您现在的位置是:首页 > 文章详情

使用Java部署TensorFlow和Keras训练好的深度学习模型的几种方法

日期:2019-07-17点击:416

使用Java部署TensorFlow和Keras训练好的深度学习模型的几种方法
写在前面
最近在一个自然语言处理方面的项目,选用的深度学习模型有两个,一个是CNN+LSTM模型,一个是GRU模型,这两个模型在GPU服务器上训练好了,然后需要使用Java调用这两个模型,CNN+LSTM使用TensorFlow写的,GRU是用Keras写的,所以需要用Java部署TensorFlow和Keras训练好的深度学习模型。关于这方面的内容网上并不是很多,我也是费了很多周折才完成任务的,这里来总结一下具体有哪些方法可以用,这些方法又有哪些缺陷,以供大家学习交流。

一、使用Java深度学习框架直接部署
(1)使用TensorFlow Java API部署TensorFlow模型
如果我们使用的是TensorFlow训练的模型,那么我们就可以直接使用Java中的TensorFlow API调用模型。这里需要注意的是我们得把训练好的模型保存为.pb格式的文件。具体代码如下:

constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=["quest_out"])

写入序列化的 PB 文件

with tf.gfile.FastGFile('/home/amax/zth/qa/new_model2_cpu.pb', mode='wb') as f:

f.write(constant_graph.SerializeToString())

然后我们需要在Java使用这个保存好的模型,在pom.xml中引入TensorFlow的依赖

 <groupId>org.tensorflow</groupId> <artifactId>tensorflow</artifactId> <version>1.11.0</version> </dependency> <!-- https://mvnrepository.com/artifact/org.tensorflow/libtensorflow_jni_gpu --> <dependency> <groupId>org.tensorflow</groupId> <artifactId>libtensorflow_jni_gpu</artifactId> <version>1.11.0</version> </dependency> 

导包成功后,在Java中调用模型

graphDef = readAllBytes(new FileInputStream(new_model2_cpu.pb));
graph = new Graph();
graph.importGraphDef(graphDef);
session = new Session(graph);
Tensor result = session.runner()

 .feed("ori_quest_embedding", Tensor.create(wordVecInputSentence))//输入你自己的数据 .feed("dropout", Tensor.create(1.0F)) .fetch("quest_out") //和上面python保存模型时的output_node_names对应 .run().get(0);

//这样就能得到模型的输出结果了

(2)使用Deeplearning4J Java API部署Keras模型
如果我们使用的是Keras训练的模型,那么你就可以选择Deeplearning4J 这个框架来调用模型。
第一步同样是使用Keras保存训练好的模型

filepath = "query_models"

checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True)

callback_lists = [checkpoint]

model.fit(x, y, epochs=1,validation_split=0.2,callbacks=callback_lists)

然后同样是Java项目中pom.xml导入Deeplearning4J 依赖

<groupId>org.deeplearning4j</groupId> <artifactId>deeplearning4j-core</artifactId> <version>1.0.0-beta2</version> 


<groupId>org.deeplearning4j</groupId> <artifactId>deeplearning4j-modelimport</artifactId> <version>1.0.0-beta2</version> 

库导入成功后,直接使用Java调用保存好的模型

MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(“query_models”);

这样模型就部署成功了,然后关于怎么使用模型这里就不多说了。
注意:这里需要注意的是Deeplearning4J 只支持部分深度学习模型,有些模型是不支持的,譬如我这里使用的GRU模型就不支持,运行上面代码会出现以下错误。明确指明不支持GRU模型

去Deeplearning4J 官网查询发现确实现在不支持GRU模型,以下是官网截图

所以如果你想使用Deeplearning4J 来部署训练好的模型,请先查看下是否支持你所使用的模型。

二、使用Python编写服务端
(1)使用socket实现进程间的通信
用python构建服务端,然后通过Java向服务端发送请求调用模型,第一种是使用socket实现进程中的通信,代码如下:

import socket
import sys
import threading
import json
import numpy as np
import jieba
import os
import numpy as np
import nltk
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM,GRU,TimeDistributed
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils
from gensim.models.word2vec import Word2Vec
from keras.optimizers import Adam
from keras.models import load_model
import pickle

nn=network.getNetWork()

cnn = conv.main(False)

深度学习训练的神经网络,使用TensorFlow训练的神经网络模型,保存在文件中

w2v_model = Word2Vec.load("word2vec.w2v").wv
UNK = pickle.load(open('unk.pkl','rb'))
model = load_model('query_models')
a = np.zeros((1, 223,200))
model.predict(a)

def test_init(string):

