034、test

之——全纪录

目录

之——全纪录

杂谈

正文

1.下载处理数据

2.数据集概览

3.构建自定义dataset

4.初始化网络

5.训练


杂谈

        综合方法试一下。


leaves

1.下载处理数据

        从官网下载数据集:Classify Leaves | Kaggle

        解压后有一个图片集,一个提交示例,一个测试集,一个训练集。

        images,27153个树叶图片:

        test.csv,8800个:

        train.csv,18353个:


2.数据集概览

        训练集、测试集、类别:

#导包
import random
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms
import torchvision
import pandas as pd
import matplotlib.pyplot as plt
from d2l import torch as d2l
from PIL import Imagetrain_data=pd.read_csv(r"D:\apycharmblackhorse\leaves\train.csv")
test_data=pd.read_csv(r"D:\apycharmblackhorse/leaves/test.csv")train_images=train_data.iloc[:,0].values #把所有的训练集图片路径读进来成list
print("训练集数量:",len(train_images))
n_train=len(train_images)
test_images=test_data.iloc[:,0].values
print("测试集数量:",len(test_images))
n_test=len(test_images)train_labels = pd.get_dummies(train_data.iloc[:, 1]).values.astype(int).argmax(1)
#独热编码后找到每行最大的索引记下来就是类别号,而顺序与独热编码colums,也就是与下方排序一致
# print(len(train_labels),train_labels)#记录并排序所有的类别名
train_labels_header = pd.get_dummies(train_data.iloc[:, 1]).columns.values
print("总类别:",len(train_labels_header))
classes=len(train_labels_header)


3.构建自定义dataset

       继承 torch.utils.Dataset 类,自定义树叶分类数据集:

#继承 torch.utils.Dataset 类,自定义树叶分类数据集
class leaves_dataset(torch.utils.data.Dataset):#root数据目录, images图片路径, labels图片标签, transform数据增强def __init__(self, root, images, labels, transform):super(leaves_dataset, self).__init__()self.root = rootself.images = imagesif labels is None:self.labels = Noneelse:self.labels = labelsself.transform = transform#获得指定样本def __getitem__(self, index):image_path = self.root + self.images[index]image = Image.open(image_path)#预处理image = self.transform(image)if self.labels is None:return imagelabel = torch.tensor(self.labels[index])return image, label#获得数据集长度def __len__(self):return self.images.shape[0]

        构建读取数据与预处理:

def load_data(images, labels, batch_size, train):aug = []normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])if (train):aug = [torchvision.transforms.CenterCrop(224),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),transforms.ToTensor(),normalize]else:aug = [torchvision.transforms.Resize([256, 256]),torchvision.transforms.CenterCrop(224),transforms.ToTensor(),normalize]transform = transforms.Compose(aug)dataset = leaves_dataset(r"D:\apycharmblackhorse\leaves\\", images, labels, transform=transform)if train==True:type="训练"else:type="测试"print("载入:",dataset.__len__(),type)return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, num_workers=0, shuffle=train)train_iter = load_data(train_images, train_labels, 512, train=True)

4.初始化网络

        使用官方预训练模型初始化网络,并修改输出类别数:

#初始化网络
net = torchvision.models.resnet18(pretrained=True)net.fc = nn.Linear(net.fc.in_features, classes)
nn.init.xavier_uniform_(net.fc.weight)
net.fc


5.训练

         定义迭代器、优化器以及其他超参数,进行训练:

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=64, num_epochs=20,param_group=True):train_slices = random.sample(list(range(n_train)), 15000)test_slices = list(set(range(n_train)) - set(train_slices))train_iter = load_data(train_images[train_slices], train_labels[train_slices], batch_size, train=True)test_iter = load_data(train_images[test_slices], train_labels[test_slices], batch_size, train=False)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]#别的层不变,最后一层10倍学习率trainer = torch.optim.Adam([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.Adam(net.parameters(), lr=learning_rate,weight_decay=0.001)print(111)try:d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)except Exception as e:print(e)#%%#较小的学习率,通过微调预训练获得的模型参数
train_fine_tuning(net, 1e-3)

        小破脑跑得慢,之前不用预训练5个epoch后acc大概只能到0.3  ,使用预训练后到了0.6,但实际上感觉对于树叶的针对性分类还是需要从头开始才是最好的选择,资源不够这里就不做尝试了,大概尝试情况:


