Kaggle -- Digit Recognizer 98.57%

使用卷积神经网络进行模型构建,代码如下:

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split# 读取数据,并跳过标题行
df = pd.read_csv('train.csv', header=None, skiprows=1)# 数据预处理
labels = df.iloc[:, 0].values.astype(int)
pixels = df.iloc[:, 1:].values.astype(float)# 重塑数据并归一化像素值
pixels = pixels.reshape(-1, 28, 28).astype('float32') / 255.0# 自定义Dataset类
class HandwrittenDigitsDataset(Dataset):def __init__(self, images, labels=None):self.images = torch.tensor(images, dtype=torch.float32)if labels is not None:self.labels = torch.tensor(labels, dtype=torch.long)else:self.labels = Nonedef __len__(self):return len(self.images)def __getitem__(self, idx):image = self.images[idx].unsqueeze(0)  # 加入通道维度if self.labels is not None:label = self.labels[idx]return image, labelelse:return image# 创建Dataset对象
dataset = HandwrittenDigitsDataset(pixels, labels)# 拆分数据集
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 定义CNN模型
class CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 7 * 7)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 实例化模型
model = CNNModel()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10
for epoch in range(num_epochs):model.train()running_loss = 0.0for images, labels in train_loader:optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')# 评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the model on the test set: {100 * correct / total:.2f}%')# 读取预测数据,并跳过标题行
predict_df = pd.read_csv('test.csv', header=None, skiprows=1)# 数据预处理
predict_pixels = predict_df.values.astype(float)# 重塑数据并归一化像素值
predict_pixels = predict_pixels.reshape(-1, 28, 28).astype('float32') / 255.0# 创建预测Dataset对象
predict_dataset = HandwrittenDigitsDataset(predict_pixels)# 创建DataLoader
predict_loader = DataLoader(predict_dataset, batch_size=64, shuffle=False)# 进行预测
predicted_labels = []
model.eval()
with torch.no_grad():for images in predict_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)predicted_labels.extend(predicted.numpy())# 将预测结果添加回原始数据
predict_df['label'] = predicted_labels# print(predict_df.loc[0])
predict_df['ImageId'] = [i for i in range(1,len(predict_df) + 1)]
predict_df['Label'] = predict_df['label']predict_df = predict_df[['ImageId','Label']]
print(predict_df.loc[0])predict_df.to_csv("ans.csv",index=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/851327.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MySQL】聊聊唯一索引是如何加锁的

首先我们要明确,加锁的对象是索引,加锁的基本单位是next-key lock,由记录锁和间隙锁组成。next-key是前开后闭区间,间隙锁是前开后开区间。根据不同的查询条件next-key 可能会退化成记录锁或间隙锁。 在能使用记录锁或者间隙锁就…

【CSP】202312-1 仓库规划

2023年 第32次CCF计算机软件能力认证 202312-1 仓库规划 原题链接:CSP32-仓库规划 时间限制: 1.0 秒 空间限制: 512 MiB 目录 题目描述 输入格式 输出格式 样例输入 样例输出 样例解释 子任务 解题思路 AC代码 题目描述 西西艾…

Spring和SpringBoot的特点

1.Spring的特点 1.IOC和AOP是Spring的两大核心特性,即控制反转和依赖注入。 2.松耦合:IOC和AOP两大特性可以尽可能地将对象之间的关系解耦 3.可配置:提供外部化配置的方式,可以灵活地配置容器及容器中的Bean 4.一站式&#xff1a…

电路防护-贴片陶瓷气体放电管

贴片陶瓷气体放电管 GDT工作原理GDT主要特性参数典型电路压敏电阻与 TVS 管的区别 GDT工作原理 陶瓷气体放电管是一种电子器件,其工作原理基于气体放电现象。这种管子的内部填充了一种特定的气体,通常是氖气或氩气。当管子两端施加足够的电压时&#xf…

本地化平台部署运维事项

现阶段越来越多的项目都是有云端SAAS部署,流程技术简单多了,需要服务器,数据库云端购买,各大云厂商也能做好服务器的异常拉起,数据库的集群,备份,主从复制等。需要安全证书,安全产品…

Codeforces Global Round 26 题解分享

A. Strange Splitting 思路 贪心 将题目中的红色元素范围不等于蓝色元素范围改成红色元素范围小于蓝色范围其实是一样的 那么红色元素范围最小是0,要占据一个元素。然后我们只要从数组中找到两个不同的元素就能够使得蓝色元素范围大于0,满足题意。 …

自动化测试进阶之路:从入门到精通

今天,我想和大家分享一些我在自动化测试方面的经验和知识,希望能帮助大家更好地掌握自动化测试技能。 一、自动化测试入门 自动化测试,顾名思义,就是利用自动化工具或脚本来执行测试用例,以减轻测试人员的工作负担&a…

vivado HW_ILA_DATA、HW_PROBE

HW_ILA_DATA 描述 硬件ILA数据对象是ILA调试核心上捕获的数据的存储库 编程到当前硬件设备上。upload_hw_ila_data命令 在从ila调试移动捕获的数据的过程中创建hw_ila_data对象 核心,hw_ila,在物理FPGA上,hw_device。 read_hw_ila_data命令还…

C++中的map容器详解

C中的map容器是一种关联式容器&#xff0c;提供了键-值对&#xff08;key-value pair&#xff09;的存储和快速查找功能。map容器由标准模板库&#xff08;STL&#xff09;提供&#xff0c;包含在<map>头文件中。map使用平衡二叉树&#xff08;通常是红黑树&#xff09;实…

软考 系统架构设计师系列知识点之杂项集萃(31)

接前一篇文章&#xff1a;软考 系统架构设计师系列知识点之杂项集萃&#xff08;30&#xff09; 第49题 软件开发环境是支持软件产品开发的软件系统&#xff0c;它由软件工具集和环境集成机制构成。环境集成机制包括&#xff1a;提供统一的数据模式和数据接口规范的数据集成机…

VB.net调用VC DLL

函数的修饰名&#xff1f;参考文献12 .DEF导出和__declspec(dllexport)的优缺点&#xff1f;参考文献11 1、__declspec(dllexport) 可以使用 __declspec(dllexport) 关键字从 DLL 中导出数据、函数、类或类成员函数。 尝试导出已修饰的 C 函数名称时&#xff0c;这种便利性…

什么是幂等问题?

什么是幂等问题&#xff1f; 先说下什么是幂等&#xff0c;幂等性是数学和计算机科学中的概念&#xff0c;用于描述操作无论执行多少次&#xff0c;都产生相同结果的特性。在软件行业中&#xff0c;广泛应用该概念。当我们说一个接口支持幂等性时&#xff0c;无论调用多少次&a…

nginx优化与防盗链【☆☆☆】

目录 一、用户层面的优化 1、隐藏版本号 方法一&#xff1a;修改配置文件 方法二&#xff1a;修改源码文件&#xff0c;重新编译安装 2、修改nginx用户与组 3、配置nginx网页缓存时间 4、nginx的日志切割 5、配置nginx实现连接超时 6、更改nginx运行进程数 7、开启网…

1 c++多线程创建和传参

什么是进程&#xff1f; 系统资源分配的最小单位。 什么是线程&#xff1f; 操作系统调度的最小单位&#xff0c;即程序执行的最小单位。 为什么需要多线程&#xff1f; &#xff08;1&#xff09;加快程序执行速度和响应速度, 使得程序充分利用CPU资源。 &#xff08;2&…

Python 全栈体系【四阶】(五十八)

第五章 深度学习 十三、自然语言处理&#xff08;NLP&#xff09; 3. 文本表示 3.1 One-hot One-hot&#xff08;独热&#xff09;编码是一种最简单的文本表示方式。如果有一个大小为V的词表&#xff0c;对于第i个词 w i w_i wi​&#xff0c;可以用一个长度为V的向量来表示…

【设计模式】行为型设计模式之 模板方法模式

介绍 GOF 定义 模板方法模式 Template Method Design Pattern &#xff1a;模板方法模式在一个方法中定义一个算法骨架&#xff0c;并将某些步骤推迟到子类中去实现&#xff1b;模板方法在不改变算法整体结构的情况下&#xff0c;可以重新定义算法中的某些步骤。 代码举例 …

npm install 的原理

1. 执行命令发生了什么 &#xff1f; 执行命令后&#xff0c;会将安装相关的依赖&#xff0c;依赖会存放在根目录的node_modules下&#xff0c;默认采用扁平化的方式安装&#xff0c;排序规则为&#xff1a;bin文件夹为第一个&#xff0c;然后是开头系列的文件夹&#xff0c;后…

Linux网络诊断工具mtr命令详解

目录 一、mtr概述 二、mtr的特点 1、动态路由显示 2、数据包类型 3、显示延迟和丢包 4、过滤和日志 5、网络探测 三、基本用法 1、基本语法 2、帮助 3、常用选项 四、输出解释 1、常见mtr命令及其输出 2、输出解释 四、命令实例 1. 最基本的用法 2. 显示报告形式…

SpringBoot 配置事务

SpringBoot 在启动时已经加载了事务管理器&#xff0c;所以只需要在需要添加事务的方法/类上添加Transactional即可生效&#xff0c;无需额外配置。 TransactionAutoConfiguration 事务的自动配置类解析&#xff1a; SpringBoot 启动时加载/META-INF/spring/org.springframewor…

⑤单细胞学习-cellchat组间通讯差异分析

④-1单细胞学习-cellchat单数据代码补充版-CSDN博客 ④-2单细胞学习-cellchat单数据代码补充版&#xff08;通讯网络&#xff09;-CSDN博客 参考&#xff1a; 1&#xff1a;单细胞分析之细胞交互-3&#xff1a;CellChat - 简书 (jianshu.com) 2&#xff1a;CellChat细胞通讯…