MindSpore基础教程:使用 MindCV和 Gradio 创建一个图像分类应用

MindSpore基础教程:使用 MindCV和 Gradio 创建一个图像分类应用

官方文档教程使用已经弃用的MindVision模块,本文是对官方文档的更新
在这篇博客中,我们将探索如何使用 MindSpore 框架和 Gradio 库来创建一个基于深度学习的图像分类应用。我们将使用预训练的 ResNet50 模型,以 CIFAR-10 数据集为例进行训练,并通过 Gradio 接口进行图像分类预测。下面是一个简单、直观的指南,适用于希望将深度学习模型转换为交互式应用的开发者。

训练模型

环境设置

首先,我们需要设置 GPU 作为训练的目标设备。MindSpore 提供了一个便捷的方式来配置环境。

from mindspore import context
context.set_context(device_target="GPU")

解析参数

我们使用 argparse 来解析命令行参数。这样可以方便地在训练时调整参数,例如数据集路径、学习率和训练周期数。

import argparse
def parse_args():"""解析命令行参数。返回:argparse.Namespace: 包含命令行参数的命名空间。"""parser = argparse.ArgumentParser(description="训练 ResNet 模型",formatter_class=argparse.ArgumentDefaultsHelpFormatter)parser.add_argument('--pretrain_path', type=str, default='',help='预训练文件的路径')parser.add_argument('--data_path', type=str, default='datasets/drizzlezyk/cifar10/',help='训练数据的路径')parser.add_argument('--output_path', default='train/resnet/', type=str,help='模型保存路径')parser.add_argument('--epochs', default=10, type=int, help='训练周期数')parser.add_argument('--lr', default=0.0001, type=int, help='学习率')return parser.parse_args()

创建数据集

使用 MindSpore 的 create_dataset 方法,我们可以轻松创建和预处理 CIFAR-10 训练数据集。

from mindcv.data import create_dataset, create_transforms, create_loaderdef create_training_dataset(data_path, batch_size):"""创建训练数据集。参数:data_path (str): 数据集的路径。batch_size (int): 批量大小。返回:Tuple[DataLoader, int]: 数据加载器和每个 epoch 的批次数量。"""dataset_train = create_dataset(name='cifar10', root=data_path, split='train', shuffle=True)transform_train = create_transforms(dataset_name='cifar10', image_resize=224)train_loader = create_loader(dataset=dataset_train, batch_size=batch_size, is_training=True,num_classes=10, transform=transform_train)num_batches = train_loader.get_dataset_size()return train_loader, num_batches

模型训练

接下来,我们定义 train_model 函数来实现模型的训练逻辑。这包括模型的初始化、损失函数、优化器的设置,以及训练过程的启动。

