python实战(十五)——中文手写体数字图像CNN分类

一、任务背景

        本次python实战,我们使用来自Kaggle的数据集《Chinese MNIST》进行CNN分类建模,不同于经典的MNIST数据集,我们这次使用的数据集是汉字手写体数字。除了常规的汉字“零”到“九”之外还多了“十”、“百”、“千”、“万”、“亿”,共15种汉字数字

二、python建模

1、数据读取

        首先,读取jpg数据文件,可以看到总共有15000张图像数据。

import pandas as pd
import ospath = '/kaggle/input/chinese-mnist/data/data/'
files = os.listdir(path)
print('数据总量:', len(files))

        我们也可以打印一张图片出来看看。

import matplotlib.pyplot as plt
import matplotlib.image as mpimg# 定义图片路径
image_path = path+files[3]# 加载图片
image = mpimg.imread(image_path)# 绘制图片
plt.figure(figsize=(3, 3))
plt.imshow(image)
plt.axis('off')  # 关闭坐标轴
plt.show()

2、数据集构建

        加载必要的库以便后续使用,再定义一些超参数。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import precision_score, recall_score, f1_score# 超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 5# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])

        这里,我们看一看数据集介绍就会知道图片名称及其含义,需要从chinese_mnist.csv文件中根据图片名称中的几个数字来确定图片对应的标签。

# 获取所有图片文件的路径
all_images = [os.path.join(path, img) for img in os.listdir(path) if img.endswith('.jpg')]# 读取索引-标签对应关系csv文件,并将'suite_id', 'sample_id', 'code'设置为索引列便于查找
index_df = pd.read_csv('/kaggle/input/chinese-mnist/chinese_mnist.csv')
index_df.set_index(['suite_id', 'sample_id', 'code'], inplace=True)# 定义函数,根据各索引取值定位图片对应的数值标签value
def get_label_from_index(filename, index_df):suite_id, sample_id, code = map(int, filename.split('.')[0].split('_')[1:])return index_df.loc[(suite_id, sample_id, code), 'value']# 构建value值对应的标签序号,用于模型训练
label_dic = {0:0, 1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10, 100:11, 1000:12, 10000:13, 100000000:14}
# 获取所有图片的标签并转化为标签序号
all_labels = [get_label_from_index(os.path.basename(img), index_df) for img in all_images]
all_labels = [label_dic[li] for li in all_labels]# 将图片路径和标签分成训练集和测试集
train_images, test_images, train_labels, test_labels = train_test_split(all_images, all_labels, test_size=0.2, random_state=2024)

        下面定义数据集类并完成数据的加载。

# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, image_paths, labels, transform=None):self.image_paths = image_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.image_paths)def __getitem__(self, idx):image = Image.open(self.image_paths[idx]).convert('L')  # 转换为灰度图像label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 创建训练集和测试集数据集
train_dataset = CustomDataset(train_images, train_labels, transform=transform)
test_dataset = CustomDataset(test_images, test_labels, transform=transform)# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 打印一些信息
print(f'训练集样本数: {len(train_dataset)}')
print(f'测试集样本数: {len(test_dataset)}')

3、模型构建

        我们构建一个包含两层卷积层和池化层的CNN并且在池化层中使用最大池化的方式。

# 定义CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(64 * 16 * 16, 128)self.fc2 = nn.Linear(128, 15)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 16 * 16)x = F.relu(self.fc1(x))x = self.fc2(x)return x

4、模型实例化及训练

        下面我们对模型进行实例化并定义criterion和optimizer。

# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

        定义训练的代码并调用代码训练模型。

from tqdm import tqdm
# 训练模型
def train(model, train_loader, criterion, optimizer, epochs):model.train()running_loss = 0.0for epoch in range(epochs):for data, target in tqdm(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch + 1}], Loss: {running_loss / len(train_loader):.4f}')running_loss = 0.0train(model, train_loader, criterion, optimizer, num_epochs)

5、测试模型

        定义模型测试代码,调用代码看指标可知我们所构建的CNN模型表现还不错。

# 测试模型
def test(model, test_loader, criterion):model.eval()test_loss = 0correct = 0all_preds = []all_targets = []with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_preds.extend(pred.cpu().numpy())all_targets.extend(target.cpu().numpy())test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)precision = precision_score(all_targets, all_preds, average='macro')recall = recall_score(all_targets, all_preds, average='macro')f1 = f1_score(all_targets, all_preds, average='macro')print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')test(model, test_loader, criterion)

三、完整代码

