python实战(十五)——中文手写体数字图像CNN分类

一、任务背景

        本次python实战,我们使用来自Kaggle的数据集《Chinese MNIST》进行CNN分类建模,不同于经典的MNIST数据集,我们这次使用的数据集是汉字手写体数字。除了常规的汉字“零”到“九”之外还多了“十”、“百”、“千”、“万”、“亿”,共15种汉字数字

二、python建模

1、数据读取

        首先,读取jpg数据文件,可以看到总共有15000张图像数据。

import pandas as pd
import ospath = '/kaggle/input/chinese-mnist/data/data/'
files = os.listdir(path)
print('数据总量:', len(files))

        我们也可以打印一张图片出来看看。

import matplotlib.pyplot as plt
import matplotlib.image as mpimg# 定义图片路径
image_path = path+files[3]# 加载图片
image = mpimg.imread(image_path)# 绘制图片
plt.figure(figsize=(3, 3))
plt.imshow(image)
plt.axis('off')  # 关闭坐标轴
plt.show()

2、数据集构建

        加载必要的库以便后续使用,再定义一些超参数。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import precision_score, recall_score, f1_score# 超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 5# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])

        这里,我们看一看数据集介绍就会知道图片名称及其含义,需要从chinese_mnist.csv文件中根据图片名称中的几个数字来确定图片对应的标签。

# 获取所有图片文件的路径
all_images = [os.path.join(path, img) for img in os.listdir(path) if img.endswith('.jpg')]# 读取索引-标签对应关系csv文件,并将'suite_id', 'sample_id', 'code'设置为索引列便于查找
index_df = pd.read_csv('/kaggle/input/chinese-mnist/chinese_mnist.csv')
index_df.set_index(['suite_id', 'sample_id', 'code'], inplace=True)# 定义函数,根据各索引取值定位图片对应的数值标签value
def get_label_from_index(filename, index_df):suite_id, sample_id, code = map(int, filename.split('.')[0].split('_')[1:])return index_df.loc[(suite_id, sample_id, code), 'value']# 构建value值对应的标签序号,用于模型训练
label_dic = {0:0, 1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10, 100:11, 1000:12, 10000:13, 100000000:14}
# 获取所有图片的标签并转化为标签序号
all_labels = [get_label_from_index(os.path.basename(img), index_df) for img in all_images]
all_labels = [label_dic[li] for li in all_labels]# 将图片路径和标签分成训练集和测试集
train_images, test_images, train_labels, test_labels = train_test_split(all_images, all_labels, test_size=0.2, random_state=2024)

        下面定义数据集类并完成数据的加载。

# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, image_paths, labels, transform=None):self.image_paths = image_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.image_paths)def __getitem__(self, idx):image = Image.open(self.image_paths[idx]).convert('L')  # 转换为灰度图像label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 创建训练集和测试集数据集
train_dataset = CustomDataset(train_images, train_labels, transform=transform)
test_dataset = CustomDataset(test_images, test_labels, transform=transform)# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 打印一些信息
print(f'训练集样本数: {len(train_dataset)}')
print(f'测试集样本数: {len(test_dataset)}')

3、模型构建

        我们构建一个包含两层卷积层和池化层的CNN并且在池化层中使用最大池化的方式。

# 定义CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(64 * 16 * 16, 128)self.fc2 = nn.Linear(128, 15)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 16 * 16)x = F.relu(self.fc1(x))x = self.fc2(x)return x

4、模型实例化及训练

        下面我们对模型进行实例化并定义criterion和optimizer。

# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

        定义训练的代码并调用代码训练模型。

from tqdm import tqdm
# 训练模型
def train(model, train_loader, criterion, optimizer, epochs):model.train()running_loss = 0.0for epoch in range(epochs):for data, target in tqdm(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch + 1}], Loss: {running_loss / len(train_loader):.4f}')running_loss = 0.0train(model, train_loader, criterion, optimizer, num_epochs)

5、测试模型

        定义模型测试代码,调用代码看指标可知我们所构建的CNN模型表现还不错。

# 测试模型
def test(model, test_loader, criterion):model.eval()test_loss = 0correct = 0all_preds = []all_targets = []with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_preds.extend(pred.cpu().numpy())all_targets.extend(target.cpu().numpy())test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)precision = precision_score(all_targets, all_preds, average='macro')recall = recall_score(all_targets, all_preds, average='macro')f1 = f1_score(all_targets, all_preds, average='macro')print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')test(model, test_loader, criterion)

三、完整代码

