使用pytorch构建ResNet50模型训练猫狗数据集

数据集

1.导包

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm  # 引入tqdm库以显示进度条

2.数据预处理

ResNet50模型适合的图片大小为224x244

# 定义数据转换
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'test': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

3.加载数据集和模型构建

# 加载数据集
data_dir = 'catdog_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'test']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes# 加载ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

4.训练

# 训练次数
num_epochs = 10# 初始化训练次数计数器
train_count = 0
for epoch in range(num_epochs):  # num_epochs 是你希望训练的轮数for phase in ['train', 'test']:if phase == 'train':model.train()else:model.eval()running_loss = 0.0running_corrects = 0# 使用tqdm显示进度条with tqdm(total=len(dataloaders[phase]), desc=f'Epoch {epoch+1}/{num_epochs}', leave=False) as progress_bar:for inputs, labels in dataloaders[phase]:optimizer.zero_grad()with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)if phase == 'train':loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]progress_bar.set_postfix(loss=epoch_loss, acc=epoch_acc)progress_bar.update(1)print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# 更新训练次数计数器train_count += 1print(f'Training Count: {train_count}')

训练过程

5.预测

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt# 定义模型的类别数量
num_classes = 2# 加载模型
model = torchvision.models.resnet50(pretrained=False)
# 修改模型的fc层以匹配训练时的结构
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
# 加载保存的权重
model.load_state_dict(torch.load('mg_ResNet50model.pth'))
model.eval()# 图像预处理
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 测试图片
img_path = 'mao_1.jpg'  # 替换为你的图片路径
img = Image.open(img_path)
img_t = preprocess(img)# 扩展维度,因为模型需要4维输入(Batch, Channels, Height, Width)
batch_t = torch.unsqueeze(img_t, 0)# 预测
with torch.no_grad():out = model(batch_t)# 获取最高分数的类别
_, index = torch.max(out, 1)# 可视化结果
plt.imshow(img)
plt.title(f'Predicted: {index.item()}')
plt.show()

预测效果

0就是猫咪,1就是小狗

全部代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm  # 引入tqdm库以显示进度条# 定义数据转换
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'test': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 加载数据集
data_dir = 'catdog_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'test']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes# 加载ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# 训练次数
num_epochs = 10# 初始化训练次数计数器
train_count = 0
for epoch in range(num_epochs):  # num_epochs 是你希望训练的轮数for phase in ['train', 'test']:if phase == 'train':model.train()else:model.eval()running_loss = 0.0running_corrects = 0# 使用tqdm显示进度条with tqdm(total=len(dataloaders[phase]), desc=f'Epoch {epoch+1}/{num_epochs}', leave=False) as progress_bar:for inputs, labels in dataloaders[phase]:optimizer.zero_grad()with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)if phase == 'train':loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]progress_bar.set_postfix(loss=epoch_loss, acc=epoch_acc)progress_bar.update(1)print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# 更新训练次数计数器train_count += 1print(f'Training Count: {train_count}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/19394.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

流媒体服务器SMS-语音对讲(一)

1.简介 在国标语音对讲对接中,会发现不同的厂商或不同型号的设备,对讲流程都不一样,本文主要介绍流媒体与设备之间的交互情况。 SMS流媒体服务代码库地址:https://gitee.com/inyeme/simple-media-server 2.流媒体与设备交互的可能…

JS中延迟加载的方式有哪些

延迟加载(Lazy loading)是一种性能优化策略,它通过将资源的加载推迟到真正需要使用的时候,来减少页面初始加载的时间和资源消耗。以下是几种常见的延迟加载方式: 1. 图片延迟加载:将页面中的图片的src属性…

Maven pom文件profile的properties在yaml配置文件替换失效问题

Maven profile的properties在yaml配置文件替换失效问题 Maven profile的properties在yaml配置文件替换失效问题原来错误的配置修改后的配置 Maven profile的properties在yaml配置文件替换失效问题 原因:spring-boot项目需要使用进行分割,如yaml配置文件…

Golang:使用embed引入静态文件

Go 语言从 1.16 版本开始引入了一个新的标准库 embed,可以在二进制文件中引入静态文件 指令:/go:embed 通过一个简单的小实例,来演示将静态文件引入到golang的二进制打包产物中 项目结构 $ tree . ├── main.go └── static└── he…

max6675热电偶温度采集

思路来源 参考价格 概述 MAX6675具有冷端补偿和将来自K型热电偶的信号数字化。数据以12位分辨率输出,SPI™兼容, 只读格式。该转换器将温度分解为0.25C,允许读数高达1024C,并显示热电偶8LSB在0C至 700C 引脚连接 温度采样电路 …

中间件复习之-消息队列

消息队列在分布式架构的作用 消息队列:在消息的传输过程中保存消息的容器,生产者和消费者不直接通讯,依靠队列保证消息的可靠性,避免了系统间的相互影响。 主要作用: 业务解耦异步调用流量削峰 业务解耦 将模块间的…

python中正则表达式学习

