UFO ET IT

Tensorflow의 그래프에있는 텐서 이름 목록

ufoet 2021. 1. 13. 07:27
반응형

Tensorflow의 그래프에있는 텐서 이름 목록


Tensorflow의 그래프 개체에는 "get_tensor_by_name (name)"이라는 메서드가 있습니다. 어쨌든 유효한 텐서 이름 목록을 얻을 수 있습니까?

그렇지 않은 경우 여기에서 사전 훈련 된 모델 inception-v3의 유효한 이름을 아는 사람이 있습니까? 그들의 예에서 pool_3은 하나의 유효한 텐서이지만 모두 목록이 좋을 것입니다. 참조 된 논문을 살펴 보았는데 일부 레이어가 표 1의 크기와 일치하는 것 같지만 모두는 아닙니다.


종이가 모델을 정확하게 반영하지 않습니다. arxiv에서 소스를 다운로드하면 model.txt와 같은 정확한 모델 설명이 있으며 여기에있는 이름은 릴리스 된 모델의 이름과 밀접한 관련이 있습니다.

첫 번째 질문에 답하기 위해 sess.graph.get_operations()작업 목록을 제공합니다. 연산의 op.name경우 이름을 op.values()제공하고 생성하는 텐서 목록을 제공합니다 (inception-v3 모델에서 모든 텐서 이름은 ": 0"이 추가 된 연산 이름이므로 pool_3:0최종적으로 생성 된 텐서도 마찬가지 입니다. 풀링 op.)


그래프에서 연산을 보려면 (여러분은 많은 것을 보게 될 것이므로 여기서는 첫 번째 문자열 만 제공했습니다).

sess = tf.Session()
op = sess.graph.get_operations()
[m.values() for m in op][1]

out:
(<tf.Tensor 'conv1/weights:0' shape=(4, 4, 3, 32) dtype=float32_ref>,)

위의 답변이 정확합니다. 위의 작업에 대해 이해하기 쉽고 간단한 코드를 발견했습니다. 그래서 여기에 공유 :-

import tensorflow as tf

def printTensors(pb_file):

    # read pb into graph_def
    with tf.gfile.GFile(pb_file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # import graph_def
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)

    # print operations
    for op in graph.get_operations():
        print(op.name)


printTensors("path-to-my-pbfile.pb")

그래프에서 모든 작업 이름의 이름을보기 위해 세션을 만들 필요조차 없습니다. 이렇게하려면 기본 그래프를 가져 tf.get_default_graph()와서 모든 작업을 추출하면됩니다 .get_operations.. 각 작업에는 많은 필드있으며 필요한 것은 이름입니다.

다음은 코드입니다.

import tensorflow as tf
a = tf.Variable(5)
b = tf.Variable(6)
c = tf.Variable(7)
d = (a + b) * c

for i in tf.get_default_graph().get_operations():
    print i.name

중첩 된 목록 이해 :

tensor_names = [t.name for op in tf.get_default_graph().get_operations() for t in op.values()]

그래프에서 텐서의 이름을 가져 오는 함수 (기본값은 기본 그래프) :

def get_names(graph=tf.get_default_graph()):
    return [t.name for op in graph.get_operations() for t in op.values()]

그래프에서 텐서를 가져 오는 함수 (기본값은 기본 그래프) :

def get_tensors(graph=tf.get_default_graph()):
    return [t for op in graph.get_operations() for t in op.values()]

참조 URL : https://stackoverflow.com/questions/35336648/list-of-tensor-names-in-graph-in-tensorflow

반응형