import pandas as pd
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import precision_score, recall_score, f1_scorepath = '/kaggle/input/chinese-mnist/data/data/'
files = os.listdir(path)
print('数据总量:', len(files))# 超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 5# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 获取所有图片文件的路径
all_images = [os.path.join(path, img) for img in os.listdir(path) if img.endswith('.jpg')]# 读取索引-标签对应关系csv文件,并将'suite_id', 'sample_id', 'code'设置为索引列便于查找
index_df = pd.read_csv('/kaggle/input/chinese-mnist/chinese_mnist.csv')
index_df.set_index(['suite_id', 'sample_id', 'code'], inplace=True)# 定义函数,根据各索引取值定位图片对应的数值标签value
def get_label_from_index(filename, index_df):suite_id, sample_id, code = map(int, filename.split('.')[0].split('_')[1:])return index_df.loc[(suite_id, sample_id, code), 'value']# 构建value值对应的标签序号,用于模型训练
label_dic = {0:0, 1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10, 100:11, 1000:12, 10000:13, 100000000:14}# 获取所有图片的标签并转化为标签序号
all_labels = [get_label_from_index(os.path.basename(img), index_df) for img in all_images]
all_labels = [label_dic[li] for li in all_labels]# 将图片路径和标签分成训练集和测试集
train_images, test_images, train_labels, test_labels = train_test_split(all_images, all_labels, test_size=0.2, random_state=2024)# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, image_paths, labels, transform=None):self.image_paths = image_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.image_paths)def __getitem__(self, idx):image = Image.open(self.image_paths[idx]).convert('L')  # 转换为灰度图像label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 创建训练集和测试集数据集
train_dataset = CustomDataset(train_images, train_labels, transform=transform)
test_dataset = CustomDataset(test_images, test_labels, transform=transform)# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 打印信息
print(f'训练集样本数: {len(train_dataset)}')
print(f'测试集样本数: {len(test_dataset)}')# 定义CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(64 * 16 * 16, 128)self.fc2 = nn.Linear(128, 15)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 16 * 16)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)# 训练模型
def train(model, train_loader, criterion, optimizer, epochs):model.train()running_loss = 0.0for epoch in range(epochs):for data, target in tqdm(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch + 1}], Loss: {running_loss / len(train_loader):.4f}')running_loss = 0.0train(model, train_loader, criterion, optimizer, num_epochs)# 测试模型
def test(model, test_loader, criterion):model.eval()test_loss = 0correct = 0all_preds = []all_targets = []with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_preds.extend(pred.cpu().numpy())all_targets.extend(target.cpu().numpy())test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)precision = precision_score(all_targets, all_preds, average='macro')recall = recall_score(all_targets, all_preds, average='macro')f1 = f1_score(all_targets, all_preds, average='macro')print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')test(model, test_loader, criterion)

四、总结

        本文基于汉字手写体数字图像进行了CNN分类实战,CNN作为图像处理的经典模型,展现出了它强大的图像特征提取能力,结合更加复杂的模型框架CNN还可用于高精度人脸识别、物体识别等任务中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/69283.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决双系统引导问题:Ubuntu 启动时不显示 Windows 选项的处理方法

方法 1:检查 GRUB 引导菜单是否隐藏 启动进入 Ubuntu 系统。打开终端,输入以下命令编辑 GRUB 配置文件:sudo nano /etc/default/grub检查以下配置项: GRUB_TIMEOUT0:如果是 0,将其改为一个较大的值&#x…

Django网站搭建流程

使用Django搭建网站是一个系统的过程,涉及从环境搭建到部署上线的多个步骤。以下是详细的流程: 1. 环境搭建 (1)安装Python Django是基于Python的Web框架,因此需要先安装Python。建议安装Python 3.8及以上版本。 下载地…

【深入理解FFMPEG】命令行阅读笔记

这里写自定义目录标题 第三章 FFmpeg工具使用基础3.1 ffmpeg常用命令3.1.13.1.3 转码流程 3.2 ffprobe 常用命令3.2.1 ffprobe常用参数3.2.2 ffprobe 使用示例 3.3 ffplay常用命令3.3.1 ffplay常用参数3.3.2 ffplay高级参数3.3.4 ffplay快捷键 第4章 封装与解封装4.1 视频文件转…

为AI聊天工具添加一个知识系统 之72 详细设计之13 图灵机

