图像分类应用

先留一段图像分类代码,空闲时间再做分析:

创建神经网络:

import torch
from torch import nn
import torch.nn.functional as F
class MyAlexNet(nn.Module):def __init__(self):super(MyAlexNet, self).__init__()self.c1=nn.Conv2d(in_channels=3,out_channels=48,kernel_size=11,stride=4,padding=2)self.ReLU=nn.ReLU()self.c2=nn.Conv2d(in_channels=48,out_channels=128,kernel_size=5,stride=1,padding=2)self.s2=nn.MaxPool2d(2)self.c3=nn.Conv2d(in_channels=128,out_channels=192,kernel_size=3,stride=1,padding=1)self.s3=nn.MaxPool2d(2)self.c4=nn.Conv2d(in_channels=192,out_channels=192,kernel_size=3,stride=1,padding=1)self.c5=nn.Conv2d(in_channels=192,out_channels=128,kernel_size=3,stride=1,padding=1)self.s5=nn.MaxPool2d(kernel_size=3,stride=2)self.flatten=nn.Flatten()self.f6=nn.Linear(128*6*6,2048)self.f7=nn.Linear(2048,2048)self.f8=nn.Linear(2048,1000)self.f9=nn.Linear(1000,2)def forward(self,x):x=self.ReLU(self.c1(x))x=self.ReLU(self.c2(x))x=self.s2(x)x=self.ReLU(self.c3(x))x=self.s3(x)x=self.ReLU(self.c4(x))x=self.ReLU(self.c5(x))x=self.s5(x)x=self.flatten(x)x=self.f6(x)x=F.dropout(x,p=0.5)x=self.f7(x)x=F.dropout(x,p=0.5)x=self.f8(x)x=F.dropout(x,p=0.5)x=self.f9(x)return x
if __name__ == '__main__':x=torch.rand([1,3,224,224])model=MyAlexNet()y=model(x)

数据预处理:

import os
from shutil import copy
import random
def mkdir(file):if not os.path.exists(file):os.makedirs(file)
#获取data文件夹下所有文件夹名(即需要分类的类名)
file_path='E:/BaiduNetdiskDownload/Kaggle猫狗大战/train'
flower_class= [cla for cla in os.listdir(file_path)]
#创建训练集train文件夹,并由类名在其目录下创建子目录
mkdir('data/train')
mkdir('data/train/cat')
mkdir('data/train/dog')
mkdir('data/val')
mkdir('data/val/cat')
mkdir('data/val/dog')
split_rate=0.1
for cla in flower_class:cla_path=file_path+'/'+cla#"E:\BaiduNetdiskDownload\Kaggle猫狗大战\train\train\cat.0.jpg"images=os.listdir(cla_path)print(cla_path)num=len(images)eval_index=random.sample(images,k=int(num*split_rate))for index,image in enumerate(images):if image in eval_index:image_path = cla_path+'/'+imageif "cat" in image_path:new_path = 'data/val/cat/'else:new_path = 'data/val/dog/'copy(image_path,new_path)else:image_path=cla_path+'/'+imageif "cat" in image_path:new_path='data/train/cat/'else:new_path='data/train/dog/'copy(image_path,new_path)print("\r[{}]processing[{}/{}]".format(cla,index+1,num),end="")print()
print("processing done!")

训练集用于 训练权重:

