pytorch神经网络训练(AlexNet)

  • 导包
import osimport torchimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import Dataset, DataLoaderfrom PIL import Imagefrom torchvision import models, transforms
  • 定义自定义图像数据集
class CustomImageDataset(Dataset): 

定义一个自定义的图像数据集类,继承自Dataset

def __init__(self, main_dir, transform=None): 

初始化方法,接收主目录和转换方法

        self.main_dir = main_dir 

主目录,包含多个子目录,每个子目录包含同一类别的图像

        self.transform = transform

 图像转换方法,用于对图像进行预处理

        self.files = [] 

存储所有图像文件的路径

        self.labels = [] 

存储所有图像的标签

        self.label_to_index = {} 

创建一个字典,用于将标签映射到索引

        for index, label in enumerate(os.listdir(main_dir)):

 遍历主目录中的所有子目录

 

          self.label_to_index[label] = index label_dir = os.path.join(main_dir, label) 

将标签映射到索引,构建标签子目录的路径

           if os.path.isdir(label_dir): for file in os.listdir(label_dir): self.files.append(os.path.join(label_dir, file))self.labels.append(label) 

如果是目录,遍历目录中的所有文件,将文件路径添加到列表,将标签添加到列表

def __len__(self):

定义数据集的长度

        return len(self.files) 

返回文件列表的长度

def __getitem__(self, idx): 

定义获取数据集中单个样本的方法

        image = Image.open(self.files[idx]) label = self.labels[idx] if self.transform: image = self.transform(image) return image, self.label_to_index[label] 

打开图像文件,获取图像的标签,如果有转换方法,对图像进行转换,返回图像和对应的标签索引

  • 定义数据转换
