使用pytorch构建ResNet50模型训练猫狗数据集

数据集

1.导包

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm  # 引入tqdm库以显示进度条

2.数据预处理

ResNet50模型适合的图片大小为224x244

# 定义数据转换
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'test': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

3.加载数据集和模型构建

# 加载数据集
data_dir = 'catdog_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'test']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes# 加载ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

4.训练

# 训练次数
num_epochs = 10# 初始化训练次数计数器
train_count = 0
for epoch in range(num_epochs):  # num_epochs 是你希望训练的轮数for phase in ['train', 'test']:if phase == 'train':model.train()else:model.eval()running_loss = 0.0running_corrects = 0# 使用tqdm显示进度条with tqdm(total=len(dataloaders[phase]), desc=f'Epoch {epoch+1}/{num_epochs}', leave=False) as progress_bar:for inputs, labels in dataloaders[phase]:optimizer.zero_grad()with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)if phase == 'train':loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]progress_bar.set_postfix(loss=epoch_loss, acc=epoch_acc)progress_bar.update(1)print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# 更新训练次数计数器train_count += 1print(f'Training Count: {train_count}')

训练过程

5.预测

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt# 定义模型的类别数量
num_classes = 2# 加载模型
model = torchvision.models.resnet50(pretrained=False)
# 修改模型的fc层以匹配训练时的结构
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
# 加载保存的权重
model.load_state_dict(torch.load('mg_ResNet50model.pth'))
model.eval()# 图像预处理
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 测试图片
img_path = 'mao_1.jpg'  # 替换为你的图片路径
img = Image.open(img_path)
img_t = preprocess(img)# 扩展维度,因为模型需要4维输入(Batch, Channels, Height, Width)
batch_t = torch.unsqueeze(img_t, 0)# 预测
with torch.no_grad():out = model(batch_t)# 获取最高分数的类别
_, index = torch.max(out, 1)# 可视化结果
plt.imshow(img)
plt.title(f'Predicted: {index.item()}')
plt.show()

预测效果

0就是猫咪,1就是小狗

全部代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm  # 引入tqdm库以显示进度条# 定义数据转换
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'test': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 加载数据集
data_dir = 'catdog_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'test']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes# 加载ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# 训练次数
num_epochs = 10# 初始化训练次数计数器
train_count = 0
for epoch in range(num_epochs):  # num_epochs 是你希望训练的轮数for phase in ['train', 'test']:if phase == 'train':model.train()else:model.eval()running_loss = 0.0running_corrects = 0# 使用tqdm显示进度条with tqdm(total=len(dataloaders[phase]), desc=f'Epoch {epoch+1}/{num_epochs}', leave=False) as progress_bar:for inputs, labels in dataloaders[phase]:optimizer.zero_grad()with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)if phase == 'train':loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]progress_bar.set_postfix(loss=epoch_loss, acc=epoch_acc)progress_bar.update(1)print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# 更新训练次数计数器train_count += 1print(f'Training Count: {train_count}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/19394.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

流媒体服务器SMS-语音对讲(一)

1.简介 在国标语音对讲对接中,会发现不同的厂商或不同型号的设备,对讲流程都不一样,本文主要介绍流媒体与设备之间的交互情况。 SMS流媒体服务代码库地址:https://gitee.com/inyeme/simple-media-server 2.流媒体与设备交互的可能…

max6675热电偶温度采集

思路来源 参考价格 概述 MAX6675具有冷端补偿和将来自K型热电偶的信号数字化。数据以12位分辨率输出,SPI™兼容, 只读格式。该转换器将温度分解为0.25C,允许读数高达1024C,并显示热电偶8LSB在0C至 700C 引脚连接 温度采样电路 …

中间件复习之-消息队列

消息队列在分布式架构的作用 消息队列:在消息的传输过程中保存消息的容器,生产者和消费者不直接通讯,依靠队列保证消息的可靠性,避免了系统间的相互影响。 主要作用: 业务解耦异步调用流量削峰 业务解耦 将模块间的…

MySQL之创建高性能的索引(八)

创建高性能的索引 覆盖索引 通常大家都会根据查询的WHERE条件来创建合适的索引,不过这只是索引优化的一个方面。设计优秀的索引应该考虑到整个查询,而不单单是WHERE条件部分。索引确实是一种查找数据的高效方式,但是MySQL也可以使用索引来直…

向量数据库引领 AI 创新——Zilliz 亮相 2024 亚马逊云科技中国峰会

2024年5月29日,亚马逊云科技中国峰会在上海召开,此次峰会聚集了来自全球各地的科技领袖、行业专家和创新企业,探讨云计算、大数据、人工智能等前沿技术的发展趋势和应用场景。作为领先的向量数据库技术公司,Zilliz 在本次峰会上展…

【漏洞复现】电信网关配置管理系统 rewrite.php 文件上传漏洞

0x01 产品简介 中国电信集团有限公司(英文名称"China Telecom”、简称“"中国电信”)成立于2000年9月,是中国特大型国有通信企业、上海世博会全球合作伙伴。电信网关配置管理系统是一个用于管理和配置电信网络中网关设备的软件系统。它可以帮助网络管理员…

