PyTorch下的5种不同神经网络-一.AlexNet

1.导入模块

导入所需的Python库,包括图像处理、深度学习模型和数据加载

import osimport torchimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import Dataset, DataLoaderfrom PIL import Imagefrom torchvision import models, transforms

2.定义自定义图像数据集

创建一个自定义的图像数据集类,用于加载和处理图像数据。

class CustomImageDataset(Dataset):def __init__(self, main_dir, transform=None):self.main_dir = main_dirself.transform = transformself.files = []self.labels = []self.label_to_index = {}for index, label in enumerate(os.listdir(main_dir)):self.label_to_index[label] = indexlabel_dir = os.path.join(main_dir, label)if os.path.isdir(label_dir):for file in os.listdir(label_dir):self.files.append(os.path.join(label_dir, file))self.labels.append(label)def __len__(self):return len(self.files)def __getitem__(self, idx):image = Image.open(self.files[idx])label = self.labels[idx]if self.transform:image = self.transform(image)return image, self.label_to_index[label]

3.定义数据转换

定义一个数据转换过程,包括图像大小调整、随机翻转、旋转、转换为张量以及标准化

transform = transforms.Compose([transforms.Resize((227, 227)),  # AlexNet的输入图像大小transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.RandomRotation(10),  # 随机旋转transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # AlexNet的标准化])

4.创建数据集

使用自定义数据集类和定义的数据转换来创建数据集

dataset = CustomImageDataset(main_dir="F:\\A-GX\\A-SJJ\\flower_photos\\flower_photos", transform=transform)

5.创建数据加载器

使用数据集创建一个数据加载器,用于批量加载和处理数据

data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

6.加载预训练的AlexNet模型

从PyTorch库中加载预训练的AlexNet模型

alexnet_model = models.alexnet(pretrained=True)

7.修改最后几层以适应新的分类任务

修改AlexNet模型的最后几层,以便它能够处理新的分类任务

num_ftrs = alexnet_model.classifier[6].in_featuresalexnet_model.classifier[6] = nn.Linear(num_ftrs, len(dataset.label_to_index))

8.定义损失函数和优化器

定义用于训练模型的损失函数和优化器。

criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(alexnet_model.parameters(), lr=0.0001)

9.模型并行化

如果有多GPU,则使用nn.DataParallel来并行化模型

if torch.cuda.device_count() > 1:alexnet_model = nn.DataParallel(alexnet_model)

10.将模型发送到GPU

将模型发送到GPU进行训练。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")alexnet_model.to(device)

11.训练模型

数据加载器和定义的参数训练模型

num_epochs = 10for epoch in range(num_epochs):alexnet_model.train()running_loss = 0.0for images, labels in data_loader:images, labels = images.to(device), labels.to(device)# 前向传播outputs = alexnet_model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()# 在每个epoch结束后评估模型train_accuracy = evaluate_model(alexnet_model, data_loader, device)print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%')

12.评估模型

定义一个评估函数,用于评估模型的性能

def evaluate_model(model, data_loader, device):model.eval()  # 将模型设置为评估模式correct = 0total = 0with torch.no_grad():  # 在这个块中,所有计算都不会计算梯度for images, labels in data_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalreturn accuracy

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/31299.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

怎么添加网页到桌面快捷方式?

推荐用过最棒的学习网站!https://offernow.cn 添加网页到桌面快捷方式? 很简单,仅需要两步,接下来以chrome浏览器为例。 第一步 在想要保存的网页右上角点击设置。 第二步 保存并分享-创建快捷方式,保存到桌面即可…

Docker定位具体占用大量存储的容器

监控告警生产环境的服务器磁盘分区使用率大于90%,进入服务器查看Docker 的 overlay2 存储驱动目录中占用很大,很可能是某个容器一直在打印日志,所以需要定位到是哪个容器,然后进行进一步排查。 然后进入到overlay2中查看是哪个目录…

【第13章】进阶调试思路:如何安装复杂节点IP-Adapter?(安装/复杂报错/节点详情页/精读)ComfyUI基础入门教程

🎈背景 IP-Adapter这个名字,大家可能听说过,可以让生成的结果从参考图中学习人物、画风的一致性,在目前是比较实用的一个节点,广泛的用于照片绘制、电商作图等方面。 但同时,这个节点也是比较难安装的一个节点。 所以,这节课,我们就通过一个案例,来学习如何在Comf…

MySQL----彻底卸载(附带每一步截图)

停止mysql服务 打开任务管理器,点击服务,找到mysql服务,这里我的是MySQL57,找到mysql服务后选中,点击右键选择停止服务 删除mysql服务 winR打开命令框,输入cmd打开cmd控制台或者电脑左下角输入cmd搜索&…

算法导论 总结索引 | 第四部分 第十五章:数据结构的扩张

1、动态规划(dynamic programming)与分治方法相似,都是通过组合子问题的解 来求解原问题 分治方法 将问题划分为互不相交的子问题,递归地求解子问题,再将它们的解组合起来。求出原问题的解 与之相反,动态规…

HarmonyOS角落里的知识:一杯冰美式的时间 -- 之打字机

一、前言 模拟编辑器或者模拟输入框中文字啪啦啪啦输入的效果,往往能够吸引人们的眼球,让用户的注意力聚焦在输入的内容上,本文将和大家探讨打字机效果的实现方式以及应用。Demo基于API12。 二、思路 拆分开来很简单,将字符串拆…

每天写java到期末考试(6.21)--集合4--练习--6.20

练习1&#xff1a; 正常写集合 bool类 代码&#xff1a; import QM_Fx.Student;import java.util.ArrayList;public class test {public static void main(String[] args) {ArrayList<Student> listnew ArrayList<>();//2.创建学生对象Student s1new Student(&quo…

八大经典排序算法

前言 本片博客主要讲解一下八大排序算法的思想和排序的代码 &#x1f493; 个人主页&#xff1a;普通young man-CSDN博客 ⏩ 文章专栏&#xff1a;排序_普通young man的博客-CSDN博客 若有问题 评论区见&#x1f4dd; &#x1f389;欢迎大家点赞&#x1f44d;收藏⭐文章 目录 …

MySQL 面试突击指南:核心知识点解析1

MySQL中有哪些存储引擎? InnoDB存储引擎 InnoDB是MySQL的默认事务型引擎,也是最重要、使用最广泛的存储引擎,设计用于处理大量短期事务。 MyISAM存储引擎 在MySQL 5.1及之前版本,MyISAM是默认的存储引擎。它提供了全文索引、压缩、空间函数(GIS)等特性,但不支持事务和…

【SCAU数据挖掘】数据挖掘期末总复习题库简答题及解析——中

1. 某学校对入学的新生进行性格问卷调查(没有心理学家的参与)&#xff0c;根据学生对问题的回答&#xff0c;把学生的性格分成了8个类别。请说明该数据挖掘任务是属于分类任务还是聚类任务?为什么?并利用该例说明聚类分析和分类分析的异同点。 解答&#xff1a; (a)该数据…

图解Sieve of Eratosthenes(埃拉托斯特尼筛法)算法求解素数个数

1.素数的定义 素数又称质数。质数是指在大于1的自然数中&#xff0c;除了1和它本身以外不再有其他因数的自然数。一个大于1的自然数&#xff0c;除了1和它自身外&#xff0c;不能被其他自然数整除的数叫做质数&#xff1b;否则称为合数&#xff08;规定1既不是质数也不是合数&…

leetCode热题100——两数之和(python)

题目 给定一个整数数组 nums 和一个整数目标值 target&#xff0c;请你在该数组中找出 和为目标值 target 的那 两个 整数&#xff0c;并返回它们的数组下标。 你可以假设每种输入只会对应一个答案。但是&#xff0c;数组中同一个元素在答案里不能重复出现。 你可以按任意顺…

Node.js 是一个开源的 跨平台的JavaScript运行环境

https://www.npmjs.com/ 中央仓库 Visual Studio Code - Code Editing. Redefined https://openjsf.org/ OpenJS 促进了关键 JavaScript 技术在全球范围内的广泛采用和持续发展。 Apache服务器 Nginx服务器 Tomcat服务器 Node.js服务器 Gunicorn服务器 uW…

低代码平台实践:打造高效动态表单解决方案的探索与思考

&#x1f525;需求背景 我司业务同事在抓取到候选人的简历之后&#xff0c;经常会出现&#xff0c;很多意向候选人简历信息不完整&#xff0c;一个个打电话确认的情况&#xff0c;严重影响了HR的工作效率&#xff0c;于是提出我们可以通过发送邮件、短信、H5链接的方式来提醒候…

.NET C# 操作Neo4j图数据库

.NET C# 操作Neo4j图数据库 目录 .NET C# 操作Neo4j图数据库环境Code 环境 VisualStudio2022 .NET 6 Neo4j.Driver 5.21 Code // 连接设置 var uri "bolt://localhost:7687"; var user "neo4j"; var password "password"; // 请替换为你的…

docker 配置与使用

目录 安装docker 作者遇到的问题1&#xff1a;安装docker 错误说明 解决方法&#xff1a; 作者遇到问题2&#xff1a;GPG密钥问题 问题说明 解决方法&#xff1a; 方法一&#xff1a;使用备用的GPG密钥服务器 方法二&#xff1a;使用国内镜像源 方法3&#xff1a;手动下…

使用lua开发apisix自定义插件并发布

接到老大需求&#xff1a;需要对cookie进行操作&#xff0c;遂查询apisix的自带插件&#xff0c;发现有&#xff0c;但不满足&#xff0c;于是自己开发了一个插件并部署&#xff0c;把开发部署流程写在这里打个日志怕以后忘掉。 一、需求 插件很简单&#xff0c;就是在reques…

什么是嵌入式,单片机又是什么,两者有什么关联又有什么区别?

在开始前刚好我有一些资料&#xff0c;是我根据网友给的问题精心整理了一份「嵌入式的资料从专业入门到高级教程」&#xff0c; 点个关注在评论区回复“888”之后私信回复“888”&#xff0c;全部无偿共享给大家&#xff01;&#xff01;&#xff01;从科普的角度&#xff0c;…

HTTP 抓包工具——Fiddler项目实战

网络爬虫实质上是模拟浏览器向 Web 服务器发送请求。对于一些简单的网络请求&#xff0c;我们 可以通过查看 URL 地址来构造请求&#xff0c;但对于一些稍复杂的网络请求&#xff0c;仍然通过观察 URL 地 址将无法构造正确。因此我们需要对这些复杂的网络请求进行捕获分…

【总线】AXI4第二课时:深入AXI4总线的基础事务

大家好,欢迎来到今天的总线学习时间!如果你对电子设计、特别是FPGA和SoC设计感兴趣&#xff0c;那你绝对不能错过我们今天的主角——AXI4总线。作为ARM公司AMBA总线家族中的佼佼者&#xff0c;AXI4以其高性能和高度可扩展性&#xff0c;成为了现代电子系统中不可或缺的通信桥梁…