CNN、数据预处理、模型保存

目录

  • CNN
    • 代码
      • 读取数据
      • 搭建CNN
      • 训练网络模型
  • 数据增强
  • 迁移学习
    • 图像识别策略
      • 数据读取
      • 定义数据预处理操作
      • 冻结resnet18的函数
      • 把模型输出层改成自己的
      • 设置哪些层需要训练
      • 设置优化器和损失函数
      • 训练
      • 开始训练
      • 再训练所有层
      • 关机了,再开机,加载训练好的模型

CNN

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

代码

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

读取数据

#定义超参数
input_size=28
num_class=10
num_epochs=3
batch_size=64
#训练集
train_dataset=datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)test_dataset=datasets.MNIST(root='./data',train=False,transform=transforms.ToTensor())
#构建batch数据
train_loader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True) #num_worker=4 使用4个子线程加载数据
test_loader=torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)
train_data_iter=iter(train_loader)
#获取训练集的第一个批次数据(第一个快递包)
batch_x,batch_y=next(train_data_iter)
print(batch_x.shape,batch_y.shape)test_data_iter=iter(test_loader)
batch_x_test,batch_y_test=next(test_data_iter)
print(batch_x_test.shape,batch_y_test.shape)

在这里插入图片描述

搭建CNN

class CNN(nn.Module):def __init__(self):super(CNN,self).__init__() #batch_size,1,28,28self.conv1=nn.Sequential(nn.Conv2d(in_channels=1,out_channels=16,kernel_size=5,stride=1,padding=2), #batch_size,16,28,28nn.ReLU(),nn.MaxPool2d(kernel_size=2), #batch_size,16,14,14)self.conv2=nn.Sequential(nn.Conv2d(16,32,5,1,2), #batch_size,32,14,14nn.ReLU(),nn.Conv2d(32,32,5,1,2), #batch_size,32,14,14  #输入输出通道不变,让其在隐藏层里面更进一步提取特征nn.ReLU(),nn.MaxPool2d(2), #batch_size,32,7,7)self.conv3=nn.Sequential(nn.Conv2d(32,64,5,1,2), #batch_size,64,7,7nn.ReLU(),)#batch_size,64*7*7self.out=nn.Linear(64*7*7,10)def forward(self,x):x=self.conv1(x)x=self.conv2(x)x=nn.Flatten(self.conv3(x))output=self.out(x)return output
def accuracy(prediction,labels):pred=torch.argmax(prediction.data,dim=1) #prediction.data中加data是为了防止数据里面单独数据可能会带来梯度信息rights=pred.eq(labels.data,view_as(pred)).sum()return rights,len(labels) #(batch_size,)/(batch_size,1)

训练网络模型

net=CNN()criterion=nn.CrossEntropyLoss() #不需要在CNN中将logistic转换为概率,因为pytorch的交叉熵损失函数会自动进行optimizer=optim.Adam(net.parameters(),lr=0.001)for epoch in range(num_epochs):train_rights=[]for batch_idx,(data,target) in enumerate(train_loader):net.train() #进入训练状态,也就是所有网络参数都处于可更新状态output=net(data) #output只是logits得分loss=criterion(output,target)optimizer.zero_grad()loss.backward()optimizer.step()right=accuracy(output,target)train_rights.append(right)if batch_idx %100 ==0:net.eval() #进入评估模式,自动关闭求导机制和模型中的BN层drop out层val_rights=[]for (data,target) in test_loader:output=net(data)right=accuracy(output,target)val_rights.append(right)train_r=(sum([tup[0] for tup in train_rights]),sum([tup[1] for tup in train_rights]))val_r=(sum([tup[0] for tup in val_rights]),sum([tup[1] for tup in val_rights]))print('当前epoch:{} [{}/{} ({:.0f}%)]\t损失:{:.6f}\t训练集准确率:{:.2f}%\t测试集准确率:{:.2f}%'.format(epoch,batch_idx*batch_size,len(train_loader.dataset),100.*batch_idx/len(train_loader),loss.data,100.*train_r[0].numpy()/train_r[1],100.*val_r[0].numpy()/val_r[1]))

