Pytorch迁移学习使用Resnet50进行模型训练预测猫狗二分类

目录

1.ResNet残差网络

1.1 ResNet定义

 1.2 ResNet 几种网络配置

 1.3 ResNet50网络结构

1.3.1 前几层卷积和池化

1.3.2 残差块:构建深度残差网络

1.3.3 ResNet主体:堆叠多个残差块

1.4 迁移学习猫狗二分类实战

1.4.1 迁移学习

1.4.2 模型训练

1.4.3 模型预测


1.ResNet残差网络

1.1 ResNet定义

深度学习在图像分类、目标检测、语音识别等领域取得了重大突破,但是随着网络层数的增加,梯度消失和梯度爆炸问题逐渐凸显。随着层数的增加,梯度信息在反向传播过程中逐渐变小,导致网络难以收敛。同时,梯度爆炸问题也会导致网络的参数更新过大,无法正常收敛。

为了解决这些问题,ResNet提出了一个创新的思路:引入残差块(Residual Block)。残差块的设计允许网络学习残差映射,从而减轻了梯度消失问题,使得网络更容易训练。

下图是一个基本残差块。它的操作是把某层输入跳跃连接到下一层乃至更深层的激活层之前,同本层输出一起经过激活函数输出。
 

24353e89d9c84a17babbbf4ebe90630b.png

 1.2 ResNet 几种网络配置

如下图:

 1.3 ResNet50网络结构

ResNet-50是一个具有50个卷积层的深度残差网络。它的网络结构非常复杂,但我们可以将其分为以下几个模块:

1.3.1 前几层卷积和池化

import torch
import torch.nn as nnclass ResNet50(nn.Module):def __init__(self, num_classes=1000):super(ResNet50, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

1.3.2 残差块:构建深度残差网络

class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels * 4, kernel_size=1, stride=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * 4)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out

1.3.3 ResNet主体:堆叠多个残差块

在ResNet-50中,我们堆叠了多个残差块来构建整个网络。每个残差块会将输入的特征图进行处理,并输出更加丰富的特征图。堆叠多个残差块允许网络在深度方向上进行信息的层层提取,从而获得更高级的语义信息。代码如下:

class ResNet50(nn.Module):def __init__(self, num_classes=1000):# ... 前几层代码 ...# 4个残差块的block1self.layer1 = self._make_layer(ResidualBlock, 64, 3, stride=1)# 4个残差块的block2self.layer2 = self._make_layer(ResidualBlock, 128, 4, stride=2)# 4个残差块的block3self.layer3 = self._make_layer(ResidualBlock, 256, 6, stride=2)# 4个残差块的block4self.layer4 = self._make_layer(ResidualBlock, 512, 3, stride=2)

 利用make_layer函数实现对基本残差块Bottleneck的堆叠。代码如下:

def _make_layer(self, block, channel, block_num, stride=1):"""block: 堆叠的基本模块channel: 每个stage中堆叠模块的第一个卷积的卷积核个数,对resnet50分别是:64,128,256,512block_num: 当期stage堆叠block个数stride: 默认卷积步长"""downsample = None   # 用于控制shorcut路的if stride != 1 or self.in_channel != channel*block.expansion:   # 对resnet50:conv2中特征图尺寸H,W不需要下采样/2,但是通道数x4,因此shortcut通道数也需要x4。对其余conv3,4,5,既要特征图尺寸H,W/2,又要shortcut维度x4downsample = nn.Sequential(nn.Conv2d(in_channels=self.in_channel, out_channels=channel*block.expansion, kernel_size=1, stride=stride, bias=False), # out_channels决定输出通道数x4,stride决定特征图尺寸H,W/2nn.BatchNorm2d(num_features=channel*block.expansion))layers = []  # 每一个convi_x的结构保存在一个layers列表中,i={2,3,4,5}layers.append(block(in_channel=self.in_channel, out_channel=channel, downsample=downsample, stride=stride)) # 定义convi_x中的第一个残差块,只有第一个需要设置downsample和strideself.in_channel = channel*block.expansion   # 在下一次调用_make_layer函数的时候,self.in_channel已经x4for _ in range(1, block_num):  # 通过循环堆叠其余残差块(堆叠了剩余的block_num-1个)layers.append(block(in_channel=self.in_channel, out_channel=channel))return nn.Sequential(*layers)   # '*'的作用是将list转换为非关键字参数传入

1.4 迁移学习猫狗二分类实战

1.4.1 迁移学习

迁移学习(Transfer Learning)是一种机器学习和深度学习技术,它允许我们将一个任务学到的知识或特征迁移到另一个相关的任务中,从而加速模型的训练和提高性能。在迁移学习中,我们通常利用已经在大规模数据集上预训练好的模型(称为源任务模型),将其权重用于新的任务(称为目标任务),而不是从头开始训练一个全新的模型。

