在本教程中,我们将介绍一个简单的方法来获取Keras模型并将其部署为REST API。本文所介绍的示例将作为你构建自己的深度学习API的模板/起点——你可以扩展代码,根据API端点的可伸缩性和稳定性对其进行定制。
具体而言,我们将了解:
· 如何(以及如何不)将Keras模型加载到内存中,以便有效地进行推理
· 如何使用Flask web框架为我们的API创建端点
· 如何使用我们的模型进行预测,用JSON-ify转换它们,并将结果反馈到客户端
· 如何使用cURL和Python来调用我们的Keras REST API
在本教程结束时,你将能很好地理解创建Keras REST API所需的组件(以最简单的形式)。
请随意使用本指南中提供的代码作为你自己的深度学习REST API起点。
配置开发环境
假设Keras已经配置并安装在你的机器上。如果没有,请确保使用官方安装说明安装Keras(https://keras.io/#installation)。
然后,需要安装Flask (http://flask.pocoo.org/)(及其相关的依赖项),一个Python web框架,这样就可以构建API端点了。还需要请求(http://docs.python-requests.org/en/master/),这样就可以使用API了。
有关的pip安装命令如下:
$ pip install flask gevent requests pillow
构建你的Keras REST API
Keras REST API独立于一个名为run_keras_server.py的文件中。为了简单起见,我们将安装保存在一个文件中——安装启用也可以很容易地模块化。
在 run_keras_server.py中,你会发现三个函数,即:
· load_model:用于加载训练好的Keras模型,并为推理做准备。
· prepare_image:这个函数在通过我们的网络进行预测之前对输入图像进行预处理。如果你没有使用图像数据,则可能需要考虑将名称更改为更通用的prepare_datapoint,并应用一些可能需要的缩放/标准化。
· predict:API的实际端点可以将请求中的输入数据分类,并将结果反馈给客户端。
# import the necessary packagesfrom keras.applications import ResNet50from keras.preprocessing.image import img_to_arrayfrom keras.applications import imagenet_utilsfrom PIL import Imageimport numpy as npimport flaskimport io
# initialize our Flask application and the Keras modelapp = flask.Flask(__name__)model = None
第一个代码片段处理导入了所需的程序包,并且对Flask应用程序和模型进行了初始化。
在此,我们定义load_model函数:
def load_model():
# load the pre-trained Keras model (here we are using a model
# pre-trained on ImageNet and provided by Keras, but you can
# substitute in your own networks just as easily)
global model
model = ResNet50(weights="imagenet")
顾名思义,这个方法负责将我们的架构实例化,并从磁盘加载权重。
为了简单起见,将使用在ImageNet数据集上预先训练过的ResNet50架构。
如果你正在使用自定义模型,则需要修改此函数以从磁盘加载架构+权重。
在对任何来自客户端的数据进行预测之前,首先需要准备并预处理数据:
def prepare_image(image, target):
# if the image mode is not RGB, convert it
if image.mode != "RGB":
image = image.convert("RGB")
# resize the input image and preprocess it
image = image.resize(target)
image = img_to_array(image)
image = np.expand_dims(image, axis=0)
image = imagenet_utils.preprocess_input(image)
# return the processed image
return image
这个函数:
· 接受输入图像
· 将模式转换为RGB(如果需要)
· 将大小调整为224x224像素(ResNet的输入空间维度)
· 通过平均减法数组和缩放对阵列进行预处理
此外,在通过模型传递输入数据之前,应该根据某一预处理、缩放或标准化来修改这个函数。
现在可以定义predict函数了——该方法会处理对/predict端点的任何请求:
@app.route("/predict