import torch
from torch import nn
from net import MyAlexNet
import numpy as np
from torch.optim import lr_scheduler
import os
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus']=False
ROOT_TRAIN = 'C:/Users/86156/PycharmProjects/pythonProject1/cat-dog/data/train'
ROOT_TEST='C:/Users/86156/PycharmProjects/pythonProject1/cat-dog/data/val'
normalize=transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
train_transform=transforms.Compose([transforms.Resize((224,224)),transforms.RandomVerticalFlip(),transforms.ToTensor(),normalize
])
val_transform=transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),normalize
])
train_dataset=ImageFolder(ROOT_TRAIN,transform=train_transform)
val_dataset=ImageFolder(ROOT_TEST,transform=val_transform)
train_dataloader=DataLoader(train_dataset,batch_size=32,shuffle=True)
val_dataloader=DataLoader(val_dataset,batch_size=32,shuffle=True)
model=MyAlexNet()
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
lr_scheduler=lr_scheduler.StepLR(optimizer,step_size=10,gamma=0.5)
def train(dataloader,model,loss_fn,optimizer):loss,current,n =0.0,0.0,0.0for batch,(x,y) in enumerate(dataloader):image,y =x,youtput=model(image)cur_loss=loss_fn(output,y)_,pred=torch.max(output,axis=1)cur_acc=torch.sum(y==pred)/output.shape[0]optimizer.zero_grad()cur_loss.backward()optimizer.step()loss+=cur_loss.item()current+=cur_acc.item()n+=1train_loss=loss/ntrain_acc=current/nprint('train_loss'+str(train_loss))print('train_acc'+str(train_acc))return train_loss,train_acc
def val(dataloader,model,loss_fn):model.eval()loss,current,n=0.0,0.0,0.0with torch.no_grad():for batch,(x,y) in enumerate(dataloader):image,y =x,youtput=model(image)cur_loss=loss_fn(output,y)_,pred=torch.max(output,axis=1)cur_acc=torch.sum(y==pred)/output.shape[0]loss+=cur_loss.item()current+=cur_acc.item()n+=1val_loss=loss/nval_acc=current/nprint('val_loss'+str(val_loss))print('val_acc'+str(val_acc))return val_loss,val_acc
def matplot_loss(train_loss,val_loss):plt.plot(train_loss,label='train_loss')plt.plot(val_loss,label='val_loss')plt.legend(loc='best')plt.ylabel('loss')plt.xlabel('epoch')plt.title("训练集和验证集loss值对比图")plt.show()
def matplot_acc(train_loss,val_loss):plt.plot(train_acc,label='train_acc')plt.plot(val_acc,label='val_acc')plt.legend(loc='best')plt.ylabel('acc')plt.xlabel('epoch')plt.title("训练集和验证集acc值对比图")plt.show()
loss_train=[]
acc_train=[]
loss_val=[]
acc_val=[]
epoch=20
min_acc=0
for t in range(epoch):lr_scheduler.step()print(f"epoch{t+1}\n----------------")train_loss,train_acc=train(train_dataloader,model,loss_fn,optimizer)val_loss,val_acc=val(val_dataloader,model,loss_fn)loss_train.append(train_loss)acc_train.append(train_acc)loss_val.append(val_loss)acc_val.append(val_acc)if val_acc>min_acc:folder='save_model'if not os.path.exists(folder):os.mkdir('save_model')min_acc=val_accprint(f"save best model,第{t+1}轮")torch.save(model.state_dict(),'save_model/best.model.pth')if t==epoch-1:torch.save(model.state_dict(),'save_model/last_model.pth')
print('Done')

测试集用于测试模型:

import torch
from net import MyAlexNet
from torch.autograd import variable
from torchvision import datasets,transforms
from torchvision.transforms import ToPILImage
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
ROOT_TRAIN = 'C:/Users/86156/PycharmProjects/pythonProject1/cat-dog/data/train'
ROOT_TEST='C:/Users/86156/PycharmProjects/pythonProject1/cat-dog/data/val'
normalize=transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
train_transform=transforms.Compose([transforms.Resize((224,224)),transforms.RandomVerticalFlip(),transforms.ToTensor(),normalize
])
val_transform=transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),normalize
])
train_dataset=ImageFolder(ROOT_TRAIN,transform=train_transform)
val_dataset=ImageFolder(ROOT_TEST,transform=val_transform)
train_dataloader=DataLoader(train_dataset,batch_size=32,shuffle=True)
val_dataloader=DataLoader(val_dataset,batch_size=32,shuffle=True)
model=MyAlexNet()
model.load_state_dict(torch.load("C:/Users/86156/PycharmProjects/pythonProject1/cat-dog/save_model/best.model.pth"))
classes=["cat","dog",
]
show=ToPILImage()
model.eval()
for i in range(50):x,y = val_dataset[i][0],val_dataset[i][1]show(x).show()x=torch.tensor(torch.unsqueeze(x,dim=0).float(),requires_grad=True)x=torch.tensor(x)with torch.no_grad():pred=model(x)print(pred)predicted,actual=classes[torch.argmax(pred[0])],classes[y]print(f'predicted:"{predicted}",Actual:"{actual}"')

