【深度学习】快速制作图像标签数据集以及训练

快速制作图像标签数据集以及训练

制作DataSet

  • 先从网络收集十张图片 每种十张
    在这里插入图片描述

  • 定义dataSet和dataloader

import glob
import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt# 通过创建data.Dataset子类Mydataset来创建输入
class Mydataset(data.Dataset):# init() 初始化方法,传入数据文件夹路径def __init__(self, root):self.imgs_path = root# getitem() 切片方法,根据索引下标,获得相应的图片def __getitem__(self, index):img_path = self.imgs_path[index]# len() 计算长度方法,返回整个数据文件夹下所有文件的个数def __len__(self):return len(self.imgs_path)# 使用glob方法来获取数据图片的所有路径
all_imgs_path = glob.glob(r"./Data/*/*.jpg")  # 数据文件夹路径# 利用自定义类Mydataset创建对象brake_dataset
# 将所有的路径塞进dataset  使用每张图片的路径进行索引图片
brake_dataset = Mydataset(all_imgs_path)
# print("图片总数:{}".format(len(brake_dataset)))  # 返回文件夹中图片总个数# 制作dataloader
brake_dataloader = torch.utils.data.DataLoader(brake_dataset, batch_size=2)  # 每次迭代时返回4个数据
# print(next(iter(break_dataloader)))

制作标签


# 为每张图片制作对应标签
species = ['sun', 'rain', 'cloud']
species_to_id = dict((c, i) for i, c in enumerate(species))
# print(species_to_id)id_to_species = dict((v, k) for k, v in species_to_id.items())
# print(id_to_species)# 对所有图片路径进行迭代
all_labels = []
for img in all_imgs_path:# 区分出每个img,应该属于什么类别for i, c in enumerate(species):if c in img:all_labels.append(i)
# print(all_labels)

制作数据和标签一起的dataset和dataloader

  • 上面的dataset不够完善
# 将数据转换为张量数据
# 对数据进行转换处理
transform = transforms.Compose([transforms.Resize((256, 256)),  # 做的第一步转换transforms.ToTensor()  # 第二步转换,作用:第一转换成Tensor,第二将图片取值范围转换成0-1之间,第三会将channel置前
])class Mydatasetpro(data.Dataset):def __init__(self, img_paths, labels, transform):self.imgs = img_pathsself.labels = labelsself.transforms = transform# 进行切片def __getitem__(self, index):img = self.imgs[index]label = self.labels[index]pil_img = Image.open(img)  # pip install pillowpil_img = pil_img.convert('RGB')data = self.transforms(pil_img)return data, label# 返回长度def __len__(self):return len(self.imgs)BATCH_SIZE = 4
brake_dataset = Mydatasetpro(all_imgs_path, all_labels, transform)
brake_dataloader = data.DataLoader(brake_dataset,batch_size=BATCH_SIZE,shuffle=True
)
imgs_batch, labels_batch = next(iter(brake_dataloader))# 4 X 3 X 256 X 256
print(imgs_batch.shape)plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs_batch[:10], labels_batch[:10])):img = img.permute(1, 2, 0).numpy()plt.subplot(2, 3, i + 1)plt.title(id_to_species.get(label.item()))plt.imshow(img)
plt.show()  # 展示图片

制作训练集和测试集

# 划分数据集和测试集
index = np.random.permutation(len(all_imgs_path))#  打乱所有图片的索引
print(index)# 根据索引获取所有图片的路径
all_imgs_path = np.array(all_imgs_path)[index]
all_labels = np.array(all_labels)[index]print("打乱顺序之后的所有图片路径{}".format(all_imgs_path))
print("打乱顺序之后的所有图片索引{}".format(all_labels))# 80%做训练集
s = int(len(all_imgs_path) * 0.8)
# print(s)train_imgs = all_imgs_path[:s]
# print(train_imgs)
train_labels = all_labels[:s]
test_imgs = all_imgs_path[s:]
test_labels = all_labels[s:]# 将训练集和标签 制作dataset 需要转换为张量
train_ds = Mydatasetpro(train_imgs, train_labels, transform)  # TrainSet TensorData
test_ds = Mydatasetpro(test_imgs, test_labels, transform)  # TestSet TensorData
# print(train_ds)
# print(test_ds)
print("**********")
# 制作trainLoader
train_dl = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)  # TrainSet Labels
test_dl = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)  # TestSet Labels

