0 简单的图像分类

本文主要针对交通标识图片进行分类,包含62类,这个就是当前科大讯飞比赛,目前准确率在0.94左右,难点如下:

1 类别不均衡,有得种类图片2百多,有个只有10个不到;

2 像素大小不同,导致有的图片很清晰,有的很模糊;

直接上代码:

import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_splitfrom torchvision import models, datasets, transforms
import torch.utils.data as tud
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from PIL import Image
import matplotlib.pyplot as plt
import warnings
import pandas as pd
from torch.utils.data import random_splitwarnings.filterwarnings("ignore")# 检测能否使用GPU
print(#labels
torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
)device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
n_classes = 62  # 几种分类的
preteain = False  # 是否下载使用训练参数 有网true 没网false
epoches = 10  # 训练的轮次
traindataset = datasets.ImageFolder(root='../all/data/train_set/', transform=transforms.Compose([transforms.Resize((224,224)),#transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))# 分割比例:比如80%的数据用于训练,20%用于验证
train_val_ratio = 0.8
train_size = int(len(traindataset) * train_val_ratio)
val_size = len(traindataset) - train_size
train_dataset, val_dataset = random_split(traindataset, [train_size, val_size])classes = traindataset.classes
print(classes)model = models.resnext50_32x4d(pretrained=preteain)
#model = models.resnet34(pretrained=preteain)if preteain == True:for param in model.parameters():param.requires_grad = Falsemodel.fc = nn.Linear(in_features=2048, out_features=n_classes, bias=True)
model = model.to(device)def train_model(model, train_loader, loss_fn, optimizer, epoch):model.train()total_loss = 0.total_corrects = 0.total = 0.for idx, (inputs, labels) in enumerate(train_loader):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)loss = loss_fn(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()preds = outputs.argmax(dim=1)total_corrects += torch.sum(preds.eq(labels))total_loss += loss.item() * inputs.size(0)total += labels.size(0)total_loss = total_loss / totalacc = 100 * total_corrects / totalprint("轮次:%4d|训练集损失:%.5f|训练集准确率:%6.2f%%" % (epoch + 1, total_loss, acc))return total_loss, accdef test_model(model, test_loader, loss_fn, optimizer, epoch):model.train()total_loss = 0.total_corrects = 0.total = 0.with torch.no_grad():for idx, (inputs, labels) in enumerate(test_loader):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)loss = loss_fn(outputs, labels)preds = outputs.argmax(dim=1)total += labels.size(0)total_loss += loss.item() * inputs.size(0)total_corrects += torch.sum(preds.eq(labels))loss = total_loss / totalaccuracy = 100 * total_corrects / totalprint("轮次:%4d|测试集损失:%.5f|测试集准确率:%6.2f%%" % (epoch + 1, loss, accuracy))return loss, accuracyloss_fn = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(), lr=0.0001)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(val_dataset, batch_size=32, shuffle=True)
for epoch in range(0, epoches):loss1, acc1 = train_model(model, train_loader, loss_fn, optimizer, epoch)loss2, acc2 = test_model(model, test_loader, loss_fn, optimizer, epoch)

模型预测:

sub = pd.read_csv("../all/data/example.csv")
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

model.eval()
for path in os.listdir("../all/data/test_set/"):
    try:
        img = Image.open("../all/data/test_set/"+path)
        img_p = transform(img).unsqueeze(0).to(device)
        output = model(img_p)
        pred = output.argmax(dim=1).item()
        if img.size[0] * img.size[1]<2000:
            plt.imshow(img)
            plt.show()
        p = 100 * nn.Softmax(dim=1)(output).detach().cpu().numpy()[0]
        sub.loc[sub['ImageID'] == path,'label'] = classes[pred]
        print(f'{path} size = {img.size}, 该图像预测类别为:', classes[pred])
    except:
        print(f'error {path}')
sub.loc[sub['ImageID']=='e57471de-6527-4b9b-90a8-4f1d93909216.png','label'] = 'Under Construction'
sub.loc[sub['ImageID']=='ff38d59e-9a11-41e4-901b-67097bb0e960.png','label'] = 'Keep Left'
sub.columns = ['ImageID','Sign Name']
label_map = pd.read_excel("../all/data/label_map.xlsx")
sub_all = pd.merge(left=sub,right=label_map,on='Sign Name',how='left')
#sub_all[['ImageID','label']].to_csv('./sub_resnet34_add_img_ratio_drop_dire.csv',index=False)

个人的心得:

1 如何进行图片增强,图片增强应该注意什么(方向问题);

2 模型大小如何进行选择;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/854427.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

滑动窗口(LeeCode209题,以JS为例)

什么是滑动窗口&#xff1f; 滑动窗口是算法中一种非常有用的技术&#xff0c;特别是在处理数据序列或数组时。它的核心思想是维护一个固定大小的窗口&#xff0c;这个窗口在数据序列上滑动&#xff0c;以便于在窗口内的元素上进行操作或计算。滑动窗口技术通常用于解决与数据…

对 2024 年美赛选题的建议

对2024年美赛选题的建议包括&#xff1a; 1. 深入探讨当下全球面临的重大问题和挑战&#xff1a;鉴于美赛通常聚焦于全球性议题&#xff0c;如气候变化、可持续发展、数据分析等&#xff0c;参赛学生应关注这些议题&#xff0c;并深入研究相关数据与背景信息&#xff0c;以提出…

趋势Deep Security(Trend Micro Deep Security)安装

趋势Deep Security安装 Deep Security下载地址&#xff1a;https://help.deepsecurity.trendmicro.com/software.html?regsen-hk&prodid1716&_ga2.165737150.1637045249.1717402661-819692893.1716530462 前言 Trend Micro Deep Security是一个提供深度包检查、入侵…

单片机建立自己的库文件(4)

文章目录 前言一、新建自己的外设文件夹1.新建外设文件夹&#xff0c;做项目好项目文件管理2.将之前写的.c .h 文件添加到文件夹中 二、在软件中添加项目 .c文件2.1 编译工程保证没问题2. 修改项目列表下的名称 三、在软件项目中添加 .h文件路径四、实际使用测试总结 前言 提示…

sheng的学习笔记-AI-集成学习(adaboost,bagging,随机森林)

ai目录&#xff1a;sheng的学习笔记-AI目录-CSDN博客 目录​​​​​​​ 集成学习 什么是集成学习 集成学习一般结构&#xff1a; 示意图 弱学习器 经典算法 Boosting 什么是boosting 方法图 AdaBoost 算法 AdaBoost示意图 流程解析&#xff1a; 错误分类率error…

太速科技-FMC213V3-基于FMC兼容1.8V IO的Full Camera Link 输入子卡

FMC213V3-基于FMC兼容1.8V IO的Full Camera Link 输入子卡 一、板卡概述 该板卡为了考虑兼容1.8V电平IO&#xff0c;适配Virtex7&#xff0c;Kintex Ultrascale&#xff0c;Virtex ultrasacle FPGA而特制&#xff0c;如果要兼容原来的3.3V 也可以修改硬件参数。板卡支持1路…

快速欧氏聚类与普通欧氏聚类比较

1、前言 文献《FEC: Fast Euclidean Clustering for Point Cloud Segmentation》介绍了一种快速欧氏聚类方法,大概原理可以参考如下图,具体原理可以参考参考文献。 2、时间效率比较:快速欧氏聚类VS普通欧氏聚类 网上搜集的快速欧式聚类,与自己手写的普通欧式聚类进行对比,…

HTTP协议简单介绍

一、HTTP协议是什么 1、HTTP协议是以TCP协议为基础的文本协议。 2、HTTP协议采用请求和响应的模式。 3、HTTP协议可以传输二进制文件、文本文件、图片等资源。 4、HTTP协议支持表单上传&#xff0c;文件上传&#xff0c;文件下载等功能。 二、HTTP协议的格式 (一)请求格式…

SLG火并6月:多强鼎立,增量用户发展成行业新题

SLG赛道进入到6月&#xff0c;《三国&#xff1a;谋定天下》、《野兽领主&#xff1a;新世界》、《无尽冬日》大量新品袭来搅动市场。 在这样的关口&#xff0c;占据SLG半壁江山的灵犀互娱《三国志战略版》先一步刊登出战报&#xff0c;宣布1亿SLG玩家已收归麾下。 但新的挑战…

Linux时间子系统6:NTP原理和Linux NTP校时机制

一、前言 上篇介绍了时间同步的基本概念和常见的时间同步协议NTP、PTP&#xff0c;本篇将详细介绍NTP的原理以及NTP在Linux上如何实现校时。 二、NTP原理介绍 1. 什么是NTP 网络时间协议&#xff08;英语&#xff1a;Network Time Protocol&#xff0c;缩写&#xff1a;NTP&a…

COVINS-G编译注意事项

install_files.sh 修改source devel/setup.bash 为 source devel/setup.zsh cv_bridge 为了防止和本机的noetic的cv_bridge冲突&#xff0c;需要放入一个旧版本的cv_bridge。 先编译好opencv3_catkin&#xff0c;然后添加cv_bridge,也就是下载vision_opencv的melodic分支到cov…

华为数通企业面试笔试实验题

1. 笔试题 1.1 实验拓扑 1.2 实验要求 公司A为小型销售公司,需要实现基本上网功能,蓝色部分为外网线,提供DHCP服务 DnsServer:114.114.114.114 帮助网管排查某一台计算机在某一台交换机的某个端口 2. 操作步骤 配置路由器相关的LAN侧接口IP地址 配置DHCP项,要求有PC1与PC2…

Java StringBuffer 和 StringBuilder 类的比较与应用

Java 中的 StringBuffer 和 StringBuilder 类都用于处理字符串&#xff0c;但它们在性能和线程安全性方面有所不同。StringBuffer 是线程安全的&#xff0c;适合多线程环境下的字符串操作&#xff1b;而 StringBuilder 则是非线程安全的&#xff0c;提供了更高的性能。本文将从…

大模型KV Cache节省神器MLA学习笔记(包含推理时的矩阵吸收分析)

首先&#xff0c;本文回顾了MHA的计算方式以及KV Cache的原理&#xff0c;然后深入到了DeepSeek V2的MLA的原理介绍&#xff0c;同时对MLA节省的KV Cache比例做了详细的计算解读。接着&#xff0c;带着对原理的理解理清了HuggingFace MLA的全部实现&#xff0c;每行代码都去对应…

软件改为开机自启动

1.按键 win R,输入“shell:startup”命令, 然后就可以打开启动目录了&#xff0c;如下&#xff1a; 2.然后&#xff0c;把要开机启动的程序的图标拖进去即可。 参考&#xff1a;开机启动项如何设置

JAVA面试(六)

缓存 MemcachedredisRedis常见数据类型和使用Redis缓存持久化RDB-快照AOF-追加文件 Redis数据过期机制惰性删除定期删除 Redis缓存淘汰策略&#xff08;8种&#xff09;算法LRU &#xff08;Least Recently Used&#xff09;&#xff1a;最近最少使用LFU&#xff08;Least Freq…

java类型转换(强制类型转换)底层转换原理,此篇带你理解清楚

介绍 Java 中的类型强制转换&#xff08;Type Casting&#xff09;可以分为基本类型&#xff08;primitive types&#xff09;的强制转换和引用类型&#xff08;reference types&#xff09;的强制转换。它们在底层的原理和实现有所不同。以下是对这两种类型强制转换的详细解释…

ElasticSearch聚合排序

聚合排序 根据之前的博客可知,ES对于聚合结果的默认排序规则有时并非是我们希望的。可以使用ES提供的sort子句进行自定义排序,有多种排序方式可供选择: 按照聚合后的文档计数的大小进行排序按照聚合后的某个指标进行排序按照每个组的名称进行排序1.1 按文档计数排序 在聚合排…

day12--150. 逆波兰表达式求值+239. 滑动窗口最大值+ 347. 前 K 个高频元素

一、150. 逆波兰表达式求值 题目链接&#xff1a;https://leetcode.cn/problems/evaluate-reverse-polish-notation/description/ 文章讲解&#xff1a;https://programmercarl.com/0150.%E9%80%86%E6%B3%A2%E5%85%B0%E8%A1%A8%E8%BE%BE%E5%BC%8F%E6%B1%82%E5%80%BC.html 视频…

R可视化:微生物相对丰度或富集热图可视化

欢迎大家关注全网生信学习者系列: WX公zhong号:生信学习者Xiao hong书:生信学习者知hu:生信学习者CDSN:生信学习者2介绍 热图(Heatmap)是一种数据可视化方法,它通过颜色的深浅或色调的变化来展示数据的分布和密度。在微生物学领域,热图常用于表示微生物在不同分组(如…