基于pytoch卷积神经网络水质图像分类实战

具体怎么学习pytorch,看b站刘二大人的视频。

完整代码:

import numpy as np
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)
'''https://zhuanlan.zhihu.com/p/156926543'''
# 定义图片目录
image_dir = 'images'# 初始化图片路径列表
img_list = []# 遍历指定目录及其子目录中的所有文件
for parent, _, filenames in os.walk(image_dir):for filename in filenames:# 拼接文件的完整路径filename_path = os.path.join(parent, filename)img_list.append(filename_path)# 初始化图像张量列表和标签列表
image_tensors = []
y_list = []for image_path in img_list:# 提取标签 (假设标签是文件名的第一个字符)label = int(os.path.basename(image_path)[0])y_list.append(label)# 打开图像img = Image.open(image_path)# 获取图像尺寸width, height = img.size# 定义裁剪的区域(假设要保留图像中心的 100x100 区域)left = (width - 100) / 2top = (height - 100) / 2right = (width + 100) / 2bottom = (height + 100) / 2# 裁剪图像img = img.crop((left, top, right, bottom))# 将图像转换为 NumPy 数组img_array = np.asarray(img)# 将 NumPy 数组转换为 PyTorch 张量img_tensor = torch.from_numpy(img_array).float()# 如果图像是 RGB,将其转换为 (C, H, W) 格式if img_tensor.ndimension() == 3 and img_tensor.shape[2] == 3:img_tensor = img_tensor.permute(2, 0, 1)  # 从 (H, W, C) 变为 (C, H, W)# 增加 batch 维度img_tensor = img_tensor.unsqueeze(0)  # 从 (C, H, W) 变为 (1, C, H, W)# 规范化到0-1之间img_tensor = img_tensor / 255.0# 添加到图像张量列表image_tensors.append(img_tensor)# 打印图像张量的形状print(f"当前图像形状: {img_tensor.shape}")# 将图像张量列表转换为四维张量
x_data = torch.cat(image_tensors, dim=0)
# 遍历 y_list 中的每个元素,并将每个数减去 1
for i in range(len(y_list)):y_list[i] -= 1# 将标签列表转换为张量
y_labels = torch.tensor(y_list).long()  # 注意这里使用 .long() 方法将标签转换为长整型print(x_data.shape,y_labels.shape)
print(y_labels)# 定义数据集和数据加载器
class CustomDataset(torch.utils.data.Dataset):def __init__(self, x_data, y_labels):self.x_data = x_dataself.y_labels = y_labelsdef __len__(self):return len(self.x_data)def __getitem__(self, idx):return self.x_data[idx], self.y_labels[idx]# 使用自定义数据集和数据加载器
custom_dataset = CustomDataset(x_data, y_labels)
train_size = int(0.8 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(custom_dataset, [train_size, val_size])train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)# 定义卷积神经网络模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32 * 25 * 25, 128)self.fc2 = nn.Linear(128, 5)  # 假设有5个类别def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = torch.flatten(x, 1)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 训练模型
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)for epoch in range(50):  # 假设训练50个epochrunning_loss = 0.0for inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}")# 在每个epoch结束后,计算并打印验证集的准确率model.eval()  # 将模型设置为评估模式correct = 0total = 0with torch.no_grad():for inputs, labels in val_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()val_accuracy = correct / totalprint(f'Validation Accuracy after Epoch {epoch + 1}: {val_accuracy}')

本数据集中x_data的维度是四维张量(203,3,100,100),y_labels的维度是一维张量 

代码中需要注意的点,卷积模型接受的是四维张量,因此要转变为四维张量。

全连接层中输入的特征数,需要自己计算,通过前面卷积层和池化层后,计算总的维度数。一般是最后的通道数*高度*宽度

定义模型中的forward函数中,在经过全连接层计算前,需要将四维的x转为2维

如果 x 的形状是 (64, 32, 28, 28),表示一个批次大小为64的图像张量,其中每个图像有32个通道,高度和宽度都是28像素。现在,我们希望将这个张量展平为一个二维张量,以便输入到全连接层进行进一步处理。

