PyTorch翻译官网教程-DEPLOYING PYTORCH IN PYTHON VIA A REST API WITH FLASK

官网链接

Deploying PyTorch in Python via a REST API with Flask — PyTorch Tutorials 2.0.1+cu117 documentation

通过flask的rest API在python中部署pytorch

在本教程中,我们将使用Flask部署PyTorch模型,并开放用于模型推断的REST API。特别是,我们将部署一个预训练的DenseNet 121模型来检测图像。

这是关于在生产环境中部署PyTorch模型的系列教程中的第一篇。使用Flask这种方式是迄今为止部署PyTorch模型的最简单方法,但它不适用于具有高性能要求的用例。

  • 如果你已经熟悉了TorchScript,你可以直接跳到我们的加载一个TorchScript模型在c++教程。(Loading a TorchScript Model in C++ )
  • 如果你需要对TorchScript进行复习,请查看我们的TorchScript入门教程。(Intro a TorchScript )

API定义

我们将首先定义API 路径、请求和响应类型。我们的API路径是 /predict它接受带有包含图像的文件参数的HTTP POST请求。响应将是JSON响应,其中包含预测结果。

{"class_id": "n02124075", "class_name": "Egyptian_cat"}{"class_id": "n02124075", "class_name": "Egyptian_cat"}


依赖项

运行以下命令安装所需的依赖项:

$ pip install Flask==2.0.1 torchvision==0.10.0

简单Web服务器

下面是一个简单的web服务器,摘自Flask的文档

from flask import Flask
app = Flask(__name__)@app.route('/')
def hello():return 'Hello World!'

将上面的代码片段保存在一个名为app.py的文件中,现在你可以通过输入以下命令来运行Flask开发服务器:

$ FLASK_ENV=development FLASK_APP=app.py flask run

当您在web浏览器中访问http://localhost:5000/时,您将看到Hello World!文本

我们将对上面的代码片段做一些修改,使其适合我们的API定义。首先,我们将把方法重命名为predict。我们将把请求路径更新为/predict。由于图像文件将通过HTTP POST请求发送,我们将更新它,使其也只接受POST请求。

@app.route('/predict', methods=['POST'])
def predict():return 'Hello World!'

我们还将更改响应类型,以便它返回一个包含ImageNet类id和名称的JSON响应。更新后的app.py文件现在将是:

from flask import Flask, jsonify
app = Flask(__name__)@app.route('/predict', methods=['POST'])
def predict():return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

推理

在下一节中,我们将重点讨论如何编写推理代码。这将涉及两个部分,一个是我们准备图像,以便它可以馈送到DenseNet,接下来,我们将编写代码以从模型中获得实际预测。

准备图像

DenseNet模型要求图像为3通道RGB图像,大小为224 x 224。我们还将用所需的均值和标准差值对图像张量进行归一化。你可以在这里读到更多(here)。

我们将使用torchvision 库中的 transforms ,并构建一个变换管道,它可以根据要求变换我们的图像。你可以在这里关于变换的内容(here)。

import ioimport torchvision.transforms as transforms
from PIL import Imagedef transform_image(image_bytes):my_transforms = transforms.Compose([transforms.Resize(255),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])image = Image.open(io.BytesIO(image_bytes))return my_transforms(image).unsqueeze(0)

上述方法接受字节数据的图像,应用一些列的transforms 并返回一个张量。为了测试上述方法,以字节模式读取图像文件(首先将../_static/img/sample_file.jpeg替换为计算机上文件的实际路径)并查看是否返回一个张量:

with open("../_static/img/sample_file.jpeg", 'rb') as f:image_bytes = f.read()tensor = transform_image(image_bytes=image_bytes)print(tensor)


预测

现在将使用预训练的DenseNet 121模型来预测图像类别。我们将使用torchvision库,加载模型并获得推理结果。虽然我们将在本例中使用预训练模型,但您可以对自己的模型使用相同的方法。了解更多关于加载模型的信息(tutorial)。

from torchvision import models# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:
model = models.densenet121(weights='IMAGENET1K_V1')
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()def get_prediction(image_bytes):tensor = transform_image(image_bytes=image_bytes)outputs = model.forward(tensor)_, y_hat = outputs.max(1)return y_hat


