0 简单的图像分类

本文主要针对交通标识图片进行分类,包含62类,这个就是当前科大讯飞比赛,目前准确率在0.94左右,难点如下:

1 类别不均衡,有得种类图片2百多,有个只有10个不到;

2 像素大小不同,导致有的图片很清晰,有的很模糊;

直接上代码:

import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_splitfrom torchvision import models, datasets, transforms
import torch.utils.data as tud
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from PIL import Image
import matplotlib.pyplot as plt
import warnings
import pandas as pd
from torch.utils.data import random_splitwarnings.filterwarnings("ignore")# 检测能否使用GPU
print(#labels
torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
)device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
n_classes = 62  # 几种分类的
preteain = False  # 是否下载使用训练参数 有网true 没网false
epoches = 10  # 训练的轮次
traindataset = datasets.ImageFolder(root='../all/data/train_set/', transform=transforms.Compose([transforms.Resize((224,224)),#transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))# 分割比例:比如80%的数据用于训练,20%用于验证
train_val_ratio = 0.8
train_size = int(len(traindataset) * train_val_ratio)
val_size = len(traindataset) - train_size
train_dataset, val_dataset = random_split(traindataset, [train_size, val_size])classes = traindataset.classes
print(classes)model = models.resnext50_32x4d(pretrained=preteain)
#model = models.resnet34(pretrained=preteain)if preteain == True:for param in model.parameters():param.requires_grad = Falsemodel.fc = nn.Linear(in_features=2048, out_features=n_classes, bias=True)
model = model.to(device)def train_model(model, train_loader, loss_fn, optimizer, epoch):model.train()total_loss = 0.total_corrects = 0.total = 0.for idx, (inputs, labels) in enumerate(train_loader):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)loss = loss_fn(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()preds = outputs.argmax(dim=1)total_corrects += torch.sum(preds.eq(labels))total_loss += loss.item() * inputs.size(0)total += labels.size(0)total_loss = total_loss / totalacc = 100 * total_corrects / totalprint("轮次:%4d|训练集损失:%.5f|训练集准确率:%6.2f%%" % (epoch + 1, total_loss, acc))return total_loss, accdef test_model(model, test_loader, loss_fn, optimizer, epoch):model.train()total_loss = 0.total_corrects = 0.total = 0.with torch.no_grad():for idx, (inputs, labels) in enumerate(test_loader):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)loss = loss_fn(outputs, labels)preds = outputs.argmax(dim=1)total += labels.size(0)total_loss += loss.item() * inputs.size(0)total_corrects += torch.sum(preds.eq(labels))loss = total_loss / totalaccuracy = 100 * total_corrects / totalprint("轮次:%4d|测试集损失:%.5f|测试集准确率:%6.2f%%" % (epoch + 1, loss, accuracy))return loss, accuracyloss_fn = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(), lr=0.0001)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(val_dataset, batch_size=32, shuffle=True)
for epoch in range(0, epoches):loss1, acc1 = train_model(model, train_loader, loss_fn, optimizer, epoch)loss2, acc2 = test_model(model, test_loader, loss_fn, optimizer, epoch)

模型预测:

sub = pd.read_csv("../all/data/example.csv")
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

model.eval()
for path in os.listdir("../all/data/test_set/"):
    try:
        img = Image.open("../all/data/test_set/"+path)
        img_p = transform(img).unsqueeze(0).to(device)
        output = model(img_p)
        pred = output.argmax(dim=1).item()
        if img.size[0] * img.size[1]<2000:
            plt.imshow(img)
            plt.show()
        p = 100 * nn.Softmax(dim=1)(output).detach().cpu().numpy()[0]
        sub.loc[sub['ImageID'] == path,'label'] = classes[pred]
        print(f'{path} size = {img.size}, 该图像预测类别为:', classes[pred])
    except:
        print(f'error {path}')
sub.loc[sub['ImageID']=='e57471de-6527-4b9b-90a8-4f1d93909216.png','label'] = 'Under Construction'
sub.loc[sub['ImageID']=='ff38d59e-9a11-41e4-901b-67097bb0e960.png','label'] = 'Keep Left'
sub.columns = ['ImageID','Sign Name']
label_map = pd.read_excel("../all/data/label_map.xlsx")
sub_all = pd.merge(left=sub,right=label_map,on='Sign Name',how='left')
#sub_all[['ImageID','label']].to_csv('./sub_resnet34_add_img_ratio_drop_dire.csv',index=False)

个人的心得:

1 如何进行图片增强,图片增强应该注意什么(方向问题);

2 模型大小如何进行选择;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/854427.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

滑动窗口(LeeCode209题,以JS为例)

什么是滑动窗口&#xff1f; 滑动窗口是算法中一种非常有用的技术&#xff0c;特别是在处理数据序列或数组时。它的核心思想是维护一个固定大小的窗口&#xff0c;这个窗口在数据序列上滑动&#xff0c;以便于在窗口内的元素上进行操作或计算。滑动窗口技术通常用于解决与数据…

对 2024 年美赛选题的建议

对2024年美赛选题的建议包括&#xff1a; 1. 深入探讨当下全球面临的重大问题和挑战&#xff1a;鉴于美赛通常聚焦于全球性议题&#xff0c;如气候变化、可持续发展、数据分析等&#xff0c;参赛学生应关注这些议题&#xff0c;并深入研究相关数据与背景信息&#xff0c;以提出…

单片机建立自己的库文件(4)

文章目录 前言一、新建自己的外设文件夹1.新建外设文件夹&#xff0c;做项目好项目文件管理2.将之前写的.c .h 文件添加到文件夹中 二、在软件中添加项目 .c文件2.1 编译工程保证没问题2. 修改项目列表下的名称 三、在软件项目中添加 .h文件路径四、实际使用测试总结 前言 提示…

sheng的学习笔记-AI-集成学习(adaboost,bagging,随机森林)

ai目录&#xff1a;sheng的学习笔记-AI目录-CSDN博客 目录​​​​​​​ 集成学习 什么是集成学习 集成学习一般结构&#xff1a; 示意图 弱学习器 经典算法 Boosting 什么是boosting 方法图 AdaBoost 算法 AdaBoost示意图 流程解析&#xff1a; 错误分类率error…

太速科技-FMC213V3-基于FMC兼容1.8V IO的Full Camera Link 输入子卡

FMC213V3-基于FMC兼容1.8V IO的Full Camera Link 输入子卡 一、板卡概述 该板卡为了考虑兼容1.8V电平IO&#xff0c;适配Virtex7&#xff0c;Kintex Ultrascale&#xff0c;Virtex ultrasacle FPGA而特制&#xff0c;如果要兼容原来的3.3V 也可以修改硬件参数。板卡支持1路…

快速欧氏聚类与普通欧氏聚类比较

1、前言 文献《FEC: Fast Euclidean Clustering for Point Cloud Segmentation》介绍了一种快速欧氏聚类方法,大概原理可以参考如下图,具体原理可以参考参考文献。 2、时间效率比较:快速欧氏聚类VS普通欧氏聚类 网上搜集的快速欧式聚类,与自己手写的普通欧式聚类进行对比,…

SLG火并6月:多强鼎立,增量用户发展成行业新题

SLG赛道进入到6月&#xff0c;《三国&#xff1a;谋定天下》、《野兽领主&#xff1a;新世界》、《无尽冬日》大量新品袭来搅动市场。 在这样的关口&#xff0c;占据SLG半壁江山的灵犀互娱《三国志战略版》先一步刊登出战报&#xff0c;宣布1亿SLG玩家已收归麾下。 但新的挑战…

Linux时间子系统6:NTP原理和Linux NTP校时机制

一、前言 上篇介绍了时间同步的基本概念和常见的时间同步协议NTP、PTP&#xff0c;本篇将详细介绍NTP的原理以及NTP在Linux上如何实现校时。 二、NTP原理介绍 1. 什么是NTP 网络时间协议&#xff08;英语&#xff1a;Network Time Protocol&#xff0c;缩写&#xff1a;NTP&a…

华为数通企业面试笔试实验题

1. 笔试题 1.1 实验拓扑 1.2 实验要求 公司A为小型销售公司,需要实现基本上网功能,蓝色部分为外网线,提供DHCP服务 DnsServer:114.114.114.114 帮助网管排查某一台计算机在某一台交换机的某个端口 2. 操作步骤 配置路由器相关的LAN侧接口IP地址 配置DHCP项,要求有PC1与PC2…

大模型KV Cache节省神器MLA学习笔记(包含推理时的矩阵吸收分析)

首先&#xff0c;本文回顾了MHA的计算方式以及KV Cache的原理&#xff0c;然后深入到了DeepSeek V2的MLA的原理介绍&#xff0c;同时对MLA节省的KV Cache比例做了详细的计算解读。接着&#xff0c;带着对原理的理解理清了HuggingFace MLA的全部实现&#xff0c;每行代码都去对应…

软件改为开机自启动

1.按键 win R,输入“shell:startup”命令, 然后就可以打开启动目录了&#xff0c;如下&#xff1a; 2.然后&#xff0c;把要开机启动的程序的图标拖进去即可。 参考&#xff1a;开机启动项如何设置

JAVA面试(六)

缓存 MemcachedredisRedis常见数据类型和使用Redis缓存持久化RDB-快照AOF-追加文件 Redis数据过期机制惰性删除定期删除 Redis缓存淘汰策略&#xff08;8种&#xff09;算法LRU &#xff08;Least Recently Used&#xff09;&#xff1a;最近最少使用LFU&#xff08;Least Freq…

day12--150. 逆波兰表达式求值+239. 滑动窗口最大值+ 347. 前 K 个高频元素

一、150. 逆波兰表达式求值 题目链接&#xff1a;https://leetcode.cn/problems/evaluate-reverse-polish-notation/description/ 文章讲解&#xff1a;https://programmercarl.com/0150.%E9%80%86%E6%B3%A2%E5%85%B0%E8%A1%A8%E8%BE%BE%E5%BC%8F%E6%B1%82%E5%80%BC.html 视频…

R可视化:微生物相对丰度或富集热图可视化

欢迎大家关注全网生信学习者系列: WX公zhong号:生信学习者Xiao hong书:生信学习者知hu:生信学习者CDSN:生信学习者2介绍 热图(Heatmap)是一种数据可视化方法,它通过颜色的深浅或色调的变化来展示数据的分布和密度。在微生物学领域,热图常用于表示微生物在不同分组(如…

【leetcode刷题】面试经典150题 , 27. 移除元素

leetcode刷题 面试经典150 27. 移除元素 难度&#xff1a;简单 文章目录 一、题目内容二、自己实现代码2.1 方法一&#xff1a;直接硬找2.1.1 实现思路2.1.2 实现代码2.1.3 结果分析 2.2 方法二&#xff1a;排序整体删除再补充2.1.1 实现思路2.1.2 实现代码2.1.3 结果分析 三、…

字符串专题详解

目录 字符串hash进阶 KMP算法 next数组 KMP算法 KMP算法优化 字符串hash进阶 字符串hash是指将一个字符串S映射为一个整数&#xff0c;使得该整数可以尽可能唯一地代表字符串S。那么在一定程度上&#xff0c;如果两个字符串转换成的整数相等&#xff0c;就可以认为这两个…

麻了,5年Java竟然不知道幂等......

在分布式系统中&#xff0c;接口幂等性是确保操作一致性的关键特性。 啥是幂等性 幂等性 指的是在给定的条件下&#xff0c;无论操作执行多少次&#xff0c;其结果都保持不变。在接口设计中&#xff0c;幂等性意味着使用相同的参数多次调用接口&#xff0c;应产生与单次调用相…

STM32学习笔记(五)--TIM输出比较PWM详解

&#xff08;1&#xff09;配置步骤1.配置RCC外设时钟 开启GPIO以及TIM外设2.配置时基单元的时钟 包含时钟源选择配置初始化时基单元3.配置输出比较单元 包含CCR的值 输出比较模式 极性选择 输出使能等4.配置GPIO口 初始化为复用式推挽输出的配置5.运行控制 启动计数器 输出PWM…

Windows CSC服务特权提升漏洞(CVE-2024-26229)

文章目录 前言声明一、漏洞描述二、漏洞成因三、影响版本四、漏洞复现五、CVE-2024-26229 BOF六、修复方案 前言 Windows CSC服务特权提升漏洞。 当程序向缓冲区写入的数据超出其处理能力时&#xff0c;就会发生基于堆的缓冲区溢出&#xff0c;从而导致多余的数据溢出到相邻的…

QT 的文件

QT 和C、linux 一样&#xff0c;也有自带的文件系统. 它的操作和C、c差不多&#xff0c;不过也需要我们来了解一下。 输入输出设备类 QObject 有一个子类&#xff0c;名为 QIODevice 类&#xff0c;如其名字&#xff0c;该类是管理所有输入输出设备的类。 比如文件、网络套…