【深度学习】pytorch,MNIST手写数字分类

efficientnet_b0的迁移学习


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import models
import matplotlib.pyplot as plt# 定义超参数
batch_size = 240
learning_rate = 0.001
num_epochs = 10# 数据预处理,包括调整图像大小并将单通道图像复制到三个通道
transform = transforms.Compose([transforms.Resize(224),  # 调整图像大小以适应EfficientNetB0transforms.Grayscale(num_output_channels=3),  # 将单通道图像复制到三个通道transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 使用ImageNet的均值和标准差
])# 加载数据集
train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=32)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=32)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载预训练的EfficientNetB0模型并调整最后的分类层
model = models.efficientnet_b0(pretrained=True)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10)  # MNIST共10个类别
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 用于绘图的数据
train_losses = []
test_accuracies = []# 训练模型
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()data, target = data.to(device), target.to(device)output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()print(f"\rEpoch {epoch + 1}/{num_epochs}, Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}")# 计算平均损失avg_loss = running_loss / len(train_loader)train_losses.append(avg_loss)# 测试准确率model.eval()correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)  # Move test data to the correct deviceoutput = model(data)_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()accuracy = 100 * correct / totaltest_accuracies.append(accuracy)print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%')# save
torch.save(model.state_dict(), 'mnist_efficientnetb0.pth')# 绘制损失函数和准确率图
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(test_accuracies, label='Test Accuracy')
plt.title('Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()plt.show()

训练10轮,测试准确率很猛:

Epoch 10/10, Loss: 0.0087, Test Accuracy: 99.60%

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/768478.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

IntelliJ IDE 插件开发 | (七)PSI 入门及实战(实现 MyBatis 插件的跳转功能)

系列文章 IntelliJ IDE 插件开发 |(一)快速入门IntelliJ IDE 插件开发 |(二)UI 界面与数据持久化IntelliJ IDE 插件开发 |(三)消息通知与事件监听IntelliJ IDE 插件开发 |(四)来查收…

android Fragment 生命周期 方法调用顺序

文章目录 Introlog 及结论代码 Intro 界面设计:点击左侧按钮,会将右侧 青色的RightFragment 替换成 黄色的AnotherRightFragment,而这两个 Fragment 的生命周期方法都会打印日志。 所以只要看执行结果中的日志,就可以知道 Fragme…

【单例测试】Mockito实战

目录 一、项目介绍二、业务代码2.1 导入依赖2.2 entity2.3 Dao2.4 业务代码 三、单元测试3.1 生成Test方法3.2 引入测试类3. 3 测试前准备3.4 测试3.4.1 name和phone参数校验3.4.2 测试数据库访问 3.4.3 数据库反例 总结 前面我们提到了《【单元测试】一文读懂java单元测试》 简…

IDEA Android新建项目基础

title: IDEA Android基础开发 search: 2024-03-16 tags: “#JavaAndroid开发” 一、构建基本项目 在使用 IDEA 进行基础的Android 开发时,我们可以通过IDEA自带的新建项目功能进行Android应用开发基础架构的搭建,可以直接找到 File --> New --> …

基于nodejs+vue学生作业管理系统python-flask-django-php

他们不仅希望页面简单大方,还希望操作方便,可以快速锁定他们需要的线上管理方式。基于这种情况,我们需要这样一个界面简单大方、功能齐全的系统来解决用户问题,满足用户需求。 课题主要分为三大模块:即管理员模块和学生…

上位机图像处理和嵌入式模块部署(qmacvisual轮廓查找)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 前面我们说过,图像的处理流程一般都是这样的,即灰度化-》降噪-》边缘检测-》二值化-》开闭运算-》轮廓检测。虽然前面的几个…

LeetCode 面试经典150题 14.最长公共前缀

题目: 编写一个函数来查找字符串数组中的最长公共前缀。 如果不存在公共前缀,返回空字符串 ""。 思路: 代码: class Solution {public String longestCommonPrefix(String[] strs) {if (strs.length 0) {return &…

知攻善防应急靶场-Linux(2)

前言: 堕落了三个月,现在因为被找实习而困扰,着实自己能力不足,从今天开始 每天沉淀一点点 ,准备秋招 加油 注意: 本文章参考qax的网络安全应急响应和知攻善防实验室靶场,记录自己的学习过程&am…

python绘图matplotlib——使用记录1

本博文来自于网络收集,如有侵权请联系删除 使用matplotlib绘图 1 常用函数汇总1.1 plot1.2 legend1.3 scatter1.4 xlim1.5 xlabel1.6 grid1.7 axhline1.7 axvspan1.8 annotate1.9 text1.10 title 2 常见图形绘制2.1 bar——柱状图2.2 barh——条形图2.3 hist——直…

flutter3_douyin:基于flutter3+dart3短视频直播实例|Flutter3.x仿抖音

flutter3-dylive 跨平台仿抖音短视频直播app实战项目。 全新原创基于flutter3.19.2dart3.3.0getx等技术开发仿抖音app实战项目。实现了类似抖音整屏丝滑式上下滑动视频、左右滑动切换页面模块,直播间进场/礼物动效,聊天等模块。 运用技术 编辑器&#x…

Netty剖析 - Why Netty

文章目录 Why NettyI/O 请求的两个阶段I/O 模型Netty 如何实现自己的 I/O 模型线程模型 - 事件分发器(Event Dispather)弥补 Java NIO 的缺陷更低的资源消耗网络框架的选型Netty 发展现状Netty 的使用 Why Netty I/O 模型、线程模型和事件处理机制优化&a…

php搭建websocket

1.项目终端执行命令:composer require topthink/think-worker 2.0.x 2.config多出三个配置文件: 3.当使用php think worker:gateway命令时,提示不支持Windows。 4.打包项目为zip格式 5.打包数据库 6.阿里云创建记录 7.宝塔面板新增站点…

Vue3 上手笔记

1. Vue3简介 2020年9月18日,Vue.js发布版3.0版本,代号:One Piece(n 经历了:4800次提交、40个RFC、600次PR、300贡献者 官方发版地址:Release v3.0.0 One Piece vuejs/core 截止2023年10月,最…

网盘——数据库操作

关于网盘的数据库模块,主要有以下几个内容:定义数据库操作类、将数据库操作类定义成单例模式、数据库操作 数据库是在Qt里面,定义成操作类,专门用这个类产生对象,对数据库实现操作,那么我们在产生对象的时…

BMS设计中的短路保护和MOSFET选型(下)

二、MOSFET参数 1、电气参数 (1)VGS :加在栅源两极之间的最大电压,一般为:-20V-+20V。 VGS额定电压是栅源两极间可以施加的最大电压。设定该额定电压的主要目的是防止电压过高导致的栅氧化层损伤。实际栅氧化层可承受的电压远高于额定电压,但是会随制造工艺的不同而改变…

01-机器学习概述

机器学习的定义 机器学习是一门从数据中研究算法的科学学科。 机器学习直白来讲, 就是根据已有的数据,进行算法选择,并基于算法和数据 构建模型,最终对未来进行预测。 机器学习就是一个模拟人决策过程的一种程序结构。 机器学…

PWM实现电机的正反转和调速以及TIM定时器

pwm.c #include "pwm.h"/* PWM --- PA2 --TIM2_CH3 //将电机信号控制一根接GND,一根接在PA2(TIM2_CH3), 输出PWM控制电机快慢 TIM2挂在APB1 定时器频率:84MHZ*/ void Pwm_Init(void) {GPIO_InitTypeDef GPIO_InitStruct;TIM_TimeBaseInitT…

Django下载使用、文件介绍

【一】下载并使用 【1】下载框架 (1)注意事项 计算机名称不要出现中文python解释器版本不同可能会出现启动报错项目中所有的文件名称不要出现中文多个项目文件尽量不要嵌套,做到一项一夹 (2)下载 Django属于第三方模块&#…

STM32微控制器中,如何处理多个同时触发的中断请求?

在STM32微控制器中,处理多个同时触发的中断请求需要一个明确的中断优先级策略,以确保关键任务能够及时得到响应。STM32的中断控制器(NVIC)支持优先级分组,允许开发者为不同的中断设置抢占优先级和子优先级。本文将详细…

uniapp 打包后缺少maps模块和share模块的解决方案

缺失maps模块 我的应用 | 高德控制台 缺失share模块 QQ互联管理中心 微信开放平台