Pytorch 猫狗识别案例

猫狗识别数据集icon-default.png?t=N7T8https://download.csdn.net/download/Victor_Li_/88483483?spm=1001.2014.3001.5501

训练集图片路径

测试集图片路径

训练代码如下

import torch
import torchvision
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import time
from torch.optim.lr_scheduler import StepLRif __name__ == '__main__':torch.autograd.set_detect_anomaly(True)mp.freeze_support()train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available. Training on CPU...')else:print('CUDA is available! Training on GPU...')device = torch.device("cuda" if torch.cuda.is_available() else "cpu")batch_size = 32# 设置数据预处理的转换transform = torchvision.transforms.Compose([torchvision.transforms.Resize((224, 224)),  # 调整图像大小为 224x224torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.RandomRotation(45),torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),torchvision.transforms.ToTensor(),  # 转换为张量torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化])dataset = torchvision.datasets.ImageFolder('./cats_and_dogs_train',transform=transform)val_ratio = 0.2val_size = int(len(dataset) * val_ratio)train_size = len(dataset) - val_sizetrain_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4,pin_memory=True)val_dataset = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, num_workers=4, pin_memory=True)# x,y = next(iter(val_dataset))# x = x.permute(1, 2, 0)  # 将通道维度调整到最后# x = (x - x.min()) / (x.max() - x.min())  # 反归一化操作# plt.imshow(x)  # 将通道维度调整到最后# plt.axis('off')  # 关闭坐标轴# plt.show()model = models.resnet34(weights=None)num_classes = 2model.fc = nn.Sequential(nn.Dropout(p=0.2),# nn.BatchNorm4d(model.fc.in_features),nn.Linear(model.fc.in_features, num_classes),nn.Sigmoid(),)lambda_L1 = 0.001lambda_L2 = 0.0001regularization_loss_L1 = 0regularization_loss_L2 = 0for name,param in model.named_parameters():param.requires_grad = Trueif 'bias' not in name:regularization_loss_L1 += torch.norm(param, p=1).detach()regularization_loss_L2 += torch.norm(param, p=2).detach()optimizer = optim.Adam(model.parameters(), lr=0.01)scheduler = StepLR(optimizer, step_size=5, gamma=0.9)criterion = nn.BCELoss().to(device)model.to(device)# print(model)loadfilename = "recognize_cats_and_dogs.pt"savefilename = "recognize_cats_and_dogs3.pt"checkpoint = torch.load(loadfilename)model.load_state_dict(checkpoint['model_state_dict'])def save_checkpoint(epoch, model, optimizer, filename, train_loss=0., val_loss=0.):checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'train_loss': train_loss,'val_loss': val_loss,}torch.save(checkpoint, filename)num_epochs = 100train_loss = []for epoch in range(num_epochs):running_loss = 0correct = 0total = 0epoch_start_time = time.time()for i, (inputs, labels) in enumerate(train_dataset):# 将数据放到设备上inputs, labels = inputs.to(device), labels.to(device)# 前向计算outputs = model(inputs)one_hot = nn.functional.one_hot(labels, num_classes).float()# 计算损失和梯度loss = criterion(outputs, one_hot) + lambda_L1 * regularization_loss_L1 + lambda_L2 * regularization_loss_L2loss.backward()if ((i + 1) % 2 == 0) or (i + 1 == len(train_dataset)):# 更新模型参数optimizer.step()optimizer.zero_grad()# 记录损失和准确率running_loss += loss.item()train_loss.append(loss.item())_, predicted = torch.max(outputs.data, 1)correct += (predicted == labels).sum().item()total += labels.size(0)accuracy_train = 100 * correct / total# 在测试集上计算准确率with torch.no_grad():running_loss_test = 0correct_test = 0total_test = 0for inputs, labels in val_dataset:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)one_hot = nn.functional.one_hot(labels, num_classes).float()loss = criterion(outputs, one_hot)running_loss_test += loss.item()_, predicted = torch.max(outputs.data, 1)correct_test += (predicted == labels).sum().item()total_test += labels.size(0)accuracy_test = 100 * correct_test / total_test# 输出每个 epoch 的损失和准确率epoch_end_time = time.time()epoch_time = epoch_end_time - epoch_start_timetain_loss = running_loss / len(train_dataset)val_loss = running_loss_test / len(val_dataset)print("Epoch [{}/{}], Time: {:.4f}s, Loss: {:.4f}, Train Accuracy: {:.2f}%, Loss: {:.4f}, Test Accuracy: {:.2f}%".format(epoch + 1, num_epochs, epoch_time, tain_loss,accuracy_train, val_loss, accuracy_test))save_checkpoint(epoch, model, optimizer, savefilename, tain_loss, val_loss)scheduler.step()# plt.plot(train_loss, label='Train Loss')# # 添加图例和标签# plt.legend()# plt.xlabel('Epochs')# plt.ylabel('Loss')# plt.title('Training Loss')## # 显示图形# plt.show()

