6

Как сохранить и восстановить модель после обучения?

8

После тренировки модели в TensorFlow возникли следующие вопросы:

  1. Как сохранить обученную модель?
  2. Как позже восстановить сохраненную модель?

Надеюсь на помощь сообщества в решении этих вопросов.

5 ответ(ов)

1

В версиях TensorFlow начиная с 0.11.0RC1 вы можете сохранять и восстанавливать свою модель непосредственно, используя функции tf.train.export_meta_graph и tf.train.import_meta_graph, как описано в документации.

Сохранение модели

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# Метод `save` косвенно вызовет `export_meta_graph`.
# В итоге вы получите файлы графа: my-model.meta

Восстановление модели

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

Этот код демонстрирует, как сохранить значения переменных в модели и затем восстановить их. Убедитесь, что у вас открылся правильный контекст сессии при сохранении и восстановлении.

0

Модель состоит из двух частей: определения модели, которое сохраняется классом Supervisor в файле graph.pbtxt в директории модели, и числовых значений тензоров, которые сохраняются в контрольных точках, таких как model.ckpt-1003418.

Определение модели можно восстановить с помощью функции tf.import_graph_def, а веса восстанавливаются с помощью класса Saver.

Однако Saver использует специальную коллекцию, содержащую список переменных, привязанных к графу модели, и эта коллекция не инициализируется при использовании import_graph_def. Поэтому в данный момент вы не можете использовать оба этих метода вместе (это в планах на будущее). На данный момент вам нужно следовать подходу Райана Сепасси: вручную создать граф с идентичными именами узлов и использовать Saver, чтобы загрузить веса в этот граф.

В качестве альтернативы, вы можете попробовать "хак": использовать import_graph_def, вручную создать переменные и добавлять их в коллекцию с помощью tf.add_to_collection(tf.GraphKeys.VARIABLES, variable) для каждой переменной, а затем использовать Saver.

0

Вы можете воспользоваться более простым способом.

Шаг 1: инициализируйте все свои переменные

W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

Аналогично, создайте W2, B2, W3 и т.д.

Шаг 2: сохраните сессию внутри объекта Saver и сохраните её

model_saver = tf.train.Saver()

# Обучите модель и сохраните её в конце
model_saver.save(session, "saved_models/CNN_New.ckpt")

Шаг 3: восстановите модель

with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Модель восстановлена.") 
    print('Инициализировано')

Шаг 4: проверьте вашу переменную

W1 = session.run(W1)
print(W1)

Если вы запускаете это в другом экземпляре Python, используйте

with tf.Session() as sess:
    # Восстановите последнюю контрольную точку
    saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

    # Инициализируйте переменные
    sess.run(tf.global_variables_initializer())

    # Получите граф по умолчанию (передайте свой кастомный граф, если он у вас есть)
    graph = tf.get_default_graph()

    # Это даст объект тензора
    W1 = graph.get_tensor_by_name('W1:0')

    # Чтобы получить значение (numpy массив)
    W1_value = session.run(W1)
0

Как отметил Ярослав, вы можете «взломать» восстановление из graph_def и контрольной точки, импортировав граф, вручную создав переменные и затем используя Saver.

Я реализовал это для личного использования и решил поделиться кодом здесь.

Ссылка: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

(Это, конечно, является хаком, и нет никаких гарантий, что модели, сохранённые таким образом, останутся читаемыми в будущих версиях TensorFlow.)

0

Вот моё простое решение для двух основных случаев, которые отличаются тем, хотите ли вы загрузить граф из файла или построить его во время выполнения.

Этот ответ актуален для TensorFlow 0.12+ (включая 1.0).

Воссоздание графа в коде

Сохранение

graph = ...  # создайте граф
saver = tf.train.Saver()  # создайте saver после создания графа
with ... as sess:  # ваш объект сессии
    saver.save(sess, 'my-model')

Загрузка

graph = ...  # создайте граф
saver = tf.train.Saver()  # создайте saver после создания графа
with ... as sess:  # ваш объект сессии
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # теперь вы можете использовать граф, продолжать обучение или что-то ещё

Загрузка также графа из файла

При использовании этой техники убедитесь, что все ваши слои/переменные имеют явно установленные уникальные имена. В противном случае TensorFlow сделает имена уникальными самостоятельно, и они будут отличаться от имен, хранящихся в файле. В предыдущей технике это не проблема, поскольку имена "перемешиваются" одинаковым образом как при загрузке, так и при сохранении.

Сохранение

graph = ...  # создайте граф

for op in [ ... ]:  # операторы, которые вы хотите использовать после восстановления модели
    tf.add_to_collection('ops_to_restore', op)

saver = tf.train.Saver()  # создайте saver после создания графа
with ... as sess:  # ваш объект сессии
    saver.save(sess, 'my-model')

Загрузка

with ... as sess:  # ваш объект сессии
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = tf.get_collection('ops_to_restore')  # вот ваши операторы в том же порядке, в котором вы сохранили их в коллекцию
Чтобы ответить на вопрос, пожалуйста, войдите или зарегистрируйтесь