在线IP检测如何做?代理IP需要检查什么?

当我们的数字足迹无处不在,隐私保护显得愈发重要。而代理IP就像是我们的隐身斗篷,让我们在各项网络业务中更加顺畅。 我们常常看到别人购买了代理IP服务后,通在线检测网站检查IP,相当于一个”售前检验““售后质检”的作用。但是…

2024-5-31 石群电路-19

2024-5-31,星期五,10:53,天气:阴雨,心情:晴。今天就要回学校啦,当大家看到这篇推文的时候我已经要收拾收拾去赶返校的火车啦,和女朋友短暂分别,不过小别胜新婚吗&#xf…

笔记-docker基于ubuntu22.04安装Jitsi Meet

背景 利用JitsiMeet打造一个可以在线会议的环境,根据躺的坑,做个记录 参考 JitsMeet部署安装说明 开始操作 环境 docker run -it --name ubuntu22.04 ubuntu:22.04 /bin/bash问题 1、安装 openjdk-11 apt install openjdk-11-jdk配置环境变量&…

自媒体必用的50 个最佳 ChatGPT 社交媒体帖子提示prompt通用模板教程

在这个信息爆炸的时代,社交媒体已经成为我们生活中不可或缺的一部分。无论是品牌宣传、个人展示,还是日常交流,我们都离不开它。然而,要在众多信息中脱颖而出,吸引大家的关注并不容易。这时候,ChatGPT这样的…

uniapp的tooltip功能放到表单laber

在uniapp中,tooltip功能通常是通过view组件的hover-class属性来实现的,而不是直接放在form的label上。hover-class属性可以定义当元素处于hover状态时的样式类,通过这个属性,可以实现一个类似tooltip的效果。 以下是一个简单的例…

跨境经营的艺术:中资企业海外市场售后服务创新与挑战

出海,已不再是企业的“备胎”,而是必须面对的“大考”!在这个全球化的大潮中,有的企业乘风破浪,勇攀高峰,也有的企业在异国他乡遭遇了“水土不服”。 面对“要么出海,要么出局”的抉择&#xff…

一分钟学习数据安全——自主管理身份SSI基本概念

之前我们已经介绍过数字身份的几种模式。其中,分布式数字身份模式逐渐普及演进的结果就是自主管理身份(SSI,Self-Sovereign Identity)。当一个人能够完全拥有和控制其数字身份,而无需依赖中心化机构,这就是…

FreeRTOS实时系统 在任务中增加数组等相关操作 导致单片机起不来或者挂掉

在调试串口任务中增加如下代码,发现可以用keil进行仿真,但是烧录程序后,调试串口没有打印,状态灯也不闪烁,单片机完全起不来 博主就纳了闷了,究竟是什么原因,这段代码可是公司永流传的老代码了&…

香橙派OrangePi AIpro上手笔记——之USB摄像头目标检测方案测试(三)

整期笔记索引 香橙派OrangePi AIpro上手笔记——之USB摄像头目标检测方案测试(一) 香橙派OrangePi AIpro上手笔记——之USB摄像头目标检测方案测试(二) 香橙派OrangePi AIpro上手笔记——之USB摄像头目标检测方案测试(…

【MySQL数据库】:MySQL复合查询

目录 基本查询回顾 多表查询 自连接 子查询 单行子查询 多行子查询 多列子查询 在from子句中使用子查询 合并查询 前面我们讲解的mysql表的查询都是对一张表进行查询,在实际开发中这远远不够。 基本查询回顾 【MySQL数据库】:MySQL基本查…

【测试】linux快捷指令工具cxtool

简介 登录linux时,我们经常需要重复输入一些指令. 这个工具可以把这些指令预置,需要的时候鼠标一点,会自动按预置的字符敲击键盘,敲击出指令. 下载地址 https://download.csdn.net/download/bandaoyu/89379371 使用方法 1,编辑配置文件,自定义自己的快捷指令。 2…

运算符重载(下)

目录 前置和后置重载前置的实现Date& Date::operator()代码 后置的实现Date Date::operator(int )代码 前置--和后置--重载前置--的实现Date& Date::operator--( )代码 后置--的实现Date Date::operator--(int )代码 流插入运算符重载流插入运算符重载的实现流提取运算…

任何图≌自己这一几何最起码常识推翻直线公理让R外标准实数一下子浮出水面

黄小宁 h定理:点集AB≌B的必要条件是A≌B。 证:若AB则A必可恒等变换地变为BA≌A,而恒等变换是保距变换。证毕。 如图所示R轴即x轴各元点x沿x轴正向不保距平移变为点y2x就使x轴沿本身拉伸(放大)变换为y2x轴不≌x轴&…

校园疫情防控|基于SprinBoot+vue的校园疫情防控系统(源码+数据库+文档)

校园疫情防控系统 目录 基于SprinBootvue的校园疫情防控系统 一、前言 二、系统设计 三、系统功能设计 1系统功能模块 2后台功能模块 5.2.1管理员功能 5.2.2学生功能 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取&#x…