TensorFlow中那些鲜为人知却又极其实用的知识
一. GraphDef才是正确地模型保存的方法
大部分用户保存TensorFlow模型的方法是tf.train.Saver.save,这是众多科研代码中用来保存模型的方法,保存之后的模型如下图所示。
实际上这种保存的方法,是给模型训练做checkpoint用的,也就是说为了让你能够随时保存实验过程,随时恢复实验用的(防止断电、死机导致实验丢失)。
如果你希望为TensorFlow保存一个能够用于产品用的模型,并且这个模型能够被C/C++/Java/NodeJS等调用(类似Caffe模型),你需要了解GraphDef。用GraphDef方式保存的模型是一个独立地Protobuf文件,看一下维基百科对Protobuf的解释:
Protocol Buffers是一种序列化数据结构的协议。对于透过管线(pipeline)或存储数据进行通信的程序开发上是很有用的。这个方法包含一个接口描述语言,描述一些数据结构,并提供程序工具根据这些描述产生代码,用于将这些数据结构产生或解析数据流。
也就是说Protobuf文件是一种无视语种的数据描述文件,存成Protobuf文件,模型可以被Protobuf支持的各大语种(C/C++/Java/NodeJS等)读取。
TensorFlow模型的正确保存方式如下:
#coding=utf-8 import tensorflow as tf # 定义图 x = tf.placeholder(tf.float32, name="x") y = tf.get_variable("y", initializer=10.0) z = tf.log(x + y, name="z") with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 进行一些训练代码,此处省略 # xxxxxxxxxxxx # 显示图中的节点 frozen_graph_def = tf.graph_util. convert_variables_to_constants( sess, sess.graph_def, output_node_names=["z"]) print(frozen_graph_def) # 保存图为pb文件 with open('model.pb', 'wb') as f: f.write(frozen_graph_def.SerializeToString())
最终,我们只会得到一个model.pb文件:
model.pb存储的是压缩版的frozen_graph_def,上面我们用print函数将frozen_graph_def 输出的结果如下,这可以看到,这是一个标准的图结构的数据(也就是静态图),不仅包含了节点,还包含了节点中的数据。
node { name: "x" op: "Placeholder" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { unknown_rank: true } } } } node { name: "y" op: "Const" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "value" value { tensor { dtype: DT_FLOAT tensor_shape { } float_val: 10.0 } } } } node { name: "y/read" op: "Identity" input: "y" attr { key: "T" value { type: DT_FLOAT } } attr { key: "_class" value { list { s: "loc:@y" } } } } node { name: "add" op: "Add" input: "x" input: "y/read" attr { key: "T" value { type: DT_FLOAT } } } node { name: "z" op: "Log" input: "add" attr { key: "T" value { type: DT_FLOAT } } } library { }
为什么在保存GraphDef前要调用tf.graph_util.convert_variables_to_constants方法,我们发现在调用tf.graph_util.convert_variables_to_constants方法时,程序有一行输出:
Converted 1 variables to const ops.
其实默认状态下,静态图的数据是被同时保存在GraphDef和Session中的,图结构、常量的值等被存储在GraphDef中,而变量的值被存储在Session中,这也是为什么每次用静态图都要在Session中使用的原因。
tf.graph_util.convert_variables_to_constants方法将Session中的变量转换到GraphDef中以常量形式存储,由于没有了变量,得到的GraphDef中包含了静态图的所有信息,即包含了整个模型,保存GraphDef即保存了整个模型。
现在我们可以用C/C++/Java/NodeJS等来读取并执行保存的GraphDef文件,以Java为例(需要Maven导入java版tensorflow api),整个流程和Python API很像,读取图,开启Session,并将读取的图放入Session,指定输入,获取输出:
import org.apache.commons.io.IOUtils; import org.tensorflow.Graph; import org.tensorflow.Session; import org.tensorflow.Tensor; import java.io.FileInputStream; import java.io.IOException; public class DemoImportGraph { public static void main(String[] args) throws IOException { try (Graph graph = new Graph()) { //导入图 byte[] graphBytes = IOUtils.toByteArray(new FileInputStream("model.pb")); graph.importGraphDef(graphBytes); //根据图建立Session try(Session session = new Session(graph)){ //相当于TensorFlow Python中的sess.run(z, feed_dict = {'x': 10.0}) float z = session.runner() .feed("x", Tensor.create(10.0f)) .fetch("z").run().get(0).floatValue(); System.out.println(z); } } } }
所以,TensorFlow模型并非只能被Python调用。按照GraphDef方式保存为Protobuf模型后,可以被任何TensorFlow提供了API的语种调用。
二. 可以在Keras中使用TensorFlow,也可以在TensorFlow中使用Keras
TensorFlow是最终要的内核之一,在默认的使用TensorFlow作为内核的情况下,Keras的各种层、包括模型的执行,都是依赖TensorFlow的各种操作、Session等去完成的,在Keras中使用TensorFlow是众所周知的,然而在TensorFlow中使用Keras确是一个不常见的情况。其实Keras早就进入了TensorFlow的核心库(tf.keras),而且成为了官方较为推荐使用tf.keras进行模型的构建,看一下TensorFlow 1.9官网教程首页的示例代码,
import tensorflow as tf mnist = tf.keras.datasets.mnist (x_train, y_train),(x_test, y_test) = mnist.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0 model = tf.keras.models.Sequential([ tf.keras.layers.Flatten(), tf.keras.layers.Dense(512, activation=tf.nn.relu), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10, activation=tf.nn.softmax) ]) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) model.fit(x_train, y_train, epochs=5) model.evaluate(x_test, y_test)
原先在TensorFlow需要几十行才能构建的模型和流程,用tf.keras模块十几行就可以搞定了。
三. TensorFlow Hub中有许多可以直接使用的模型
TensorFlow Hub是TensorFlow官方提供的用于模型发布、复用的工具。例如下面的代码可以获取句子的Embedding,我们只需要给出TensorFlow Hub模型发布的url以及输入,通过简单的几行调用即可完成原先需要数百还才能完成的工作。另外,指定url的方式相比于自己下载模型的方式便利了许多。
import tensorflow as tf import tensorflow_hub as hub with tf.Graph().as_default(): module_url = "https://tfhub.dev/google/nnlm-en-dim128-with-normalization/1" embed = hub.Module(module_url) embeddings = embed(["A long sentence.", "single-word", "http://example.com"]) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.tables_initializer()) print(sess.run(embeddings))
四. 在静态图中也可以像动态图那样写条件判断语句
原先在静态图中是无法使用Python的if语句来为静态图定义条件判断结构的,需要使用特殊的tf.cond操作来定义一个条件判断节点,非常的麻烦,近期TensorFlow新出的AutoGraph功能可以让用户按照Python的if语句来定义结构,然后利用AutoGraph注解将其转换为相应的静态图结构,这样可以大幅度降低静态图构建的难度:
@autograph.convert() def fizzbuzz(num): if num % 3 == 0 and num % 5 == 0: print('FizzBuzz') elif num % 3 == 0: print('Fizz') elif num % 5 == 0: print('Buzz') else: print(num) return num with tf.Graph().as_default(): # The result works like a regular op: takes tensors in, returns tensors. # You can inspect the graph using tf.get_default_graph().as_graph_def() num = tf.placeholder(tf.int32) result = fizzbuzz(num) with tf.Session() as sess: for n in range(10,16): sess.run(result, feed_dict={num:n})
低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
- 上一篇
「大学生学编程系列」第二篇:如何选择第一门编程语言?
第一篇讲述了为什么要选择做一名程序员,从源头上讲述要想成为一名程序员需要很强的驱动力,因为编程相对而言算是比较难入门的一个职业。在入门之前必须有克服困难的勇气,有成为一名程序员的决心 有了决心和信心了,剩下的就是加足马力开干了,问题又来了怎么干,选什么样子的编程语言适合自学入手?要根据自身的实际情况出发选择编程语言切入。完全的零基础学习编程相对来讲要费劲很多,如果要学习建议先从计算机组成原理开始入手学习,对于零基础的来说,开始学习一般来讲都会信心百倍,要懂得保持住这份信念,所以上来不要把这份信心给打没了,先让自己缓冲一段时间,先从简单的入手,学习一段时间慢慢培养计算机语感,如同学习英文也需要培养语感是一样的,了解计算机基本的框架结构,进制之间是如何转化的,cpu和内存以及硬盘之间是如何关联的。不但涨了知识还能进一步培养自己的自信心。 如何选择第一门编程语言? 选择编程语言主要从以下几点入手: 1.第一优先级选择自己喜欢的编程语言,兴趣才是第一老师,这个可能和编程语言的难易程度以及是不是很好找工作多少有点冲突,因为喜欢就会舍得下功夫去钻研学习,人有时候就怕较真,一旦较真就没有干不成的事情...
- 下一篇
特朗普真的是笨蛋吗?至少搜索引擎是这么想的!
今天分享给大家一片来自“差评”的文章,比较有意思~请往下看: 假如你整打算找一些和笨蛋( idiot )有关系的图片,弹出来的结果是: 咦明明搜索的关键词是笨蛋才对啊,为什么谷歌图片返回的结果全都是美国现任总统川普啊。。 你感到一头雾水,思考笨蛋这个词到底和川普产生了什么样千丝万缕的关系,是谷歌工程师对川普不满开的玩笑?还是谷歌暗中开发了神奇的人工智能自动给川普贴上了标签?还是。。 实际上,这只不过是谷歌的图片搜索算法自然生成的结果。。 大家应该记得就在上个星期,美国总统川普访问了英国。 但是很明显有一大堆英国人民对这位我行我素的总统感到非常不满,他们甚至众筹了 18000 英镑做了一个 “ 川普宝宝 ” 的气球,在川普访问期间放飞,对川普的孩子气表示抗议。。 但是同样愤怒的另一帮英国人民为了表达自己的情绪,则是选择把绿日( Green Day )的著名歌曲《 美式笨蛋 》( American Idiot )顶上音乐榜前十,让大家都能看到这位 “ 来自美国的笨蛋 ”。。。 他们竟然还真成功了。。。 《 美式笨蛋 》这首歌在英国顺利成为英国摇滚榜,亚马逊音乐,谷歌音乐的第一名,英国下载榜...
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- SpringBoot2整合MyBatis,连接MySql数据库做增删改查操作
- SpringBoot2整合Redis,开启缓存,提高访问速度
- SpringBoot2全家桶,快速入门学习开发网站教程
- CentOS8安装MyCat,轻松搞定数据库的读写分离、垂直分库、水平分库
- Hadoop3单机部署,实现最简伪集群
- Jdk安装(Linux,MacOS,Windows),包含三大操作系统的最全安装
- CentOS7编译安装Cmake3.16.3,解决mysql等软件编译问题
- SpringBoot2更换Tomcat为Jetty,小型站点的福音
- CentOS7设置SWAP分区,小内存服务器的救世主
- Docker快速安装Oracle11G,搭建oracle11g学习环境