LeNet

概念

代码

model

import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()  # super()继承父类的构造函数self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x): x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x)            # output(16, 14, 14)x = F.relu(self.conv2(x))    # output(32, 10, 10)x = self.pool2(x)            # output(32, 5, 5)x = x.view(-1, 32*5*5)       # output(32*5*5)x = F.relu(self.fc1(x))      # output(120)x = F.relu(self.fc2(x))      # output(84)x = self.fc3(x)              # output(10)return x

forward:定义正向传播的过程。

ReLU:激活哈数

观察网络中的参数传递:发现传递的都是channel通道数,最后output在softmax函数里展开的也是展开的通道数。

train

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transformsdef main():transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 50000张训练图片# 第一次使用时要将download设置为True才会自动去下载数据集train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,shuffle=True, num_workers=0)# 10000张验证图片# 第一次使用时要将download设置为True才会自动去下载数据集val_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,shuffle=False, num_workers=0)val_data_iter = iter(val_loader)val_image, val_label = next(val_data_iter)# classes = ('plane', 'car', 'bird', 'cat',#            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = LeNet()loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)for epoch in range(5):  # loop over the dataset multiple timesrunning_loss = 0.0for step, data in enumerate(train_loader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = loss_function(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if step % 500 == 499:    # print every 500 mini-batcheswith torch.no_grad():outputs = net(val_image)  # [batch, 10]predict_y = torch.max(outputs, dim=1)[1]accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')save_path = './Lenet.pth'torch.save(net.state_dict(), save_path)if __name__ == '__main__':main()

predict.py

import torch
import torchvision.transforms as transforms
from PIL import Imagefrom model import LeNetdef main():transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = LeNet()net.load_state_dict(torch.load('Lenet.pth'))im = Image.open('1.jpg').convert('RGB')im = transform(im)  # [C, H, W]im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]with torch.no_grad():outputs = net(im)predict = torch.max(outputs, dim=1)[1].numpy()# predict = torch.softmax(outputs,dim=1)# print(predict)# tensor([[9.9884e-01, 1.9386e-04, 3.8757e-04, 2.0671e-05, 2.5372e-04, 3.6199e-05,# 3.7643e-05, 1.7624e-04, 2.0138e-05, 3.4801e-05]])print(classes[int(predict)])if __name__ == '__main__':main()

知识点:

增加新的维度: 

im = torch.unsqueeze(im, dim=0)  # [N, C, H, W] 

predict = torch.max(outputs, dim=1)[1].numpy():

这一行代码使用torch.max()函数找到outputs张量在第一个维度上的最大值,并返回最大值和对应的索引。dim=1表示在第一个维度上进行最大值的计算,即对每个样本的输出进行比较。[1]表示返回最大值对应的索引。最后,.numpy()将结果转换为NumPy数组。 

更换:

predict = torch.softmax(outputs,dim=1)

print:tensor([[9.9884e-01, 1.9386e-04, 3.8757e-04, 2.0671e-05, 2.5372e-04, 3.6199e-05,
         3.7643e-05, 1.7624e-04, 2.0138e-05, 3.4801e-05]])

Pytorch使用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/211018.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Bash脚本处理ogg、flac格式到mp3格式的批量转换

现在下载的许多音乐文件是flac和ogg格式的,QQ音乐上下载的就是这样的,这些文件尺寸比较大,在某些场合使用不便,比如在车机上播放还是mp3格式合适,音质这些在车机上播放足够了,要求不高。比如本人就喜欢下载…

软件接口安全设计规范

《软件项目接口安全设计规范》 1.token授权机制 2.https传输加密 3.接口调用防滥用 4.日志审计里监控 5.开发测试环境隔离,脱敏处理 6.数据库运维监控审计

卷王开启验证码后无法登陆问题解决

问题描述 使用 docker 部署,后台设置开启验证,重启服务器之后,docker重启,再次访问系统,验证码获取失败,导致无法进行验证,也就无法登陆系统。 如果不了解卷王的,可以去官网看下。…

飞天使-linux操作的一些技巧与知识点3

http工作原理 http1.0 协议 使用的是短连接,建立一次tcp连接,发起一次http的请求,结束,tcp断开 http1.1 协议使用的是长连接,建立一次tcp的连接,发起多次http的请求,结束,tcp断开ngi…

ky10 server x86 设置网卡开机自启