CIFAR-10

1.数据集


2.未完待续

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/153979.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

信息系统项目管理师-范围管理论文提纲

快速导航 1.信息系统项目管理师-项目整合管理 2.信息系统项目管理师-项目范围管理 3.信息系统项目管理师-项目进度管理 4.信息系统项目管理师-项目成本管理 5.信息系统项目管理师-项目质量管理 6.信息系统项目管理师-项目资源管理 7.信息系统项目管理师-项目沟通管理 8.信息系…

只有cpu的时候加载模型

只有cpu的时候加载模型 checkpoint torch.load(model_path, map_locationtorch.device(‘cpu’))

Codeforces Round 910 (Div. 2) --- B-E 补题记录

B - Milena and Admirer Problem - B - Codeforces 题目大意: 现在给出一个无序序列,你可以使用任意次操作将这个无序序列修改为不递减序列,操作为你可以使用两个数a和b来替换ai,序列就变为了 ai-1, a,…

【C++ Primer Plus学习记录】for循环

很多情况下都需要程序执行重复的任务&#xff0c;C中的for循环可以轻松地完成这种任务。 我们来从程序清单5.1了解for循环所做的工作&#xff0c;然后讨论它是如何工作的。 //forloop.cpp #if 1 #include<iostream> using namespace std;int main() {int i;for (i 0; …

Ubuntu文件系统损坏:The root filesystem on /dev/sda5 requires a manual fsck

前言 Ubuntu在启动过程中&#xff0c;经常会遇到一些开故障&#xff0c;导致设备无法正常开机&#xff0c;例如文件系统损坏等。 故障描述 Ubuntu系统启动过程中&#xff0c;出现以下文件系统损坏错误&#xff1a; 产生原因 该故障是由磁盘检测不能通过导致&#xff0c;可能是因…

el-table中的文本居中

el-table中的文本居中 整个表格和内容居中的方式&#xff1a; header-cell-style设置头部居中&#xff1b; cell-style设置单元格内容居中<el-table:data"tableData":header-cell-style"{text-align:center}":cell-style"{text-align:center}&quo…

代码随想录 11.21 || 单调栈 LeetCode 84.柱状图中最大的矩形

84.柱状图中最大的矩形 给定 n 个非负整数&#xff0c;用来表示柱状图中各个柱子的高度。每个柱子彼此相邻&#xff0c;且宽度为 1。求在柱状图中&#xff0c;能够勾勒出来的矩形的最大面积。和 42.接雨水 类似&#xff0c;在由数组组成的柱状图中&#xff0c;根据条件求解。 图…

C++标准模板(STL)- 类型支持 (类型关系,检查是否能转换一个类型为另一类型,std::is_convertible)

类型特性 类型特性定义一个编译时基于模板的结构&#xff0c;以查询或修改类型的属性。 试图特化定义于 <type_traits> 头文件的模板导致未定义行为&#xff0c;除了 std::common_type 可依照其所描述特化。 定义于<type_traits>头文件的模板可以用不完整类型实例…

打破思维的玻璃罩

你是否听过这个实验&#xff1a;将一只跳蚤放进杯中&#xff0c;它很轻松就能跳出来。给杯子加上玻璃罩后&#xff0c;跳蚤一开始会不断尝试跳出来&#xff0c;但发现每次的努力都是徒劳的&#xff0c;慢慢就不再尝试。即便有一天玻璃罩被拿掉&#xff0c;跳蚤也不会认为自己可…

NLP:使用 SciKit Learn 的文本矢量化方法

