Как сохранить и восстановить модель после обучения?
После тренировки модели в TensorFlow возникли следующие вопросы:
- Как сохранить обученную модель?
- Как позже восстановить сохраненную модель?
Надеюсь на помощь сообщества в решении этих вопросов.
5 ответ(ов)
В версиях TensorFlow начиная с 0.11.0RC1 вы можете сохранять и восстанавливать свою модель непосредственно, используя функции tf.train.export_meta_graph
и tf.train.import_meta_graph
, как описано в документации.
Сохранение модели
w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# Метод `save` косвенно вызовет `export_meta_graph`.
# В итоге вы получите файлы графа: my-model.meta
Восстановление модели
sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
v_ = sess.run(v)
print(v_)
Этот код демонстрирует, как сохранить значения переменных в модели и затем восстановить их. Убедитесь, что у вас открылся правильный контекст сессии при сохранении и восстановлении.
Модель состоит из двух частей: определения модели, которое сохраняется классом Supervisor
в файле graph.pbtxt
в директории модели, и числовых значений тензоров, которые сохраняются в контрольных точках, таких как model.ckpt-1003418
.
Определение модели можно восстановить с помощью функции tf.import_graph_def
, а веса восстанавливаются с помощью класса Saver
.
Однако Saver
использует специальную коллекцию, содержащую список переменных, привязанных к графу модели, и эта коллекция не инициализируется при использовании import_graph_def
. Поэтому в данный момент вы не можете использовать оба этих метода вместе (это в планах на будущее). На данный момент вам нужно следовать подходу Райана Сепасси: вручную создать граф с идентичными именами узлов и использовать Saver
, чтобы загрузить веса в этот граф.
В качестве альтернативы, вы можете попробовать "хак": использовать import_graph_def
, вручную создать переменные и добавлять их в коллекцию с помощью tf.add_to_collection(tf.GraphKeys.VARIABLES, variable)
для каждой переменной, а затем использовать Saver
.
Вы можете воспользоваться более простым способом.
Шаг 1: инициализируйте все свои переменные
W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")
Аналогично, создайте W2, B2, W3 и т.д.
Шаг 2: сохраните сессию внутри объекта Saver
и сохраните её
model_saver = tf.train.Saver()
# Обучите модель и сохраните её в конце
model_saver.save(session, "saved_models/CNN_New.ckpt")
Шаг 3: восстановите модель
with tf.Session(graph=graph_cnn) as session:
model_saver.restore(session, "saved_models/CNN_New.ckpt")
print("Модель восстановлена.")
print('Инициализировано')
Шаг 4: проверьте вашу переменную
W1 = session.run(W1)
print(W1)
Если вы запускаете это в другом экземпляре Python, используйте
with tf.Session() as sess:
# Восстановите последнюю контрольную точку
saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))
# Инициализируйте переменные
sess.run(tf.global_variables_initializer())
# Получите граф по умолчанию (передайте свой кастомный граф, если он у вас есть)
graph = tf.get_default_graph()
# Это даст объект тензора
W1 = graph.get_tensor_by_name('W1:0')
# Чтобы получить значение (numpy массив)
W1_value = session.run(W1)
Как отметил Ярослав, вы можете «взломать» восстановление из graph_def
и контрольной точки, импортировав граф, вручную создав переменные и затем используя Saver
.
Я реализовал это для личного использования и решил поделиться кодом здесь.
Ссылка: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868
(Это, конечно, является хаком, и нет никаких гарантий, что модели, сохранённые таким образом, останутся читаемыми в будущих версиях TensorFlow.)
Вот моё простое решение для двух основных случаев, которые отличаются тем, хотите ли вы загрузить граф из файла или построить его во время выполнения.
Этот ответ актуален для TensorFlow 0.12+ (включая 1.0).
Воссоздание графа в коде
Сохранение
graph = ... # создайте граф
saver = tf.train.Saver() # создайте saver после создания графа
with ... as sess: # ваш объект сессии
saver.save(sess, 'my-model')
Загрузка
graph = ... # создайте граф
saver = tf.train.Saver() # создайте saver после создания графа
with ... as sess: # ваш объект сессии
saver.restore(sess, tf.train.latest_checkpoint('./'))
# теперь вы можете использовать граф, продолжать обучение или что-то ещё
Загрузка также графа из файла
При использовании этой техники убедитесь, что все ваши слои/переменные имеют явно установленные уникальные имена. В противном случае TensorFlow сделает имена уникальными самостоятельно, и они будут отличаться от имен, хранящихся в файле. В предыдущей технике это не проблема, поскольку имена "перемешиваются" одинаковым образом как при загрузке, так и при сохранении.
Сохранение
graph = ... # создайте граф
for op in [ ... ]: # операторы, которые вы хотите использовать после восстановления модели
tf.add_to_collection('ops_to_restore', op)
saver = tf.train.Saver() # создайте saver после создания графа
with ... as sess: # ваш объект сессии
saver.save(sess, 'my-model')
Загрузка
with ... as sess: # ваш объект сессии
saver = tf.train.import_meta_graph('my-model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
ops = tf.get_collection('ops_to_restore') # вот ваши операторы в том же порядке, в котором вы сохранили их в коллекцию
Как клонировать список, чтобы он не изменялся неожиданно после присваивания?
Преобразование списка словарей в DataFrame pandas
Как отсортировать список/кортеж списков/кортежей по элементу на заданном индексе
Как отменить последнюю миграцию?
Как явно освободить память в Python?