本文要点 要点 实际上是要设计一个图灵机,利用λ转换规则和λ演算 来定义StringProcessor的发生产规则的转换功能。三种文法型运行图灵机来处理 不同的串---符号串, 数字串和文字串 一个 StrIngProcessor,图灵机(利用λ转换规则…

BARN_dataset的生成代码jackal-map-creation-master的使用说明:

主要代码是gen_world_ca.py,其中有各个参数来调节,来生成适合自己机器人的gazebo环境,顺带着还会生成路径等等(没有具体研究),具体参数如下: jackal takes up 2 extra grid squares on each side in addit…

基于新年视角下的城市人流数据分析

2025年新年~~~ 旅游消费似乎又成为城市活力的动力话题。 透过话题看数据,透过数据看结果,无非是从--人流量--到--人留量,能不能留下人,能否因人而产生消费。 基于这个角度,地方政府经营城市的商业模式本质则是为城市…

ORACLE-主备备-Failover

背景 随着业务的不断增涨,至使现有的单节点DG环境的连接已经无法满足当前业务需求,并且随着业务的重要性,同时也要求数据库的高可用性,减少数据库故障对业务的影响。于是规划迁移方案。 迁移方案如下: 因PRIMARY库本地磁盘空间已达到80%决定弃用,搭建高可用2个节点的RAC做…

OpenEuler学习笔记(十):用OpenEuler搭建web服务器

以下是在OpenEuler系统上搭建Web服务器的详细步骤,这里以常见的Nginx为例。 1. 系统更新 在进行任何操作之前,最好先更新系统的软件包,确保系统是最新的状态。 sudo dnf update -y2. 安装Nginx 可以使用OpenEuler的软件包管理器dnf来安装…

【C语言系列】深入理解指针(4)

深入理解指针(4) 一、回调函数是什么?二、qsort使用举例2.1使用qsort函数排序整型数据2.2使用qsort排序结构数据 三、qsort函数的模拟实现四、总结 一、回调函数是什么? 回调函数就是一个通过函数指针调用的函数。 如果你把函数的…

vim的多文件操作

[rootxxx ~]# vim aa.txt bb.txt cc.txt #多文件操作 next #下一个文件 prev #上一个文件 first #第一个文件 last #最后一个文件 快捷键: ctrlshift^ #当前和上个之间切换 说明:快捷键ctrlshift^&#xff0c…

Salesforce Too Many Email Invocations: 11

在 Salesforce 中,“Too Many Email Invocations: 11” 错误通常表示您的组织在单个事务中超过了 Apex 电子邮件调用的限制。Salesforce 设置这些限制是为了防止滥用并确保公平使用。以下是解决该问题的方法: 理解限制 Salesforce 允许每个事务中最多进…

力扣【347. 前 K 个高频元素】Java题解(堆)

TopK问题,我们直接上堆。 首先遍历一次然后把各个数字的出现频率存放在哈希表中便于后面堆的操作。 因为是出现频率前 k 高,所以用小顶堆,当我们遍历的频率值大于堆顶值时就可以替换堆顶。 代码: class Solution {public int[] …

[NOIP2007]矩阵取数游戏

点我写题 题目描述 帅帅经常跟同学玩一个矩阵取数游戏:对于一个给定的n*m的矩阵,矩阵中的每个元素aij均为非负整数。游戏规则如下: 1.每次取数时须从每行各取走一个元素,共n个。m次后取完矩阵所有元素; 2.每次取走的…

解决CentOS9系统下Zabbix 7.2图形中文字符乱码问题

操作系统:CentOS 9 Zabbix版本:Zabbix7.2 问题描述:主机图形中文字符乱码 解决方案: # 安装字体配置和中文语言包 sudo yum install -y fontconfig langpacks-zh_CN.noarch # 检查是否已有中文字体: fc-list :lan…

[SUCTF 2018]MultiSQL1

进去题目页面如下 发现可能注入点只有登录和注册,那么我们先注册一个用户,发现跳转到了/user/user.php, 查看用户信息,发现有传参/user/user.php?id1 用?id1 and 11,和?id1 and 12,判断为数字型注入 原本以为是简单的数字型注入,看到大…

计算机视觉-卷积

卷积-图像去噪 一、图像 二进制 灰度 彩色 1.1二进制图像 0 1 一个点可以用一个bit(0/1)来表示 1.2灰度图像 0-255 一个点可以用一个byte来表示 1.3彩色图像 RGB 表达一个彩色图像先说它的分辨率p/w(宽)和q/h(高…

mybatis(78/134)

前天学了很多&#xff0c;关于java的反射机制&#xff0c;其实跳过了new对象&#xff0c;然后底层生成了字节码&#xff0c;创建了对应的编码。手搓了一遍源码&#xff0c;还是比较复杂的。 <?xml version"1.0" encoding"UTF-8" ?> <!DOCTYPE …

Vuex 的核心概念:State, Mutations, Actions, Getters

Vuex 的核心概念&#xff1a;State, Mutations, Actions, Getters Vuex 是 Vue.js 的官方状态管理库&#xff0c;提供了集中式的状态管理机制。它的核心概念包括 State&#xff08;状态&#xff09;、Mutations&#xff08;变更&#xff09;、Actions&#xff08;动作&#xf…

1.23 补题 寒假训练营

E 一起走很长的路&#xff01; 输入描述 第一行输入两个整数 n,q&#xff08;1≤n,q≤210^5&#xff09;&#xff0c;代表多米诺骨牌的个数和询问次数。 第二行输入 n 个整数 a1,a2,…,an​&#xff08;1≤ai≤10^9&#xff09;&#xff0c;表示多米诺骨牌的重量。 此后输入…

自定义数据集使用框架的线性回归方法对其进行拟合

代码 import torch import numpy as np import torch.nn as nncriterion nn.MSELoss()data np.array([[-0.5, 7.7],[1.8, 98.5],[0.9, 57.8],[0.4, 39.2],[-1.4, -15.7],[-1.4, -37.3],[-1.8, -49.1],[1.5, 75.6],[0.4, 34.0],[0.8, 62.3]])x_data data[:, 0] y_data data…