张量y_hat 将包含预测类别的索引id, 然而,我们需要一个人类可读的类名。为此,我们需要一个类别id和命名的映射。下载 imagenet_class_index.json 这个文件( this file),并记住保存它的位置(或者,如果您遵循本教程中的确切步骤,将其保存在教程/_static中)。这个文件包含ImageNet类别id到ImageNet类名的映射。我们将加载这个JSON文件并获取预测类别索引的类名。

import jsonimagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))def get_prediction(image_bytes):tensor = transform_image(image_bytes=image_bytes)outputs = model.forward(tensor)_, y_hat = outputs.max(1)predicted_idx = str(y_hat.item())return imagenet_class_index[predicted_idx]


在使用imagenet_class_index字典之前,首先我们将把张量值转换为字符串值,因为imagenet_class_index字典中的键是字符串。我们将测试上述方法:

with open("../_static/img/sample_file.jpeg", 'rb') as f:image_bytes = f.read()print(get_prediction(image_bytes=image_bytes))

你应该得到这样的返回:

['n02124075', 'Egyptian_cat']

数组中的第一项是ImageNet类别id,第二项是人类可读的名称。

注意

您是否注意到model变量不是get_prediction方法的局部变量,或者说为什么model是一个全局变量?就内存和计算而言,加载模型可能是一项昂贵的操作。如果我们在get_prediction方法中加载模型,那么每次调用该方法时都会不必要地加载模型。因为我们正在构建一个web服务器,每秒可能有数千个请求,我们不应该浪费时间为每个推理加载模型。因此,我们只将模型加载到内存中一次。在生产系统中,为了能够大规模地处理请求,必须高效地使用计算,因此通常应该在处理请求之前加载模型。


在我们的API服务器中集成模型

在最后一部分中,我们将把模型添加到Flask API服务器中。由于我们的API服务器应该接受一个图像文件,我们将更新我们的预测方法来从请求中读取文件:

from flask import request@app.route('/predict', methods=['POST'])
def predict():if request.method == 'POST':# we will get the file from the requestfile = request.files['file']# convert that to bytesimg_bytes = file.read()class_id, class_name = get_prediction(image_bytes=img_bytes)return jsonify({'class_id': class_id, 'class_name': class_name})

app.py文件现在已经完成。以下是完整版本;将路径替换为您保存文件的路径,它应该运行:

import io
import jsonfrom torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, requestapp = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(weights='IMAGENET1K_V1')
model.eval()def transform_image(image_bytes):my_transforms = transforms.Compose([transforms.Resize(255),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])image = Image.open(io.BytesIO(image_bytes))return my_transforms(image).unsqueeze(0)def get_prediction(image_bytes):tensor = transform_image(image_bytes=image_bytes)outputs = model.forward(tensor)_, y_hat = outputs.max(1)predicted_idx = str(y_hat.item())return imagenet_class_index[predicted_idx]@app.route('/predict', methods=['POST'])
def predict():if request.method == 'POST':file = request.files['file']img_bytes = file.read()class_id, class_name = get_prediction(image_bytes=img_bytes)return jsonify({'class_id': class_id, 'class_name': class_name})if __name__ == '__main__':app.run()

让我们测试一下我们的web服务器!运行:

$ FLASK_ENV=development FLASK_APP=app.py flask run


我们可以使用requests库向我们的应用发送POST请求:

