tensorflow的模型保存形式?
1.ckpt格式
就是通过如下几个函数实现的tensorflow模型保存的模型,是ckpt格式的模型。
saver = tf.train.Saver()
...
saver.save(sess, saveFile)
就可以保存出如下文件:
checkpoint
model-450.data-00000-of-00001
model-450.index
model-450.meta
具体说明:
checkpoint
Checkpoint 文件会记录最近一次的断点文件(Checkpoint File) 的前缀,根据前缀可以找对对应的索引和数据文件。当调用tf.train.latest_checkpoint,可以快速找到最近一次的断点文件。
ckp.data-00000-of-00001
数据(data) 文件记录了所有变量(Variable) 的值。当restore 某个变量时,首先从索引文件中找到相应变量在哪个数据文件,然后根据索引直接获取变量的值,从而实现变量数据的恢复。
ckp.index
索引(index)文件,保存了一个不可变表的数据。其中,关键字为Tensor 的名称,其值描述该Tensor 的元数据信息,包括该Tensor 存储在哪个数据(data) 文件中,及其在该数据文件中的偏移,及其校验和等信息。
ckp.meta
元文件(meta) 中保存了MetaGraphDef 的持久化数据,即模型数据,它包括GraphDef, SaverDef 等元数据。
通俗的讲,这种模型是全部保存的,即模型的框架,参数及其他的信息。这种模型重载之后,是可以继续训练的,即可以pre-train或fine-tune。
个人理解,这种形式适合模型迭代需要,但不会应用于生产或者应用。
这种模型在加载的时候直接使用**saver.restore(sess, ckpt_file)**就可以了。具体的代码网上很多,这里就不赘述了。
2.SavedModel 格式
这种保存形式的常规代码形式如下:
builder = tf.saved_model.builder.SavedModelBuilder("./model")
signature = predict_signature_def(inputs={'myInput': x},
outputs={'myOutput': y})
builder.add_meta_graph_and_variables(sess=sess,
tags=[tag_constants.SERVING],
signature_def_map={'predict'})
builder.save()
简单的保存形式如下:
tf.saved_model.simple_save(sess,
"./model",
inputs={"myInput": x},
outputs={"myOutput": y})
具体代码大家可以网上细研究哈。
这种形式保存的文件是什么样的呢?
类似这个样子:
variables/
variables.data-*****-of-*****
variables.index
model.pb
其中:model.pb是二进制模型文件,也就是图,variables路径下的是变量参数等。
这种模型的加载方式,大概如下:
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, ["test"], "./model")
graph = tf.get_default_graph()
input = ...
x = sess.graph.get_tensor_by_name('input:0')
y = sess.graph.get_tensor_by_name('output:0')
result = sess.run(y,
feed_dict={x: input})
具体代码和使用大家可以看tensorflow的手册或者源码。
这种模型适合怎么应用?如果是部署在线服务(Serving)时,官方推荐使用 SavedModel 格式。
3.FrozenGraphDef 格式
具体保存代码形式:
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph_def,
output_node_names=["predict"])
with open('./model2/model_' + self.timestamp + '.pb', 'wb') as f:
f.write(graph_def.SerializeToString())
或者
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph_def,
output_node_names=["predict"])
with tf.gfile.FastGFile('./model2/_model_' + self.timestamp + '.pb', mode='wb') as f:
f.write(graph_def.SerializeToString())
二者差不多哈,就是保存时候有点差异。这种方式保存的模型文件是什么?
model.pb
python上的模型加载形式如下:
output_graph_path = './model/model_1572338162.pb'
with tf.Session() as sess:
with gfile.FastGFile(output_graph_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
sess.run(tf.global_variables_initializer())
input_x = sess.graph.get_tensor_by_name("word_ids:0")
sequence_lengths = sess.graph.get_tensor_by_name("sequence_lengths:0")
dropout = sess.graph.get_tensor_by_name("dropout:0")
output = sess.graph.get_tensor_by_name("proj/predict:0")
logit = sess.run(output, feed_dict={input_x: sent_, sequence_lengths:[len(sent_[0])], dropout:[1]})
对的,就这一个文件。这种方式保存的模型是序列化的模型,二进制文件。这个模型只保留必要的从输入到输出的一条路径的图,其他的不需要都不会保存,所以这个模型不能pre_train了。
但是应用场景,比较适合手机端等app调用,模型比较小。当然如果还想更小?就需要研究模型压缩方法了。此处不讨论。
以上python上的演示代码都是从不同项目上粘贴来的,不一致,大家可以自己体会哈。摘要性的介绍完毕。
下一篇介绍一下,java调用tensorflow模型进行使用,及在分布式上调用的问题。