Kaggle -- Digit Recognizer 98.57%

使用卷积神经网络进行模型构建,代码如下:

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split# 读取数据,并跳过标题行
df = pd.read_csv('train.csv', header=None, skiprows=1)# 数据预处理
labels = df.iloc[:, 0].values.astype(int)
pixels = df.iloc[:, 1:].values.astype(float)# 重塑数据并归一化像素值
pixels = pixels.reshape(-1, 28, 28).astype('float32') / 255.0# 自定义Dataset类
class HandwrittenDigitsDataset(Dataset):def __init__(self, images, labels=None):self.images = torch.tensor(images, dtype=torch.float32)if labels is not None:self.labels = torch.tensor(labels, dtype=torch.long)else:self.labels = Nonedef __len__(self):return len(self.images)def __getitem__(self, idx):image = self.images[idx].unsqueeze(0)  # 加入通道维度if self.labels is not None:label = self.labels[idx]return image, labelelse:return image# 创建Dataset对象
dataset = HandwrittenDigitsDataset(pixels, labels)# 拆分数据集
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 定义CNN模型
class CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 7 * 7)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 实例化模型
model = CNNModel()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10
for epoch in range(num_epochs):model.train()running_loss = 0.0for images, labels in train_loader:optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')# 评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the model on the test set: {100 * correct / total:.2f}%')# 读取预测数据,并跳过标题行
predict_df = pd.read_csv('test.csv', header=None, skiprows=1)# 数据预处理
predict_pixels = predict_df.values.astype(float)# 重塑数据并归一化像素值
predict_pixels = predict_pixels.reshape(-1, 28, 28).astype('float32') / 255.0# 创建预测Dataset对象
predict_dataset = HandwrittenDigitsDataset(predict_pixels)# 创建DataLoader
predict_loader = DataLoader(predict_dataset, batch_size=64, shuffle=False)# 进行预测
predicted_labels = []
model.eval()
with torch.no_grad():for images in predict_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)predicted_labels.extend(predicted.numpy())# 将预测结果添加回原始数据
predict_df['label'] = predicted_labels# print(predict_df.loc[0])
predict_df['ImageId'] = [i for i in range(1,len(predict_df) + 1)]
predict_df['Label'] = predict_df['label']predict_df = predict_df[['ImageId','Label']]
print(predict_df.loc[0])predict_df.to_csv("ans.csv",index=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/851327.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MySQL】聊聊唯一索引是如何加锁的

首先我们要明确,加锁的对象是索引,加锁的基本单位是next-key lock,由记录锁和间隙锁组成。next-key是前开后闭区间,间隙锁是前开后开区间。根据不同的查询条件next-key 可能会退化成记录锁或间隙锁。 在能使用记录锁或者间隙锁就…

电路防护-贴片陶瓷气体放电管

贴片陶瓷气体放电管 GDT工作原理GDT主要特性参数典型电路压敏电阻与 TVS 管的区别 GDT工作原理 陶瓷气体放电管是一种电子器件,其工作原理基于气体放电现象。这种管子的内部填充了一种特定的气体,通常是氖气或氩气。当管子两端施加足够的电压时&#xf…

vivado HW_ILA_DATA、HW_PROBE

HW_ILA_DATA 描述 硬件ILA数据对象是ILA调试核心上捕获的数据的存储库 编程到当前硬件设备上。upload_hw_ila_data命令 在从ila调试移动捕获的数据的过程中创建hw_ila_data对象 核心,hw_ila,在物理FPGA上,hw_device。 read_hw_ila_data命令还…

nginx优化与防盗链【☆☆☆】

目录 一、用户层面的优化 1、隐藏版本号 方法一:修改配置文件 方法二:修改源码文件,重新编译安装 2、修改nginx用户与组 3、配置nginx网页缓存时间 4、nginx的日志切割 5、配置nginx实现连接超时 6、更改nginx运行进程数 7、开启网…

1 c++多线程创建和传参

什么是进程? 系统资源分配的最小单位。 什么是线程? 操作系统调度的最小单位,即程序执行的最小单位。 为什么需要多线程? (1)加快程序执行速度和响应速度, 使得程序充分利用CPU资源。 (2&…

Python 全栈体系【四阶】(五十八)

第五章 深度学习 十三、自然语言处理(NLP) 3. 文本表示 3.1 One-hot One-hot(独热)编码是一种最简单的文本表示方式。如果有一个大小为V的词表,对于第i个词 w i w_i wi​,可以用一个长度为V的向量来表示…

npm install 的原理

1. 执行命令发生了什么 ? 执行命令后,会将安装相关的依赖,依赖会存放在根目录的node_modules下,默认采用扁平化的方式安装,排序规则为:bin文件夹为第一个,然后是开头系列的文件夹,后…

Linux网络诊断工具mtr命令详解

目录 一、mtr概述 二、mtr的特点 1、动态路由显示 2、数据包类型 3、显示延迟和丢包 4、过滤和日志 5、网络探测 三、基本用法 1、基本语法 2、帮助 3、常用选项 四、输出解释 1、常见mtr命令及其输出 2、输出解释 四、命令实例 1. 最基本的用法 2. 显示报告形式…

SpringBoot 配置事务

SpringBoot 在启动时已经加载了事务管理器,所以只需要在需要添加事务的方法/类上添加Transactional即可生效,无需额外配置。 TransactionAutoConfiguration 事务的自动配置类解析: SpringBoot 启动时加载/META-INF/spring/org.springframewor…

帕友的小贴士,锻炼

帕金森病作为一种慢性神经系统疾病,对患者的生活质量产生了深远的影响。虽然医学界对于帕金森病的治疗仍在不断探索,但合理的锻炼已经被证实是改善患者症状、提高生活质量的有效途径之一。本文旨在为帕金森病患者推荐一些适合的锻炼方法,帮助…

c#未能加载基类System错误 这台计算机上缺少此项目引用的 NuGet 程序包

拷贝代码到另一台计算机运行,打开Form1.cs报错 首先确认package的框架 如果是472,则更换472的框架 打开项目->xx属性,进行修改 如果框架正确,就是未识别到程序包 可以参考: https://www.cnblogs.com/txwtech/p/1…

WPF真入门教程32--WPF数字大屏项目实干

1、项目背景 WPF (Windows Presentation Foundation) 是微软的一个框架,用于构建桌面客户端应用程序,它支持富互联网应用程序(RIA)的开发。在数字大屏应用中,WPF可以用来构建复杂的用户界面,展示庞大的数据…

day31贪心算法part01| 理论基础 455.分发饼干 376. 摆动序列 53. 最大子序和

**455.分发饼干 ** 视频讲解 | 力扣链接刚开始想到的&#xff0c;但是这样太暴力了&#xff0c;太笨了 class Solution { public:int findContentChildren(vector<int>& g, vector<int>& s) {// 胃口g 饼干尺寸sint result 0;sort(s.begin(), s.end());…

后方碰撞预警系统技术规范(简化版)

后方碰撞预警系统技术规范(简化版) 1 系统概述2 预警区域3 预警目标4 功能需求功能条件5 显示需求6 指标需求1 系统概述 后方碰撞预警系统RCW(Rear Collision Warning)是在后方车辆即将与自车发生碰撞之前,激活危险警告灯以较高频率闪烁,从而吸引后方驾驶员的注意力,避免…

mysql 数据库datetime 类型,转换为DO里面的long类型后,只剩下年了,没有了月和日

解决方法也简单&#xff1a; 自定义个一个 Date2LongTypeHandler <resultMap id"BeanResult" type"XXXX.XXXXDO"><result column"gmt_create" property"gmtCreate" jdbcType"DATE" javaType"java.lang.Long&…

44【Aseprite 作图】樱花丸子——拆解

1 枝干 2 花朵&#xff1a;其实只要形状差不多都行&#xff0c;有三个颜色&#xff0c;中间花蕊颜色深一点&#xff0c;中间花蕊外的颜色偏白&#xff1b;不透明度也可以改一下&#xff0c;就变成不同颜色 3 丸子 最外层的颜色最深&#xff0c;中间稍浅&#xff0c;加一些高光…

2024全国大学生数学建模竞赛优秀参考资料分享

0、竞赛资料 优秀的资料必不可少&#xff0c;优秀论文是学习的关键&#xff0c;视频学习也非常重要&#xff0c;如有需要请点击下方名片获取。 一、赛事介绍 全国大学生数学建模竞赛(以下简称竞赛)是中国工业与应用数学学会主办的面向全国大学生的群众性科技活动&#xff0c;旨…

修复损坏的Excel文件比你想象的要简单,这里提供几种常见的修复方法

打开重要的Excel文件时遇到问题吗?Microsoft Excel是否要求你验证文件是否已损坏?Excel文件可能由于各种原因而损坏,从而无法打开。但不要失去希望;你可以轻松修复损坏的Excel文件。 更改Excel信任中心设置 Microsoft Excel有一个内置的安全功能,可以在受限模式下打开有…

MySQL数据库的基础:逻辑集合数据库与表的基础操作

本篇会加入个人的所谓鱼式疯言 ❤️❤️❤️鱼式疯言:❤️❤️❤️此疯言非彼疯言 而是理解过并总结出来通俗易懂的大白话, 小编会尽可能的在每个概念后插入鱼式疯言,帮助大家理解的. &#x1f92d;&#x1f92d;&#x1f92d;可能说的不是那么严谨.但小编初心是能让更多人能接…

字符串常量简单介绍

C/C内存四区介绍 如前文所示&#xff0c;字符串常量存储在静态存储区的字符串常量区&#xff0c;这样做的好处是 当程序使用到多个相同的字符串常量时&#xff0c;实际上都是使用的同一份&#xff0c;这样就可以减小程序的体积。注意字符串常量是只读的不能被修改。 如图所示&…