cut_list = jieba.lcut(string) x_test = np.zeros((1, 223,200)) for i in range(223): x_test[0,i,:] = UNK for i in range(len(cut_list)): if cut_list[i] in w2v_model: x_test[0,i,:] = w2v_model.wv[cut_list[i]] return x_test,len(cut_list) 

string_list = list()
def query_complet(string):

x_test,length = test_init(string) y = model.predict(x_test) if length>8: return word1 = w2v_model.wv.most_similar(positive=[y[0][length-1]], topn=2)[0][0] word2 = w2v_model.wv.most_similar(positive=[y[0][length-1]], topn=2)[1][0] if word1 == '?' or word1 == '?': string_list.append(string) else: new_str = string+word1 query_complet(new_str) if word2 == '?' or word2 == '?': string_list.append(string) else: new_str = string+word2 query_complet(new_str) 

def new_query_complet(string):

query_complet(string) return string_list 

def main():

# 创建服务器套接字 serversocket = socket.socket(socket.AF_INET,socket.SOCK_STREAM) # 设置一个端口 port = 12345 # 将套接字与本地主机和端口绑定 serversocket.bind(("172.17.169.232",port)) # 设置监听最大连接数 serversocket.listen(5) # 获取本地服务器的连接信息 myaddr = serversocket.getsockname() print("服务器地址:%s"%str(myaddr)) # 循环等待接受客户端信息 while True: # 获取一个客户端连接 clientsocket,addr = serversocket.accept() print("连接地址:%s" % str(addr)) try: t = ServerThreading(clientsocket)#为每一个请求开启一个处理线程 t.start() pass except Exception as identifier: print(identifier) pass pass serversocket.close() pass 

class ServerThreading(threading.Thread):

# words = text2vec.load_lexicon() def __init__(self,clientsocket,recvsize=1024*1024,encoding="utf-8"): threading.Thread.__init__(self) self._socket = clientsocket self._recvsize = recvsize self._encoding = encoding pass def run(self): print("开启线程.....") try: #接受数据 msg = '' while True: # 读取recvsize个字节 rec = self._socket.recv(self._recvsize) # 解码 msg += rec.decode(self._encoding) # 文本接受是否完毕,因为python socket不能自己判断接收数据是否完毕, # 所以需要自定义协议标志数据接受完毕 if msg.strip().endswith('over'): msg=msg[:-4] break # 发送数据 self._socket.send("啦啦啦啦".encode(self._encoding)) pass except Exception as identifier: self._socket.send("500".encode(self._encoding)) print(identifier) pass finally: self._socket.close() print("任务结束.....") pass 

//启动服务
main()

Java客户端代码如下:

public void test2() throws IOException { JSONObject jsonObject = new JSONObject(); String content = "医疗保险缴费需要"; jsonObject.put("content", content); String str = jsonObject.toJSONString(); // 访问服务进程的套接字 Socket socket = null;

// List questions = new ArrayList<>();
// log.info("调用远程接口:host=>"+HOST+",port=>"+PORT);

 try { // 初始化套接字,设置访问服务的主机和进程端口号,HOST是访问python进程的主机名称,可以是IP地址或者域名,PORT是python进程绑定的端口号 socket = new Socket("172.17.169.232",12345); // 获取输出流对象 OutputStream os = socket.getOutputStream(); PrintStream out = new PrintStream(os); // 发送内容 out.print(str); // 告诉服务进程,内容发送完毕,可以开始处理 out.print("over"); // 获取服务进程的输入流 InputStream is = socket.getInputStream(); String text = IOUtils.toString(is); System.out.println(text); } catch (IOException e) { e.printStackTrace(); } finally { try {if(socket!=null) socket.close();} catch (IOException e) {} System.out.println("远程接口调用结束."); } } 

socket实现Python服务端确实比较简单,但是代码量比较大,没有前面Java直接部署训练好的模型简单。

(2)使用Python的Flask框架
Flask框架实现服务端,这个框架我是听我同学说的,因为他们公司就是使用这种方法部署深度学习模型的,不过我们项目当中没有用到,有兴趣的同学可以自己去了解一下这个Flask框架,这里不累述了。

总结

常用的方法基本上就上面这些了,以上方法各有各的优缺点,大家可以根据自己的项目需求自行选择合适的方法来部署训练好的深度学习模型,希望这篇博客可以帮到你们。

作者:中二小苇
来源:CSDN
原文:https://blog.csdn.net/u012350430/article/details/96272968
版权声明:本文为博主原创文章,转载请附上博文链接!

原文链接:https://yq.aliyun.com/articles/709726
关注公众号

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。

持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。

转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。

文章评论

共有0条评论来说两句吧...

文章二维码

扫描即可查看该文章

点击排行

推荐阅读

最新文章