pytorch实现水果2分类(蓝莓,苹果)

1.数据集的路径,结构

dataset.py

目的:

        输入:没有输入,路径是写死了的。

        输出:返回的是一个对象,里面有self.data。self.data是一个列表,里面是(图片路径.jpg,标签)

        -data[item]返回的是(img_tensor , one-hot编码)。one-hot编码是[0,1]或者[1,0]

import glob
import os.pathimport cv2
import torch
from torch.utils.data import Dataset
from torchvision import transformsclass DtataAndLabel(Dataset):def __init__(self,path='fruits',is_train=True):self.tran=transforms.Compose([transforms.ToTensor(),transforms.Resize(size=(88,88))])is_train='train' if True else 'test'self.data=[]path=os.path.join(path,is_train)print('path=',path)print(os.path.join(path, '*', '*'))img_paths=glob.glob(os.path.join(path,'*','*'))for img_path in img_paths:label=0 if img_path.split('\\')[-2]=='blueberry' else 1self.data.append((img_path,label))def __getitem__(self, idx):#每一张图片返回一个img_tensor,one_hotimg_path,label =self.data[idx]img=cv2.imread(img_path)# img_gray=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)img_tensor=self.tran(img)img_tensor=img_tensor/255img_tensor=torch.flatten(img_tensor)one_hot=torch.zeros(2)one_hot[label]=1return img_tensor,one_hotdef __len__(self):return len(self.data)if __name__ == '__main__':# 测试data=DtataAndLabel()print(data[1][0].shape)print(data[1][1])

net.py

目的:将输入维度(k(k是加载进去的图片数),88,88,3)三通道的宽高是88,88,通过网络变化为(k,2)。

import torch.nn
import torch.nn as nnclass Net(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Linear(88*88*3, 800),nn.ReLU(),nn.Linear(800, 500),nn.ReLU(),nn.Linear(500, 800),nn.ReLU(),nn.Linear(800, 200),nn.ReLU(),nn.Linear(200, 2),)self.softmax=nn.Softmax(dim=1)def forward(self,x):x=self.model(x)x=self.softmax(x)return x
if __name__ == '__main__':net=Net()#测试一下x=torch.randn(1,100*100)out=net(x)print(out.shape)

test_train.py

目的:将图像丢进模型,然后训练出最优模型

步骤:

       1.定义初始化

                -定义拿到data对象

                -定义加载器分批加载,这里可以变换维度

                -定义初始化网络

                -定义损失函数,这里采用了均方差函数

                -定义优化器

        2.实现训练

                -将每一批数据丢给网络,此时维度发生了变化,产生了升维

                -使用优化器        

                        ---自动梯度清0

                        ---自动求导更新参数

                -计算损失值和准确度

        ·~自己建一个文件夹

import torch.optim
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdmfrom net import Net
from dataset import DtataAndLabel
import torch.nn as nn
class TrainAndTest():def __init__(self):self.writer = SummaryWriter("logs")self.train_data=DtataAndLabel(is_train=True)self.test_data=DtataAndLabel(is_train=False)#使用加载器分批加载self.train_loader=DataLoader(self.train_data,batch_size=10,shuffle=True)self.test_loader=DataLoader(self.test_data,batch_size=10,shuffle=True)#初始化网络#损失函数#优化器net=Net()self.net=netself.loss=nn.MSELoss()self.opt=torch.optim.Adam(net.parameters(),lr=0.001)self.min_loss=100.0self.weight_path='weight/best.pt'def train(self,epoch):sum_loss = 0sum_acc = 0for img_tensors, targets in tqdm(self.train_loader, desc="train...", total=len(self.train_loader)):out = self.net(img_tensors)loss = self.loss(out, targets)self.opt.zero_grad()loss.backward()self.opt.step()sum_loss += loss.item()pred_cls = torch.argmax(out, dim=1)target_cls = torch.argmax(targets, dim=1)accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))sum_acc += accuracy.item()avg_loss = sum_loss / len(self.train_loader)avg_acc = sum_acc / len(self.train_loader)print(f'train:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')self.writer.add_scalars("loss", {"train_avg_loss": avg_loss}, epoch)self.writer.add_scalars("acc", {"train_avg_acc": avg_acc}, epoch)def test(self,epoch):sum_loss = 0sum_acc = 0for img_tensors, targets in tqdm(self.test_loader, desc="test...", total=len(self.test_loader)):out = self.net(img_tensors)loss = self.loss(out, targets)sum_loss += loss.item()pred_cls = torch.argmax(out, dim=1)target_cls = torch.argmax(targets, dim=1)accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))sum_acc += accuracy.item()avg_loss = sum_loss / len(self.test_loader)avg_acc = sum_acc / len(self.test_loader)print(f'test:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')self.writer.add_scalars("loss", {"test_avg_loss": avg_loss}, epoch)self.writer.add_scalars("acc", {"test_avg_acc": avg_acc}, epoch)if avg_loss<self.min_loss:self.min_loss=min(self.min_loss,avg_loss)torch.save(self.net.state_dict(), self.weight_path)def run(self):for epo in range(100):self.train(epo)self.test(epo)if __name__ == '__main__':trainer=TrainAndTest()trainer.run()

