利用gradio快速搭建AI应用

引言

Gradio 是一个用于快速创建交互式界面的Python库,这些界面可以用于演示和测试机器学习模型。使用Gradio,开发者可以非常轻松地为他们的模型构建一个前端界面,而不需要任何Web开发经验。
在这里插入图片描述

与类似产品的对比

  • TensorBoard:主要用于TensorFlow的训练可视化。而Gradio则更注重模型的交互式演示。
  • Streamlit:也是一个快速创建交互式应用的工具,但Gradio更注重于机器学习模型的界面。

Gradio三大特点

  • 设置快速、简单
    Gradio 可以通过 pip 安装。创建一个Gradio界面只需要添加几行代码 到您的项目。无缝使用计算机上的任何 Python 库。如果你能写一个 python函数,gradio可以运行它。

  • 呈现并分享
    Gradio 可以嵌入到Python Notebookks 或呈现为 网页。Gradio 界面可以自动生成您可以共享的公共链接 与同事一起,让他们与您计算机上的模型进行交互 从他们自己的设备远程。

  • 永久托管
    创建界面后,您可以将其永久托管在 Hugging Face 上。Hugging Face Spaces 将在其服务器上托管该界面,并为您提供一个链接,您可以分享。

快速应用示例

安装

Gradio 最显著的优势之一是它支持各种机器学习框架,包括 PyTorch。这意味着无论您使用哪个框架来训练模型,您都可以使用 Gradio 轻松部署它。要开始使用 Gradio,您需要安装所需的库。您可以使用 pip 安装它们:

pip install gradio torch torchvision

针对 MNIST 数据训练深度学习模型

在本节中,我们将建立模型。我们使用 CNN 模型并使用 MNIST 数据集对其进行训练。让我们一起建设吧!

第一步,我们需要为模型设置 MNIST 数据集。幸运的是,我们可以使用torchvision库来帮助轻松下载和准备我们的数据集。

from torchvision import datasets, transforms
import torchtransform=transforms.Compose([transforms.ToTensor(),transforms.Resize(28),transforms.Normalize((0.1307,), (0.3081,)),])dataset1 = datasets.MNIST('../data', train=True, download=True,transform=transform)dataset2 = datasets.MNIST('../data', train=False,transform=transform)train_loader = torch.utils.data.DataLoader(dataset1, batch_size=10000, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=1000)

现在,我们将使用卷积层设计模型,如下所示:

import torch.nn as nn
import torch.nn.functional as Fclass CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)self.fc1 = nn.Linear(32 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2)x = x.view(-1, 32 * 7 * 7)x = F.relu(self.fc1(x))x = self.fc2(x)return x# Make device agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create an instance of the CNN model
model = CNNModel().to(device)

该 CNN 模型由两个具有 ReLU 激活的卷积层组成,后面是用于下采样的最大池化层。在卷积层之后,有两个带有 ReLU 激活的全连接层,用于最终分类。

下面是训练部分:

# Setup loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1)
# Set the number of epochs
epochs = 15for epoch in tqdm(range(epochs)):for batch, (X, y) in enumerate(train_loader):model.train()# Put data on the target deviceX, y = X.to(device), y.to(device)#1. Forward passy_pred = model(X)# 2. Calculate loss (per batch)loss = loss_fn(y_pred, y)# 3. Optimizer zero gradoptimizer.zero_grad()# 4. Loss backwardloss.backward()# 5. Optimizer stepoptimizer.step()

然后,我们评估模型在测试数据集上的性能。

# Define the accuracy function
def accuracy_fn(y_true, y_pred):correct = torch.eq(y_true, y_pred).sum().item()acc = (correct / len(y_pred)) * 100return acc# Load and test the model on test data
for X, y in test_loader:X, y = X.to(device), y.to(device)y_pred = model(X)print(accuracy_fn(y_true = y, y_pred = y_pred.argmax(dim = 1)))
94.54

这看起来准确度分数非常高(但不要相信它们;我将在部署部分进行解释)。最后,我们保存模型以便使用 Gradio 进行部署。

torch.save(obj=model.state_dict(), f='./mnist_model.pt')

开始使用 Gradio

现在,确保您已准备好经过训练的 PyTorch 模型。对于此示例,我们假设您有一个在自定义数据集上训练的草图识别模型。

接下来,让我们深入研究一些代码并使用 Gradio 部署模型:

