034、test

之——全纪录

目录

之——全纪录

杂谈

正文

1.下载处理数据

2.数据集概览

3.构建自定义dataset

4.初始化网络

5.训练


杂谈

        综合方法试一下。


leaves

1.下载处理数据

        从官网下载数据集:Classify Leaves | Kaggle

        解压后有一个图片集,一个提交示例,一个测试集,一个训练集。

        images,27153个树叶图片:

        test.csv,8800个:

        train.csv,18353个:


2.数据集概览

        训练集、测试集、类别:

#导包
import random
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms
import torchvision
import pandas as pd
import matplotlib.pyplot as plt
from d2l import torch as d2l
from PIL import Imagetrain_data=pd.read_csv(r"D:\apycharmblackhorse\leaves\train.csv")
test_data=pd.read_csv(r"D:\apycharmblackhorse/leaves/test.csv")train_images=train_data.iloc[:,0].values #把所有的训练集图片路径读进来成list
print("训练集数量:",len(train_images))
n_train=len(train_images)
test_images=test_data.iloc[:,0].values
print("测试集数量:",len(test_images))
n_test=len(test_images)train_labels = pd.get_dummies(train_data.iloc[:, 1]).values.astype(int).argmax(1)
#独热编码后找到每行最大的索引记下来就是类别号,而顺序与独热编码colums,也就是与下方排序一致
# print(len(train_labels),train_labels)#记录并排序所有的类别名
train_labels_header = pd.get_dummies(train_data.iloc[:, 1]).columns.values
print("总类别:",len(train_labels_header))
classes=len(train_labels_header)


3.构建自定义dataset

       继承 torch.utils.Dataset 类,自定义树叶分类数据集:

#继承 torch.utils.Dataset 类,自定义树叶分类数据集
class leaves_dataset(torch.utils.data.Dataset):#root数据目录, images图片路径, labels图片标签, transform数据增强def __init__(self, root, images, labels, transform):super(leaves_dataset, self).__init__()self.root = rootself.images = imagesif labels is None:self.labels = Noneelse:self.labels = labelsself.transform = transform#获得指定样本def __getitem__(self, index):image_path = self.root + self.images[index]image = Image.open(image_path)#预处理image = self.transform(image)if self.labels is None:return imagelabel = torch.tensor(self.labels[index])return image, label#获得数据集长度def __len__(self):return self.images.shape[0]

        构建读取数据与预处理:

def load_data(images, labels, batch_size, train):aug = []normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])if (train):aug = [torchvision.transforms.CenterCrop(224),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),transforms.ToTensor(),normalize]else:aug = [torchvision.transforms.Resize([256, 256]),torchvision.transforms.CenterCrop(224),transforms.ToTensor(),normalize]transform = transforms.Compose(aug)dataset = leaves_dataset(r"D:\apycharmblackhorse\leaves\\", images, labels, transform=transform)if train==True:type="训练"else:type="测试"print("载入:",dataset.__len__(),type)return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, num_workers=0, shuffle=train)train_iter = load_data(train_images, train_labels, 512, train=True)

4.初始化网络

        使用官方预训练模型初始化网络,并修改输出类别数:

#初始化网络
net = torchvision.models.resnet18(pretrained=True)net.fc = nn.Linear(net.fc.in_features, classes)
nn.init.xavier_uniform_(net.fc.weight)
net.fc


5.训练

         定义迭代器、优化器以及其他超参数,进行训练:

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=64, num_epochs=20,param_group=True):train_slices = random.sample(list(range(n_train)), 15000)test_slices = list(set(range(n_train)) - set(train_slices))train_iter = load_data(train_images[train_slices], train_labels[train_slices], batch_size, train=True)test_iter = load_data(train_images[test_slices], train_labels[test_slices], batch_size, train=False)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]#别的层不变,最后一层10倍学习率trainer = torch.optim.Adam([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.Adam(net.parameters(), lr=learning_rate,weight_decay=0.001)print(111)try:d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)except Exception as e:print(e)#%%#较小的学习率,通过微调预训练获得的模型参数
train_fine_tuning(net, 1e-3)

        小破脑跑得慢,之前不用预训练5个epoch后acc大概只能到0.3  ,使用预训练后到了0.6,但实际上感觉对于树叶的针对性分类还是需要从头开始才是最好的选择,资源不够这里就不做尝试了,大概尝试情况:


