반응형
TensorFlow에서 학습 한 모델에서 일부 가중치 값 가져 오기
TensorFlow를 사용하여 ConvNet 모델을 훈련 시켰으며 레이어에서 특정 가중치를 얻고 싶습니다. 예를 들어 torch7에서는 단순히 model.modules[2].weights
. 레이어 2의 가중치를 가져옵니다. TensorFlow에서 동일한 작업을 어떻게 수행합니까?
TensorFlow에서 훈련 된 가중치는 tf.Variable
객체 로 표현됩니다 . -예를 들어 -yourself tf.Variable
라는 이름으로 만든 경우 (where is a )를 v
호출하여 NumPy 배열로 값을 가져올 수 있습니다 .sess.run(v)
sess
tf.Session
현재에 대한 포인터가없는 경우를 tf.Variable
호출하여 현재 그래프에서 훈련 가능한 변수 목록을 가져올 수 있습니다 tf.trainable_variables()
. 이 함수는 tf.Variable
현재 그래프에서 학습 가능한 모든 객체 의 목록을 반환하며 v.name
속성 을 일치시켜 원하는 객체를 선택할 수 있습니다 . 예를 들면 :
# Desired variable is called "tower_2/filter:0".
var = [v for v in tf.trainable_variables() if v.name == "tower_2/filter:0"][0]
따라서이 코드를 단계별로 진행하면 먼저 사용 / 학습 가능한 변수 목록이 표시됩니다. 그런 다음 가중치 행렬 / 목록을 변수 이름으로 정렬하는 목록에서 정렬 할 수 있습니다. 예를 들어 해당 정보를 처리하는 방법과 같습니다.
vars = tf.trainable_variables()
print(vars) #some infos about variables...
vars_vals = sess.run(vars)
for var, val in zip(vars, vars_vals):
print("var: {}, value: {}".format(var.name, val)) #...or sort it in a list....
반응형
'UFO ET IT' 카테고리의 다른 글
uint8_t ≠ unsigned char은 언제입니까? (0) | 2020.12.12 |
---|---|
Bearer 토큰은 Web API 2에서 서버 측에 어떻게 저장됩니까? (0) | 2020.12.12 |
모노 이드 동형은 정확히 무엇입니까? (0) | 2020.12.12 |
기호를 확인할 수 없음 : FusedLocationProviderClient. (0) | 2020.12.11 |
문서 디렉토리에 대한 시스템 정의 환경 변수가 있습니까? (0) | 2020.12.11 |