输入命令查看网卡名称 ip a 输入命令编辑网卡信息 vi /etc/sysconfig/network-scripts/*33改成yes 按ESC键,输入:wq,保存

Aloha 机械臂的学习记录2——AWE:AWE + ACT

继续下一个阶段: Train policy python act/imitate_episodes.py \ --task_name [TASK] \ --ckpt_dir data/outputs/act_ckpt/[TASK]_waypoint \ --policy_class ACT --kl_weight 10 --chunk_size 50 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \ --n…

F : A DS二分查找_寻找比目标字母大的最小字母

Description 给你一个字符串str,字符串中的字母都已按照升序排序,且只包含小写字母。另外给出一个目标字母target,请你寻找在这一有序字符串里比目标字母大的最小字母。 在比较时,字母是依序循环出现的。例如,str“ab…

Python中锁的常见用法

在 Python 中,可以使用线程锁来控制多个线程对共享资源的访问。以下是一些常见的 Python 中锁的用法: 创建线程锁 在 Python 中,可以使用 threading 模块中的 Lock 类来创建线程锁。例如: import threading# 创建线程锁 lock …

Python网络爬虫环境的安装指南

网络爬虫是一种自动化的网页数据抓取技术,广泛用于数据挖掘、信息搜集和互联网研究等领域。Python作为一种强大的编程语言,拥有丰富的库支持网络爬虫的开发。本文将为你详细介绍如何在你的计算机上安装Python网络爬虫环境。 一、安装python开发环境 进…

什么是电压纹波,造成不良,如何测量、如何抑制设计

1 引言 电源给电子产品提供能量同时也附带了一些不好的影响成分,如纹波、噪声等,这些对本振、、滤波、放大器、混频器、检波、A/D 转换等电路都会产生影响,会直接影响电子产品正常工作,所以项目设计要合理、要有实测数据、要尽量减小系统电压的纹波。 1.1 电压纹波(volta…

bc-linux-欧拉重制root密码

最近需要重新安装虚拟机的系统 安装之后发现对方提供的root密码不对,无法进入系统。 上网搜了下发现可以进入单用户模式进行密码修改从而重置root用户密码。 在这个界面下按e键 找到图中部分,把标红的部分删除掉,然后写上rw init/bin/…

strftime(“%-m/%-d/%Y“) 报错 ValueError: Invalid format string

问题 运行测试用例时,出现ValueError: Invalid format string的错误,代码大致如下: from datetime import date .... current date.today() return current.strftime("%-m/%-d/%Y")原因 开发此代码的时候是在mac上开发的&#…

24、文件上传漏洞——Apache文件解析漏洞

文章目录 一、环境简介一、Apache与php三种结合方法二、Apache解析文件的方法三、Apache解析php的方法四、漏洞原理五、修复方法 一、环境简介 Apache文件解析漏洞与用户配置有密切关系。严格来说,属于用户配置问题,这里使用ubantu的docker来复现漏洞&am…

IOday7作业

1> 使用无名管道完成父子进程间的通信 #include<myhead.h>int main(int argc, const char *argv[]) {//创建存放两个文件描述符的数组int fd[2];int pid -1;//打开无名管道if(pipe(fd) -1){perror("pipe");return -1;}//创建子进程pid fork();if(pid &g…

wordpress小记

1.插件市场搜索redis&#xff0c;并按照 Redis Object cache插件 2.开启php的redis扩展 执行php -m|grep redis&#xff0c;没有显示就执行 yum -y install php-redis3.再次修改wp配置文件&#xff0c;增加redis的配置 define( WP_REDIS_HOST, 114.80.36.124 );define( WP_…

非标设计之电磁阀

电磁阀&#xff1a; 分类&#xff1a; 动画演示两位三通电磁阀&#xff1a; 两位三通电磁阀动画演示&#xff1a; 111&#xff1a; 气缸回路的介绍&#xff1a; 失电状态&#xff1a; 电磁阀得电状态&#xff1a; 两位五通电磁阀的回路&#xff1a;&#xff08;常用&#xf…

算数运算符和算数表达式

基本算数运算符 算数运算符&#xff1a; &#xff08;加法运算符或正值运算符&#xff09;、-&#xff08;减法运算符或负值运算符&#xff09;、*&#xff08;乘&#xff09;、/&#xff08;除&#xff09;、%&#xff08;求余数&#xff09; 双目运算符&#xff1a; 双目…

四则运算 .

输入一个表达式&#xff08;用字符串表示&#xff09;&#xff0c;求这个表达式的值。 保证字符串中的有效字符包括[‘0’-‘9’],‘’,‘-’, ‘*’,‘/’ ,‘(’&#xff0c; ‘)’,‘[’, ‘]’,‘{’ ,‘}’。且表达式一定合法。字符串长度满足1≤n≤1000 输入描述&#x…

CGAL的2D符合规定的三角剖分和网格

1、符合规定的三角剖分 1.1、定义 如果三角形的任何面的外接圆在其内部不包含顶点&#xff0c;则该三角形是 Delaunay 三角形。 约束 Delaunay 三角形是一种尽可能接近 Delaunay 的约束三角形。 约束 Delaunay 三角形的任何面的外接圆在其内部不包含从该面可见的数据点。 如果…

陀螺仪LSM6DSV16X与AI集成(3)----读取融合算法输出的四元数

陀螺仪LSM6DSV16X与AI集成.2--姿态解算 概述视频教学样品申请完整代码下载使用demo板生成STM32CUBEMX串口配置IIC配置CS和SA0设置串口重定向参考程序初始化SFLP步骤初始化SFLP读取四元数数据演示 概述 LSM6DSV16X 特性涉及到的是一种低功耗的传感器融合算法&#xff08;Sensor…