import requestsresp = requests.post("http://localhost:5000/predict",files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

打印rep .json()将显示以下内容:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

下一个步骤

我们编写的服务器非常简单,可能无法完成生产应用程序所需的所有功能。所以,这里有一些你可以做的事情来让它变得更好:

  • 请求路径 /predict假定请求中总是有一个图像文件。这可能并不适用于所有请求。我们的用户可以发送带有不同参数的图像或根本不发送图像。
  • 用户也可以发送非图像类型的文件。由于我们不处理错误,这将破坏我们的服务器。显式添加异常的错误处理路径,将使我们能够更好地处理错误输入
  • 尽管该模型可以识别大量的图像类别,但它可能无法识别所有的图像。优化实现以处理模型无法识别图像中的任何内容的情况。
  • 我们以开发模式运行Flask服务器,这种模式不适合部署到生产环境中。您可以查看本教程,了解如何在生产环境中部署Flask服务器。(this tutorial
  • 您还可以通过创建带有表单的页面来添加UI,该表单接受图像并显示预测结果。请查看类似项目的演示及其源代码。(source code.
  • 在本教程中,我们只展示了如何构建一个每次可以返回单个图像预测的服务。我们可以修改我们的服务,使其能够一次返回多个图像的预测结果。此外,service-streamer库会自动将请求排队到您的服务中,并将它们采样到可以馈送到模型中的小批量中。您可以查看本教程(this tutorial.)。
  • 最后,我们鼓励您查看页面顶部链接的关于部署PyTorch模型的其他教程.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/2809.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

06-C++ 基本算法 - 二分法

&#x1f4d6; 前言 在这个笔记中&#xff0c;我们将介绍二分法这种基本的算法思想&#xff0c;以及它在 C 中的应用。我们将从一个小游戏猜数字开始&#xff0c;通过这个案例来引出二分法的概念。然后我们将详细讲解什么是二分法以及它的套路和应用。最后&#xff0c;我们还会…

在 3ds Max 中创建逼真的玻璃材质

推荐&#xff1a; NSDT场景编辑器助你快速搭建可二次开发的3D应用场景 尽管本教程基于 3ds Max&#xff0c;但相同的设置适用于许多其他 3D 产品。 注意&#xff1a;单击每个步骤中的缩略图可查看更大的屏幕截图&#xff0c;其中包括视口和用户界面的相关部分。 步骤 1由于本教…

广西学子复读15年,不服从分配。网友:完全是浪费时间

广西学子复读15年&#xff0c;不服从分配。网友&#xff1a;完全是浪费时间 唐尚珺的复读行为引起了网友们的不同解读。有人认为他是一个执念深重的人&#xff0c;目标是考上清华北大&#xff0c;但这个说法是否真实&#xff0c;我们无法确定。无论如何&#xff0c;我们必须认识…

electron+vue3全家桶+vite项目搭建【24】设置应用图标,打包文件的图标

文章目录 引入实现步骤测试结果 引入 demo项目地址 在electron中&#xff0c;我们可以通过electron-builder的配置文件来设置打包后的应用图标 实现步骤 因为mac环境下的图标需要特殊格式&#xff0c;这里我们可以利用electron-icon-builder进行配置 1.引入相关依赖 # 安…

GPT 如此强大,我们可以利用它实现什么?

GPT&#xff08;Generative Pre-trained Transformer&#xff09;是一种基于Transformer结构的预训练语言生成模型&#xff0c;由OpenAI研发。它可以生成高质量的自然语言文本&#xff0c;取得了很好的效果&#xff0c;被广泛应用于各个领域。以下是一些利用GPT实现的应用。 一…

Linux 处理僵尸进程

要查看僵尸进程的来源&#xff0c;可以使用ps命令或top命令。 使用ps命令&#xff1a; 打开终端或命令行界面&#xff0c;输入以下命令并按Enter执行&#xff1a;ps -e -o pid,ppid,state,cmd | grep -w Z该命令将显示所有僵尸进程的进程ID&#xff08;PID&#xff09;、父进…

ts中setState的类型

两种方法: 例子: 父组件 const [value, setValue] useState(); <ChildsetValue{setValue} />子组件 interface Ipros {setValue: (value: string) > void } const Child: React.FC<Ipros> (props) > {}

less常用用法简略总结

1、嵌套&#xff08;与sass相同&#xff09; ul{width:100px;li{width:99px;} } 2、变量&#xff08;变量名&#xff1a;值&#xff09;&#xff0c;sass&#xff1a;&#xff08;$color:green&#xff09; ColorA:green; ColorB:red; .box1{background-color: ColorA; } .b…

SpringMvc配置静态资源访问路径

文章目录 1. 整体流程2. registry.addResourceHandler()2.1 函数分析2.2 结果演示 3. ResourceHandlerRegistration.addResourceLocations()3.1 函数分析3.2 结果演示 1. 整体流程 1. 写一个配置类继承WebMvcConfigurationSupport 2. 利用 registry.addResourceHandler("…

chatgpt 讯飞星火 对比

"ChatGPT"和"讯飞星火"是两个不同的自然语言处理&#xff08;NLP&#xff09;模型&#xff0c;由不同的公司开发和提供。以下是它们之间的一些对比&#xff1a;1.开发公司&#xff1a; ChatGPT&#xff1a;由OpenAI开发&#xff0c;是OpenAI旗下的GPT-3模型…

Vue成绩案例实现添加、删除、显示无数据、添加日期、总分均分以及数据本地化等功能

一、成绩案例 ✅✅✅通过本次案例实现添加、删除、显示无数据、添加日期、总分均分以及数据本地化等功能。 准备成绩案例模板&#xff0c;我们需要在这些模板上面进行功能操作。 <template><div class"score-case"><div class"table">…

nginx基础3——配置文件详解(实用功能篇)

文章目录 一、平滑升级二、修饰符2.1 无修饰符效果2.2 精准匹配&#xff08;&#xff09;2.3 区分大小写匹配&#xff08;~&#xff09;2.4 不区分大小写匹配&#xff08;~*&#xff09;2.5 匹配优先级 三、访问控制四、用户认证五、配置https六、开启状态界面七、rewrite重写u…

matplotlib 3D

import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D import numpy as np# 创建一个三维坐标轴 fig plt.figure() ax fig.add_subplot(221, projection3d) xx fig.add_subplot(222) yy fig.add_subplot(223) xy fig.add_subplot(224)# 生成示例数据…

关于你欠缺的NoSQL中的redis和mongoDB

文章目录 前言一、在string list hash结构中&#xff0c;每个至少完成5个命令&#xff0c;包含插入 修改 删除 查询&#xff0c;list 和hash还需要增加遍历的操作命令1、STRING类型2、List类型数据的命令操作&#xff1a;3、举例说明list和hash的应用场景&#xff0c;每个至少一…

echarts图例对齐

富文本不生效&#xff0c;是没有设置lineHeight

企业内部FAQ系统的搭建重要性是什么?

企业内部FAQ系统&#xff08;Frequently Asked Questions&#xff0c;常见问题解答系统&#xff09;的搭建对于企业来说具有重要的意义。它可以帮助企业有效地管理和解决员工和客户的常见问题&#xff0c;提高工作效率和服务质量。 企业内部FAQ系统搭建的重要性&#xff1a; …

Python批量实现Word、EXCLE、PPT转PDF文件

一、绪论背景 在日常办公和文档处理中&#xff0c;有时我们需要将多个Word文档、Excel表格或PPT演示文稿转换为PDF文件。将文档转换为PDF格式的好处是它可以保留文档的布局和格式&#xff0c;并且可以在不同平台上进行方便的查看和共享。 本篇博文将介绍如何使用Python编程语言…

lua脚本语言学习笔记

Lua 是一种轻量小巧的脚本语言&#xff0c;用标准C语言编写并以源代码形式开放&#xff0c; 其设计目的是为了嵌入应用程序中&#xff0c;从而为应用程序提供灵活的扩展和定制功能。 因为我们使用redis的时候一般要写lua脚本&#xff0c;这篇文章就介绍一下lua脚本语言的基础用…

旅行社优惠卡app软件开发

旅游行业的不断发展&#xff0c;越来越多的旅行社开始推出各种优惠卡来吸引游客。而随着智能手机的普及&#xff0c;开发一款旅行社优惠卡APP软件成为了一种必然的趋势。 该软件的主要功能是提供旅行社的各种优惠卡信息&#xff0c;包括优惠卡的种类、价格、使用范围、有效…

C#开发winformwpf后台捕获鼠标移动事件

做 WPF和winform的时候&#xff0c;可以在界面上设置鼠标移动事件来检测鼠标移动&#xff0c;如果项目为后期改造这样做的话改动量很大&#xff0c;今天通过另外一种后台调用windows api的方式进行快速捕获和触发&#xff0c;提高开发效率分享给大家。 /// <summary>/// …