文章目录 介绍基本语法常用函数捕获组和命名组非捕获组贪婪匹配和非贪婪匹配多行模式和点匹配所有模式示例总结 介绍 Python 中的正则表达式(regular expressions, 简称 regex)由 re 模块提供。正则表达式是一种用于匹配字符串的强大工具,常…

MySQL之创建高性能的索引(八)

创建高性能的索引 覆盖索引 通常大家都会根据查询的WHERE条件来创建合适的索引,不过这只是索引优化的一个方面。设计优秀的索引应该考虑到整个查询,而不单单是WHERE条件部分。索引确实是一种查找数据的高效方式,但是MySQL也可以使用索引来直…

向量数据库引领 AI 创新——Zilliz 亮相 2024 亚马逊云科技中国峰会

2024年5月29日,亚马逊云科技中国峰会在上海召开,此次峰会聚集了来自全球各地的科技领袖、行业专家和创新企业,探讨云计算、大数据、人工智能等前沿技术的发展趋势和应用场景。作为领先的向量数据库技术公司,Zilliz 在本次峰会上展…

【漏洞复现】电信网关配置管理系统 rewrite.php 文件上传漏洞

0x01 产品简介 中国电信集团有限公司(英文名称"China Telecom”、简称“"中国电信”)成立于2000年9月,是中国特大型国有通信企业、上海世博会全球合作伙伴。电信网关配置管理系统是一个用于管理和配置电信网络中网关设备的软件系统。它可以帮助网络管理员…

在线IP检测如何做?代理IP需要检查什么?

当我们的数字足迹无处不在,隐私保护显得愈发重要。而代理IP就像是我们的隐身斗篷,让我们在各项网络业务中更加顺畅。 我们常常看到别人购买了代理IP服务后,通在线检测网站检查IP,相当于一个”售前检验““售后质检”的作用。但是…

2024-5-31 石群电路-19

2024-5-31,星期五,10:53,天气:阴雨,心情:晴。今天就要回学校啦,当大家看到这篇推文的时候我已经要收拾收拾去赶返校的火车啦,和女朋友短暂分别,不过小别胜新婚吗&#xf…

css动画效果(边框流光闪烁阴影效果)

1.整体效果 https://mmbiz.qpic.cn/sz_mmbiz_gif/EGZdlrTDJa7odDQYuaatklJUMc5anU10PWLAt14rNnNUD6oHJG9U63fc0yibiapuDViatVk62ma3K63oqQ3U1VtMQ/640?wx_fmtgif&fromappmsg&wxfrom13 CSS边框流光闪烁阴影动画效果是一种令人印象深刻的技术,它通过动态的光…

笔记-docker基于ubuntu22.04安装Jitsi Meet

背景 利用JitsiMeet打造一个可以在线会议的环境,根据躺的坑,做个记录 参考 JitsMeet部署安装说明 开始操作 环境 docker run -it --name ubuntu22.04 ubuntu:22.04 /bin/bash问题 1、安装 openjdk-11 apt install openjdk-11-jdk配置环境变量&…

es初始化

一.初始化es public void initES() {/*LOGGER.info("host" host);LOGGER.info("port" port);LOGGER.info("scheme" scheme);LOGGER.info("userName" userName);LOGGER.info("password" password);*/// 客户端连接创建…

自媒体必用的50 个最佳 ChatGPT 社交媒体帖子提示prompt通用模板教程

在这个信息爆炸的时代,社交媒体已经成为我们生活中不可或缺的一部分。无论是品牌宣传、个人展示,还是日常交流,我们都离不开它。然而,要在众多信息中脱颖而出,吸引大家的关注并不容易。这时候,ChatGPT这样的…

uniapp的tooltip功能放到表单laber

在uniapp中,tooltip功能通常是通过view组件的hover-class属性来实现的,而不是直接放在form的label上。hover-class属性可以定义当元素处于hover状态时的样式类,通过这个属性,可以实现一个类似tooltip的效果。 以下是一个简单的例…

代码随想录35期Day56-Java

Day56题目 LeetCode647回文子串的数量 核心思想:使用数组dp[i][j]表示s从i到j是否是回文串 class Solution {public int countSubstrings(String s) {// dp[i][j] 表示s从i到j是不是回文串boolean[][] dp new boolean[s.length()][s.length()];for(int i 0 ; i < s.len…

跨境经营的艺术:中资企业海外市场售后服务创新与挑战

出海&#xff0c;已不再是企业的“备胎”&#xff0c;而是必须面对的“大考”&#xff01;在这个全球化的大潮中&#xff0c;有的企业乘风破浪&#xff0c;勇攀高峰&#xff0c;也有的企业在异国他乡遭遇了“水土不服”。 面对“要么出海&#xff0c;要么出局”的抉择&#xff…

JavaScript笔记一-初识JavaScript

1、JavaScript的起源 JavaScript诞生于1995年&#xff0c;它的出现主要是用于处理网页中的前端验证。所谓的前端验证&#xff0c;就是指检查用户输入的内容是否符合一定的规则。比如&#xff1a;用户名的长度&#xff0c;密码的长度&#xff0c;邮箱的格式等。 2、JavaScript…