基于flask的猫狗图像预测案例

📚博客主页:knighthood2001
公众号:认知up吧 (目前正在带领大家一起提升认知,感兴趣可以来围观一下)
🎃知识星球:【认知up吧|成长|副业】介绍
❤️如遇文章付费,可先看看我公众号中是否发布免费文章❤️
🙏笔者水平有限,欢迎各位大佬指点,相互学习进步!

假设,你有模型,有训练好的模型文件,有模型推理代码,就可以把他放到flask上进行展示。

项目架构

在这里插入图片描述

  • index.html是模板文件
  • app.py是项目运行的入口
  • best_model.pth是训练好的模型参数
  • model.py是神经网络模型,这里采用的是GoogleNet网络。
  • model_reasoning.py是模型推理,通过这里面的代码,我们可以在本地进行猫狗图片的预测。

运行图

在这里插入图片描述

点击选择文件
在这里插入图片描述
图片下面就显示预测结果了。
在这里插入图片描述

项目完整代码与讲解

index.html

<!DOCTYPE html>
<html lang="en">
<head><meta charset="UTF-8"><title>图像分类</title><style>body {font-family: Arial, sans-serif;margin: 20px;}#result {margin-top: 10px;}#preview-image {max-width: 400px;margin-top: 20px;}</style>
</head>
<body><h1>图像分类</h1><form id="upload-form" action="/predict" method="post" enctype="multipart/form-data"><input type="file" name="file" accept="image/*" onchange="previewImage(event)"><input type="submit" value="预测"></form><img id="preview-image" src="" alt=""><br><div id="result"></div><script>document.getElementById('upload-form').addEventListener('submit', async (e) => {e.preventDefault();  // 阻止默认的表单提交行为const formData = new FormData(); // 创建一个新的FormData对象,用于封装表单数据formData.append('file', document.querySelector('input[type=file]').files[0]);  // 添加表单数据// 使用fetch API发送POST请求到'/predict'路径,并将formData作为请求体const response = await fetch('/predict', {method: 'POST',body: formData});// 获取响应的JSON数据const result = await response.json();// 将预测结果显示在页面上ID为'result'的元素中document.getElementById('result').innerText = `预测结果: ${result.prediction}`;});function previewImage(event) {const file = event.target.files[0];  // 获取上传的文件对象const reader = new FileReader();  // 创建一个FileReader对象,用于读取文件内容// 清空上一次的预测结果document.getElementById('result').innerText = '';// 当文件读取完成后,将文件内容显示在页面上ID为'preview-image'的元素中reader.onload = function(event) {document.getElementById('preview-image').setAttribute('src', event.target.result);}// 如果用户选择了文件,则开始读取文件内容if (file) {reader.readAsDataURL(file); // 将文件读取为DataURL格式,这样可以直接用作img元素的src属性}}</script>
</body>
</html>

前端我练的不多,很多解释已经在代码中讲了。

model.py

这是GoogleNet的网络架构

import torch
from torch import nn
from torchsummary import summary
# 定义一个Inception模块
class Inception(nn.Module):def __init__(self, in_channels, c1, c2, c3, c4):  # 这些参数,所在的位置都会发送变化,所有需要这个参数super(Inception, self).__init__()self.ReLU = nn.ReLU()# 路线1,单1×1卷积层self.p1_1 = nn.Conv2d(in_channels=in_channels, out_channels=c1, kernel_size=1)# 路线2,1×1卷积层, 3×3的卷积self.p2_1 = nn.Conv2d(in_channels=in_channels, out_channels=c2[0], kernel_size=1)self.p2_2 = nn.Conv2d(in_channels=c2[0], out_channels=c2[1], kernel_size=3, padding=1)# 路线3,1×1卷积层, 5×5的卷积self.p3_1 = nn.Conv2d(in_channels=in_channels, out_channels=c3[0], kernel_size=1)self.p3_2 = nn.Conv2d(in_channels=c3[0], out_channels=c3[1], kernel_size=5, padding=2)# 路线4,3×3的最大池化, 1×1的卷积self.p4_1 = nn.MaxPool2d(kernel_size=3, padding=1, stride=1)self.p4_2 = nn.Conv2d(in_channels=in_channels, out_channels=c4, kernel_size=1)def forward(self, x):p1 = self.ReLU(self.p1_1(x))p2 = self.ReLU(self.p2_2(self.ReLU(self.p2_1(x))))p3 = self.ReLU(self.p3_2(self.ReLU(self.p3_1(x))))p4 = self.ReLU(self.p4_2(self.p4_1(x)))return torch.cat((p1, p2, p3, p4), dim=1)class GoogLeNet(nn.Module):def __init__(self, Inception, in_channels, out_channels):super(GoogLeNet, self).__init__()self.b1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),nn.ReLU(),nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),Inception(256, 128, (128, 192), (32, 96), 64),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),Inception(512, 160, (112, 224), (24, 64), 64),Inception(512, 128, (128, 256), (24, 64), 64),Inception(512, 112, (128, 288), (32, 64), 64),Inception(528, 256, (160, 320), (32, 128), 128),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),Inception(832, 384, (192, 384), (48, 128), 128),nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten(),nn.Linear(1024, out_channels))for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)if m.bias is not None:nn.init.constant_(m.bias, 0)def forward(self, x):x = self.b1(x)x = self.b2(x)x = self.b3(x)x = self.b4(x)x = self.b5(x)return xif __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = GoogLeNet(Inception, 1, 10).to(device)print(summary(model, (1, 224, 224)))

model_reasoning.py

import torch
from torchvision import transforms
from model import GoogLeNet, Inception
from PIL import Imagedef test_model(model, test_file):# 设定测试所用到的设备,有GPU用GPU没有GPU用CPUdevice = "cuda" if torch.cuda.is_available() else 'cpu'model = model.to(device)classes = ['猫', '狗']print(classes)image = Image.open(test_file)# normalize = transforms.Normalize([0.162, 0.151, 0.138], [0.058, 0.052, 0.048])# # 定义数据集处理方法变量# test_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), normalize])# 定义数据集处理方法变量test_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])image = test_transform(image)# 添加批次维度,变成[1,3,224,224]image = image.unsqueeze(0)with torch.no_grad():model.eval()image = image.to(device)  # 图片也要放到设备当中output = model(image)print(output.tolist())pre_lab = torch.argmax(output, dim=1)result = pre_lab.item()print("预测值:", classes[result])return classes[result]def test_special_model(best_model_file, test_file):# 加载模型model = GoogLeNet(Inception, in_channels=3, out_channels=2)model.load_state_dict(torch.load(best_model_file))# 模型的推理判断return test_model(model, test_file)if __name__ == "__main__":# # 加载模型# model = GoogLeNet(Inception, in_channels=3, out_channels=2)# model.load_state_dict(torch.load('best_model.pth'))# # 模型的推理判断# test_model(model, "test_data/images.jfif")test_special_model("best_model.pth", "static/1.jpg")

这段代码与之前的模型推理代码不同的是,我添加了test_special_model函数,方便后续app.py中可以直接调用这个函数进行模型推理。

app.py

import os
from flask import Flask, request, jsonify, render_templatefrom model_reasoning import test_special_model
from model_reasoning import test_model
app = Flask(__name__)# 定义路由
@app.route('/')
def index():return render_template('index.html')@app.route('/predict', methods=['POST'])
def predict():if request.method == 'POST':# 获取上传的文件file = request.files['file']if file:# 调用模型进行预测# # 加载模型# model = GoogLeNet(Inception, in_channels=3, out_channels=2)# basedir = os.path.abspath(os.path.dirname(__file__))## model.load_state_dict(torch.load(basedir + '/best_model.pth'))# result = test_model(model, file)basedir = os.path.abspath(os.path.dirname(__file__))best_model_file = basedir + '/best_model.pth'result = test_special_model(best_model_file, file)return jsonify({'prediction': result})else:return jsonify({'error': 'No file found'})if __name__ == '__main__':app.run(debug=True)

如果没有上文中的test_special_model函数,那么这里你就需要

   # 加载模型model = GoogLeNet(Inception, in_channels=3, out_channels=2)basedir = os.path.abspath(os.path.dirname(__file__))model.load_state_dict(torch.load(basedir + '/best_model.pth'))result = test_model(model, file)

并且还需要导入相应的库。

best_model.pth

最重要的是,你需要训练好的一个模型。

有需要的,可以联系我,我直接把这个项目代码发你。省得你还需要配置项目架构。

小插曲

我为什么会使用绝对路径,因为我在使用相对路径后,代码提示找不到这个路径。

    basedir = os.path.abspath(os.path.dirname(__file__))best_model_file = basedir + '/best_model.pth'

然后,我刚刚又试了一下,发现使用相对路径,又可以运行成功了。

真是不可思议(这个小插曲花了我大半个小时)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/42036.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

二次元转向SLG,B站游戏的破圈之困

文 | 螳螂观察 作者 | 夏至 2023年是B站游戏的滑铁卢&#xff0c;尽管这年B站的游戏营收还有40多亿&#xff0c;但相比去年大幅下降了20%&#xff0c;整整少了10亿&#xff0c;这是过去5年来的最大跌幅&#xff0c;也是陈睿接管B站游戏业务一年以来&#xff0c;在鼻子上碰的第…

鸿蒙语言基础类库:【@ohos.process (获取进程相关的信息)】

获取进程相关的信息 说明&#xff1a; 本模块首批接口从API version 7开始支持。后续版本的新增接口&#xff0c;采用上角标单独标记接口的起始版本。开发前请熟悉鸿蒙开发指导文档&#xff1a;gitee.com/li-shizhen-skin/harmony-os/blob/master/README.md点击或者复制转到。…

昇思13天

ResNet50迁移学习 ResNet50迁移学习总结 背景介绍 在实际应用场景中&#xff0c;由于训练数据集不足&#xff0c;很少有人会从头开始训练整个网络。普遍做法是使用在大数据集上预训练得到的模型&#xff0c;然后将该模型的权重参数用于特定任务中。本章使用迁移学习方法对Im…

imx6ull/linux应用编程学习(13) CMAKE

什么是cmake&#xff1f; cmake 工具通过解析 CMakeLists.txt 自动帮我们生成 Makefile&#xff0c;可以实现跨平台的编译。cmake 就是用来产生 Makefile 的工具&#xff0c;解析 CMakeLists.txt 自动生成 Makefile&#xff1a; cmake 的使用方法 cmake 就是一个工具命令&am…

怎么将aac文件弄成mp3格式?把aac改成MP3格式的四种方法

怎么将aac文件弄成mp3格式&#xff1f;手头有一些aac格式的音频文件&#xff0c;但由于某些设备或软件不支持这种格式&#xff0c;你希望将它们转换成更为通用的MP3格式。而且音频格式的转换在现在已经是一个常见且必要的操作。aac是一种相对较新的音频编码格式&#xff0c;通常…

大模型增量预训练新技巧-解决灾难性遗忘

大模型增量预训练新技巧-解决灾难性遗忘 机器学习算法与自然语言处理 2024年03月21日 00:02 吉林 以下文章来源于NLP工作站 &#xff0c;作者刘聪NLP NLP工作站. AIGC前沿知识分享&落地经验总结 转载自 | NLP工作站 作者 | 刘聪NLP 目前不少开源模型在通用领域具有不错…

el-scrollbar实现自动滚动到底部(AI聊天)

目录 项目背景 实现步骤 实现代码 完整示例代码 项目背景 chatGPT聊天消息展示滚动面板&#xff0c;每次用户输入提问内容或者ai进行流式回答时需要不断的滚动到底部确保展示最新的消息。 实现步骤 采用element ui 的el-scrollbar作为聊天消息展示组件。 通过操作dom来实…

理解算法复杂度:空间复杂度详解

引言 在计算机科学中&#xff0c;算法复杂度是衡量算法效率的重要指标。时间复杂度和空间复杂度是算法复杂度的两个主要方面。在这篇博客中&#xff0c;我们将深入探讨空间复杂度&#xff0c;了解其定义、常见类型以及如何进行分析。空间复杂度是衡量算法在执行过程中所需内存…

昇思25天学习打卡营第19天|Diffusion扩散模型

学AI还能赢奖品&#xff1f;每天30分钟&#xff0c;25天打通AI任督二脉 (qq.com) Diffusion扩散模型 本文基于Hugging Face&#xff1a;The Annotated Diffusion Model一文翻译迁移而来&#xff0c;同时参考了由浅入深了解Diffusion Model一文。 本教程在Jupyter Notebook上成…

昇思MindSpore学习笔记5-02生成式--RNN实现情感分类

摘要&#xff1a; 记录MindSpore AI框架使用RNN网络对自然语言进行情感分类的过程、步骤和方法。 包括环境准备、下载数据集、数据集加载和预处理、构建模型、模型训练、模型测试等。 一、概念 情感分类。 RNN网络模型 实现效果&#xff1a; 输入: This film is terrible 正…

放大镜案例

放大镜 <!DOCTYPE html> <html lang"zh-cn"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>商品放大镜</title><link rel&qu…

如何使用allure生成测试报告

第一步下载安装JDK1.8&#xff0c;参考链接JDK1.8下载、安装和环境配置教程-CSDN博客 第二步配置allure环境&#xff0c;参考链接allure的安装和使用(windows环境)_allure windows-CSDN博客 第三步&#xff1a; 第四步&#xff1a; pytest 查看目前运行的测试用例有无错误 …

如何使用 pytorch 创建一个神经网络

我已发布在&#xff1a;如何使用 pytorch 创建一个神经网络 SapientialM.Github.io 构建神经网络 1 导入所需包 import os import torch from torch import nn from torch.utils.data import DataLoader from torchvision import datasets, transforms2 检查GPU是否可用 dev…

Yolov10训练,转化onnx,推理

yolov10对于大目标的效果好&#xff0c;小目标不好 一、如果你训练过yolov5&#xff0c;yolov8&#xff0c;的话那么你可以直接用之前的环境就行 目录 一、如果你训练过yolov5&#xff0c;yolov8&#xff0c;的话那么你可以直接用之前的环境就行 二、配置好后就可以配置文件…

前端JS特效第24集:jquery css3实现瀑布流照片墙特效

jquery css3实现瀑布流照片墙特效&#xff0c;先来看看效果&#xff1a; 部分核心的代码如下(全部代码在文章末尾)&#xff1a; <!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8" /> <title>jquerycss3实现瀑…

Nginx:负载均衡小专题

运维专题 Nginx&#xff1a;负载均衡小专题 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite&#xff1a;http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.csdn.net/…

【专项刷题】— 位运算

常见类型介绍&#xff1a; & &#xff1a;有 0 就是 0 | &#xff1a;有 1 就是 1 ^ &#xff1a;相同为 0 &#xff0c;相异为 1 或者 无进位相加给定一个数确定它的二进制位的第x个数是0还是1&#xff1a;将一个数的二进制的第x位改成1&#xff1a;将一个数的二进制的第x…

Windows10/11家庭版开启Hyper-V虚拟机功能详解

Hyper-V是微软的一款虚拟机软件&#xff0c;可以使我们在一台Windows PC上&#xff0c;在虚拟环境下同时运行多个互相之间完全隔离的操作系统&#xff0c;这就实现了在Windows环境下运行Linux以及其他OS的可能性。和第三方虚拟机软件&#xff0c;如VMware等相比&#xff0c;Hyp…

大模型知识问答: 文本分块要点总结

节前&#xff0c;我们组织了一场算法岗技术&面试讨论会&#xff0c;邀请了一些互联网大厂朋友、今年参加社招和校招面试的同学。 针对大模型技术趋势、算法项目落地经验分享、新手如何入门算法岗、该如何准备面试攻略、面试常考点等热门话题进行了深入的讨论。 总结链接如…

C++ 信号量和锁的区别

网上关于信号量和锁的区别&#xff0c;写的比较官方晦涩难懂&#xff0c;对于这个知识点吸收难&#xff0c;通过示例&#xff0c;我们看到信号量&#xff0c;可以控制同一时刻的线程数量&#xff0c;就算同时开启很多线程&#xff0c;依然可以的达到线程数可控 #include <i…