首先,我们需要一段代码来存储我们在上一节中制作的模型结构和transformer函数。

### loaded_model.py
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transformsclass CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)self.fc1 = nn.Linear(32 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2)x = x.view(-1, 32 * 7 * 7)x = F.relu(self.fc1(x))x = self.fc2(x)return xdef create_transformer():transform=transforms.Compose([transforms.ToTensor(),transforms.Resize(28),transforms.Normalize((0.1307,), (0.3081,)),])return transform

然后,下面的代码向您展示了如何使用 Gradio 部署经过训练的模型:

### app.py
import gradio as gr
import torch
from loaded_model import CNNModel
import loaded_modelmodel = CNNModel()
model.load_state_dict(torch.load(f="./mnist_model.pt"))# Define a function to make predictions with your model
def classify_image(image):# Preprocess the imagepreprocess = loaded_model.create_transformer()image_tensor = preprocess(image)image_tensor = image_tensor.unsqueeze(0)# Make predictionwith torch.no_grad():output = model(image_tensor)_, predicted_class = torch.max(output, 1)return f"Predicted class: {predicted_class.item()}"gr.Interface(fn=classify_image, inputs="sketchpad", outputs="label").launch()

最后,在您的网络浏览器上运行 Gradio:

python app.py

恭喜!你可以部署它:
在这里插入图片描述
然而,有时这个模型会做出错误的预测。此错误可能是由于 MNIST 数据不适合从 Gradio 输入接口输入造成的。因此,如果您在现实世界中需要一个良好的性能模型,您应该使用原始数据(来自 Graido 输入接口的数字图像)来训练模型。

结论

部署机器学习模型可能是一项艰巨的任务,但 Gradio 通过提供用户友好且直观的界面简化了该过程。只需几行代码,您就可以将 PyTorch 模型转变为交互式 Web 应用程序并与世界分享。无论您是经验丰富的数据科学家还是刚刚开始学习机器学习,Gradio 都是一个有价值的工具,可以帮助您展示模型并与更广泛的受众互动。

那么,为什么还要等呢?尝试一下 Gradio 和 PyTorch,立即开始分享您的机器学习模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/231880.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Python从入门到进阶】44、Scrapy的基本介绍和安装

接上篇《43.验证码识别工具结合requests的使用》 上一篇我们学习了如何使用验证码识别工具进行登录验证的自动识别。本篇我们开启一个新的章节,来学习一下快速、高层次的屏幕抓取和web抓取框架Scrapy。 一、Scrapy框架的背景和特点 Scrapy框架是一个为了爬取网站数…

C++内存布局