精度的计算:

                比如通过网络出现的维度是(1,2),其数值是[[0.9 , 0.1]](0.9与0.1表示预测的两个类别的概率)。我们通过maxarg取到其中最大的索引0,与之前真实的标签0或者1做比较。从而可以得出结果

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/869914.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JMH325【剑侠情缘3】第2版80级橙武网游单机更稳定亲测视频安装教学更新整合收集各类修改教学补丁兴趣可以慢慢探索

资源介绍&#xff1a; 是否需要虚拟机&#xff1a;是 文件大小&#xff1a;压缩包约14G 支持系统&#xff1a;win10、win11 硬件需求&#xff1a;运行内存8G 4核及以上CPU独立显卡 下载方式&#xff1a;百度网盘 任务修复&#xff1a; 1&#xff0c;掌门任务&#xff08…

【Android组件】封装加载弹框

&#x1f4d6;封装加载弹框 ✅1. 构造LoadingDialog✅2. 调用LoadingDialog 效果&#xff1a; ✅1. 构造LoadingDialog 构造LoadingDialog类涉及到设计模式中的建造者模式&#xff0c;进行链式调用&#xff0c;注重的是构建的过程&#xff0c;设置需要的属性。 步骤一&#x…

[数据结构] 归并排序快速排序 及非递归实现

&#xff08;&#xff09;标题&#xff1a;[数据结构] 归并排序&&快速排序 及非递归实现 水墨不写bug &#xff08;图片来源于网络&#xff09; 目录 (一)快速排序 类比递归谋划非递归 快速排序的非递归实现&#xff1a; &#xff08;二&#xff09;归并排序 归…

Elasticsearch文档_id以数组方式返回

背景需求是只需要文档的_id字段&#xff0c;并且_id组装成一个数组。 在搜索请求中使用 script_fields 来整理 _id 为数组输出&#xff1a; POST goods_info/_search?size0 {"query": {"term": {"brand": {"value": "MGC"…

明白这两大关键点,轻松脱单不再是难题!

很多未婚男女都渴望找到心仪的伴侣&#xff0c;建立稳定的情感关系&#xff0c;但往往在脱单的过程中跌跌撞撞。平时与同学、同事之间相处得很融洽&#xff0c;一旦遇到心仪的异性&#xff0c;情商直接掉线&#xff0c;难道情商也会选择性地发挥作用吗&#xff1f;其实&#xf…

什么牌子的开放式耳机好用?南卡、Cleer、小米、开石超值机型力荐!

​开放式耳机在如今社会中已经迅速成为大家购买耳机的新趋势&#xff0c;深受喜欢听歌和热爱运动的人群欢迎。当大家谈到佩戴的稳固性时&#xff0c;开放式耳机都会收到一致好评。对于热爱运动的人士而言&#xff0c;高品质的开放式耳机无疑是理想之选。特别是在近年来的一些骑…

AnimateLCM:高效生成连贯真实的视频

视频扩散模型因其能够生成连贯且高保真的视频而日益受到关注。然而&#xff0c;迭代去噪过程使得这类模型计算密集且耗时&#xff0c;限制了其应用范围。香港中文大学 MMLab、Avolution AI、上海人工智能实验室和商汤科技公司的研究团队提出了AnimateLCM&#xff0c;这是一种允…

电子电气架构 --- 关于DoIP的一些闲思 上

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你的不对。非必要不费力证明自己,无利益不试图说服别人,是精神上的节…

JavaDS —— 单链表 与 LinkedList

顺序表和链表区别 ArrayList &#xff1a; 底层使用连续的空间&#xff0c;可以随机访问某下标的元素&#xff0c;时间复杂度为O&#xff08;1&#xff09; 但是在插入和删除操作的时候&#xff0c;需要将该位置的后序元素整体往前或者向后移动&#xff0c;时间复杂度为O&…

什么是智能制造?

