tensorflow serving java案例
背景介绍
这篇文章是tensorflow serving java api使用的参考案例,基本上把TFS的核心API的用法都介绍清楚。案例主要分为三部分:
- 动态更新模型:用于在TFS处于runtime时候动态加载模型。
- 获取模型状态:用于获取加载的模型的基本信息。
- 在线模型预测:进行在线预测,分类等操作,着重介绍在线预测。
因为模型的预测需要参考模型内部变量,所以可以先行通过TFS的REST接口获取TF模型的元数据然后才能构建TFS的RPC请求对象。
TFS 使用入门
模型源数据获取
curl http://host:port/v1/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]/metadata
说明:
- 参考TFS REST API
- 返回结果参考TF模型结构。
public static void getModelStatus() { // 1、设置访问的RPC协议的host和port ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build(); // 2、构建PredictionServiceBlockingStub对象 PredictionServiceGrpc.PredictionServiceBlockingStub predictionServiceBlockingStub = PredictionServiceGrpc.newBlockingStub(channel); // 3、设置待获取的模型 Model.ModelSpec modelSpec = Model.ModelSpec.newBuilder() .setName("wdl_model").build(); // 4、构建获取元数据的请求 GetModelMetadata.GetModelMetadataRequest modelMetadataRequest = GetModelMetadata.GetModelMetadataRequest.newBuilder() .setModelSpec(modelSpec) .addAllMetadataField(Arrays.asList("signature_def")) .build(); // 5、获取元数据 GetModelMetadata.GetModelMetadataResponse getModelMetadataResponse = predictionServiceBlockingStub.getModelMetadata(modelMetadataRequest); channel.shutdownNow(); }
说明:
- Model.ModelSpec.newBuilder绑定需要访问的模型的名字。
- GetModelMetadataRequest中addAllMetadataField绑定curl命令返回的metadata当中的
signature_def
字段。
动态更新模型
public static void addNewModel() { // 1、构建动态更新模型1 ModelServerConfigOuterClass.ModelConfig modelConfig1 = ModelServerConfigOuterClass.ModelConfig.newBuilder() .setBasePath("/models/new_model") .setName("new_model") .setModelType(ModelServerConfigOuterClass.ModelType.TENSORFLOW) .build(); // 2、构建动态更新模型2 ModelServerConfigOuterClass.ModelConfig modelConfig2 = ModelServerConfigOuterClass.ModelConfig.newBuilder() .setBasePath("/models/wdl_model") .setName("wdl_model") .setModelType(ModelServerConfigOuterClass.ModelType.TENSORFLOW) .build(); // 3、合并动态更新模型到ModelConfigList对象中 ModelServerConfigOuterClass.ModelConfigList modelConfigList = ModelServerConfigOuterClass.ModelConfigList.newBuilder() .addConfig(modelConfig1) .addConfig(modelConfig2) .build(); // 4、添加到ModelConfigList到ModelServerConfig对象当中 ModelServerConfigOuterClass.ModelServerConfig modelServerConfig = ModelServerConfigOuterClass.ModelServerConfig.newBuilder() .setModelConfigList(modelConfigList) .build(); // 5、构建ReloadConfigRequest并绑定ModelServerConfig对象。 ModelManagement.ReloadConfigRequest reloadConfigRequest = ModelManagement.ReloadConfigRequest.newBuilder() .setConfig(modelServerConfig) .build(); // 6、构建modelServiceBlockingStub访问句柄 ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build(); ModelServiceGrpc.ModelServiceBlockingStub modelServiceBlockingStub = ModelServiceGrpc.newBlockingStub(channel); ModelManagement.ReloadConfigResponse reloadConfigResponse = modelServiceBlockingStub.handleReloadConfigRequest(reloadConfigRequest); System.out.println(reloadConfigResponse.getStatus().getErrorMessage()); channel.shutdownNow(); }
说明:
- 动态更新模型是一个全量的模型加载,在发布A模型后想动态发布B模型需要同时传递模型A和B的信息。
- 再次强调,需要全量更新,全量更新,全量更新!!!
在线模型预测
public static void doPredict() throws Exception { // 1、构建feature Map<String, Feature> featureMap = new HashMap<>(); featureMap.put("match_type", feature("")); featureMap.put("position", feature(0.0f)); featureMap.put("brand_prefer_1d", feature(0.0f)); featureMap.put("brand_prefer_1m", feature(0.0f)); featureMap.put("brand_prefer_1w", feature(0.0f)); featureMap.put("brand_prefer_2w", feature(0.0f)); featureMap.put("browse_norm_score_1d", feature(0.0f)); featureMap.put("browse_norm_score_1w", feature(0.0f)); featureMap.put("browse_norm_score_2w", feature(0.0f)); featureMap.put("buy_norm_score_1d", feature(0.0f)); featureMap.put("buy_norm_score_1w", feature(0.0f)); featureMap.put("buy_norm_score_2w", feature(0.0f)); featureMap.put("cate1_prefer_1d", feature(0.0f)); featureMap.put("cate1_prefer_2d", feature(0.0f)); featureMap.put("cate1_prefer_1m", feature(0.0f)); featureMap.put("cate1_prefer_1w", feature(0.0f)); featureMap.put("cate1_prefer_2w", feature(0.0f)); featureMap.put("cate2_prefer_1d", feature(0.0f)); featureMap.put("cate2_prefer_1m", feature(0.0f)); featureMap.put("cate2_prefer_1w", feature(0.0f)); featureMap.put("cate2_prefer_2w", feature(0.0f)); featureMap.put("cid_prefer_1d", feature(0.0f)); featureMap.put("cid_prefer_1m", feature(0.0f)); featureMap.put("cid_prefer_1w", feature(0.0f)); featureMap.put("cid_prefer_2w", feature(0.0f)); featureMap.put("user_buy_rate_1d", feature(0.0f)); featureMap.put("user_buy_rate_2w", feature(0.0f)); featureMap.put("user_click_rate_1d", feature(0.0f)); featureMap.put("user_click_rate_1w", feature(0.0f)); Features features = Features.newBuilder().putAllFeature(featureMap).build(); Example example = Example.newBuilder().setFeatures(features).build(); // 2、构建Predict请求 Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder(); // 3、构建模型请求维度ModelSpec,绑定模型名和预测的签名 Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder(); modelSpecBuilder.setName("wdl_model"); modelSpecBuilder.setSignatureName("predict"); predictRequestBuilder.setModelSpec(modelSpecBuilder); // 4、构建预测请求的维度信息DIM对象 TensorShapeProto.Dim dim = TensorShapeProto.Dim.newBuilder().setSize(300).build(); TensorShapeProto shapeProto = TensorShapeProto.newBuilder().addDim(dim).build(); TensorProto.Builder tensor = TensorProto.newBuilder(); tensor.setTensorShape(shapeProto); tensor.setDtype(DataType.DT_STRING); // 5、批量绑定预测请求的数据 for (int i=0; i<300; i++) { tensor.addStringVal(example.toByteString()); } predictRequestBuilder.putInputs("examples", tensor.build()); // 6、构建PredictionServiceBlockingStub对象准备预测 ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build(); PredictionServiceGrpc.PredictionServiceBlockingStub predictionServiceBlockingStub = PredictionServiceGrpc.newBlockingStub(channel); // 7、执行预测 Predict.PredictResponse predictResponse = predictionServiceBlockingStub.predict(predictRequestBuilder.build()); // 8、解析请求结果 List<Float> floatList = predictResponse .getOutputsOrThrow("probabilities") .getFloatValList(); }
说明:
- TFS的RPC请求过程中设置的参数需要考虑TF模型的数据结构。
- TFS的RPC请求有同步和异步两种方式,上述只展示同步方式。
TF模型结构
{ "model_spec": { "name": "wdl_model", "signature_name": "", "version": "4" }, "metadata": { "signature_def": { "signature_def": { "predict": { "inputs": { "examples": { "dtype": "DT_STRING", "tensor_shape": { "dim": [{ "size": "-1", "name": "" }], "unknown_rank": false }, "name": "input_example_tensor:0" } }, "outputs": { "logistic": { "dtype": "DT_FLOAT", "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "1", "name": "" } ], "unknown_rank": false }, "name": "head/predictions/logistic:0" }, "class_ids": { "dtype": "DT_INT64", "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "1", "name": "" } ], "unknown_rank": false }, "name": "head/predictions/ExpandDims:0" }, "probabilities": { "dtype": "DT_FLOAT", "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "2", "name": "" } ], "unknown_rank": false }, "name": "head/predictions/probabilities:0" }, "classes": { "dtype": "DT_STRING", "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "1", "name": "" } ], "unknown_rank": false }, "name": "head/predictions/str_classes:0" }, "logits": { "dtype": "DT_FLOAT", "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "1", "name": "" } ], "unknown_rank": false }, "name": "add:0" } }, "method_name": "tensorflow/serving/predict" }, "classification": { "inputs": { "inputs": { "dtype": "DT_STRING", "tensor_shape": { "dim": [{ "size": "-1", "name": "" }], "unknown_rank": false }, "name": "input_example_tensor:0" } }, "outputs": { "classes": { "dtype": "DT_STRING", "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "2", "name": "" } ], "unknown_rank": false }, "name": "head/Tile:0" }, "scores": { "dtype": "DT_FLOAT", "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "2", "name": "" } ], "unknown_rank": false }, "name": "head/predictions/probabilities:0" } }, "method_name": "tensorflow/serving/classify" }, "regression": { "inputs": { "inputs": { "dtype": "DT_STRING", "tensor_shape": { "dim": [{ "size": "-1", "name": "" }], "unknown_rank": false }, "name": "input_example_tensor:0" } }, "outputs": { "outputs": { "dtype": "DT_FLOAT", "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "1", "name": "" } ], "unknown_rank": false }, "name": "head/predictions/logistic:0" } }, "method_name": "tensorflow/serving/regress" }, "serving_default": { "inputs": { "inputs": { "dtype": "DT_STRING", "tensor_shape": { "dim": [{ "size": "-1", "name": "" }], "unknown_rank": false }, "name": "input_example_tensor:0" } }, "outputs": { "classes": { "dtype": "DT_STRING", "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "2", "name": "" } ], "unknown_rank": false }, "name": "head/Tile:0" }, "scores": { "dtype": "DT_FLOAT", "tensor_shape": { "dim": [{ "size": "-1", "name": "" }, { "size": "2", "name": "" } ], "unknown_rank": false }, "name": "head/predictions/probabilities:0" } }, "method_name": "tensorflow/serving/classify" } } } } }
低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
- 上一篇
Spring Cloud Alibaba基础教程:Nacos配置的加载规则详解
上一篇,我们学习了如何在Nacos中创建配置,以及如何使用Spring Cloud Alibaba的Nacos客户端模块来加载配置。在入门例子中,我们只配置了Nacos的地址信息,没有配置任何其他与配置加载相关的其他内容。所以,接下来准备分几篇说说大家问的比较多的一些实际使用的问题或疑问。 加载规则 在《Spring Cloud Alibaba基础教程:使用Nacos作为配置中心》一文中,我们的例子完全采用了默认配置完成。所以,一起来看看Spring Cloud Alibaba Nacos模块默认情况下是如何加载配置信息的。 首先,回顾一下,我们在入门例子中,Nacos中创建的配置内容是这样的: Data ID:alibaba-nacos-config-client.properties Group:DEFAULT_GROUP 拆解一下,主要有三个元素,它们与具体应用的配置内容对应关系如下: Data ID中的alibaba-nacos-config-client:对应客户端的配置spring.cloud.nacos.config.prefix,默认值为${spring.applica...
- 下一篇
tensorflow serving api
背景介绍 tensorflow serving 在客户端和服务端之间的通信采用的是RPC/REST协议。在TFS提供的REST协议接口存在一定的局限性,REST和RPC对比如下: 1、REST在实际应用中不能支持运行过程动态模型发布。 2、REST预测过程中组装报文格式复杂。 3、RPC接口提供TFS的所有核心能力。 基于以上原因且考虑我们主要使用java进行编程,因此我们必须具备能够编译TFS JAVA API的能力,这篇文章主要目的就是提供编译TFS JAVA API的方法。 编译方法 step_1 安装protoc protoc 3 已经有编译好的版本, 直接从protoc官网 下载编译好的安装包protoc-3.6.1-osx-x86_64.zip, 然后将命令复制到/usr/local/bin即可。 cd /tmp mv protoc-3.5.1-osx-x86_64.zip . unzip protoc-3.5.1-osx-x86_64.zip cd bin cp protoc /usr/local/bin/ lebron374$ protoc --version l...
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- Windows10,CentOS7,CentOS8安装MongoDB4.0.16
- CentOS8编译安装MySQL8.0.19
- MySQL8.0.19开启GTID主从同步CentOS8
- CentOS8安装Docker,最新的服务器搭配容器使用
- CentOS8,CentOS7,CentOS6编译安装Redis5.0.7
- SpringBoot2整合Redis,开启缓存,提高访问速度
- CentOS7,8上快速安装Gitea,搭建Git服务器
- Docker使用Oracle官方镜像安装(12C,18C,19C)
- CentOS关闭SELinux安全模块
- SpringBoot2初体验,简单认识spring boot2并且搭建基础工程