CIFAR-10

1.数据集


2.未完待续

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/153979.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Codeforces Round 910 (Div. 2) --- B-E 补题记录

B - Milena and Admirer Problem - B - Codeforces 题目大意: 现在给出一个无序序列,你可以使用任意次操作将这个无序序列修改为不递减序列,操作为你可以使用两个数a和b来替换ai,序列就变为了 ai-1, a,…

【C++ Primer Plus学习记录】for循环

很多情况下都需要程序执行重复的任务&#xff0c;C中的for循环可以轻松地完成这种任务。 我们来从程序清单5.1了解for循环所做的工作&#xff0c;然后讨论它是如何工作的。 //forloop.cpp #if 1 #include<iostream> using namespace std;int main() {int i;for (i 0; …

Ubuntu文件系统损坏:The root filesystem on /dev/sda5 requires a manual fsck

前言 Ubuntu在启动过程中&#xff0c;经常会遇到一些开故障&#xff0c;导致设备无法正常开机&#xff0c;例如文件系统损坏等。 故障描述 Ubuntu系统启动过程中&#xff0c;出现以下文件系统损坏错误&#xff1a; 产生原因 该故障是由磁盘检测不能通过导致&#xff0c;可能是因…

代码随想录 11.21 || 单调栈 LeetCode 84.柱状图中最大的矩形

84.柱状图中最大的矩形 给定 n 个非负整数&#xff0c;用来表示柱状图中各个柱子的高度。每个柱子彼此相邻&#xff0c;且宽度为 1。求在柱状图中&#xff0c;能够勾勒出来的矩形的最大面积。和 42.接雨水 类似&#xff0c;在由数组组成的柱状图中&#xff0c;根据条件求解。 图…

NLP:使用 SciKit Learn 的文本矢量化方法

一、说明 本文是使用所有 SciKit Learns 预处理方法生成文本数字表示的深入解释和教程。对于以下每个矢量化器&#xff0c;将给出一个简短的定义和实际示例&#xff1a;one-hot、count、dict、TfIdf 和哈希矢量化器。 SciKit Learn 是一个用于机器学习项目的广泛库&#xff0c;…

官宣!Sam Altman加入微软,OpenAI临时CEO曝光,回顾董事会‘’政变‘’始末

11月20日下午&#xff0c;微软首席执行官Satya Nadella在社交平台宣布&#xff0c;“微软仍然致力于与 OpenAI的合作伙伴关系。同时欢迎Sam Altman 和 Greg Brockman 及其团队加入微软&#xff0c;领导一个全新的AI研究团队”。 Sam第一时间对这个消息进行了确认。 此外&…

Dart笔记:glob 文件系统遍历

Dart笔记 文件系统遍历工具&#xff1a;glob 模块 作者&#xff1a;李俊才 &#xff08;jcLee95&#xff09;&#xff1a;https://blog.csdn.net/qq_28550263 邮箱 &#xff1a;291148484163.com 本文地址&#xff1a;https://blog.csdn.net/qq_28550263/article/details/13442…

2023 羊城杯 final

前言 笔者并未参加此次比赛, 仅仅做刷题记录. 题目难度中等偏下吧, 看你记不记得一些利用手法了. arrary_index_bank 考点: 数组越界 保护: 除了 Canary, 其他保护全开, 题目给了后门 漏洞点: idx/one 为 int64, 是带符号数, 所以这里存在向上越界, 并且 buf 为局部变量,…

ROS1余ROS2共存的一键安装(全)

