迁移学习实现图片分类任务

导入工具包

import time
import osimport numpy as np
from tqdm import tqdmimport torch
import torchvision
import torch.nn as nn
import torch.nn.functional as Fimport matplotlib.pyplot as plt
%matplotlib inline# 忽略烦人的红色提示
import warnings
warnings.filterwarnings("ignore")

获取计算硬件

# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

图片预处理

from torchvision import transforms# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

这里对train训练集和text集的处理不同,几个transforms的操作通过compose进行整合。

载入图片分类数据集

# 数据集文件夹路径
dataset_dir = 'fruit30_split'train_path = os.path.join(dataset_dir, 'train')
test_path = os.path.join(dataset_dir, 'val')
print('训练集路径', train_path)
print('测试集路径', test_path)from torchvision import datasets# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)print('训练集图像数量', len(train_dataset))
print('类别个数', len(train_dataset.classes))
print('各类别名称', train_dataset.classes)print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)

datasets下的ImageFolder,可以直接构建数据集。

类别与索引号一一对应

class_names = train_dataset.classes
n_class = len(class_names)# 映射关系:类别 到 索引号
train_dataset.class_to_idx

定义数据加载器Dataloader,dataloader用于给模型喂数据。

from torch.utils.data import DataLoaderBATCH_SIZE = 32# 训练集的数据加载器
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=4)# 测试集的数据加载器
test_loader = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=4)

查看一个batch的图像与标注

# DataLoader 是 python生成器,每次调用返回一个 batch 的数据
images, labels = next(iter(train_loader))images. Shape
#torch.Size([32, 3, 224, 224])
labels
#tensor([11, 19,  3, 25, 29, 13, 21, 18, 11,  1, 13, 15, 13,  0, 15, 25,  0,  7,11, 10,  9,  6, 26,  2, 11, 10, 29, 29, 15,  8, 19,  8])

迁移学习范式

导入训练所用的工具包

from torchvision import models
import torch.optim as optim
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与当前数据集类别数对应
# 新建的层默认 requires_grad=True
model.fc = nn.Linear(model.fc.in_features, n_class)
model.fc
Linear(in_features=512, out_features=30, bias=True)
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters())

采用第一种迁移学习的方式,优化器采用的是Adam的优化器。

训练配置

model = model.to(device)# 交叉熵损失函数
criterion = nn.CrossEntropyLoss() # 训练轮次 Epoch
EPOCHS = 20

模拟一个batch的训练

这里着重注意反向传播三部曲

# 反向传播“三部曲”
optimizer.zero_grad() # 清除梯度
loss.backward() # 反向传播
optimizer.step() # 优化更新

 运行完整训练

# 遍历每个 EPOCH
for epoch in tqdm(range(EPOCHS)):model. Train() #每次开始前将模型设置为训练模式for images, labels in train_loader:  # 获取训练集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images)           # 前向预测,获得当前 batch 的预测结果loss = criterion(outputs, labels) # 比较预测结果和标注,计算当前 batch 的交叉熵损失函数optimizer.zero_grad()loss.backward()                   # 损失函数对神经网络权重反向传播求梯度optimizer.step()                  # 优化更新神经网络权重

在测试集上进行初步测试

model.eval() #模型设置为测试模式
with torch.no_grad(): #不再回传梯度correct = 0total = 0for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images)              # 前向预测,获得当前 batch 的预测置信度_, preds = torch.max(outputs, 1)     # 获得最大置信度对应的类别,作为预测结果total += labels.size(0)correct += (preds == labels).sum()   # 预测正确样本个数,如果预测类别等于标注类别print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))

保存模型