通过 torch.flatten(x, 1) 操作,我们将在指定维度(这里是第一个维度,也就是通道维度)上对张量进行展平。展平后的张量形状将变为 (64, 32*28*28),其中64是批次大小,而 32*28*28 是展平后的特征数量,即每个图像的特征数量。这与前面定义的全连接层的输入特征数要一致。

Dataloader中batch_size就是设置第一个维度,比如这里的batch_size是32,那么

for inputs, labels in train_loader:

 这里的inputs维度是(32,3,100,100)

新学习pytorch中的分割数据集与测试集方法。

# 使用自定义数据集和数据加载器
custom_dataset = CustomDataset(x_data, y_labels)
train_size = int(0.8 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(custom_dataset, [train_size, val_size])train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

结果展现,可以看见准确率有0.82:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/849056.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

resultType的类型错误

resultType的类型错误,不能是List而应该是对应的返回Bean对象的类型,VO 这里是引用 org.mybatis.spring.MyBatisSystemException: nested exception is org.apache.ibatis.exceptions.PersistenceException: Error querying database. Cause: java.lang…

opencv进阶 ——(十二)基于三角剖分实现人脸对齐

三角剖分概念 三角剖分(Triangulation)是一种将多边形或曲面分解为一系列互不相交的三角形的技术,它是计算几何、计算机图形学、地理信息系统、工程和科学计算中的一个基本概念。通过三角剖分,复杂的形状可以被简化为基本的三角…

病理级Polymer酶标二抗IHC试剂盒上线!

免疫组织化学 Immunohistochemistry,lHC 是利用抗体与抗原特异性识别原理,对组织样本中的抗原进行定位/定性分析的实验技术。组织切片保留了样品的解剖学结构特征,从而可以高分辨率地显现蛋白在细胞,甚至细胞器中的定位。基于以上特性&…

Apple - Image I/O Programming Guide

翻译自:Image I/O Programming Guide(更新时间:2016-09-13 https://developer.apple.com/library/archive/documentation/GraphicsImaging/Conceptual/ImageIOGuide/imageio_intro/ikpg_intro.html#//apple_ref/doc/uid/TP40005462 文章目录 …

orbslam2代码解读(1):数据预处理过程

写orbslam2代码解读文章的初衷 首先最近陆陆续续花了一两周时间学习视觉slam,因为之前主要是做激光slam,有一定基础所以学的也比较快,也是看完了视觉14讲的后端后直接看orbslam2的课,看的cvlife的课(课里大部分是代码…

jenkins的简单使用

2.1.简介 Jenkins是一个开源软件项目,是基于Java开发的一种持续集成工具,用于监控持续重复的工作,旨在提供一个开放易用的软件平台,使软件的持续集成变成可能。 2.4.Jenkins安装 1.下载安装包jenkins.war; 2.在安装…

笔记 | 软件工程04:软件项目管理

1 软件项目及其特点 1.1 什么是项目 1.2 项目特点 1.3 影响项目成功的因素 1.4 什么是软件项目 针对软件这一特定产品和服务的项目努力开展“软件开发活动",(理解:软件项目是一种活动) 1.5 软件项目的特点 1.6 军用软件项目的特点 2 …

怎么用电脑把图片转换二维码?图片在线生成二维码的步骤内容

现在很多人会通过二维码来存储物品的信息图片,其他人可以通过扫描二维码的方式来查看对应的图片内容,那么当我们需要将一批图片每个单独生成二维码,该如何操作能够快速将图片转换二维码呢? 今天,小编来分享给大家一个…

CNN卷积神经网络

一、概述 卷积神经网络(CNN)是深度学习领域的重要算法,特别适用于处理具有网格结构的数据,比如说图像和音频。它起源于二十世纪80至90年代,但真正得到快速发展和应用是在二十一世纪,随着深度学习理论的兴起…

【ai】phc:安装issac环境且fix libstdc++.so 版本报错

Pycharm远程连接服务器(2023-11-9) 大神分享了pycharm远程连接ubuntu工作站的方法。 https://github.com/ZhengyiLuo/PHC 给出的操作同样适用: 参考 Pycharm远程连接服务器(2023-11-9) :前提是一样的 PHC的要求:isaac 创建 conda activate isaac

【Vue】scoped解决样式冲突

默认情况下写在组件中的样式会 全局生效 → 因此很容易造成多个组件之间的样式冲突问题。 全局样式: 默认组件中的样式会作用到全局,任何一个组件中都会受到此样式的影响 局部样式: 可以给组件加上scoped 属性,可以让样式只作用于当前组件 一、代码示例 BaseOne…

RocketMQ可视化界面安装

RocketMQ可视化界面安装 **起因:**访问rocketmq-externals项目的git地址,下载了源码,在目录中并没有找到rocketmq-console文件夹。 git下面文档提示rocketMQ的仪表板转移到了新的项目中,点击仪表板到新项目地址; 下载…

搜索与图论:宽度优先搜索

搜索与图论&#xff1a;宽度优先搜索 题目描述参考代码 题目描述 输入样例 5 5 0 1 0 0 0 0 1 0 1 0 0 0 0 0 0 0 1 1 1 0 0 0 0 1 0输出样例 8参考代码 #include <iostream> #include <algorithm> #include <cstring> using namespace std;const int N …

VsQt单元测试目录的管理方式

正常项目的文件管理方式 正常项目的目录&#xff0c;是由文件系统中实际的文件夹进行分类管理的。 但是如果单元测试用实际文件夹管理的话&#xff0c;会出现问题&#xff0c;就是被测类太多了&#xff0c;用文件系统管理的话&#xff0c;不太方面查看&#xff0c;如下图所示。…

contentType 与 dataType

contentType 与 dataType contentType contentType&#xff1a;发送的数据格式&#xff08;请求方发送给服务器的数据格式&#xff09;&#xff0c;这个内容会放在请求方的 请求头中 application/x-www-form-urlencoded 这个是默认的请求格式。 提交给后台的数据会按照 KV&am…

创新实训2024.06.06日志:部署web服务

1. 运行web项目前后端服务 首先我们要先在服务器上运行客户端以及服务端的应用程序。随后再考虑如何通过公网/局域网访问的问题。 如何启动服务在仓库对应分支下的Readme文件中已经有详细描述了。 1.1. 启动服务端 对于服务端&#xff0c;即&#xff08;要求你在服务端子项…

SCARA机器人中旋转花键的维护和保养方法!

作为精密传动元件的一种&#xff0c;旋转花键在工作过程中承受了较大的负荷。在自动化设备上运用广泛&#xff0c;如&#xff1a;水平多关节机械手臂&#xff08;SCARA&#xff09;、产业用机器人、自动装载机、雷射加工机、搬运装置、机械加工中心的ATC装置等&#xff0c;最适…

如何在Windows 10和11上修复DISM错误87?这里提供办法

​在电脑上运行DISM命令时&#xff0c;是否收到“错误代码87”消息&#xff1f;这是一个非常常见的错误&#xff0c;你可以轻松地修复它。我们将向你展示在Windows 11或Windows 10计算机上解决此问题的多种方法。 确保键入正确的命令 运行DISM命令时出现错误代码87的最常见原…

优雅谈大模型10:MoE

大模型技术论文不断&#xff0c;每个月总会新增上千篇。本专栏精选论文重点解读&#xff0c;主题还是围绕着行业实践和工程量产。若在某个环节出现卡点&#xff0c;可以回到大模型必备腔调或者LLM背后的基础模型新阅读。而最新科技&#xff08;Mamba,xLSTM,KAN&#xff09;则提…

应对800G以太网挑战:数据中心迁移

在过去几年中&#xff0c;云基础设施和服务的大规模使用推动了对更多带宽、更快速度和更低延迟性能的需求。交换机和服务器技术的改进要求布线和架构随之调整。因此&#xff0c;800G以太网对数据中心迁移的需求&#xff0c;特别是对速率&#xff08;包括带宽、光纤密度和通道速…