TensorFlow模型的保存及模型的应用

本文详细介绍了TensorFlow中的三种模型保存格式:ckpt、SavedModel和FrozenGraphDef。ckpt适用于模型迭代,SavedModel用于在线服务部署,FrozenGraphDef则适合移动端应用。每种格式的保存、加载方式及适用场景均有阐述。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

tensorflow的模型保存形式?

1.ckpt格式

就是通过如下几个函数实现的tensorflow模型保存的模型,是ckpt格式的模型。

saver = tf.train.Saver()
...
saver.save(sess, saveFile)

就可以保存出如下文件:

checkpoint
model-450.data-00000-of-00001
model-450.index
model-450.meta

具体说明:
checkpoint
Checkpoint 文件会记录最近一次的断点文件(Checkpoint File) 的前缀,根据前缀可以找对对应的索引和数据文件。当调用tf.train.latest_checkpoint,可以快速找到最近一次的断点文件。
ckp.data-00000-of-00001
数据(data) 文件记录了所有变量(Variable) 的值。当restore 某个变量时,首先从索引文件中找到相应变量在哪个数据文件,然后根据索引直接获取变量的值,从而实现变量数据的恢复。
ckp.index
索引(index)文件,保存了一个不可变表的数据。其中,关键字为Tensor 的名称,其值描述该Tensor 的元数据信息,包括该Tensor 存储在哪个数据(data) 文件中,及其在该数据文件中的偏移,及其校验和等信息。
ckp.meta
元文件(meta) 中保存了MetaGraphDef 的持久化数据,即模型数据,它包括GraphDef, SaverDef 等元数据。

通俗的讲,这种模型是全部保存的,即模型的框架,参数及其他的信息。这种模型重载之后,是可以继续训练的,即可以pre-train或fine-tune。
个人理解,这种形式适合模型迭代需要,但不会应用于生产或者应用。

这种模型在加载的时候直接使用**saver.restore(sess, ckpt_file)**就可以了。具体的代码网上很多,这里就不赘述了。

2.SavedModel 格式

这种保存形式的常规代码形式如下:

builder = tf.saved_model.builder.SavedModelBuilder("./model")

signature = predict_signature_def(inputs={'myInput': x},
                                  outputs={'myOutput': y})
builder.add_meta_graph_and_variables(sess=sess,
                                     tags=[tag_constants.SERVING],
                                     signature_def_map={'predict'})
builder.save()

简单的保存形式如下:

tf.saved_model.simple_save(sess,
            "./model",
            inputs={"myInput": x},
            outputs={"myOutput": y})

具体代码大家可以网上细研究哈。
这种形式保存的文件是什么样的呢?
类似这个样子:

variables/
   variables.data-*****-of-*****
   variables.index
model.pb

其中:model.pb是二进制模型文件,也就是图,variables路径下的是变量参数等。
这种模型的加载方式,大概如下:

with tf.Session(graph=tf.Graph()) as sess:
  tf.saved_model.loader.load(sess, ["test"], "./model")
  graph = tf.get_default_graph()

  input = ...
  x = sess.graph.get_tensor_by_name('input:0')
  y = sess.graph.get_tensor_by_name('output:0')
  result = sess.run(y,
           feed_dict={x: input})

具体代码和使用大家可以看tensorflow的手册或者源码。

这种模型适合怎么应用?如果是部署在线服务(Serving)时,官方推荐使用 SavedModel 格式。

3.FrozenGraphDef 格式

具体保存代码形式:

 frozen_graph_def = tf.graph_util.convert_variables_to_constants(
     sess,
     sess.graph_def,
     output_node_names=["predict"])
     
 with open('./model2/model_' + self.timestamp + '.pb', 'wb') as f:
     f.write(graph_def.SerializeToString())

或者

 frozen_graph_def = tf.graph_util.convert_variables_to_constants(
     sess,
     sess.graph_def,
     output_node_names=["predict"])

with tf.gfile.FastGFile('./model2/_model_' + self.timestamp + '.pb', mode='wb') as f:
    f.write(graph_def.SerializeToString())

二者差不多哈,就是保存时候有点差异。这种方式保存的模型文件是什么?

model.pb

python上的模型加载形式如下:

    output_graph_path = './model/model_1572338162.pb'
    with tf.Session() as sess:

        with gfile.FastGFile(output_graph_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()
            tf.import_graph_def(graph_def, name='')
            sess.run(tf.global_variables_initializer())
            input_x = sess.graph.get_tensor_by_name("word_ids:0")
            sequence_lengths = sess.graph.get_tensor_by_name("sequence_lengths:0")
            dropout = sess.graph.get_tensor_by_name("dropout:0")
            output = sess.graph.get_tensor_by_name("proj/predict:0")
            logit = sess.run(output, feed_dict={input_x: sent_, sequence_lengths:[len(sent_[0])], dropout:[1]})

对的,就这一个文件。这种方式保存的模型是序列化的模型,二进制文件。这个模型只保留必要的从输入到输出的一条路径的图,其他的不需要都不会保存,所以这个模型不能pre_train了。
但是应用场景,比较适合手机端等app调用,模型比较小。当然如果还想更小?就需要研究模型压缩方法了。此处不讨论。

以上python上的演示代码都是从不同项目上粘贴来的,不一致,大家可以自己体会哈。摘要性的介绍完毕。

下一篇介绍一下,java调用tensorflow模型进行使用,及在分布式上调用的问题。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值