-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
27 lines (20 loc) · 797 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# coding: utf-8
import os
from os import path

import tensorflow as tf
# Build a single-layer LSTM over a tiny constant batch, export the graph's
# GraphDef to graph-def/graph.pb, then initialize the variables and save them
# to checkpoints/model.checkpoint (both relative to the working directory).

GRAPH_DIR = "graph-def"
CHECKPOINT_DIR = "checkpoints"

cell = tf.contrib.rnn.BasicLSTMCell(num_units=64)
# The inputs literal has shape (2, 1, 3): batch 2, max_time 1, depth 3, under
# dynamic_rnn's default batch-major layout (time_major=False).
# NOTE(review): sequence_length=[3, 2] exceeds max_time=1. The values look as
# if they were meant for inputs of shape (2, 3, 1) (two length-3 sequences,
# the second padded with a trailing 0). Left unchanged because the exported
# graph may be consumed elsewhere -- confirm intent.
outputs, _ = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float32,
    sequence_length=tf.constant([3, 2]),
    inputs=tf.constant([[[1., 1., 1.]], [[1., 1., 0.]]]))
# Give the output an explicit, stable name ("y") so it can be looked up by
# name (e.g. "y:0") once the serialized graph is reloaded.
outputs = tf.identity(outputs, name="y")
# Created before the graph is serialized, so the Saver's save/restore ops are
# included in graph.pb.
saver = tf.train.Saver(tf.global_variables())

# open() and Saver.save() both fail if the parent directory is missing, so
# create the output directories up front (Python 2 compatible: no exist_ok).
for directory in (GRAPH_DIR, CHECKPOINT_DIR):
    if not path.isdir(directory):
        os.makedirs(directory)

with tf.Session() as sess:
    # Serialized before tf.global_variables_initializer() is called below, so
    # the initializer op is NOT part of graph.pb; presumably consumers restore
    # variable values from the checkpoint instead -- verify against consumer.
    with open(path.join(GRAPH_DIR, "graph.pb"), "wb") as fout:
        fout.write(sess.graph.as_graph_def().SerializeToString())
    print("Written graph.")

    sess.run(tf.global_variables_initializer())
    # Run the forward pass once before saving; the result is discarded.
    sess.run(outputs)

    # Save variables in checkpoint
    save_path = saver.save(sess, path.join(CHECKPOINT_DIR, "model.checkpoint"))
    print("Saved checkpoint to {}.".format(save_path))