torch.save(model, 'checkpoint/fruit30_pytorch_C1.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/667165.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

okhttp 的 拦截器

拦截器有很多作用,实现就是责任链模式,细节,等我有时间补上。 后面有时间更新一下。 OkHttp最核心的工作是在 getResponseWithInterceptorChain() 中进行,在进入这个方法分析之前,我们先来了 解什么是责任链模式&…

Java split 分割字符串避坑

使用split进行字符串分割时需要注意2点 1、特殊字符作为分隔符时需要使用\\进行转义(如\\ -> \\\\; | -> \\| ) 特殊字符 .$|()[{^?*\\ 例如对"|"分隔 未转义 String str "01|02|03"; String[] strArr str.split("|");System.out.…

点击按钮打开自定义iframe弹窗

1、效果 点击按钮打开弹窗&#xff1a; 打开弹窗后&#xff1a; 2、代码 <!DOCTYPE html> <html><head><title>iframe弹窗</title><style>/* 使用媒体查询来实现响应式设计 */media (min-width: 768px) {.popup {width: 80%; /* 设置…

【c/python】GtkBox

一、GtkBox及C语言示例 GtkBox是一个容器部件&#xff0c;用于在GTK&#xff08;GIMP Toolkit&#xff09;应用程序中水平或垂直地排列多个子部件。以下是一个简单的例子&#xff0c;展示了如何在一个基本的GTK应用程序中使用GtkBox来垂直排列两个按钮&#xff1a; 首先&#…

用Python Tkinter打造的精彩连连看小游戏【附源码】

文章目录 连连看小游戏&#xff1a;用Python Tkinter打造的精彩游戏体验游戏简介技术背景MainWindow类:职责:方法:Point类: 主执行部分:完整代码&#xff1a;总结&#xff1a; 连连看小游戏&#xff1a;用Python Tkinter打造的精彩游戏体验 在丰富多彩的游戏世界中&#xff0c…

左旋字符串的三种方法,并判断一个字符串是否为另外一个字符串旋转之后的字符串。(strcpy,strncat,strcmp,strstr函数的介绍)

一. 实现一个函数&#xff0c;可以左旋字符串中的k个字符。 例如&#xff1a; ABCD左旋一个字符得到BCDA ABCD左旋两个字符得到CDAB 通过分析&#xff0c;可以知道实际的旋转次数&#xff0c;其实是k%&#xff08;字符串长度&#xff09;。假设一个字…

西瓜书学习笔记——流形学习(公式推导+举例应用)

文章目录 等度量映射&#xff08;仅保留点与其邻近点的距离&#xff09;算法介绍实验分析 局部线性嵌入&#xff08;不仅保留点与其邻近点的距离还要保留邻近关系&#xff09;算法介绍实验分析 等度量映射&#xff08;仅保留点与其邻近点的距离&#xff09; 算法介绍 等度量映…

树莓派5一键安装C++版本OpenCV

安装环境 本人当前的安装环境&#xff1a; 树莓派5Raspberry Pi os (64-bit) Debian12 Bookworm 镜像下载地址 我这里是将镜像安装好后直接安装opencv&#xff0c;如果不是刚安装好的镜像需要注意是否有openCV的python之类的安装过&#xff0c;不然可能出现编译错误 一、扩展内…

SpringBoot中数据库的连接及Mybatis的配置和使用

目录 1 在pom.xml中引入相关依赖 2 对数据库进行配置 2.1 配置application.yml 2.2 idea连接数据库 (3.2.1有用到) 3 Mybatis的使用 3.1 测试文件的引入 3.2 使用 3.2.1 使用注解(有小技巧(✪ω✪)) 3.2.2 使用动态sql 1 在pom.xml中引入相关依赖 <dependencies&g…

海外多语言盲盒开发:打破语言障碍,连接全球消费者

随着全球化的加速和互联网的普及&#xff0c;语言障碍成为了影响跨国交流和商业活动的重要因素。为了满足跨国市场的需求&#xff0c;海外多语言盲盒开发成为了一个新兴的领域。本文将探讨海外多语言盲盒开发的意义、现状和未来发展。 一、海外多语言盲盒开发的意义 在全球化…

RedHat8.4安装邮件服务器

一、配置发件服务器 1.1 根据现场IP&#xff0c;配置主机名 vim /etc/hosts 192.168.8.120 mail.test.com 将主机名更改为邮件服务器域名mail.test.com 1.2 关闭防火墙&#xff0c;禁止开机启动 systemctl stop firewalld systemctl disable firewalld 1.3 关闭selinux v…

基于springboot就业信息管理系统源码和论文

随着信息化时代的到来&#xff0c;管理系统都趋向于智能化、系统化&#xff0c;就业信息管理系统也不例外&#xff0c;但目前国内仍都使用人工管理&#xff0c;市场规模越来越大&#xff0c;同时信息量也越来越庞大&#xff0c;人工管理显然已无法应对时代的变化&#xff0c;而…

InnoDB 锁系统(小白入门)

1995年 &#xff0c;MySQL 1.0发布&#xff0c;仅供内部使用&#xff01; 开发多用户、数据库驱动的应用时&#xff0c;最大的一个难点是&#xff1a;一方面要最大程度地利用数据库的并发访问&#xff0c;另一方面还要确保每个用户能以一致性的方式读取和修改数据。 MVCC 并发…

基于python+控制台的员工信息管理系统

基于python控制台的员工信息管理系统 一、系统介绍二、效果展示三、其他系统实现四、获取源码 一、系统介绍 1.添加职工数据 2.显示职工数据 3.查询职工数据 4.修改职工数据 5.删除职工数据 6.保存职工数据 7.排序职工数据 8.统计职工工资数据 9.退出 二、效果展示 三、其他系…

从搜索引擎到答案引擎:LLM驱动的变革

在过去的几周里&#xff0c;我一直在思考和起草这篇文章&#xff0c;认为谷歌搜索正处于被颠覆的边缘&#xff0c;它实际上可能会影响 SEO 作为业务牵引渠道的可行性。 考虑到谷歌二十多年来的完全统治地位&#xff0c;以及任何竞争对手都完全无力削弱它&#xff0c;坦率地说&…

CSS transition(过渡效果)详解并附带示例

CSS过渡效果&#xff08;CSS transitions&#xff09;是一种在元素属性值发生变化时&#xff0c;通过指定过渡效果来实现平滑的动画效果的方法。通过定义起始状态和结束状态之间的过渡属性&#xff0c;可以使元素的变化更加流畅和可视化。 过渡效果的基本语法如下&#xff1a;…

乐意购项目前端开发 #6

一、商品详情页面 代码模版 创建Detail文件夹, 然后创建index.vue文件 <script setup> import { getDetail } from "/api/goods/index"; import { ref, onMounted } from "vue"; import { useRoute } from "vue-router"; import { useCar…

SpringBoot 登录检验JWT令牌 生成与校验

JWT官网 https://jwt.io/ 引入依赖 <dependency><groupId>io.jsonwebtoken</groupId><artifactId>jjwt</artifactId><version>0.9.1</version> </dependency>设置过期时间 LocalDateTime localDateTime LocalDateTime.now().…

STM32--SPI通信协议(1)SPI基础知识总结

前言 I2C (Inter-Integrated Circuit)和SPI (Serial Peripheral Interface)是两种常见的串行通信协议&#xff0c;用于连接集成电路芯片之间的通信&#xff0c;选择I2C或SPI取决于具体的应用需求。如果需要较高的传输速度和简单的接口&#xff0c;可以选择SPI。如果需要连接多…

css1字体属性

一.font-family(字体系列&#xff09; 不同字体系统用&#xff0c;隔开&#xff1b; 多个字母的字体系统用“”&#xff1b; 二.font-size&#xff08;字体大小&#xff09;&#xff08;有单位px&#xff09;&#xff08;默认字体16px&#xff09; 三.font-weight&#xff08…