训练代码

import torch
import torchvision.models as models
from torch import nn
from torch import optim
from DataSetMake import brake_dataloader
from DataSetMake import train_dl, test_dl# 判断是否使用GPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")#  使用resnet 训练
model_ft = models.resnet50(pretrained=True)  # 使用迁移学习,加载预训练权in_features = model_ft.fc.in_features
model_ft.fc = nn.Sequential(nn.Linear(in_features, 256),nn.ReLU(),# nn.Dropout(0, 4),nn.Linear(256, 4),nn.LogSoftmax(dim=1))model_ft = model_ft.to(DEVICE)  # 将模型迁移到gpu# 优化器
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(DEVICE)  # 将loss_fn迁移到GPU# Adam损失函数
optimizer = optim.Adam(model_ft.fc.parameters(), lr=0.003)epochs = 50  # 迭代次数
steps = 0
running_loss = 0
print_every = 10
train_losses, test_losses = [], []for epoch in range(epochs):model_ft.train()# 遍历训练集数据for imgs, labels in brake_dataloader:steps += 1# 标签转换为 tensorlabels = torch.tensor(labels, dtype=torch.long)# 将图片和标签 放到设备上imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)optimizer.zero_grad()  # 梯度归零#  前向推理outputs = model_ft(imgs)# 计算lossloss = loss_fn(outputs, labels)loss.backward()  # 反向传播计算梯度optimizer.step()  # 梯度优化# 累加lossrunning_loss += loss.item()if steps % print_every == 0:test_loss = 0accuracy = 0# 验证模式model_ft.eval()# 测试集 不需要计算梯度with torch.no_grad():# 遍历测试集数据for imgs, labels in test_dl:#  转换为tensorlabels = torch.tensor(labels, dtype=torch.long)#  数据标签 部署到gpuimgs, labels = imgs.to(DEVICE), labels.to(DEVICE)#  前向推理outputs = model_ft(imgs)#  计算损失loss = loss_fn(outputs, labels)# 累加测试机的损失test_loss += loss.item()ps = torch.exp(outputs)top_p, top_class = ps.topk(1, dim=1)equals = top_class == labels.view(*top_class.shape)accuracy += torch.mean(equals.type(torch.FloatTensor)).item()train_losses.append(running_loss / len(train_dl))test_losses.append(test_loss / len(test_dl))print(f"Epoch {epoch + 1}/{epochs}.. "f"Train loss: {running_loss / print_every:.3f}.. "f"Test loss: {test_loss / len(test_dl):.3f}.. "f"Test accuracy: {accuracy / len(test_dl):.3f}")#  回到训练模式 训练误差清0running_loss = 0model_ft.train()
torch.save(model_ft, "aerialmodel.pth")

在这里插入图片描述

预测代码

import os
import torch
from PIL import Image
from torch import nn
from torchvision import transforms, modelsi = 0  # 识别图片计数
# 这里最好新建一个test_data文件随机放一些上面整理好的图片进去
root_path = r"D:\CODE\ImageClassify\Test"  # 待测试文件夹
names = os.listdir(root_path)for name in names:print(name)i = i + 1data_class = ['sun', 'rain', 'cloud']  # 按文件索引顺序排列#  找出文件夹中的所有图片image_path = os.path.join(root_path, name)image = Image.open(image_path)print(image)#  张量定义格式transform = transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor()])# 图片转换为张量image = transform(image)print(image.shape)#  定义resnet模型model_ft = models.resnet50()# 模型结构in_features = model_ft.fc.in_featuresmodel_ft.fc = nn.Sequential(nn.Linear(in_features, 256),nn.ReLU(),# nn.Dropout(0, 4),nn.Linear(256, 4),nn.LogSoftmax(dim=1))# 加载已经训练好的模型参数model = torch.load("aerialmodel.pth", map_location=torch.device("cpu"))# 将每张图片 调整维度image = torch.reshape(image, (1, 3, 256, 256))  # 修改待预测图片尺寸,需要与训练时一致model.eval()#  速出预测结果with torch.no_grad():output = model(image)print(output)  # 输出预测结果# print(int(output.argmax(1)))# 对结果进行处理,使直接显示出预测的种类  根据索引判别是哪一类print("第{}张图片预测为:{}".format(i, data_class[int(output.argmax(1))]))

工程目录结构

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/130761.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CAD操作技巧学习总结

1&#xff0c;已知一个圆&#xff0c;画该圆切线。 L命令画直线&#xff0c;再tan指令确定第一个点为切点&#xff0c;依次输入&#xff08;长度&#xff09;<&#xff08;角度&#xff09;&#xff0c;如55<-45,负号为顺时针。 2&#xff0c;中心点偏移。 O命令偏移&am…

go语言 | grpc原理介绍(三)

了解 gRPC 通信模式中的消息流 gRPC 支持四种通信模式&#xff0c;分别是简单 RPC、服务端流式 RPC、客户端流式 RPC 和双向流式 RPC。 简单 RPC 在gRPC中&#xff0c;一个简单的RPC调用遵循请求-响应模型&#xff0c;通常涉及以下几个关键步骤和组件&#xff1a; 请求头&a…

鸿蒙LiteOs读源码教程+向LiteOS中添加一个简单的基于线程运行时的短作业优先调度策略

一、鸿蒙Liteos读源码教程 鸿蒙的源码是放在openharmony文件夹下&#xff0c;openharmony下的kernel文件夹存放操作系统内核的相关代码和实现。 内核是操作系统的核心部分&#xff0c;所以像负责&#xff1a;资源管理、任务调度、内存管理、设备驱动、进程通信的源码都可以在…

升级 MacOS 系统后,playCover 内游戏打不开了如何解决

我们有些小伙伴在升级了 macOS 系统后大概率会遇到之前能够正常使用的 playCover 突然游戏打不开了&#xff0c;最近 mac 刚刚正式推出了 MacOS 14.1 ,导致很多用户打开游戏会闪退&#xff0c;我们其实只需要更新一下 playCover 就可以解决 playCover 正式版更新会比较慢所以我…

基于生成对抗网络的照片上色动态算法设计与实现 - 深度学习 opencv python 计算机竞赛

文章目录 1 前言1 课题背景2 GAN(生成对抗网络)2.1 简介2.2 基本原理 3 DeOldify 框架4 First Order Motion Model5 最后 1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 基于生成对抗网络的照片上色动态算法设计与实现 该项目较为新颖&am…

【rust/esp32】初识slint ui框架并在st7789 lcd上显示

文章目录 说在前面关于slint关于no-std关于dma准备工作相关依赖代码结果参考 说在前面 esp32版本&#xff1a;s3运行环境&#xff1a;no-std开发环境&#xff1a;wsl2LCD模块&#xff1a;ST7789V2 240*280 LCDSlint版本&#xff1a;master分支github地址&#xff1a;这里 关于s…

【音视频 | opus】opus编码的Ogg封装文件详解

&#x1f601;博客主页&#x1f601;&#xff1a;&#x1f680;https://blog.csdn.net/wkd_007&#x1f680; &#x1f911;博客内容&#x1f911;&#xff1a;&#x1f36d;嵌入式开发、Linux、C语言、C、数据结构、音视频&#x1f36d; &#x1f923;本文内容&#x1f923;&a…

汇编语言(举个栗子)

汇编语言&#xff08;Assembly Language&#xff09;是任何一种用于电子计算机、微处理器、微控制器或其他可编程器件的低级语言&#xff0c;亦称为符号语言。在汇编语言中&#xff0c;用助记符代替机器指令的操作码&#xff0c;用地址符号或标号代替指令或操作数的地址。在不同…

基于51单片机土壤湿度检测及自动浇花系统仿真(带时间显示)

wx供重浩&#xff1a;创享日记 对话框发送&#xff1a;单片机浇花 获取完整源码源文件仿真源文件原理图源文件论文报告等 单片机土壤湿度检测及自动浇花系统仿真&#xff08;带时间显示&#xff09; 具体功能&#xff1a; &#xff08;1&#xff09;液晶第一行显示实际湿度&am…

信道编码译码及MATLAB仿真

文章目录 前言一、什么是信道编码&#xff1f;二、信道编码的基本逻辑—冗余数据1、奇偶检验码2、重复码 三、编码率四、4G 和 5G 的信道编码1、卷积码2、维特比译码&#xff08;Viterbi&#xff09;—— 概率译码3、LTE 的咬尾卷积码4、LTE 的 turbo 码 五、MATLAB 仿真1、plo…

0008Java安卓程序设计-ssm基于Android平台的健康管理系统

文章目录 **摘要**目录系统实现开发环境 编程技术交流、源码分享、模板分享、网课教程 &#x1f427;裙&#xff1a;776871563 摘要 首先,论文一开始便是清楚的论述了系统的研究内容。其次,剖析系统需求分析,弄明白“做什么”,分析包括业务分析和业务流程的分析以及用例分析,…

Git 的基本操作 ——命令行

Git 的工作流程 详解如下&#xff1a; 本地仓库&#xff1a;是在开发人员自己电脑上的Git仓库,存放我们的代码(.git 隐藏文件夹就是我们的本地仓库) 远程仓库&#xff1a;是在远程服务器上的Git仓库,存放代码(可以是github.com或者gitee.com 上的仓库,或者自己该公司的服务器…

【ElasticSearch系列-05】SpringBoot整合elasticSearch

ElasticSearch系列整体栏目 内容链接地址【一】ElasticSearch下载和安装https://zhenghuisheng.blog.csdn.net/article/details/129260827【二】ElasticSearch概念和基本操作https://blog.csdn.net/zhenghuishengq/article/details/134121631【三】ElasticSearch的高级查询Quer…

PTA:前序序列创建二叉树

前序序列创建二叉树 题目输入格式输出格式输入样例&#xff08;及其对应的二叉树&#xff09;输出样例 代码 题目 编一个程序&#xff0c;读入用户输入的一串先序遍历字符串&#xff0c;根据此字符串建立一个二叉树&#xff08;以二叉链表存储&#xff09;。 例如如下的先序遍…

SpringCloudAlibaba - 项目完整搭建(Nacos + OpenFeign + Getway + Sentinel)

目录 一、SpringCloudAlibaba 项目完整搭建 1.1、初始化项目 1.1.1、创建工程 1.1.2、配置父工程的 pom.xml 1.1.3、创建子模块 1.2、user 微服务 1.2.1、配置 pom.xml 1.2.2、创建 application.yml 配置文件 1.2.3、创建启动类 1.2.4、测试 1.3、product 微服务 1…

如何使用CodeceptJS、Playwright和GitHub Actions构建端到端测试流水线

介绍 端到端测试是软件开发的一个重要方面&#xff0c;因为它确保系统的所有组件都能正确运行。CodeceptJS是一个高效且强大的端到端自动化框架&#xff0c;与Playwright 结合使用时&#xff0c;它成为自动化Web、移动甚至桌面 (Electron.js) 应用程序比较好用的工具。 在本文中…

代码随想录算法训练营第23期day38|动态规划理论基础、509. 斐波那契数、70. 爬楼梯、746. 使用最小花费爬楼梯

目录 一、动态规划理论基础 1.动态规划的解题步骤 2.动态规划应该如何debug 二、&#xff08;leetcode 509&#xff09;斐波那契数 1.递归解法 2.动态规划 1&#xff09;确定dp数组以及下标的含义 2&#xff09;确定递推公式 3&#xff09;dp数组如何初始化 4&#x…

C++虚表与虚表指针详解

类的虚表 每个包含了虚函数的类都包含一个虚表。 当一个类&#xff08;B&#xff09;继承另一个类&#xff08;A&#xff09;时&#xff0c;类B会继承类A的函数的调用权。所以如果一个基类包含了虚函数&#xff0c;那么其继承类也可调用这些虚函数&#xff0c;换句话说&…

基于ASP.NET MVC + Bootstrap的仓库管理系统

基于ASP.NET MVC Bootstrap的仓库管理系统。源码亲测可用&#xff0c;含有简单的说明文档。 适合单仓库&#xff0c;基本的仓库入库管理&#xff0c;出库管理&#xff0c;盘点&#xff0c;报损&#xff0c;移库&#xff0c;库位等管理&#xff0c;有着可视化图表。 系统采用Bo…

MySQL导入数据库报错Error Code: 2006

Error Code: 2006 - MySQL server has gone away 因为导入的某张表数据过大导致导入中途失败 , 修改max_allowed_packet 即可解决。 SET GLOBAL max_allowed_packet 1024*1024*200;