没有显卡慢的跟狗屎一样。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/719530.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

二刷代码随想录算法训练营第十天 | 232.用栈实现队列、 225. 用队列实现栈

目录 一、232. 用栈实现队列 二、225. 用队列实现栈 一、232. 用栈实现队列 题目链接:力扣 文章讲解:代码随想录 视频讲解: 栈的基本操作! | LeetCode:232.用栈实现队列 题目: 请你仅使用两个栈实现先…

Vision Pro开发者学习路线

官方给到的Vision Pro开发者学习路线: 1. 学习基础知识: - 学习 Xcode、Swift 和 SwiftUI 的基础知识,包括语法、UI 设计等。 - 掌握 ARKit 和 SwiftUI 的使用,了解如何创建沉浸式增强现实体验。 2. 学习 3D 建模&#xf…

『Linux从入门到精通』第 ㉕ 期 - System V 共享内存

文章目录 💐专栏导读💐文章导读🐧共享内存原理🐧共享内存相关函数🐦key 与 shmid 区别 🐧代码实例 💐专栏导读 🌸作者简介:花想云 ,在读本科生一枚&#xff0…

CentOS7安装DockerCompose和Docker镜像仓库的配置

CentOS7安装DockerCompose 1.下载 Linux下需要通过命令下载: # 安装 curl -L https://github.com/docker/compose/releases/download/1.23.1/docker-compose-uname -s-uname -m > /usr/local/bin/docker-compose2.修改文件权限 修改文件权限: # …

YOLOv9独家原创改进|加入幽灵卷积Ghost Convolution模块,轻量化!

专栏介绍:YOLOv9改进系列 | 包含深度学习最新创新,主力高效涨点!!! 一、论文摘要 由于内存和计算资源有限,在嵌入式设备上部署卷积神经网络是困难的。特征图中的冗余是那些成功的细胞神经网络的一个重要特征…

【网站项目】158企业人事管理系统

🙊作者简介:拥有多年开发工作经验,分享技术代码帮助学生学习,独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。🌹赠送计算机毕业设计600个选题excel文件,帮助大学选题。赠送开题报告模板&#xff…

突破编程_C++_字符串算法(判断字符串是否包含)