科技的每一次飞跃都深刻改变着我们的生产生活方式。其中&#xff0c;智能制造作为工业4.0的核心概念&#xff0c;正引领着全球制造业向更加高效、灵活、智能的方向迈进。那么&#xff0c;究竟什么是智能制造&#xff1f;它如何重塑我们的工业版图&#xff0c;又将对未来社会产生…

TTT架构超越Transformer,ML模型替代RNN隐藏状态!

目录 01 算法原理 02 骨干架构 03 实验结果 一种崭新的大语言模型&#xff08;LLM&#xff09;架构有望取代当前主导 AI 领域的 Transformer&#xff0c;并在性能上超越 Mamba。 论文地址&#xff1a;https://arxiv.org/abs/2407.04620 本周一&#xff0c;关于 Test-Time Tr…

修复 Ubuntu 24.04 Dock 丢失应用程序图标

找出应用程序窗口的类名 首先&#xff0c;您需要启动应用程序窗口。然后&#xff0c;按 Alt F2 启动“运行 Command”对话框。当对话框打开时&#xff0c;输入 lg 并按 Enter 键。 在该窗口中&#xff0c;单击Windows按钮&#xff0c;然后找出目标应用程序窗口的类名称。 在/…

Flutter——最详细(Table)网格、表格组件使用教程

背景 用于展示表格组件&#xff0c;可指定线宽、列宽、文字方向等属性 属性作用columnWidths列的宽度defaultVerticalAlignment网格内部组件摆放方向border网格样式修改children表格里面的组件textDirection文本排序方向 import package:flutter/material.dart;class CustomTa…

公众号运营秘籍:8 大策略让你的粉丝翻倍!

在当今信息爆炸的时代&#xff0c;微信公众号的运营者们面临着前所未有的挑战&#xff1a;如何在这个充满竞争的红海中脱颖而出&#xff0c;吸引并留住粉丝&#xff1f;事实上&#xff0c;微信公众号的红利期并未完全过去&#xff0c;关键在于我们如何策略性地运营&#xff0c;…

使用PEFT库进行ChatGLM3-6B模型的QLORA高效微调

PEFT库进行ChatGLM3-6B模型QLORA高效微调 QLORA微调ChatGLM3-6B模型安装相关库使用ChatGLM3-6B模型GPU显存占用准备数据集加载数据集数据处理数据集处理加载量化模型-4bit预处理量化模型配置LoRA适配器训练超参数配置开始训练保存LoRA模型模型推理合并模型使用微调后的模型 QLO…

【Pytorch实用教程】transformer中创建嵌入层的模块nn.Embedding的用法

文章目录 1. nn.Embedding的简单介绍1.1 基本用法1.2 示例代码1.3 注意事项2. 通俗的理解num_embeddings和embedding_dim2.1 num_embeddings2.2 embedding_dim2.3 使用场景举例结合示例1. nn.Embedding的简单介绍 nn.Embedding 是 PyTorch 中的一个模块,用于创建一个嵌入层。…

准大一新生开学千万要带证件照用途大揭秘

1、提前关注好都有哪些考场&#xff0c;以及这些考场大致在网页的哪个位置。比如我选对外经贸大学&#xff0c;我就直接找到第二个点进去。 2、电脑上同时开了谷歌浏览器和IE浏览器&#xff0c;以及手机也登陆了。亲测下来&#xff0c;同一时间刷新&#xff0c;谷歌浏览器能显示…

​cesium、three.js,三维GIS为啥那么热?到底怎么学呢?

​cesium、three.js&#xff0c;三维GIS为啥那么热&#xff1f;他们的应用场景都是什么呢&#xff1f;接下来我们可以一起来看看~ 三维GIS的应用 GIS和3D的应用是趋势&#xff0c;目前已经有很多应用案例&#xff0c;例如BIM&#xff0c;智慧城市&#xff0c;数字孪生等。如下…

汇聚荣拼多多电商实力强吗?

汇聚荣拼多多电商实力强吗?汇聚荣拼多多&#xff0c;作为中国电商领域的后起之秀&#xff0c;已经在市场上占据了一席之地。那么&#xff0c;它的实力究竟如何呢?在回答这个问题之前&#xff0c;我们需要先了解一下拼多多的基本情况。拼多多是一家以社交电商为主要模式的购物…

3个方法教你如果快速绕过Excel工作表保护密码

在日常生活中&#xff0c;我们可能会遇到一些特殊情况&#xff0c;比如不小心忘记了Excel文件中设置的打开密码。别担心&#xff01;这里为您带来一份详细的Excel文件密码移除教程&#xff0c;助您轻松绕过Excel工作表保护。 方法一&#xff1a;使用备份文件 如果您有文件的备…