温故而知新,本文浅聊和回顾下C内存布局的知识。 一、c内存布局 C的内存布局主要包括以下几个部分: 代码段:存储程序的机器代码。.数据段:存储全局变量和静态变量。数据段又分为初始化数据段(存储初始化的全局变量和…

python与机器学习2,激活函数

目录 1 什么是激活函数? activation function 1.1 阈值 1.2 激活函数a(x) ,包含偏置值θ 1.3 激活函数a(x) ,包含偏置值b 2 激活函数1: 单位阶跃函数 2.1 函数形式 2.2 函数图形 2.3 函数特点 2.4 代码实现这个 单位阶跃函数 3 激活…

Convolutional Neural Network(CNN)——卷积神经网络

1.NN的局限性 拓展性差 NN的计算量大性能差,不利于在不同规模的数据集上有效运行若输入维度发生变化,需要修改并重新训练网络容易过拟合 全连接导致参数量特别多,容易过拟合如果增加更多层,参数量会翻倍无法有效利用局部特征 输入…

结构型设计模式(三)享元模式 代理模式 桥接模式

享元模式 Flyweight 1、什么是享元模式 享元模式的核心思想是共享对象,即通过尽可能多地共享相似对象来减少内存占用或计算开销。这意味着相同或相似的对象在内存中只存在一个共享实例。 2、为什么使用享元模式 减少内存使用:通过共享相似对象&#…

汽车UDS诊断——SecureDataTransmission 加密数据传输(0x84)

诊断协议那些事儿 诊断协议那些事儿专栏系列文章,本文介绍诊断和通讯管理功能单元下的84服务SecureDataTransmission,在常规诊断通信中,数据极易被第三方获取,所以在一些特殊的数据传输时,标准定义了加密数据传输的服务。 简而言之,就是在发送诊断数据时,发送方先把数…

fragstats:景观指数的时间序列分析框架

作者:CSDN _养乐多_ 本文将介绍景观指数的时间序列分析计算的软件使用方法和 python 代码,该框架可用于分析景观指数时间序列图像的趋势分析、突变分析、机器学习(分类/聚类/回归)、相关性分析、周期分析等方面。 文章目录 一、…

智能优化算法应用:基于人工电场算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于人工电场算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于人工电场算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.人工电场算法4.实验参数设定5.算法结果6.…

04-Revision和流量管理

1 Revision 关于Revision 应用程序代码及相关容器配置某个版本的不可变快照KService上的spec.template的每次变动,都会自动生成一个新的Revision通常不需要手动创建及维护 Revision的使用场景 将流量切分至不同版本的应用程序间(Canary Deployment、Blu…

静态路由及动态路由

文章目录 静态路由及动态路由一、静态路由基础1. 静态路由配置2. 负载分担3. 路由备份4. 缺省路由5. 静态路由实操 二、RIP 动态路由协议1. RIP 协议概述2. RIP 协议版本对比2.1 有类路由及无类路由 3. RIP 路由协议原理4. RIP 计时器5. 度量值6. 收敛7. 示例 静态路由及动态路…

Kafka基本原理及使用

目录 基本概念 单机版 环境准备 基本命令使用 集群版 消息模型 成员组成 1. Topic(主题): 2. Partition(分区): 3. Producer(生产者): 4. Consumer(…

使用TensorRT对Yolov5进行部署【基于Python】

如果还未配置TensorRT,请看这篇博文:Win11下TensorRT环境部署 这里使用TensorRT对Yolov5进行部署流程比较固定:先将pt模型转换为onnx,再将onnx模型转为engine,所以在执行export.py时要将onnx、engine给到include。 P…

Linear Regression线性回归(一元、多元)

目录 介绍: 一、一元线性回归 1.1数据处理 1.2建模 二、多元线性回归 2.1数据处理 2.2数据分为训练集和测试集 2.3建模 介绍: 线性回归是一种用于预测数值输出的统计分析方法。它通过建立自变量(也称为特征变量)和因变…

【Redis】五、Redis持久化、RDB和AOF

文章目录 Redis持久化一、RDB(Redis DataBase)触发机制如何恢复rdb文件 二、AOF(Append Only File)三、扩展 Redis持久化 面试和工作,持久化都是重点! Redis 是内存数据库,如果不将内存中的数据…

微服务实战系列之ZooKeeper(实践篇)

前言 关于ZooKeeper,博主已完整的通过庖丁解牛式的“解法”,完成了概述。我想掌握了这些基础原理和概念后,工作的问题自然迎刃而解,甚至offer也可能手到擒来,真实一举两得,美极了。 为了更有直观的体验&a…

uniapp 预览图片

preImg(index){let urls []this.images.map((item,i) > {if(indexi){urls.unshift(item.file_path)}else{urls.push(item.file_path)}})uni.previewImage({urls})}

linux之Samba服务器

环境:虚拟机CENTOS 7和 测试机相通 一、Samba服务器_光盘共享(匿名访问) 1.在虚拟机CENTOS 7安装smb服务,并在防火墙上允许samba流量通过 2. 挂载光盘 3.修改smb.conf配置文件,实现光盘匿名共享 4. 启动smb服务 5.在…

JVM基础扫盲

什么是JVM JVM是Java设计者用于屏蔽多平台差异,基于操作系统之上的一个"小型虚拟机",正是因为JVM的存在,使得Java应用程序运行时不需要关注底层操作系统的差异。使得Java程序编译只需编译一次,在任何操作系统都可以以相…

英码科技受邀参加2023计算产业生态大会,分享智慧轨道交通创新解决方案

12月13-14日,“凝心聚力,共赢计算新时代”——2023计算产业生态大会在北京香格里拉饭店成功举办。英码科技受邀参加行业数字化分论坛活动,市场总监李甘来先生现场发表了题为《AI哨兵,为铁路安全运营站好第一道岗》的精彩主题演讲&…

1951 年以来的美国ACIS 气候地图数据集(5 公里空间分辨率)

应用气候信息系统 (ACIS) NRCC NN ACIS是Applied Climate Information System的缩写,是由美国国家气象局(NOAA)开发的一种气候信息系统。ACIS气候地图是通过收集和整理全球的气象数据,利用计算机技术和数据分析方法生成的气候图表…