1 算法题 :判断一个字符串是否包含另一个字符串的所有字符(不一定连续) 1.1 题目含义 判断一个字符串(称为“主字符串”或“大字符串”)是否包含另一个字符串(称为“子字符串”或“小字符串”&#xff09…

代码随想录算法训练营第31天—贪心算法05 | ● 435. 无重叠区间 ● *763.划分字母区间 ● *56. 合并区间

435. 无重叠区间 https://programmercarl.com/0435.%E6%97%A0%E9%87%8D%E5%8F%A0%E5%8C%BA%E9%97%B4.html 考点 贪心算法重叠区间 我的思路 先按照区间左坐标进行排序,方便后续处理进行for循环,循环范围是0到倒数第二个元素如果当前区间和下一区间重叠…

在Linux以命令行方式(静默方式/非图形化方式)安装MATLAB(正版)

1.根据教程,下载windows版本matlab,打开图形化界面,选择linux版本的只下载不安装 2.获取安装文件夹 3.获取许可证 4.安装 (1)跳过引用文章的2.2章节 (2)本文的安装文件夹代替引用文章的解压IS…

Java进阶(锁)——锁的升级,synchronized与lock锁区别

目录 引出Java中锁升级synchronized与lock锁区别 缓存三兄弟:缓存击穿、穿透、雪崩缓存击穿缓存穿透缓存雪崩 总结 引出 Java进阶(锁)——锁的升级,synchronized与lock锁区别 Java中锁升级 看一段代码: public class…

Fastwhisper + Pyannote 实现 ASR + 说话者识别

文章目录 前言一、faster-whisper简单介绍二、pyannote.audio介绍三、faster-whisper pyannote.audio 实现语者识别四、多说几句 前言 最近在研究ASR相关的业务,也是调研了不少模型,踩了不少坑,ASR这块,目前中文普通话效果最好的…

Scrapy与分布式开发(1.1):课程导学

Scrapy与分布式开发:从入门到精通,打造高效爬虫系统 课程大纲 在这个专栏中,我们将一起探索Scrapy框架的魅力,以及如何通过Scrapy-Redis实现分布式爬虫的开发。在本课程导学中,我们将为您简要介绍课程的学习目标、内容…

Verilog Coding Styles For Improved Simulation Efficiency论文学习记录

原文基于Verilog-XL仿真器,测试了以下几种方式对仿真效率的影响。 1. 使用 Case 语句而不是 if / else if 语句 八选一多路选择器 case 实现效率比 if / else if 提升 6% 。 2. 如果可以尽量不使用 begin end 语句 使用 begin end 的 ff 触发器比不使用 begin end …

初学者学习51还是STM32

初学者学习51还是STM32 在嵌入式系统领域,51和STM32是两种常见的单片机架构。对于初学者来说,选择学习哪种架构可能会成为一个难题。本文将对初学者学习51和STM32进行比较,以帮助读者做出明智的选择。 1. 51架构 51架构是指Intel 8051系列…

深度相机xyz点云文件三维坐标和jpg图像文件二维坐标的相互变换函数

深度相机同时拍摄xyz点云文件和jpg图像文件。xyz文件里面包含三维坐标[x,y,z]和jpg图像文件包含二维坐标[x,y],但是不能直接进行变换,需要一定的步骤来推演。 下面函数是通过box二维框[xmin, ymin, xmax, ymax, _, _ ]去截取xyz文件中对应box里面的点云…

MyCAT学习——在openEuler22.03中安装MyCAT2(网盘下载版)

准备工作 因为MyCAT 2基于JDK 1.8开发。也需要在虚拟机中安装JDK(JDK官网就能下载,我这提供一个捷径) jdk-8u401-linux-x64.rpmhttps://pan.baidu.com/s/1ywcDsxYOmfZONpmH9oDjfw?pwdrhel下载对应的tar安装包,以及对应的jar包 安装程序包…

九州金榜|孩子厌学要怎么办?

孩子从小学到初中再到高中,孩子出现厌学情绪很正常,但是孩子出现厌学情绪后,就必然会影响到孩子学习成绩,孩子产生厌学情绪的原因有哪些呢?只有找准孩子厌学原因才能去帮助孩子怎样去克服孩子厌学情绪,下面…

ajax请求servlet成功但接收不到返回数据问题

ajax请求servlet成功但接收不到返回数据问题 javaweb初学者,最近老师布置的课设,所有功能都完成了,唯独ajax与servlet交互出现问题,无论怎么调试都收不到数据 查询两天无果,刚才无意间看到 Crabime前辈的文章才恍然大…

深入解析YOLO:实时目标检测技术的革命者

深入解析YOLO:实时目标检测技术的革命者 目标检测作为计算机视觉领域的一个核心任务,一直以来都是研究的热点。而YOLO(You Only Look Once)技术作为其中的杰出代表,以其独特的处理方式和卓越的性能,成为了…

day34贪心算法 part03

1005. K 次取反后最大化的数组和 简单 给你一个整数数组 nums 和一个整数 k ,按以下方法修改该数组: 选择某个下标 i 并将 nums[i] 替换为 -nums[i] 。 重复这个过程恰好 k 次。可以多次选择同一个下标 i 。 以这种方式修改数组后,返回数…