import pandas as pd
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import precision_score, recall_score, f1_scorepath = '/kaggle/input/chinese-mnist/data/data/'
files = os.listdir(path)
print('数据总量:', len(files))# 超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 5# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 获取所有图片文件的路径
all_images = [os.path.join(path, img) for img in os.listdir(path) if img.endswith('.jpg')]# 读取索引-标签对应关系csv文件,并将'suite_id', 'sample_id', 'code'设置为索引列便于查找
index_df = pd.read_csv('/kaggle/input/chinese-mnist/chinese_mnist.csv')
index_df.set_index(['suite_id', 'sample_id', 'code'], inplace=True)# 定义函数,根据各索引取值定位图片对应的数值标签value
def get_label_from_index(filename, index_df):suite_id, sample_id, code = map(int, filename.split('.')[0].split('_')[1:])return index_df.loc[(suite_id, sample_id, code), 'value']# 构建value值对应的标签序号,用于模型训练
label_dic = {0:0, 1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10, 100:11, 1000:12, 10000:13, 100000000:14}# 获取所有图片的标签并转化为标签序号
all_labels = [get_label_from_index(os.path.basename(img), index_df) for img in all_images]
all_labels = [label_dic[li] for li in all_labels]# 将图片路径和标签分成训练集和测试集
train_images, test_images, train_labels, test_labels = train_test_split(all_images, all_labels, test_size=0.2, random_state=2024)# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, image_paths, labels, transform=None):self.image_paths = image_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.image_paths)def __getitem__(self, idx):image = Image.open(self.image_paths[idx]).convert('L')  # 转换为灰度图像label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 创建训练集和测试集数据集
train_dataset = CustomDataset(train_images, train_labels, transform=transform)
test_dataset = CustomDataset(test_images, test_labels, transform=transform)# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 打印信息
print(f'训练集样本数: {len(train_dataset)}')
print(f'测试集样本数: {len(test_dataset)}')# 定义CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(64 * 16 * 16, 128)self.fc2 = nn.Linear(128, 15)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 16 * 16)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)# 训练模型
def train(model, train_loader, criterion, optimizer, epochs):model.train()running_loss = 0.0for epoch in range(epochs):for data, target in tqdm(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch + 1}], Loss: {running_loss / len(train_loader):.4f}')running_loss = 0.0train(model, train_loader, criterion, optimizer, num_epochs)# 测试模型
def test(model, test_loader, criterion):model.eval()test_loss = 0correct = 0all_preds = []all_targets = []with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_preds.extend(pred.cpu().numpy())all_targets.extend(target.cpu().numpy())test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)precision = precision_score(all_targets, all_preds, average='macro')recall = recall_score(all_targets, all_preds, average='macro')f1 = f1_score(all_targets, all_preds, average='macro')print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')test(model, test_loader, criterion)

四、总结

        本文基于汉字手写体数字图像进行了CNN分类实战,CNN作为图像处理的经典模型,展现出了它强大的图像特征提取能力,结合更加复杂的模型框架CNN还可用于高精度人脸识别、物体识别等任务中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/69283.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【深入理解FFMPEG】命令行阅读笔记

这里写自定义目录标题 第三章 FFmpeg工具使用基础3.1 ffmpeg常用命令3.1.13.1.3 转码流程 3.2 ffprobe 常用命令3.2.1 ffprobe常用参数3.2.2 ffprobe 使用示例 3.3 ffplay常用命令3.3.1 ffplay常用参数3.3.2 ffplay高级参数3.3.4 ffplay快捷键 第4章 封装与解封装4.1 视频文件转…

ORACLE-主备备-Failover

背景 随着业务的不断增涨,至使现有的单节点DG环境的连接已经无法满足当前业务需求,并且随着业务的重要性,同时也要求数据库的高可用性,减少数据库故障对业务的影响。于是规划迁移方案。 迁移方案如下: 因PRIMARY库本地磁盘空间已达到80%决定弃用,搭建高可用2个节点的RAC做…

OpenEuler学习笔记(十):用OpenEuler搭建web服务器

以下是在OpenEuler系统上搭建Web服务器的详细步骤,这里以常见的Nginx为例。 1. 系统更新 在进行任何操作之前,最好先更新系统的软件包,确保系统是最新的状态。 sudo dnf update -y2. 安装Nginx 可以使用OpenEuler的软件包管理器dnf来安装…

【C语言系列】深入理解指针(4)

深入理解指针(4) 一、回调函数是什么?二、qsort使用举例2.1使用qsort函数排序整型数据2.2使用qsort排序结构数据 三、qsort函数的模拟实现四、总结 一、回调函数是什么? 回调函数就是一个通过函数指针调用的函数。 如果你把函数的…

vim的多文件操作

[rootxxx ~]# vim aa.txt bb.txt cc.txt #多文件操作 next #下一个文件 prev #上一个文件 first #第一个文件 last #最后一个文件 快捷键: ctrlshift^ #当前和上个之间切换 说明:快捷键ctrlshift^&#xff0c…

解决CentOS9系统下Zabbix 7.2图形中文字符乱码问题

操作系统:CentOS 9 Zabbix版本:Zabbix7.2 问题描述:主机图形中文字符乱码 解决方案: # 安装字体配置和中文语言包 sudo yum install -y fontconfig langpacks-zh_CN.noarch # 检查是否已有中文字体: fc-list :lan…

[SUCTF 2018]MultiSQL1

进去题目页面如下 发现可能注入点只有登录和注册,那么我们先注册一个用户,发现跳转到了/user/user.php, 查看用户信息,发现有传参/user/user.php?id1 用?id1 and 11,和?id1 and 12,判断为数字型注入 原本以为是简单的数字型注入,看到大…