测试代码如下

import torch
import torchvision
import torch.nn as nn
import torchvision.models as models
import matplotlib.pyplot as plt
import torch.multiprocessing as mpif __name__ == '__main__':mp.freeze_support()train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available. Training on CPU...')else:print('CUDA is available! Training on GPU...')device = torch.device("cuda" if torch.cuda.is_available() else "cpu")batch_size = 32transform = torchvision.transforms.Compose([torchvision.transforms.Resize((224,224)),  # 调整图像大小为 224x224torchvision.transforms.ToTensor(),  # 转换为张量torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化])dataset = torchvision.datasets.ImageFolder('./cats_and_dogs_test',transform=transform)test_dataset = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)model = models.resnet34()num_classes = 2for param in model.parameters():param.requires_grad = Falsemodel.fc = nn.Sequential(nn.Dropout(),nn.Linear(model.fc.in_features,num_classes),nn.LogSoftmax(dim=1))model.to(device)# print(model)filename = "recognize_cats_and_dogs.pt"checkpoint = torch.load(filename)model.load_state_dict(checkpoint['model_state_dict'])class_name = ['cat','dog']# 在测试集上计算准确率with torch.no_grad():for inputs, labels in test_dataset:inputs, labels = inputs.to(device), labels.to(device)output = model(inputs)_, predicted = torch.max(output.data, 1)for x,y,z in zip(inputs,labels,predicted):x = (x - x.min()) / (x.max() - x.min())plt.imshow(x.cpu().permute(1,2,0))plt.axis('off')plt.title('predicted: {0}'.format(class_name[z]))plt.show()

部分测试结果如下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/125625.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

csapp datalab

知识点总结 1. 逻辑运算符关系 and(与)、or(或)和xor(异或)是逻辑运算符,用于对布尔值进行操作。它们可以在不同的逻辑表达式之间进行转换。下面是and、or和xor之间的转换规则: a…

SpringCloud 微服务全栈体系(九)

第九章 Docker 三、Dockerfile 自定义镜像 常见的镜像在 DockerHub 就能找到,但是我们自己写的项目就必须自己构建镜像了。 而要自定义镜像,就必须先了解镜像的结构才行。 1. 镜像结构 镜像是将应用程序及其需要的系统函数库、环境、配置、依赖打包而…

OpenCV官方教程中文版 —— 分水岭算法图像分割

OpenCV官方教程中文版 —— 分水岭算法图像分割 前言一、原理二、示例三、完整代码 前言 本节我们将要学习 • 使用分水岭算法基于掩模的图像分割 • 函数:cv2.watershed() 一、原理 任何一副灰度图像都可以被看成拓扑平面,灰度值高的区域可以被看成…

深入探索 C++ 多态 ② - 继承关系

前言 上一章 简述了虚函数的调用链路,本章主要探索 C 各种继承关系的类对象的多态特性。 深入探索 C 多态 ① - 虚函数调用链路深入探索 C 多态 ② - 继承关系深入探索 C 多态 ③ - 虚析构 1. 概述 封装,继承,多态是 C 的三大特性&#xf…

驱动day10作业

基于platform驱动模型完成LED驱动的编写 驱动程序 #include <linux/init.h> #include <linux/module.h> #include<linux/platform_device.h> #include<linux/mod_devicetable.h> #include<linux/of.h> #include<linux/of_gpio.h> #inclu…

基于深度学习的安全帽识别检测系统(python OpenCV yolov5)

收藏和点赞&#xff0c;您的关注是我创作的动力 文章目录 概要 一、研究的内容与方法二、基于深度学习的安全帽识别算法2.1 深度学习2.2 算法流程2.3 目标检测算法2.3.1 Faster R-CNN2.3.2 SSD2.3.3 YOLO v3 三 实验与结果分析3.1 实验数据集3.1.1 实验数据集的构建3.1.2 数据…

iOS的应用生命周期以及应用界面

在iOS的原生开发中&#xff0c;我们需要特别关注两个东西&#xff1a;AppDelegate和ViewController。我们主要的编码工作就是在AppDelegate和ViewControlle这两个类中进行的。它们的类图如下图所示&#xff1a; AppDelegate是应用程序委托对象&#xff0c;它继承了UIResponder类…

均值、方差、标准差

1 中间值和均值 表现&#xff02;中间值&#xff02;的统计名词&#xff1a; a.均值:   mean&#xff0c;数列的算术平均值&#xff0c;反应了数列的集中趋势,等于有效数值的合除以有效数值的个数&#xff0e;b.中位值:  median&#xff0c;等于排序后中间位置的值&#x…

c++多线程

