【深度学习】pytorch,MNIST手写数字分类

efficientnet_b0的迁移学习


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import models
import matplotlib.pyplot as plt# 定义超参数
batch_size = 240
learning_rate = 0.001
num_epochs = 10# 数据预处理,包括调整图像大小并将单通道图像复制到三个通道
transform = transforms.Compose([transforms.Resize(224),  # 调整图像大小以适应EfficientNetB0transforms.Grayscale(num_output_channels=3),  # 将单通道图像复制到三个通道transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 使用ImageNet的均值和标准差
])# 加载数据集
train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=32)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=32)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载预训练的EfficientNetB0模型并调整最后的分类层
model = models.efficientnet_b0(pretrained=True)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10)  # MNIST共10个类别
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 用于绘图的数据
train_losses = []
test_accuracies = []# 训练模型
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()data, target = data.to(device), target.to(device)output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()print(f"\rEpoch {epoch + 1}/{num_epochs}, Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}")# 计算平均损失avg_loss = running_loss / len(train_loader)train_losses.append(avg_loss)# 测试准确率model.eval()correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)  # Move test data to the correct deviceoutput = model(data)_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()accuracy = 100 * correct / totaltest_accuracies.append(accuracy)print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%')# save
torch.save(model.state_dict(), 'mnist_efficientnetb0.pth')# 绘制损失函数和准确率图
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(test_accuracies, label='Test Accuracy')
plt.title('Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()plt.show()

训练10轮,测试准确率很猛:

Epoch 10/10, Loss: 0.0087, Test Accuracy: 99.60%

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/768478.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

IntelliJ IDE 插件开发 | (七)PSI 入门及实战(实现 MyBatis 插件的跳转功能)

系列文章 IntelliJ IDE 插件开发 |(一)快速入门IntelliJ IDE 插件开发 |(二)UI 界面与数据持久化IntelliJ IDE 插件开发 |(三)消息通知与事件监听IntelliJ IDE 插件开发 |(四)来查收…

【mybatis】TypeHandler解读

在谈论MyBatis的源码时,TypeHandler 是其中一个非常关键的组成部分,它负责Java类型和JDBC类型之间的相互转换。理解TypeHandler的工作原理,对于深入理解MyBatis的数据处理流程十分重要。 什么是TypeHandler? 在MyBatis中,TypeH…

android Fragment 生命周期 方法调用顺序

文章目录 Introlog 及结论代码 Intro 界面设计:点击左侧按钮,会将右侧 青色的RightFragment 替换成 黄色的AnotherRightFragment,而这两个 Fragment 的生命周期方法都会打印日志。 所以只要看执行结果中的日志,就可以知道 Fragme…

【单例测试】Mockito实战

目录 一、项目介绍二、业务代码2.1 导入依赖2.2 entity2.3 Dao2.4 业务代码 三、单元测试3.1 生成Test方法3.2 引入测试类3. 3 测试前准备3.4 测试3.4.1 name和phone参数校验3.4.2 测试数据库访问 3.4.3 数据库反例 总结 前面我们提到了《【单元测试】一文读懂java单元测试》 简…

IDEA Android新建项目基础

title: IDEA Android基础开发 search: 2024-03-16 tags: “#JavaAndroid开发” 一、构建基本项目 在使用 IDEA 进行基础的Android 开发时,我们可以通过IDEA自带的新建项目功能进行Android应用开发基础架构的搭建,可以直接找到 File --> New --> …

vue的history路由实现形式

vue的路由实现形式 SPA single page web application,单页Web应用 简单的说SPA就是一个WEB项目只有一个HTML页面,一旦页面加载完成,SPA不会因为用户的操作而进行页面的重新加载和跳转。取而代之的是利用JS动态的改变HTML的内容&#xff0c…

代码随想录算法训练营day19 | 二叉树阶段性总结

各个部分题目的代码题解都在我往日的二叉树的博客中。 (day14到day22) 目录 二叉树理论基础二叉树的遍历方式深度优先遍历广度优先遍历 求二叉树的属性二叉树的修改与制造求二叉搜索树的属性二叉树公共最先问题二叉搜索树的修改与构造总结 二叉树理论基础 二叉树的理论基础参…

基于nodejs+vue学生作业管理系统python-flask-django-php

他们不仅希望页面简单大方,还希望操作方便,可以快速锁定他们需要的线上管理方式。基于这种情况,我们需要这样一个界面简单大方、功能齐全的系统来解决用户问题,满足用户需求。 课题主要分为三大模块:即管理员模块和学生…

平台介绍-搭建赛事运营平台(1)

平台的一个很重要的市场方向是为企业搭建各类运营平台。运营平台是这类企业的核心系统,例如对银行而言就是柜台系统,对于电商而言就是电子商城。运营平台和内部信息平台的显著区别是要面向外部C端客户。内部信息平台的受众只是企业内部人员。 最近签约开…

HAL STM32G4 +ADC手动触发采集+各种滤波算法实现

HAL STM32G4 ADC手动触发采集各种滤波算法实现 📍相关篇《HAL STM32G4 TIM1 3路PWM互补输出VOFA波形演示》 ✨本篇内容也是继欧拉电子相关无刷电机驱动控制学习的相关基础内容。仅作为个人笔记记录使用。 📍感谢网友提供的相关内容《基于STM32的ADC采样及…

上位机图像处理和嵌入式模块部署(qmacvisual轮廓查找)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 前面我们说过,图像的处理流程一般都是这样的,即灰度化-》降噪-》边缘检测-》二值化-》开闭运算-》轮廓检测。虽然前面的几个…

LeetCode 面试经典150题 14.最长公共前缀

题目: 编写一个函数来查找字符串数组中的最长公共前缀。 如果不存在公共前缀,返回空字符串 ""。 思路: 代码: class Solution {public String longestCommonPrefix(String[] strs) {if (strs.length 0) {return &…

知攻善防应急靶场-Linux(2)

前言: 堕落了三个月,现在因为被找实习而困扰,着实自己能力不足,从今天开始 每天沉淀一点点 ,准备秋招 加油 注意: 本文章参考qax的网络安全应急响应和知攻善防实验室靶场,记录自己的学习过程&am…

python绘图matplotlib——使用记录1

本博文来自于网络收集,如有侵权请联系删除 使用matplotlib绘图 1 常用函数汇总1.1 plot1.2 legend1.3 scatter1.4 xlim1.5 xlabel1.6 grid1.7 axhline1.7 axvspan1.8 annotate1.9 text1.10 title 2 常见图形绘制2.1 bar——柱状图2.2 barh——条形图2.3 hist——直…

flutter3_douyin:基于flutter3+dart3短视频直播实例|Flutter3.x仿抖音

flutter3-dylive 跨平台仿抖音短视频直播app实战项目。 全新原创基于flutter3.19.2dart3.3.0getx等技术开发仿抖音app实战项目。实现了类似抖音整屏丝滑式上下滑动视频、左右滑动切换页面模块,直播间进场/礼物动效,聊天等模块。 运用技术 编辑器&#x…

git标签的简单操作

创建标签 git tag v1.0 # 对head指向的commit创建标签 git tag v1.1 commit_id # 对指定的commit创建标签 git tag v2.0 -a -m "标签注释" commit_id # 创建注释标签查看标签 git tag -l v1* # 查看标签,匹配v1开头的 git show v2.0 # 查看标签详细信息…

Qt如何重写closeEvent

在 Qt 中,重写 closeEvent 函数是处理窗口关闭事件的一种方式。当你关闭一个 Qt 窗口时,该窗口会接收到一个 QCloseEvent 对象。通过重写窗口类的 closeEvent 函数,你可以自定义窗口关闭时的行为。 下面是一个简单的例子,展示了如…

Netty剖析 - Why Netty

文章目录 Why NettyI/O 请求的两个阶段I/O 模型Netty 如何实现自己的 I/O 模型线程模型 - 事件分发器(Event Dispather)弥补 Java NIO 的缺陷更低的资源消耗网络框架的选型Netty 发展现状Netty 的使用 Why Netty I/O 模型、线程模型和事件处理机制优化&a…

php搭建websocket

1.项目终端执行命令:composer require topthink/think-worker 2.0.x 2.config多出三个配置文件: 3.当使用php think worker:gateway命令时,提示不支持Windows。 4.打包项目为zip格式 5.打包数据库 6.阿里云创建记录 7.宝塔面板新增站点…

Vue3 上手笔记

1. Vue3简介 2020年9月18日,Vue.js发布版3.0版本,代号:One Piece(n 经历了:4800次提交、40个RFC、600次PR、300贡献者 官方发版地址:Release v3.0.0 One Piece vuejs/core 截止2023年10月,最…