PyTorch入门之【AlexNet】

参考文献:https://www.bilibili.com/video/BV1DP411C7Bw/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620
AlexNet 是一个经典的卷积神经网络模型,用于图像分类任务。

目录

  • 大纲
  • dataloader
  • model
  • train
  • test

大纲

在这里插入图片描述
各个文件的作用:

  • data就是数据集
  • dataloader.py就是数据集的加载以及实例初始化
  • model.py就是AlexNet模块的定义
  • train.py就是模型的训练
  • test.py就是模型的测试

dataloader

import torch
import torchvision
import torchvision.transforms as transformsimport matplotlib.pyplot as plt
import numpy as np# define the dataloader
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])batch_size = 16trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,shuffle=False)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')if __name__ == '__main__':# get some random training imagesdataiter = iter(train_loader)images, labels = next(dataiter)# print labelsprint(' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))# show imagesimg_grid = torchvision.utils.make_grid(images)img_grid = img_grid / 2 + 0.5npimg = img_grid.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()

model

import torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self, num_classes=10):super(AlexNet, self).__init__()self.conv_1 = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),nn.BatchNorm2d(96),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.conv_2 = nn.Sequential(nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.conv_3 = nn.Sequential(nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(384),nn.ReLU())self.conv_4 = nn.Sequential(nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(384),nn.ReLU())self.conv_5 = nn.Sequential(nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.fc_1 = nn.Sequential(nn.Dropout(0.5),nn.Linear(9216, 4096),nn.ReLU())self.fc_2 = nn.Sequential(nn.Dropout(0.5),nn.Linear(4096, 4096),nn.ReLU())self.fc_3= nn.Sequential(nn.Linear(4096, num_classes))def forward(self, x):out = self.conv_1(x)out = self.conv_2(out)out = self.conv_3(out)out = self.conv_4(out)out = self.conv_5(out)out = out.reshape(out.size(0), -1)out = self.fc_1(out)out = self.fc_2(out)out = self.fc_3(out)return outif __name__ == '__main__':model = AlexNet()print(model)x = torch.randn(1, 3, 224, 224)y = model(x)print(y.size())

train

import torch
import torch.nn as nnfrom dataloader import train_loader, test_loader
from model import AlexNet# define the hyperparameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 10
num_epochs = 20
learning_rate = 1e-3# load the model
model = AlexNet(num_classes).to(device)# loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  # train the model
total_len = len(train_loader)for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):# move tensors to the configured deviceimages = images.to(device)labels = labels.to(device)# forward passoutputs = model(images)loss = criterion(outputs, labels)# backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_len, loss.item()))# Validationwith torch.no_grad():model.eval()correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()model.train()print('Accuracy of the network on the {} validation images: {} %'.format(10000, 100 * correct / total))# save the model checkpoint
torch.save(model.state_dict(), 'alexnet.pth')

test

import torchfrom dataloader import test_loader, classes
from model import AlexNet# load the pretrained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AlexNet().to(device)
model.load_state_dict(torch.load('alexnet.pth', map_location=device))# test the pretrained model on CIFAR-10 test data
with torch.no_grad():model.eval()correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the {} validation images: {} %'.format(10000, 100 * correct / total))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/96089.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

通过ElementUi在Vue搭建的项目中实现CRUD

🏅我是默,一个在CSDN分享笔记的博主。📚📚 🌟在这里,我要推荐给大家我的专栏《Vue》。🎯🎯 🚀无论你是编程小白,还是有一定基础的程序员,这个专栏…

摄影后期图像编辑软件Lightroom Classic 2023 mac中文特点介绍

Lightroom Classic 2023 mac是一款图像处理软件,是数字摄影后期制作的重要工具之一,lrc2023 mac适合数字摄影后期制作、摄影师、设计师等专业人士使用。 Lightroom Classic 2023 mac软件特点 高效的图像管理:Lightroom Classic提供了强大的图…

WPF 实现点击按钮跳转页面功能

方法1. 配置环境 首先添加prism依赖项&#xff0c;配置好所有文件。需要配置的有两个文件&#xff1a;App.xaml.cs和App.xaml App.xaml.cs using System.Data; using System.Linq; using System.Threading.Tasks; using System.Windows;namespace PrismDemo {/// <summa…

输入电压转化为电流性 5~20mA方案

输入电压转化为电流性 5~20mA方案 方案一方案二方案三 方案一 XTR111是一款精密的电压-电流转换器是最广泛应用之一。原因有二&#xff1a;一是线性度非常好、二是价格便宜。总结成一点&#xff0c;就是性价比高。 典型电路 最终电路 Z1二极管处输出电流表达式&#xff1a;…

【Python】读取显示pgm图像文件

文章目录 零. 前言一. pgm基本概念二. pgm基本信息读取三. pgm图像渲染四. 代码优化 零. 前言 这学期要学多媒体信息隐藏对抗&#xff0c;发现其中的图像数据集文件都是pgm文件形式的。虽然是图像文件&#xff0c;但是却不能直接通过图像查看器来打开&#xff0c;上网一搜&…

1、内核加载模块

一、静态加载 1、新功能源码与内核源码一起编译进uImage文件内 新功能源码与Linux内核源码在同一目录结构下在linux-3.14/drivers/char/目录下编写hello.c文件&#xff0c;内容如下 #include <linux/module.h> #include <linux/kernel.h>int __init myhello_ini…

英语四六级高频核心词(故事版)

第一组&#xff1a;" A Century of Community Effort to Improve Quality of Life and Climate" In the early years of the 20th century, a small community found itself facing a decade of challenges. The most pressing issue was the mental quality of life…

理解C++强制类型转换

理解C强制类型转换 文章目录 理解C强制类型转换理解C强制转换运算符1 static_cast1.1. static_cast用于内置数据类型之间的转换1.2 用于指针之间的转换 1.3 用于基类与派生类之间的转换2. const_cast2.1示例12.2 示例2——this指针 3.reinterpret_cast4.dynamic_cast C认为C风格…

多普勒频率相关内容介绍

图1 多普勒效应 1、径向速度 径向速度是作用于雷达或远离雷达的速度的一部分。 图2 不同的速度 2、喷气发动机调制 JEM是涡轮机的压缩机叶片的旋转的多普勒频率。 3、多普勒困境 最大无模糊范围需要尽可能低的PRF&#xff1b; 最大无模糊速度需要尽可能高的PRF&#xff1b…

国庆看坚如磐石

坚如磐石上映了&#xff0c;可以在爱奇艺观看。 而博主在使用蓝牙耳机连接电脑的过程中&#xff0c;发现没有蓝牙开启选项&#xff0c;并且在服务的设备管理器中也没有找到&#xff0c;很明显这是缺少驱动导致的&#xff0c;因此便去联想官方网站下载对应的驱动。 这里可以输入…

【LLM】主流大模型体验(文心一言 科大讯飞 字节豆包 百川 阿里通义千问 商汤商量)

note 智谱AI体验百度文心一言体验科大讯飞大模型体验字节豆包百川智能大模型阿里通义千问商汤商量简要分析&#xff1a;仅从测试“老婆饼为啥没有老婆”这个问题的结果来看&#xff0c;chatglm分点作答有条理&#xff08;但第三点略有逻辑问题&#xff09;&#xff1b;字节豆包…

数据结构与算法(四):哈希表

参考引用 Hello 算法 Github&#xff1a;hello-algo 1. 哈希表 1.1 哈希表概述 哈希表&#xff08;hash table&#xff09;&#xff0c;又称散列表&#xff0c;其通过建立键 key 与值 value 之间的映射&#xff0c;实现高效的元素查询 具体而言&#xff0c;向哈希表输入一个键…

STM32复习笔记(四):看门狗

目录 &#xff08;一&#xff09;简介 &#xff08;二&#xff09;IWDG IWDG的CUBEMX工程配置 IWDG相关函数&#xff08;非常少&#xff0c;所以直接贴上来&#xff09;&#xff1a; &#xff08;三&#xff09;WWDG &#xff08;一&#xff09;简介 看门狗分为独立看门…

几种开源协议的区别(Apache、MIT、BSD、MPL、GPL、LGPL)

作为一名软件开发人员&#xff0c;你一定也是经常接触到开源软件&#xff0c;但你真的就了解这些开源软件使用的开源许可协议吗&#xff1f; 你不会真的认为&#xff0c;开源就是完全免费吧&#xff1f;那么让我们通过本文来寻找答案。 一、开源许可协议简述 开源许可协议是指开…

karmada v1.7.0安装指导

前言 安装心得 经过多种方式操作&#xff0c;发现二进制方法安装太复杂&#xff0c;证书生成及其手工操作太多了&#xff0c;没有安装成功&#xff1b;helm方式的安装&#xff0c;v1.7.0的chart包执行安装会报错&#xff0c;手工修复了报错并修改了镜像地址&#xff0c;还是各…

家居家纺经营配送小程序商城的作用是什么

家居家纺产品是每个家庭都必备的&#xff0c;无论商场还是小摊贩&#xff0c;市场中经营商家数量都比较多&#xff0c;而随着互联网电商发展&#xff0c;在实际经营中&#xff0c;传统线下商家也面临多个难题&#xff1a; 首先就是获客问题&#xff0c;线下渠道推广宣传方式单…

冒泡排序和选择排序

目录 一、冒泡排序 1.冒泡排序的原理 2.实现冒泡排序 1.交换函数 2.单躺排序 3.冒泡排序实现 4.测试 5.升级冒泡排序 6.升级版代码 7.升级版测试 二、选择排序 1.选择排序的原理 2.实现选择排序 1.单躺排序 2.选择排序实现 3.测试 ​4.修改 5.测试 一、冒泡排序…

MacBook 录制电脑内部声音

MacBook 录制电脑内部声音 老妈喜欢跳广场舞&#xff0c;现在广场舞音频下载都收费了&#xff01;没办法&#xff0c;只能自己录歌了&#xff0c;外录有杂音大家也都知道&#xff0c;所以就只能采用内录的方式然后再用 Audition 调整一下音量大小。 一、&#xff08;前置条件&a…

HVDC-MMC互连(1000MW,±320KV)使用聚合MMC模型进行优化的SPS模拟

微❤关注“电气仔推送”获得资料&#xff08;专享优惠&#xff09; 模型概述&#xff1a; 本示例展示了一个SimPowerSystems&#xff08;SPS&#xff09;模型&#xff0c;使用基于模块化多电平变换器&#xff08;MMC&#xff09;技术的电压源换流器&#xff08;VSC&#xff09…

【AI视野·今日Robot 机器人论文速览 第四十九期】Fri, 6 Oct 2023

AI视野今日CS.Robotics 机器人学论文速览 Fri, 6 Oct 2023 Totally 29 papers &#x1f449;上期速览✈更多精彩请移步主页 Interesting: &#x1f4da;ContactGen, 基于生成模型的抓取手势生成&#xff0c;类人五指手。(from 伊利诺伊大学 香槟) 数据集&#xff1a;GRAB da…