from mindcv import create_model, create_loss, create_scheduler, create_optimizer
from mindspore.train import Model
from mindspore import load_checkpoint, load_param_into_netdef train_model(args):"""训练模型。参数:args (argparse.Namespace): 包含命令行参数的命名空间。"""train_loader, num_batches = create_training_dataset(args.data_path, batch_size=32)net = create_model(model_name='resnet50', num_classes=10)if args.pretrain_path:param_dict = load_checkpoint(args.pretrain_path)load_param_into_net(net, param_dict)loss_fn = create_loss(name='CE', reduction='mean')lr_scheduler = create_scheduler(steps_per_epoch=num_batches, scheduler='constant', lr=args.lr)optimizer = create_optimizer(net.trainable_params(), opt='adam', lr=lr_scheduler)model = Model(net, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'})checkpoint_config = CheckpointConfig(save_checkpoint_steps=num_batches, keep_checkpoint_max=10)checkpoint_callback = ModelCheckpoint(prefix='checkpoint_resnet', directory=args.output_path,config=checkpoint_config)model.train(args.epochs, train_loader,callbacks=[checkpoint_callback, LossMonitor(), TimeMonitor(data_size=num_batches)])

构建 Gradio 接口

预测函数

在 Gradio 接口中,我们定义一个 predict_image 函数来处理图像输入并返回预测结果。

import gradio as gr
import numpy as np
from mindspore import Tensor
import cv2def predict_image(img):# 创建模型实例net = create_model(model_name='resnet50', num_classes=NUM_CLASS)param_dict = load_checkpoint('/root/MyCode/pycharm/ResNet50/train/resnet/checkpoint_resnet-5_1563.ckpt')load_param_into_net(net, param_dict)# 封装模型为 Model 类实例model = Model(net)# 调整图像格式和大小img = cv2.resize(img, (224, 224))img = np.array(img, dtype=np.float32) / 255.0  # 归一化并确保数据类型为 Float32# 如果图像是 BGR 格式,转换为 RGB 格式# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)# 标准化处理img = (img - np.array([0.485, 0.456, 0.406], dtype=np.float32)) / np.array([0.229, 0.224, 0.225], dtype=np.float32)# 转换维度 - 通道优先格式 (C, H, W)img = np.transpose(img, (2, 0, 1))# 添加批次维度 (N, C, H, W)img = np.expand_dims(img, axis=0)# 将图像数据转换为 MindSpore 张量img_tensor = Tensor(img, dtype=mindspore.float32)  # 显式指定数据类型# 预测图像output = model.predict(img_tensor)# 应用 Softmax 获取概率softmax = Softmax(axis=1)predict_probability = softmax(output).asnumpy()predict_probability = predict_probability[0]  # 获取批量中的第一个元素# 将预测概率映射到类别名称return {class_names[i]: float(predict_probability[i]) for i in range(NUM_CLASS)}

Gradio 界面

使用 Gradio,我们可以快速构建一个交互式界面。用户可以上传图片,模型将返回图像分类的预测结果。

image = gr.Image()
label = gr.Label(num_top_classes=NUM_CLASS)gr.Interface(css=".footer {display:none !important}",fn=predict_image,inputs=image,live=False,description="Please upload a image in JPG, JPEG or PNG.",title='Image Classification by ResNet50',outputs=gr.Label(num_top_classes=NUM_CLASS, label="预测类别"),examples=['./example_img/airplane.jpg', './example_img/automobile.jpg', './example_img/bird.jpg','./example_img/cat.jpg', './example_img/deer.jpg', './example_img/dog.jpg','./example_img/frog.jpg', './example_img/horse.JPG', './example_img/ship.jpg','./example_img/truck.jpg']).launch(share=True)

image-20231121192446268

完整代码

import argparsefrom mindcv import create_model, create_loss, create_scheduler, create_optimizer
from mindspore.train import Model
from mindspore import load_checkpoint, load_param_into_net
from mindcv.data import create_dataset, create_transforms, create_loader
from mindspore import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint# 设置GPU
from mindspore import contextcontext.set_context(device_target="GPU")def parse_args():"""解析命令行参数。返回:argparse.Namespace: 包含命令行参数的命名空间。"""parser = argparse.ArgumentParser(description="训练 ResNet 模型",formatter_class=argparse.ArgumentDefaultsHelpFormatter)parser.add_argument('--pretrain_path', type=str, default='',help='预训练文件的路径')parser.add_argument('--data_path', type=str, default='datasets/drizzlezyk/cifar10/',help='训练数据的路径')parser.add_argument('--output_path', default='train/resnet/', type=str,help='模型保存路径')parser.add_argument('--epochs', default=10, type=int, help='训练周期数')parser.add_argument('--lr', default=0.0001, type=int, help='学习率')return parser.parse_args()def create_training_dataset(data_path, batch_size):"""创建训练数据集。参数:data_path (str): 数据集的路径。batch_size (int): 批量大小。返回:Tuple[DataLoader, int]: 数据加载器和每个 epoch 的批次数量。"""dataset_train = create_dataset(name='cifar10', root=data_path, split='train', shuffle=True)transform_train = create_transforms(dataset_name='cifar10', image_resize=224)train_loader = create_loader(dataset=dataset_train, batch_size=batch_size, is_training=True,num_classes=10, transform=transform_train)num_batches = train_loader.get_dataset_size()return train_loader, num_batchesdef train_model(args):"""训练模型。参数:args (argparse.Namespace): 包含命令行参数的命名空间。"""train_loader, num_batches = create_training_dataset(args.data_path, batch_size=32)net = create_model(model_name='resnet50', num_classes=10)if args.pretrain_path:param_dict = load_checkpoint(args.pretrain_path)load_param_into_net(net, param_dict)loss_fn = create_loss(name='CE', reduction='mean')lr_scheduler = create_scheduler(steps_per_epoch=num_batches, scheduler='constant', lr=args.lr)optimizer = create_optimizer(net.trainable_params(), opt='adam', lr=lr_scheduler)model = Model(net, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'})checkpoint_config = CheckpointConfig(save_checkpoint_steps=num_batches, keep_checkpoint_max=10)checkpoint_callback = ModelCheckpoint(prefix='checkpoint_resnet', directory=args.output_path,config=checkpoint_config)model.train(args.epochs, train_loader,callbacks=[checkpoint_callback, LossMonitor(), TimeMonitor(data_size=num_batches)])if __name__ == '__main__':train_model(parse_args())
import gradio as gr
import numpy as np
from mindspore import Tensor
from mindspore.nn import Softmax
import cv2
from typing import Type, Union, List, Optional
from mindspore import nn
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindcv.models import create_model
import mindsporeprint(mindspore.__version__)NUM_CLASS = 10
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']def predict_image(img):# 创建模型实例net = create_model(model_name='resnet50', num_classes=NUM_CLASS)param_dict = load_checkpoint('/root/MyCode/pycharm/ResNet50/train/resnet/checkpoint_resnet-5_1563.ckpt')load_param_into_net(net, param_dict)# 封装模型为 Model 类实例model = Model(net)# 调整图像格式和大小img = cv2.resize(img, (224, 224))img = np.array(img, dtype=np.float32) / 255.0  # 归一化并确保数据类型为 Float32# 如果图像是 BGR 格式,转换为 RGB 格式# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)# 标准化处理img = (img - np.array([0.485, 0.456, 0.406], dtype=np.float32)) / np.array([0.229, 0.224, 0.225], dtype=np.float32)# 转换维度 - 通道优先格式 (C, H, W)img = np.transpose(img, (2, 0, 1))# 添加批次维度 (N, C, H, W)img = np.expand_dims(img, axis=0)# 将图像数据转换为 MindSpore 张量img_tensor = Tensor(img, dtype=mindspore.float32)  # 显式指定数据类型# 预测图像output = model.predict(img_tensor)# 应用 Softmax 获取概率softmax = Softmax(axis=1)predict_probability = softmax(output).asnumpy()predict_probability = predict_probability[0]  # 获取批量中的第一个元素# 将预测概率映射到类别名称return {class_names[i]: float(predict_probability[i]) for i in range(NUM_CLASS)}image = gr.Image()
label = gr.Label(num_top_classes=NUM_CLASS)gr.Interface(css=".footer {display:none !important}",fn=predict_image,inputs=image,live=False,description="Please upload a image in JPG, JPEG or PNG.",title='Image Classification by ResNet50',outputs=gr.Label(num_top_classes=NUM_CLASS, label="预测类别"),examples=['./example_img/airplane.jpg', './example_img/automobile.jpg', './example_img/bird.jpg','./example_img/cat.jpg', './example_img/deer.jpg', './example_img/dog.jpg','./example_img/frog.jpg', './example_img/horse.JPG', './example_img/ship.jpg','./example_img/truck.jpg']).launch(share=True)

总结

通过 MindSpore 和 Gradio,我们可以不仅训练强大的深度学习模型,还可以将这些模型转化为交互式应用,使非专业人士也能轻松体验 AI 的魅力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/155513.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

股票基础数据(二)

二. 股票基础数据 文章目录 二. 股票基础数据一. 查询股票融资信息数据二. 查询所有的股票信息三. 查询所有的股票类型信息四. 根据类型查询所有的股票数据信息五. 查询股票当前的基本信息六. 查询股票的K线图, 返回对应的 base64 信息七. 展示股票的K线图数据, 对应的是数据信…

大模型的实践应用7-阿里的多版本通义千问Qwen大模型的快速应用与部署

大家好,我是微学AI,今天给大家介绍一下大模型的实践应用7-阿里的多版本通义千问Qwen大模型的快速应用与部署。阿里云开源了Qwen系列模型,即Qwen-7B和Qwen-14B,以及Qwen的聊天模型,即Qwen-7B-Chat和Qwen-14B-Chat。通义千问模型针对多达 3 万亿个 token 的多语言数据进行了…

LLM之Prompt(二):清华提出Prompt 对齐优化技术BPO

论文题目:《Black-Box Prompt Optimization: Aligning Large Language Models without Model Training》 论文链接:https://arxiv.org/abs/2311.04155 github地址:https://github.com/thu-coai/BPO BPO背景介绍 最近,大型语言模…

米哈游大数据云原生实践

云布道师 近年来,容器、微服务、Kubernetes 等各项云原生技术的日渐成熟,越来越多的公司开始选择拥抱云原生,并将企业应用部署运行在云原生之上。随着米哈游业务的高速发展,大数据离线数据存储量和计算任务量增长迅速&#xff0c…

中大型企业网搭建(毕设类型)

毕业设计类别 某大学网络规划与部署 目录 某大学网络规划与部署 第一章项目概述 1.1 项目背景 1.2 网络需求分析 第二章网络总体设计方案 2.1 网络整体架构 2.2 网络设计思路 第三章 网络技术应用 3.1 DHCP 3.2 MSTP 3.3 VRRP 3.4 OSPF 3.5 VLAN 3.6 NAT 3.7 WLAN 3…

78基于matlab的BiLSTM分类算法,输出迭代曲线,测试集和训练集分类结果和混淆矩阵

基于matlab的BiLSTM分类算法,输出迭代曲线,测试集和训练集分类结果和混淆矩阵,程序有详细注释,数据可更换自己的,程序已调通,可直接运行。

全面解析IEC 60364三种接地系统的概念、特点及应用

根据IEC 60364规定的各种保护方式、术语概念,低压配电系统按接地方式的不同分为三类,即 TT 、 TN 和 IT 系统。 1.TT系统TT grounding system TT供电系统:是指将电气设备的金属外壳直接接地的保护系统,称为保护接地系统&#xff…

C#WPF用户控件及自定义控件实例

本文演示C#WPF自定义控件实例 用户控件(UserControl)和自定义控件(CustomControl)都是对UI控件的一种封装方式,目的都是实现封装后控件的重用。 只不过各自封装的实现方式和使用的场景上存在差异。 1 基于UserControl 创建 创建控件最简单一个方法就是基于UserControl …

如何使用API接口对接淘宝获取店铺销量排序,店铺名称等参数

要接入淘宝官方开放平台API接口获取店铺销量排序,店铺名称等参数,需要按照以下步骤进行操作: 找到可用的API接口:首先,需要找到支持查询店铺信息的API接口。可以在电商数据平台的开放平台上查找相应的API接口。注册并…

YOLOv8更换骨干网络HorNet:递归门控卷积的高效高阶空间交互——涨点神器!

🗝️YOLOv8实战宝典--星级指南:从入门到精通,您不可错过的技巧   -- 聚焦于YOLO的 最新版本, 对颈部网络改进、添加局部注意力、增加检测头部,实测涨点 💡 深入浅出YOLOv8:我的专业笔记与技术总结   -- YOLOv8轻松上手, 适用技术小白,文章代码齐全,仅需 …

【深度学习实验】注意力机制(四):点积注意力与缩放点积注意力之比较

文章目录 一、实验介绍二、实验环境1. 配置虚拟环境2. 库版本介绍 三、实验内容0. 理论介绍a. 认知神经学中的注意力b. 注意力机制 1. 注意力权重矩阵可视化(矩阵热图)2. 掩码Softmax 操作3. 打分函数——加性注意力模型3. 打分函数——点积注意力与缩放…

用户增长常见分析模型

一、用户增长是什么 用户增长基本上会涉及生意场上的各行各业,你开个店面希望有更多的客户光顾,你做了个APP希望有更多的用户经常使用,你搭建了个电商平台希望有更多的人下单买东西。 用户增长,即以提升用户LTV为目的&#xff08…

Self-Supervised Exploration via Disagreement论文笔记

通过分歧进行自我监督探索 0、问题 使用可微的ri直接去更新动作策略的参数的,那是不是就不需要去计算价值函数或者critic网络了? 1、Motivation 高效的探索是RL中长期存在的问题。以前的大多数方式要么陷入具有随机动力学的环境,要么效率…

应用在金银精炼控制系统中的Modbus转Profinet网关案例

应用在金银精炼控制系统中的Modbus转Profinet网关案例 Modbus转Profinet网关(XD-MDPN100)能够支持多种通信协议和接口,满足不同设备和系统的需求。在金银精炼控制系统中使用,通过控制PID阀门的大小,将1200plc与PID控制…

Git 远程仓库(Github)

目录 添加远程库 查看当前的远程库 提取远程仓库 推送到远程仓库 删除远程仓库 Git 并不像 SVN 那样有个中心服务器。 目前我们使用到的 Git 命令都是在本地执行,如果你想通过 Git 分享你的代码或者与其他开发人员合作。 你就需要将数据放到一台其他开发人员…

【paddlepaddle】

安装paddlepaddle 报错 ImportError: /home/ubuntu/miniconda3/envs/paddle_gan/bin/../lib/libstdc.so.6: version GLIBCXX_3.4.30 not found (required by /home/ubuntu/miniconda3/envs/paddle_gan/lib/python3.8/site-packages/paddle/fluid/libpaddle.so) 替换 /home/ubu…

【python】Python生成GIF动图,多张图片转动态图,pillow

pip install pillow 示例代码: from PIL import Image, ImageSequence# 图片文件名列表 image_files [car.png, detected_map.png, base64_image_out.png]# 打开图片 images [Image.open(filename) for filename in image_files]# 设置输出 GIF 文件名 output_g…

GAMES101—Lec 05~06:光栅化

目录 概念回顾(个人理解)光栅化1.采样2.采样出现的问题:走样 反走样 概念回顾(个人理解) 屏幕:在图形学中,我们认为屏幕是一个二维数组,数组里的每一个元素为一个二维像素。 光栅化…

【Operating Systems:Three Easy Pieces 操作系统导论 】第28章 插叙:线程 API

【Operating Systems:Three Easy Pieces 操作系统导论 】 第28章 插叙&#xff1a;线程 API pthread 库介绍 线程创建 #include <pthread.h> // 头文件 int pthread_create(pthread_t * thread,const pthread_attr_t * attr,void * (*start_routine)(void*),void *…

【数据结构(四)】栈(1)

文章目录 1. 关于栈的一个实际应用2. 栈的介绍3. 栈的应用场景4. 栈的简单应用4.1. 思路分析4.2. 代码实现 5. 栈的进阶应用(实现综合计算器)5.1. 栈实现一位数计算(中缀表达式)5.1.1. 思路分析5.1.2. 代码实现 5.2. 栈实现多位数计算(中缀表达式)5.2.1. 解决思路5.2.2. 代码实…