迁移学习的核心思想是:在解决一个新任务之前,我们可以先从已经学习过的任务中获取一些通用的特征或知识,并将这些特征或知识迁移到新任务中。这样做的好处在于,源任务模型已经在大规模数据集上进行了充分训练,学到了很多通用的特征,例如边缘检测、纹理等,这些特征对于许多任务都是有用的。

1.4.2 模型训练

首先,我们需要准备用于猫狗二分类的数据集。数据集可以从Kaggle上下载,其中包含了大量的猫和狗的图片。

在下载数据集后,我们需要将数据集划分为训练集和测试集。训练集文件夹命名为train,其中建立两个文件夹分别为cat和dog,每个文件夹里存放相应类别的图片。测试集命名为test,同理。然后我们使用ResNet50网络模型,在我们的计算机上使用GPU进行训练并保存我们的模型,训练完成后在测试集上验证模型预测的正确率。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50# 设置随机种子
torch.manual_seed(42)# 定义超参数
batch_size = 32
learning_rate = 0.001
num_epochs = 10# 定义数据转换
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载数据集
train_dataset = ImageFolder("train", transform=transform)
test_dataset = ImageFolder("test", transform=transform)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)# 加载预训练的ResNet-50模型
model = resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)  # 替换最后一层全连接层,以适应二分类问题device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)# 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item()}")
torch.save(model,'model/c.pth')
# 测试模型
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)print(outputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()breakprint(f"Accuracy on test images: {(correct / total) * 100}%")

1.4.3 模型预测

首先加载我们保存的模型,这里我们进行单张图片的预测,并把预测结果打印日志。

import cv2 as cv
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torchvision.transforms as transforms
import  torch
from PIL import Image
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=torch.load('model/c.pth')
print(model)
model.to(device)test_image_path = 'test/dogs/dog.4001.jpg'  # Replace with your test image path
image = Image.open(test_image_path)
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
input_tensor = transform(image).unsqueeze(0).to(device)  # Add a batch dimension and move to GPU# Set the model to evaluation mode
model.eval()with torch.no_grad():outputs = model(input_tensor)_, predicted = torch.max(outputs, 1)predicted_label = predicted.item()label=['猫','狗']
print(label[predicted_label])
plt.axis('off')
plt.imshow(image)
plt.show()

运行截图

至此这篇文章到此结束。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/6572.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华为数通HCIP-ISIS基础

IS-IS的基本概念 isis(中间系统到中间路由协议) 链路状态路由协议、IGP、无类路由协议; IS-IS是一种链路状态路由协议,IS-IS与OSPF在许多方面非常相似:运行IS-IS协议的直连设备之间通过发送Hello报文发现彼此,然后建…

从零开始搭建vue3 + ts + pinia + vite +element-plus项目

前言:据说vue2将于 2023 年 12 月 31 日停止维护,最近打算搭建一个vue3项目来学习一下,以防忘记,记录一下搭建过程。 一、使用npm创建项目 前提条件:已安装 16.0 或更高版本的 Node.js 执行 “npm init vuelatest”…

【Java基础教程】(四十三)多线程篇 · 下:深入剖析Java多线程编程:同步、死锁及经典案例——生产者与消费者,探究sleep()与wait()的差异

Java基础教程之多线程 下 🔹本节学习目标1️⃣ 线程的同步与死锁1.1 同步问题的引出2.2 synchronized 同步操作2.3 死锁 2️⃣ 多线程经典案例——生产者与消费者🔍分析sleep()和wait()的区别? 🌾 总结 🔹本节学习目标…

谷歌插件(Chrome扩展) “Service Worker (无效)” 解决方法

问题描述: 写 background 文件的时候报错了,说 Service Worker 设置的 background 无效。 解决(检查)方法: 检查配置文件(manifest.json) 中的 manifest_version 是否为 3。 background 中的…

如何动态修改 spring aop 切面信息?让自动日志输出框架更好用

业务背景 很久以前开源了一款 auto-log 自动日志打印框架。 其中对于 spring 项目,默认实现了基于 aop 切面的日志输出。 但是发现一个问题,如果切面定义为全切范围过大,于是 v0.2 版本就是基于注解 AutoLog 实现的。 只有指定注解的类或…

DataWhale AI夏令营——机器学习

DataWhale AI夏令营——机器学习 学习记录一1. 异常值分析2. 单变量箱线图可视化3. 特征重要性分析 学习记录一 锂电池电池生产参数调控及生产温度预测挑战赛 已配置环境,跑通baseline,并在此基础上对数据进行了简单的分析。 1. 异常值分析 对训练集…

K8S初级入门系列之八-网络

一、前言 本章节我们将了解K8S的相关网络概念,包括K8S的网络通讯原理,以及Service以及相关的概念,包括Endpoint,EndpointSlice,Headless service,Ingress等。 二、网络通讯原理和实现 同一K8S集群&…

PMP 数据收集工具与技术

数据收集工具与技术 (9个) 标杆对照 标杆对照是指将实际或计划的产品、流程和实践与其他可比组织的 做法进行比较,以便识别最佳实践、形成改进意见,并为绩效考核 提供依据。 头脑风暴 头脑风暴是一种数据收集和创意技术,主要用于在短时间…

三维点云中的坐标变换(只讲关键部分)

一、坐标旋转 坐标旋转包含绕x、y、z轴旋转,在右手坐标系中,x-翻滚(roll),y-俯仰(pitch),z-航向(yaw)。如果想详细了解,可以网络搜索 在PCL中,从baseLink到map的转换关系为:先绕x轴旋转,在绕y轴旋转,最后绕…

【软件工程中的各种图】

1、用例图(use case diagrams) 【概念】描述用户需求,从用户的角度描述系统的功能 【描述方式】椭圆表示某个用例;人形符号表示角色 【目的】帮组开发团队以一种可视化的方式理解系统的功能需求 【用例图】 2、静态图(Static …

【数据结构】C--单链表(小白入门基础知识)

前段时间写了一篇关于顺序表的博客,http://t.csdn.cn/0gCRp 顺序表在某些时候存在着一些不可避免的缺点: 问题: 1. 中间 / 头部的插入删除,时间复杂度为 O(N) 2. 增容需要申请新空间,拷贝数据,释放旧空间。会有不…

前端 | ( 十一)CSS3简介及基本语法(上) | 尚硅谷前端html+css零基础教程2023最新

学习来源:尚硅谷前端htmlcss零基础教程,2023最新前端开发html5css3视频 系列笔记: 【HTML4】(一)前端简介【HTML4】(二)各种各样的常用标签【HTML4】(三)表单及HTML4收尾…

2023/07/23

1. 必须等待所有请求结束后才能执行后续操作的处理方式 方式一: async func () {const p1 await api1();const p2 await api2();const p3 await api3();Promise.all([p1, p2, p3]).then(res > {后续操作...}) }方式二:待补充 2. flex 弹性盒子布…

FPGA实现串口回环

文章目录 前言一、串行通信1、分类1、同步串行通信2、异步串行通信 2、UART串口通信1、UART通信原理2、串口通信时序图 二、系统设计1、系统框图2.RTL视图 三、源码1、串口发送模块2、接收模块3、串口回环模块4、顶层模块 四、测试效果五、总结六、参考资料 前言 环境&#xff…

【计算机视觉 | 目标检测】arxiv 计算机视觉关于目标检测的学术速递(7 月 21 日论文合集)

文章目录 一、检测相关(15篇)1.1 Representation Learning in Anomaly Detection: Successes, Limits and a Grand Challenge1.2 AlignDet: Aligning Pre-training and Fine-tuning in Object Detection1.3 Cascade-DETR: Delving into High-Quality Universal Object Detectio…

《Docker与持续集成/持续部署:构建高效交付流程,打造敏捷软件交付链》

🌷🍁 博主 libin9iOak带您 Go to New World.✨🍁 🦄 个人主页——libin9iOak的博客🎐 🐳 《面试题大全》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~&#x1f33…

c语言修炼之指针和数组笔试题解析(1.2)

前言: 书接上回,让我们继续开始今天的学习叭!废话不多说,还是字符数组的内容上代码! char *p是字符指针,*表示p是个指针,char表示p指向的对象类型是char型! char*p"abcdef&q…

使用Plist编辑器——简单入门指南

本指南将介绍如何使用Plist编辑器。您将学习如何打开、编辑和保存plist文件,并了解plist文件的基本结构和用途。跟随这个简单的入门指南,您将掌握如何使用Plist编辑器轻松管理您的plist文件。 plist文件是一种常见的配置文件格式,用于存储应…

7.6Java EE——Bean的生命周期

Bean在不同作用域内的生命周期 Bean的生命周期是指Bean实例被创建、初始化和销毁的过程。在Bean的两种作用域singleton和prototype中,Spring容器对Bean的生命周期的管理是不同的。在singleton作用域中,Spring容器可以管理Bean的生命周期,控制…

vue父组件和子组件数据传递

vue --父组件向子组件传递数据 父组件&#xff1a; <template><div class"parent"><p>父组件&#xff1a;{{ msg }}</p><Child message"Hello, I am parent!"></Child></div> </template><script>…