在这里插入图片描述

数据增强

比如数据不够,可以对数据进行旋转,翻转等操作来添加数据
在这里插入图片描述

迁移学习

例如使用预训练模型
在这里插入图片描述

图像识别策略

输出为102

数据读取

data_dir = './汪学长的随堂资料/2/flower_data/'
train_dir = data_dir + '/train' # 训练数据的文件路径
valid_dir = data_dir + '/valid' # 验证数据的文件路径

定义数据预处理操作

data_transforms = {'train':transforms.Compose([transforms.Resize([96, 96]),transforms.RandomRotation(45), # 随机旋转, -45~45度之间transforms.CenterCrop(64), #对中心进行裁剪,变成64*64transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1), # 亮度、对比度、饱和度、色调transforms.RandomGrayscale(p=0.025), #彩色图变成灰度图transforms.ToTensor(), # 0-255 ——> 0-1transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) #这组均值和标准差是最适合图片进行使用的,因为是3通道所以有3组]),'valid':transforms.Compose([transforms.Resize([64, 64]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]),}
image_datasets

在这里插入图片描述

dataloaders

在这里插入图片描述

dataset_sizes

在这里插入图片描述

model_name = "resnet18" # resnet34, resnet50, feature_extract = True #使用训练好的参数

冻结resnet18的函数

def set_parameter_requires_gard(model ,feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = False
model_ft = models.resnet18() #内置的resnet18
model_ft

改最后一层的,因为默认的是1000输出
在这里插入图片描述

把模型输出层改成自己的

def initialize_model(feature_extract, use_pretrained=True):model_ft = models.resnet18(pretrained = use_pretrained)set_parameter_requires_gard(model_ft, feature_extract)model_ft.fc = nn.Linear(512, 102)input_size = 64return model_ft, input_size

设置哪些层需要训练

model_ft, input_size = initialize_model(feature_extract, use_pretrained=True)device = torch.device("mps") # cuda/cpumodel_ft = model_ft.to(device)filename = 'best.pt' # .pt .pthparams_to_update = model_ft.parameters()if feature_extract:params_to_update = []for name, parm in model_ft.named_parameters():if parm.requires_grad == True:params_to_update.append(parm)print(name)

在这里插入图片描述

model_ft

在这里插入图片描述

设置优化器和损失函数

optimizer_ft = optim.Adam(params_to_update, lr=1e-3)# 定义学习率调度器
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)criterion = nn.CrossEntropyLoss()
optimizer_ft.param_groups[0]

训练

def train_model(model, dataloaders, criterion, optimizer, num_epochs=50, filename="best.pt"):# 初始化一些变量since = time.time() # 记录初始时间best_acc = 0 # 记录验证集上的最佳精度model.to(device)train_acc_history = []val_acc_history = []train_losses = []valid_losses = []LRS = [optimizer.param_groups[0]['lr']]best_model_wts = copy.deepcopy(model.state_dict())for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch + 1, num_epochs))print('-' * 10)# 在每个epoch内,遍历训练和验证两个阶段for phase in ['train', 'valid']:if phase == 'train':model.train()else:model.eval()running_loss = 0.0 # 累积训练过程中的损失running_corrects = 0 # 累积训练过程中的正确预测的样本数量for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)loss = criterion(outputs, labels)preds = torch.argmax(outputs, dim=1)optimizer.zero_grad()if phase == 'train':loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(dataloaders[phase].dataset)# 整个epoch的平均损失epoch_acc = running_corrects.float() / len(dataloaders[phase].dataset) # 整个epoch的准确率time_elapsed = time.time() - sinceprint('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('{} Loss: {:.4f}; ACC: {:.4f}'.format(phase, epoch_loss, epoch_acc))if phase == "valid" and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())state = {'state_dict': model.state_dict(),'best_acc': best_acc,'optimizer': optimizer.state_dict()}torch.save(state, filename)if phase == 'valid':val_acc_history.append(epoch_acc)valid_losses.append(epoch_loss)if phase == 'train':train_acc_history.append(epoch_acc)train_losses.append(epoch_loss)print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))LRS.append(optimizer.param_groups[0]['lr'])print()scheduler.step() # 调用学习率调度器来进行学习率更新操作# 已经全部训练完了time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:.4f}'.format(best_acc))model.load_state_dict(best_model_wts)return model, val_acc_history, train_acc_history, valid_losses, train_losses ,LRS

开始训练


# def train_model(model, dataloaders, criterion, optimizer, num_epochs=50, filename="best.pt"):
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses ,LRS = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5)

在这里插入图片描述
在这里插入图片描述

再训练所有层

# 解冻
for param in model_ft.parameters():parm.requires_grad = Trueoptimizer = optim.Adam(model_ft.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1) # 每7个epoch, 学习率衰减1/10
criterion = nn.CrossEntropyLoss()
# 加载之间训练好的权重参数
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

在这里插入图片描述

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses ,LRS = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=3)

在这里插入图片描述

关机了,再开机,加载训练好的模型

model_ft, input_size = initialize_model(feature_extract, use_pretrained=True)filename = 'best.pt'# 加载模型
checkpoint = torch.load(filename)
model_ft.load_state_dict(checkpoint['state_dict'])

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/16908.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何快速用Go获取短信验证码

要用Go获取短信验证码,通常需要连接到一个短信服务提供商的API,并通过该API发送请求来获取验证码。由于不同的短信服务提供商可能具有不同的API和授权方式,我将以一个简单的示例介绍如何使用Go语言来获取短信验证码。 在这个示例中&#xff0…

【ARM Coresight 系列文章 2.4 - Coresight 寄存器:DEVARCH,DEVID, DEVTYPE】

文章目录 1.1 DEVARCH(device architecture register)1.2 DEVID(Device configuration Register)1.3 DEVTYPE(Device Type Identifier Register) 1.1 DEVARCH(device architecture register) DEVARCH 寄存器标识了coresight 组件的架构信息。 bits[31:21] 定义了组件架构&…

力扣简单1道_两数之和

两数之和 给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标。你可以假设每种输入只会对应一个答案。但是,数组中同一个元素在答案里不能重复出现。你可以按任意顺序…

微信小程序开发学习之--地图绘制行政区域图

不知道大家有没有感觉就是在做微信小程序地图功能时刚刚接触时候真的感觉好迷茫呀,文档看不懂,资料找不到,就很难受呀,比如我现在的功能就想想绘制出一个区域的轮廓图,主要是为了显眼,效果图如下&#xff1…

【入门SpringCloud(一)】什么是SpringCloud?

一、概述 集群(Cluster):同一种软件服务的多个服务节点共同为系统提供服务过程,称之为该软件服务集群。 分布式(Distribute):分布式是一种系统架构,是将系统中的不同组件分布在不同…

Mac 安装配置adb命令环境(详细步骤)

一、注意:前提要安装java环境。 因为android sdk里边开发的一些包都是依赖java语言的,所以,首先要确保已经配置了java环境。 二、在Mac下配置android adb命令环境,配置方式如下: 1、下载并安装IDE (andr…

LLaMA系列 | LLaMA和LLaMA-2精简总结

文章目录 1、LLaMA1.1、模型结构1.2、训练方式1.3、结论 2、LLaMA-22.1、相比LLaMA1的升级2.3、模型结构2.3.1、MHA, MQA, GQA区别与联系 2.4、训练方式 1、LLaMA 🔥 纯基座语言模型 《LLaMA: Open and Efficient Foundation Language Models》:https:/…

Unity3d C#快速打开萤石云监控视频流(ezopen)支持WebGL平台,替代UMP播放视频流的方案(含源码)

前言 Universal Media Player算是视频流播放功能常用的插件了,用到现在已经不知道躺了多少坑了,这个插件虽然是白嫖的,不过被甲方和领导吐槽的就是播放视频流的速度特别慢,可能需要几十秒来打开监控画面,等待的时间较…

Spring学习笔记之spring概述

文章目录 Spring介绍Spring8大模块Spring特点 Spring介绍 Spring是一个轻量级的控制反转和面向切面的容器框架 Spring最初的出现是为了解决EJB臃肿的设计,以及难以测试等问题。 Spring为了简化开发而生,让程序员只需关注核心业务的实现,尽…

HTML+CSS+JavaScript:轮播图的自动播放、手动播放、鼠标悬停暂停播放

一、需求 昨天我们做了轮播图的自动播放,即每隔一秒自动切换一次 今天我们增加两个需求: 1、鼠标点击向右按钮,轮播图往后切换一次;鼠标点击向左按钮,轮播图往前切换一次 2、鼠标悬停在轮播图区域中时,…

Verilog语法学习——LV5_位拆分与运算

LV5_位拆分与运算 题目来源于牛客网 [牛客网在线编程_Verilog篇_Verilog快速入门 (nowcoder.com)](https://www.nowcoder.com/exam/oj?page1&tabVerilog篇&topicId301) 题目 题目描述: 现在输入了一个压缩的16位数据,其实际上包含了四个数据…

从互联网到云时代,Apache RocketMQ 是如何演进的?

作者:隆基 2022 年,RocketMQ 5.0 的正式版发布。相对于 4.0 版本而言,架构走向云原生化,并且覆盖了更多业务场景。 消息队列演进史 操作系统、数据库、中间件是基础软件的三驾马车,而消息队列属于最经典的中间件之一…

用python需要下载软件吗,python需要安装哪些软件

大家好,本文将围绕安装python需要什么样的电脑配置展开说明,python需要安装哪些软件是一个很多人都想弄明白的事情,想搞清楚用python需要下载软件吗需要先了解以下几个事情。 编程这东西很神奇。对于那些知道如何有用和有趣的这个工具,对于Xi…

Windows 实例如何开放端口

矩池云 Windows 实例相比于 Linux 实例,除了在租用机器的时候自定义端口外,还需要在 Windows防火墙中添加入口规则。接下来将教大家如何设置 Windows 防火墙,启用端口。 租用成功后通过 RDP 链接连接服务器,然后搜索防火墙&#x…

React的UmiJS搭建的项目集成海康威视h5player播放插件H5视频播放器开发包 V2.1.2

最近前端的一个项目,大屏需要摄像头播放,摄像头厂家是海康威视的,网上找了一圈都没有React集成的,特别是没有使用UmiJS搭脚手架搭建的,所以记录一下。 海康威视的开放平台的API地址,相关插件和文档都可以下…

简单的python有趣小程序,有趣的代码大全python

这篇文章主要介绍了python简单有趣的程序源代码,具有一定借鉴价值,需要的朋友可以参考下。希望大家阅读完这篇文章后大有收获,下面让小编带着大家一起了解一下。

Hadoop学习日记-YARN组件

YARN(Yet Another Resource Negotiator)作为一种新的Hadoop资源管理器,是另一种资源协调者。 YARN是一个通用的资源管理系统和调度平台,可为上层应用提供统一的资源管理和调度 YARN架构图 YARN3大组件: (物理层面&#xff09…

ICML 2023 | 拓展机器学习的边界

编者按:如今,机器学习已成为人类未来发展的焦点领域,如何进一步拓展机器学习技术和理论的边界,是一个极富挑战性的重要话题。7月23日至29日,第四十届国际机器学习大会 ICML 2023 在美国夏威夷举行。该大会是由国际机器…

HarmonyOS/OpenHarmony元服务开发-配置卡片的配置文件

卡片相关的配置文件主要包含FormExtensionAbility的配置和卡片的配置两部分: 1.卡片需要在module.json5配置文件中的extensionAbilities标签下,配置FormExtensionAbility相关信息。FormExtensionAbility需要填写metadata元信息标签,其中键名称…

CentOS 7安装Docker

文章目录 🌞版本选择☀️1.CentOS安装Docker🌱1.1.卸载(可选)🌱1.2.安装docker🌱1.3.启动docker🌱1.4.配置镜像加速 ☀️2.CentOS7安装DockerCompose🌱2.1.下载🌱2.2.修改…