计算机视觉-卷积

卷积-图像去噪 一、图像 二进制 灰度 彩色 1.1二进制图像 0 1 一个点可以用一个bit(0/1)来表示 1.2灰度图像 0-255 一个点可以用一个byte来表示 1.3彩色图像 RGB 表达一个彩色图像先说它的分辨率p/w(宽)和q/h(高…

mybatis(78/134)

前天学了很多&#xff0c;关于java的反射机制&#xff0c;其实跳过了new对象&#xff0c;然后底层生成了字节码&#xff0c;创建了对应的编码。手搓了一遍源码&#xff0c;还是比较复杂的。 <?xml version"1.0" encoding"UTF-8" ?> <!DOCTYPE …

1.23 补题 寒假训练营

E 一起走很长的路&#xff01; 输入描述 第一行输入两个整数 n,q&#xff08;1≤n,q≤210^5&#xff09;&#xff0c;代表多米诺骨牌的个数和询问次数。 第二行输入 n 个整数 a1,a2,…,an​&#xff08;1≤ai≤10^9&#xff09;&#xff0c;表示多米诺骨牌的重量。 此后输入…

【中间件快速入门】什么是Redis

现在后端开发会用到各种中间件&#xff0c;一不留神项目可能在哪天就要用到一个我们之前可能听过但是从来没接触过的中间件&#xff0c;这个时候对于开发人员来说&#xff0c;如果你不知道这个中间件的设计逻辑和使用方法&#xff0c;那在后面的开发和维护工作中可能就会比较吃…

金晟新能源由盈转亏:毛利率下滑产能利用率不佳,关联交易持续增加

《港湾商业观察》黄懿 近期&#xff0c;广东金晟新能源股份有限公司&#xff08;下称“金晟新能源”&#xff09;递交了招股书&#xff0c;拟冲刺港交所IPO&#xff0c;中金公司、招银国际为联席保荐人。 金晟新能源处于电池回收的新兴大势行业&#xff0c;但是&#xff0c;受…

RTMP|RTSP播放器只解码视频关键帧功能探讨

技术背景 我们在做RTMP|RTSP直播播放器的时候&#xff0c;遇到过这样的技术诉求&#xff0c;在一些特定的应用场景中&#xff0c;可能只需要关键帧的信息&#xff0c;例如视频内容分析系统&#xff0c;可能只对关键帧进行分析&#xff0c;以提取特征、检测对象或场景变化。鉴于…

2K高刷电竞显示器怎么选?

2K高刷电竞显示器怎么选&#xff1f;哪个价格适合你&#xff1f;哪个配置适合你呢&#xff1f; 1.HKC G27H2Pro - 2K高刷电竞显示器怎么选 外观设计 - HKC G27H2Pro 2K高刷电竞显示器 电竞风拉满&#xff1a;作为猎鹰系列的一员&#xff0c;背部 “鹰翼图腾” 切割线搭配炎红…

STM32-时钟树

STM32-时钟树 时钟 时钟

基于SpringBoot+WebSocket的前后端连接,并接入文心一言大模型API

前言&#xff1a; 本片博客只讲述了操作的大致流程&#xff0c;具体实现步骤并不标准&#xff0c;请以参考为准。 本文前提&#xff1a;熟悉使用webSocket 如果大家还不了解什么是WebSocket&#xff0c;可以参考我的这篇博客&#xff1a; rWebSocket 详解&#xff1a;全双工…

StarRocks BE源码编译、CLion高亮跳转方法

阅读SR BE源码时&#xff0c;很多类的引用位置爆红找不到&#xff0c;或无法跳转过去&#xff0c;而自己的Linux机器往往缺乏各种C依赖库&#xff0c;配置安装比较麻烦&#xff0c;因此总体的思路是通过CLion远程连接SR社区已经安装完各种依赖库的Docker容器&#xff0c;进行编…

STM32 按键密码系统的实现

本次基于STM32F407开发板&#xff0c;来实现密码系统&#xff0c;输入四位密码&#xff0c;密码正确时LED1亮&#xff0c;密码错误时四个LED灯双闪。 LED双闪代码 简单的逻辑&#xff0c;让四个LED灯先亮然后再延时一会LED灯灭&#xff0c;循环4此实现双闪的效果。 按键密码的…

linux常用加固方式

目录 一.系统加固 二.ssh加固 三.换个隐蔽的端口 四.防火墙配置 五.用户权限管理 六.暴力破解防护 七.病毒防护 八.磁盘加密 九.双因素认证2FA 十.日志监控 十一.精简服务 一.系统加固 第一步&#xff1a;打好系统补丁 sudo apt update && sudo apt upgra…

hadoop==docker desktop搭建hadoop

hdfs map readuce yarn https://medium.com/guillermovc/setting-up-hadoop-with-docker-and-using-mapreduce-framework-c1cd125d4f7b 清理资源 docker-compose down docker system prune -f