当前位置: 移动技术网 > IT编程>脚本编程>Python > flask搭建Keras服务出现的问题解决办法

flask搭建Keras服务出现的问题解决办法

2020年09月28日  | 移动技术网IT编程  | 我要评论
当使用Keras训练好了一个识别模型后,如果采用线上部署为服务,一般情况下采用flask或者Django进行服务搭建。

当使用Keras训练好了一个识别模型后,如果采用线上部署为服务,一般情况下采用flask或者Django进行服务搭建。在我电脑上使用的环境是Keras 2.3.1、tensorflow 1.15.3这个版本。将手写数字的识别模型使用flask部署为服务。代码如下:

from flask import Flask from flask import request import numpy as np import keras from keras import models import tensorflow as tf import cv2 #将network定义为全局变量 network=models.load_model('./model/mnist_cnn.h5') app=Flask(__name__) @app.route('/') def index(): return 'welcome to visit huanhuncao server' @app.route('/post',methods=['GET','POST']) def lx_post(): if request.method=='POST': test_img=cv2.imread('3.jpg',0) test_img=test_img.reshape(28,28,1) test_img=test_img.reshape((1,)+test_img.shape) test_img=test_img.astype('float32')/255 output=network.predict(test_img) output=output.argmax(axis=1) output=str(output) return output if __name__=='__main__': app.run(host='172.24.103.157',port=6001) 

我们分析如上的代码,首先在全局变量里就将模型文件载入进去,然后一旦接收到post请求,就将结果返回,从程序逻辑上看似乎没什么问题。运行这个服务,进行测试下:
发现服务会报错:
在这里插入图片描述
错误提示:

ValueError: Tensor Tensor("dense_2/Softmax:0", shape=(?, 10), dtype=float32) is not an element of this graph. 

错误的原因在于Keras使用tensorflow作为后端时,tensorflow的操作都是默认加载在一个默认的Graph中,所以如果为了避免出错,自己就要创建Graph以及Session。
针对这个问题进行修改,修改后的代码如下:

from flask import Flask from flask import request import numpy as np import keras from keras import models import tensorflow as tf import cv2 #将network定义为全局变量 global sess,graph #tf2.x中为sess = tf.compat.v1.keras.backend.get_session() sess=keras.backend.get_session() graph=tf.get_default_graph() network=models.load_model('./model/mnist_cnn.h5') app=Flask(__name__) @app.route('/') def index(): return 'welcome to visit huanhuncao server' @app.route('/post',methods=['GET','POST']) def lx_post(): if request.method=='POST': test_img=cv2.imread('3.jpg',0) test_img=test_img.reshape(28,28,1) test_img=test_img.reshape((1,)+test_img.shape) test_img=test_img.astype('float32')/255 #在默认会话与计算图中进行模型的预测 with sess.as_default(): with graph.as_default(): output=network.predict(test_img) output=output.argmax(axis=1) output=str(output) return output if __name__=='__main__': app.run(host='172.24.103.157',port=6001) 

运行这个服务,可以看到已经成功了。
在这里插入图片描述
在使用flask进行Keras模型预测的时候,还有一种错误会出现,比如使用Keras 2.2和tensorflow 1.15会出现一种线程错误,这种解决办法是将flask以单线程进行运行,网上很多说法是将Keras和tensorflow的版本进行降级,在我看来是不需要这种做法的。

app.run(host='172.24.103.157',port=6001,threaded=False) 

其实只要tensorflow和Keras的版本一一对应上,是不会出现这个问题的。Keras与tensorflow的版本对应从这个网页里可以查看。

本文地址:https://blog.csdn.net/qq_37781464/article/details/108842854

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

相关文章:

验证码:
移动技术网