目录 一、进程与线程 二、多线程的实现 2.1 C中创建多线程的方法 2.2 join() 、 detach() 和 joinable() 2.2.1 join() 2.2.2 detach() 2.2.3 joinable() 2.3 this_thread 三、同步机制&#xff08;同步原语&#xff09; 3.1 同步与互斥 3.2 互斥锁&#xff08;mu…

在安装和配置DVWA渗透测试环境遇到的报错问题

安装环境 前面的安装我参考的这个博主&#xff1a;渗透测试漏洞平台DVWA环境安装搭建及初级SQL注入-CSDN博客 修改bug 1.首先十分感谢提供帮助的博主&#xff0c;搭建DVWA Web渗透测试靶场_dvwa 白屏-CSDN博客&#xff0c;解决了我大多数问题&#xff0c;报错如下&#xff1…

leetCode 137. 只出现一次的数字 II(拓展篇) + 模5加法器 + 真值表(数字电路)

leetCode 137. 只出现一次的数字 II 题解可看我的往期文章 leetCode 137. 只出现一次的数字 II 位运算 模3加法器 真值表&#xff08;数字电路&#xff09; 有限状态机-CSDN博客https://blog.csdn.net/weixin_41987016/article/details/134138112?spm1001.2014.3001.5501…

N-131基于jsp,ssm物流快递管理系统

开发工具&#xff1a;eclipse&#xff0c;jdk1.8 服务器&#xff1a;tomcat7.0 数据库&#xff1a;mysql5.7 技术&#xff1a; springspringMVCmybaitsEasyUI 项目包括用户前台和管理后台两部分&#xff0c;功能介绍如下&#xff1a; 一、用户(前台)功能&#xff1a; 用…

040-第三代软件开发-全新波形抓取算法

第三代软件开发-全新波形抓取算法 文章目录 第三代软件开发-全新波形抓取算法项目介绍全新波形抓取算法代码小解 关键字&#xff1a; Qt、 Qml、 抓波、 截获、 波形 项目介绍 欢迎来到我们的 QML & C 项目&#xff01;这个项目结合了 QML&#xff08;Qt Meta-Object …

【Linux】jdk Tomcat MySql的安装及Linux后端接口部署

一&#xff0c;jdk安装 1.1 上传安装包到服务器 打开MobaXterm通过Linux地址连接到Linux并登入Linux&#xff0c;再将主机中的配置文件复制到MobaXterm 使用命令查看&#xff1a;ll 1.2 解压对应的安装包 解压jdk 解压命令&#xff1a;tar -xvf jdk 加键盘中Tab键即可…

「Dr. Bomkus 的试炼」排行榜说明

简要概括 七大区域&#xff0c;一个任务&#xff1a;六场扣人心弦的试炼&#xff0c;有一个休闲大厅作为每场试炼的起点。 试炼 排行榜&#xff1a;掌握每场试炼&#xff0c;攀登排行榜。 以 Ethos Point 来记分&#xff1a;每个试炼中的任务都会获得一个EP。 两种任务类型&am…

idea提交代码一直提示 log into gitee

解决idea提交代码一直提示 log into gitee问题 文章目录 打开setting->Version control->gitee,删除旧账号&#xff0c;重新配置账号&#xff0c;删除重新登录就好 打开setting->Version control->gitee,删除旧账号&#xff0c;重新配置账号&#xff0c;删除重新登…

局域网内远程控制电脑的软件

局域网内远程控制电脑的软件在日常办公中&#xff0c;非常常见了。它可以帮助用户在局域网内远程控制其他电脑&#xff0c;实现文件传输、桌面展示、软键盘输入等功能。 局域网内远程控制电脑的软件有很多种&#xff0c;其中比较实用的有域之盾软件、安企神软件、网管家软件等等…

专业135总400+合工大合肥工业大学833信号分析与处理信息通信上岸经验分享

专业135总400合工大合肥工业大学833信号分析与处理信息通信上岸经验分享 基础课经验很多&#xff0c;大同小异&#xff0c;我分享一下自己的833专业课复习经验。 一&#xff1a;用到的书本 1.《信号与系统》&#xff08;第三版&#xff09;郑君里&#xff0c;高等教育出版社…

最新Microsoft Edge浏览器如何使用圆角

引入 最近我看了edge官方的文档&#xff0c;里面宣传了edge的最新UI设计&#xff0c;也就是圆角&#xff0c;但是我发现我的浏览器在升级至最新版本之后&#xff0c;却没有圆角 网上有很多人说靠实验性功能即可解锁&#xff0c;但是指令我都试过了&#xff0c;每次都是搜索无结…

记一次红队打的逻辑漏洞(验证码绕过任意用户密码重置)

八月初参加某市演练时遇到一个典型的逻辑漏洞&#xff0c;可以绕过验证码并且重置任意用户的密码。 首先访问页面&#xff0c;用户名处输入账号会回显用户名称&#xff0c;输入admin会回显系统管理员。&#xff08;hvv的时候蓝队响应太快了&#xff0c;刚把admin的权限拿到了&a…