一、说明 本文是使用所有 SciKit Learns 预处理方法生成文本数字表示的深入解释和教程。对于以下每个矢量化器&#xff0c;将给出一个简短的定义和实际示例&#xff1a;one-hot、count、dict、TfIdf 和哈希矢量化器。 SciKit Learn 是一个用于机器学习项目的广泛库&#xff0c;…

new Vue() 发生了什么

前言: 在Vue.js中&#xff0c;当你创建一个新的Vue实例时&#xff0c;通过 new Vue() 发生了一系列重要的操作&#xff0c;包括Vue实例的初始化、数据绑定、模板编译等。这个过程是Vue应用的核心&#xff0c;本文将深入探讨new Vue()发生了什么以及其原理&#xff0c;提供示例…

官宣!Sam Altman加入微软,OpenAI临时CEO曝光,回顾董事会‘’政变‘’始末

11月20日下午&#xff0c;微软首席执行官Satya Nadella在社交平台宣布&#xff0c;“微软仍然致力于与 OpenAI的合作伙伴关系。同时欢迎Sam Altman 和 Greg Brockman 及其团队加入微软&#xff0c;领导一个全新的AI研究团队”。 Sam第一时间对这个消息进行了确认。 此外&…

Dart笔记:glob 文件系统遍历

Dart笔记 文件系统遍历工具&#xff1a;glob 模块 作者&#xff1a;李俊才 &#xff08;jcLee95&#xff09;&#xff1a;https://blog.csdn.net/qq_28550263 邮箱 &#xff1a;291148484163.com 本文地址&#xff1a;https://blog.csdn.net/qq_28550263/article/details/13442…

2023 羊城杯 final

前言 笔者并未参加此次比赛, 仅仅做刷题记录. 题目难度中等偏下吧, 看你记不记得一些利用手法了. arrary_index_bank 考点: 数组越界 保护: 除了 Canary, 其他保护全开, 题目给了后门 漏洞点: idx/one 为 int64, 是带符号数, 所以这里存在向上越界, 并且 buf 为局部变量,…

ROS1余ROS2共存的一键安装(全)

ROS1的安装&#xff1a; ROS的一键安装&#xff08;全&#xff09;_ros一键安装_牙刷与鞋垫的博客-CSDN博客 ROS2的安装 在开始这一部分的ROS2安装之前&#xff0c;是可以安装ROS1的&#xff0c;当然如果你只需要安装ROS2的话就执行从此处开始的代码即可 我是ubuntu20.4的版…

Ansible的when语句做条件判断

环境 控制节点&#xff1a;Ubuntu 22.04Ansible 2.10.8管理节点&#xff1a;CentOS 8 使用 when 语句做条件判断 创建文件 test1.yml 如下&#xff1a; --- - hosts: alltasks:- name: task1debug:msg: "hello"when: 1 > 0- name: task2debug:msg: "OK&q…

电力感知边缘计算网关产品设计方案-业务流程设计

1.工业数据通信流程 工业数据是由仪器仪表、PLC、DCS等工业生产加工设备提供的,通过以太网连接工业边缘计算网关实现实时数据采集。按照现有的通信组网方案,在理想通信状态下可以保证有效获取工业数据的真实性和有效性。 边缘计算数据通信框架图: 2.边缘计算数据处理方案 …

makefile备忘

结构描述 目标 … : 依赖 … 命令1 命令2 . . . 标记符 CFLAGS $^ 表示所有的依赖文件 $ 表示生成的目标文件 $< 代表第一个依赖文件 调试信息选项&#xff1a;-g优化选项&#xff1a;-O编译警告选项&#xff1a;-Wall指定包含目录选项&#xff1a;-I指定库目录选项&am…

Linux驱动开发——块设备驱动

目录 一、 学习目标 二、 磁盘结构 三、块设备内核组件 四、块设备驱动核心数据结构和函数 五、块设备驱动实例 六、 习题 一、 学习目标 块设备驱动是 Linux 的第二大类驱动&#xff0c;和前面的字符设备驱动有较大的差异。要想充分理解块设备驱动&#xff0c;需要对系统…