ROS1的安装&#xff1a; ROS的一键安装&#xff08;全&#xff09;_ros一键安装_牙刷与鞋垫的博客-CSDN博客 ROS2的安装 在开始这一部分的ROS2安装之前&#xff0c;是可以安装ROS1的&#xff0c;当然如果你只需要安装ROS2的话就执行从此处开始的代码即可 我是ubuntu20.4的版…

电力感知边缘计算网关产品设计方案-业务流程设计

1.工业数据通信流程 工业数据是由仪器仪表、PLC、DCS等工业生产加工设备提供的,通过以太网连接工业边缘计算网关实现实时数据采集。按照现有的通信组网方案,在理想通信状态下可以保证有效获取工业数据的真实性和有效性。 边缘计算数据通信框架图: 2.边缘计算数据处理方案 …

Linux驱动开发——块设备驱动

目录 一、 学习目标 二、 磁盘结构 三、块设备内核组件 四、块设备驱动核心数据结构和函数 五、块设备驱动实例 六、 习题 一、 学习目标 块设备驱动是 Linux 的第二大类驱动&#xff0c;和前面的字符设备驱动有较大的差异。要想充分理解块设备驱动&#xff0c;需要对系统…

高效开发与设计:提效Spring应用的运行效率和生产力 | 京东云技术团队

引言 现状和背景 Spring框架是广泛使用的Java开发框架之一&#xff0c;它提供了强大的功能和灵活性&#xff0c;但在大型应用中&#xff0c;由于Spring框架的复杂性和依赖关系&#xff0c;应用的启动时间和性能可能会受到影响。这可能导致开发过程中的迟缓和开发效率低下。优…

Golang基础-面向过程篇

文章目录 基本语法变量常量函数import导包匿名导包 指针defer静态数组动态数组(slice)定义方式slice追加元素slice截取 map定义方式map使用方式 基本语法 go语言输出hello world的语法如下 package mainimport ("fmt""time" )func main() {fmt.Println(&…

循环链表2

循环链表的实现 对于数据结构中所有的结构而言&#xff0c;每一次都是用之前初始化&#xff08;处理一开始的随机值&#xff09;一下&#xff0c; 用完销毁&#xff08;不管有没有malloc都能用&#xff0c;用了可以保证没有动态内存泄漏了&#xff09;一下 而在C里面&#x…

Dubbo开发系列

一、概述 以上是 Dubbo 的工作原理图&#xff0c;从抽象架构上分为两层&#xff1a;服务治理抽象控制面 和 Dubbo 数据面 。 服务治理控制面。服务治理控制面不是特指如注册中心类的单个具体组件&#xff0c;而是对 Dubbo 治理体系的抽象表达。控制面包含协调服务发现的注册中…

PLC设备相关常用英文单词(一)

PLC设备相关常用英文单词&#xff08;一&#xff09; Baud rate 波特率Bus 总线Binary 二进制Configuration 组态Consistent data 一致性数据Counter 计数器Cycle time 循环时间Conveyor 传送Device names 设备名称Debug 调试Download 下载Expand 扩展Fix 固定Flow 流量Functio…

【LeetCode:689. 三个无重叠子数组的最大和 | 序列dp+前缀和】

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

WMS系统先验后收策略

在制造业工厂的仓库管理中&#xff0c;确保物料的质量和数量是至关重要的。传统的仓库管理方式往往采用“先收后验”策略&#xff0c;即先接收物料&#xff0c;然后再进行质量检验。然而&#xff0c;这种方式存在一定的风险&#xff0c;例如不良品流入、数量不准确等问题。为了…

腾讯云服务器标准型S5实例CPU性能如何?配置特性说明

腾讯云服务器CVM标准型S5实例具有稳定的计算性能&#xff0c;CVM 2核2G S5活动优惠价格280.8元一年自带1M带宽&#xff0c;15个月313.2元、2核4G配置748.2元15个月&#xff0c;CPU内存配置还可以选择4核8G、8核16G等配置&#xff0c;公网带宽可选1M、3M、5M或10M&#xff0c;腾…