transform = transforms.Compose([transforms.Resize((227, 227)),  # AlexNet的输入图像大小transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.RandomRotation(10),  # 随机旋转transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # AlexNet的标准化])

  • 创建数据集
dataset = CustomImageDataset(main_dir="D:\\图像处理、深度学习\\flowers", transform=transform)
  • 创建数据加载器
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
  • 加载预训练的AlexNet模型
alexnet_model = models.alexnet(pretrained=True)
  • 修改最后几层以适应新的分类任务
num_ftrs = alexnet_model.classifier[6].in_featuresalexnet_model.classifier[6] = nn.Linear(num_ftrs, len(dataset.label_to_index))
  • 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(alexnet_model.parameters(), lr=0.0001)
  • 如果有多个GPU,可以使用nn.DataParallel来并行化模型
if torch.cuda.device_count() > 1:alexnet_model = nn.DataParallel(alexnet_model)
  • 将模型发送到GPU(如果可用)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")alexnet_model.to(device)                                                               

  • 模型评估
def evaluate_model(model, data_loader, device):model.eval()  # 将模型设置为评估模式correct = 0total = 0with torch.no_grad():  # 在这个块中,所有计算都不会计算梯度for images, labels in data_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalreturn accuracy
  • 训练模型
num_epochs = 10for epoch in range(num_epochs):alexnet_model.train()running_loss = 0.0for images, labels in data_loader:images, labels = images.to(device), labels.to(device)

前向传播

        outputs = alexnet_model(images)loss = criterion(outputs, labels)

反向传播和优化

        optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()

在每个epoch结束后评估模型

    train_accuracy = evaluate_model(alexnet_model, data_loader, device)print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/26571.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

美丽的拉萨,神奇的布达拉宫

原文链接:美丽的拉萨,神奇的布达拉宫 2022年11月30日,可能将成为一个改变人类历史的日子——美国人工智能开发机构OpenAI推出了聊天机器人ChatGPT-3.5,将人工智能的发展推向了一个新的高度。2023年11月7日,OpenAI首届…

TcpClient 服务器、客户端连接

TcpClient 服务器 TcpListener 搭建tcp服务器的类,基于socket套接字通信的 1 创建服务器对象 TcpListener server new TcpListener(IPAddress.Parse("127.0.0.1"), 3000); 2 开启服务器 设置最大连接数 server.Start(1000); 3 接收客户端的链接,只能…

Android帧绘制流程深度解析 (二)

书接上回:Android帧绘制流程深度解析 (一) 5、 dispatchVsync: 在请求Vsync以后,choreographer会等待Vsync的到来,在Vsync信号到来后,会触发dispatchVsync函数,从而调用onVsync方法…

手机和模拟器的 Frida 环境配置

目录 一、配置 JDK 和 android 环境 二、连接设备和查看权限 1、连接设备 2、查看手机权限 三、手机配置 Frida 1、frida-server下载 2、验证 四、模拟器配置 Frida 1、下载模拟器并调节成手机版: 2、连接并查看架构 3、配置并开启 x86 的 frida-serve…

Phybers:脑纤维束分析软件包

摘要 本研究提供了一个用于分析脑纤维束数据的Python库(Phybers)。纤维束数据集包含由表示主要白质通路的3D点组成的流线(也称为纤维束)。目前已经提出了一些算法来分析这些数据,包括聚类、分割和可视化方法。由于流线的几何复杂性、文件格式和数据集的大小(可能包…

HTML静态网页成品作业(HTML+CSS)—— 环保主题介绍网页(5个页面)

🎉不定期分享源码,关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 🏷️本套采用HTMLCSS,未使用Javacsript代码,共有5个页面。 二、作品演示 三、代…

多层tablayout+ViewPager,NestedScrollView+ViewPager+RecyclerView,嵌套吸顶滑动冲突

先看实现的UI效果 其实就是仿BOSS的页面效果,第二层tab下的viewpager滑到最右边再右滑,就操作第一层viewpager滑动。页面上滑时把第一层tab和vp里的banner都推出界面,让第二层tab吸顶。 滑上去第二个tab块卡在顶部,如图 我混乱…

React 渲染函数render、初始化函数、更新函数运行了两次,原因为何,如何解决? React.StrictMode

文章目录 Intro官网解释解决另一篇官网文章——初始化函数或更新函数运行了两次 Intro 我在用 react 写一个 demo ,当我在某个自定义组件的 return 语句之前加上一句log之后,发现:每次页面重新渲染,该行日志都打印了两次&#xf…

vue-loader

Vue Loader 是一个 webpack 的 loader,它允许你以一种名为单文件组件 (SFCs)的格式撰写 Vue 组件 起步 安装 npm install vue --save npm install webpack webpack-cli style-loader css-loader html-webpack-plugin vue-loader vue-template-compiler webpack…

论文阅读Rolling-Unet,卷积结合MLP的图像分割模型

这篇论文提出了一种新的医学图像分割网络Rolling-Unet,目的是在不用Transformer的前提下,能同时有效提取局部特征和长距离依赖性,从而在性能和计算成本之间找到良好的平衡点。 论文地址:https://ojs.aaai.org/index.php/AAAI/article/view/2…

618购物狂欢节有哪些数码好物值得抢购?年终必备神器清单大揭秘!

一年一度的“618年中大促”即将拉开帷幕,大家是否已经挑选好了心仪的宝贝呢?那些平时心仪已久的商品,是否总期待着在价格最优惠时收入囊中?毫无疑问,618就是这样一个绝佳的时机,因为各大电商平台都会纷纷推…

Lecture2——最优化问题建模

一,建模 1,重要性 实际上,我们并没有得到一个数学公式——通常问题是由某个领域的专家口头描述的。能够将问题转换成数学公式非常重要。建模并不是一件容易的事:有时,我们不仅想找到一个公式,还想找到一个…

群晖NAS安装配置Joplin Server用来存储同步Joplin笔记内容

一、Joplin Server简介 1.1、Joplin Server介绍 Joplin支持多种方式进行同步用户的笔记数据(如:Joplin自己提供的收费的云服务Joplin Cloud,还有第三方的云盘如Dropbox、OneDrive,还有自建的云盘Nextcloud、或者通过WebDAV协议来…

长沙干洗服务,打造您的专属衣橱

长沙干洗服务,用心呵护您的每一件衣物!致力于为您打造专属的衣橱,让您的每一件衣物都焕发出独特的魅力。 我们深知每一件衣物都承载着您的故事和情感,因此我们会以更加细心的态度对待每一件衣物。无论是您心爱的牛仔裤&#xff0c…

sizeof和strlen

1.sizeof和strlen的对比 1.1sizeof sizeof是计算变量所占内存空间大小的,单位是:字节 如果操作数是类型的话,计算的是使用类型创建的变量所占内存空间的大小。 sizeof只关注占用内存空间的大小,不在乎内存中存放的是什么数据 …

QML Canvas 代码演示

一、文字阴影 / 发光 Canvas{id: root; width: 400; height: 400onPaint: //所有的绘制都在onPaint中{var ctx getContext("2d") //获取上下文// 绘制带阴影的文本ctx.fillStyle "#333" //设置填充颜色ctx.fillRect(0, 0, root.width, root.height…

Stability AI发布新版文生图模型:依然开源

Stability AI最近发布了Stable Diffusion 3 Medium(简称SD3 Medium),这是其最新的文生图模型,被官方称为“迄今为止最先进的开源模型”。SD3 Medium的性能甚至超过了Midjourney 6,特别是在生成手部和脸部图像方面表现出…

一杯咖啡的艺术 | 如何利用数字孪生技术做出完美的意式浓缩咖啡?

若您对数据分析以及人工智能感兴趣,欢迎与我们一起站在全球视野关注人工智能的发展,与Forrester 、德勤、麦肯锡等全球知名企业共探AI如何加速制造进程, 共同参与6月20日由Altair主办的面向工程师的全球线上人工智能会议“AI for Engineers”…

可以自定义的文字识别OCR

可以自定义的文字识别OCR 什么是OCR文档自学习自定义模板单证票据信息抽取操作体验 这里提到的可以自定义的文字识别OCR ,其实就是OCR文档自学习。 什么是OCR文档自学习 什么是OCR文档自学习呢?OCR文档自学习,是面向“无算法基础”的企业与个…

C#——字典diction详情

字典 字典: 包含一个key(键)和这个key所以对应的value&#xff08;值&#xff09;&#xff0c;字典是是无序的&#xff0c;key是唯一的&#xff0c;可以根据key获取值。 定义字典: new Diction<key的类型&#xff0c;value的类型>() 方法 添加 var dic new Dictionar…