[Kaggle] Digit Recognizer 手写数字识别(神经网络)

文章目录

    • 1. baseline
    • 2. 改进
      • 2.1 增加训练时间
      • 2.2 更改网络结构

Digit Recognizer 练习地址

相关博文:
[Hands On ML] 3. 分类(MNIST手写数字预测)
[Kaggle] Digit Recognizer 手写数字识别

1. baseline

  • 导入包
import tensorflow as tf
from tensorflow import keras
# import keras
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pdtrain = pd.read_csv('train.csv')
y_train_full = train['label']
X_train_full = train.drop(['label'], axis=1)
X_test = pd.read_csv('test.csv')
  • 数据维度
X_train_full.shape
(42000, 784)

42000个训练样本,每个样本 28*28 展平后的像素值 784 个

  • 像素归一化,拆分训练集、验证集
X_valid, X_train = X_train_full[:8000] / 255.0, X_train_full[8000:] / 255.0
y_valid, y_train = y_train_full[:8000], y_train_full[8000:]
  • 数据预览
from PIL import Image
img = Image.fromarray(np.uint8(np.array(X_train_full)[0].reshape(28,28)))
img.show()
print(np.uint8(np.array(X_train_full)[0].reshape(28,28)))


数字 1 的像素矩阵:

[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 188 255  94   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 191 250 253  93   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  123 248 253 167  10   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  80  247 253 208  13   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  29 207  253 235  77   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  54 209 253  253  88   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0  93 254 253 238  170  17   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0  23 210 254 253 159   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0  16 209 253 254 240  81   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0  27 253 253 254  13   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0  20 206 254 254 198   7   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0 168 253 253 196   7   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0  20 203 253 248  76   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0  22 188 253 245  93   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0 103 253 253 191   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0  89 240 253 195  25   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0  15 220 253 253  80   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0  94 253 253 253  94   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0  89 251 253 250 131   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0 214 218  95   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]]
  • 添加模型
model = keras.models.Sequential()
# model.add(keras.layers.Flatten(input_shape=[784]))
model.add(keras.layers.Dense(300, activation="relu"))
model.add(keras.layers.Dense(100, activation="relu"))
model.add(keras.layers.Dense(10, activation="softmax"))

或者以下写法

model = keras.models.Sequential([
# keras.layers.Flatten(input_shape=[784]),
keras.layers.Dense(300, activation="relu"),
keras.layers.Dense(100, activation="relu"),
keras.layers.Dense(10, activation="softmax")
])
  • 定义优化器,配置模型
opt = keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, decay=0.01)
model.compile(loss="sparse_categorical_crossentropy",optimizer=opt, metrics=["accuracy"])
  • 训练
history = model.fit(X_train, y_train, epochs=30,validation_data=(X_valid, y_valid))
...
Epoch 30/30
1063/1063 [==============================] - 2s 2ms/step - 
loss: 0.0927 - accuracy: 0.9748 - 
val_loss: 0.1295 - val_accuracy: 0.9643
  • 模型参数
model.summary()

输出:

Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_15 (Dense)             (None, 300)               235500    
_________________________________________________________________
dense_16 (Dense)             (None, 100)               30100     
_________________________________________________________________
dense_17 (Dense)             (None, 10)                1010      
=================================================================
Total params: 266,610
Trainable params: 266,610
Non-trainable params: 0
_________________________________________________________________
  • 绘制模型结构
from tensorflow.keras.utils import plot_model
plot_model(model, './model.png', show_shapes=True)

  • 绘制训练曲线
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
plt.gca().set_ylim(0, 1) # set the vertical range to [0-1]
plt.show()

  • 对测试集预测
y_pred = model.predict(X_test)
pred = y_pred.argmax(axis=1).reshape(-1)
print(pred.shape)image_id = pd.Series(range(1,len(pred)+1))
output = pd.DataFrame({'ImageId':image_id, 'Label':pred})
output.to_csv("submission_svc.csv",  index=False)

得分 : 0.95989

2. 改进

根据上面的准确率:

...
Epoch 30/30
1063/1063 [==============================] - 2s 2ms/step - 
loss: 0.0927 - accuracy: 0.9748 - 
val_loss: 0.1295 - val_accuracy: 0.9643

人类的准确率几乎是100%,我们的训练集准确率 97.48%,验证集准确率 96.43%,我们的模型存在高偏差

参考, 超参数调试、正则化以及优化:https://michael.blog.csdn.net/article/details/108372707

怎么办?

2.1 增加训练时间

训练次数更改为 epochs=100

...
Epoch 100/100
1063/1063 [==============================] - 2s 2ms/step - 
loss: 0.0751 - accuracy: 0.9798 - 
val_loss: 0.1194 - val_accuracy: 0.9661

得分 : 0.96296,比上面好 0.307%

2.2 更改网络结构

  • 添加隐藏层
model = keras.models.Sequential()
model.add(keras.layers.Dense(300, activation="relu"))
model.add(keras.layers.Dense(100, activation="relu"))
model.add(keras.layers.Dense(100, activation="relu")) # 增加一层
model.add(keras.layers.Dense(10, activation="softmax"))
Epoch 100/100
1063/1063 [==============================] - 2s 2ms/step - 
loss: 0.0585 - accuracy: 0.9847 - 
val_loss: 0.1114 - val_accuracy: 0.9672


得分 : 0.96546,比上面好 0.25%

  • 再添加隐藏层
model = keras.models.Sequential()
model.add(keras.layers.Dense(300, activation="relu"))
model.add(keras.layers.Dense(100, activation="relu"))
model.add(keras.layers.Dense(100, activation="relu")) # 增加一层
model.add(keras.layers.Dense(50, activation="relu")) # 增加一层
model.add(keras.layers.Dense(10, activation="softmax"))
Epoch 100/100
1063/1063 [==============================] - 2s 2ms/step - 
loss: 0.0544 - accuracy: 0.9860 - 
val_loss: 0.1039 - val_accuracy: 0.9700


得分 : 0.96578,比上面好 0.032%

  • 增加隐藏单元数量、使用 batch_size = 128、训练250轮
DROP_OUT = 0.3
model = keras.models.Sequential()
model.add(keras.layers.Dense(500, activation="relu"))
model.add(keras.layers.Dense(500, activation="relu"))
model.add(keras.layers.Dense(500, activation="relu"))
model.add(keras.layers.Dense(500, activation="relu"))
model.add(keras.layers.Dense(10, activation="softmax"))
history = model.fit(X_train, y_train, epochs=250, batch_size=128,validation_data=(X_valid, y_valid))
Epoch 250/250
266/266 [==============================] - 3s 10ms/step - 
loss: 9.7622e-08 - accuracy: 1.0000 - 
val_loss: 0.2358 - val_accuracy: 0.9766


得分 : 0.97442,比上面好 0.864%

  • 使用 dropout 随机使一些神经元失效,是一种正则化方法
DROP_OUT = 0.3
model = keras.models.Sequential()
model.add(keras.layers.Dense(500, activation="relu"))
model.add(keras.layers.Dropout(DROP_OUT)) # dropout 正则化
model.add(keras.layers.Dense(500, activation="relu"))
model.add(keras.layers.Dropout(DROP_OUT))
model.add(keras.layers.Dense(500, activation="relu"))
model.add(keras.layers.Dropout(DROP_OUT))
model.add(keras.layers.Dense(500, activation="relu"))
model.add(keras.layers.Dropout(DROP_OUT))
model.add(keras.layers.Dense(10, activation="softmax"))
history = model.fit(X_train, y_train, epochs=250, batch_size=128,validation_data=(X_valid, y_valid))
Epoch 250/250
266/266 [==============================] - 4s 16ms/step - 
loss: 0.0171 - accuracy: 0.9940 - 
val_loss: 0.0928 - val_accuracy: 0.9779


得分 : 0.97546,比上面好 0.104%

  • 实验对比汇总:
模型/准确率(%)训练集验证集测试集
简单模型97.4896.4395.989
增加训练次数97.9896.6196.296(+0.307%)
增加隐藏层98.4796.7296.546(+0.25%)
再增加隐藏层98.6097.0096.578(+0.032%)
增加隐藏单元数量、batch_size = 128、训练250轮10097.6697.442(+0.864%)
使用 dropout 随机失活(正则化)99.4097.7997.546(+0.104%)

目前最好得分,可以在 kaggle 排到1597名。


我的CSDN博客地址 https://michael.blog.csdn.net/

长按或扫码关注我的公众号(Michael阿明),一起加油、一起学习进步!
Michael阿明

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/473770.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

逻辑回归原理

逻辑回归原理 1 逻辑回归简介 logistic回归(LR),是一种广义的线性回归分析模型,常用于数据挖掘,疾病预测,经济预测等方面。 优点:计算代价低,思路清晰易于理解和实现;…

Django中示例验证码的实现总结

验证码 在用户注册、登录页面,为了防止暴力请求,可以加入验证码功能,如果验证码错误,则不需要继续处理,可以减轻业务服务器、数据库服务器的压力。 1)安装包Pillow3.4.1。 1pip install Pillow3.4.1点击查看…

java.lang.IllegalStateException: Not connected to server

在开发人际银行的时候 客户端smack老是出现如下错误: 12-09 13:00:37.115: E/AndroidRuntime(5221): FATAL EXCEPTION: Thread-1812-09 13:00:37.115: E/AndroidRuntime(5221): java.lang.IllegalStateException: Not connected to server.12-09 13:00:37.115: E/AndroidRuntim…

LeetCode 956. 最高的广告牌(DP)

文章目录1. 题目2. 解题1. 题目 你正在安装一个广告牌,并希望它高度最大。 这块广告牌将有两个钢制支架,两边各一个。每个钢支架的高度必须相等。 你有一堆可以焊接在一起的钢筋 rods。 举个例子,如果钢筋的长度为 1、2 和 3,则…

python面试题总结(一)字符串反转,写取指定数函数

1.请至少用一种方法下面字符串的反转? # 1.请至少用一种方法下面字符串的反转? s hello print() print(-a1-切片,简单的步长为-1, 即字符串的翻转(常用)-) #方法一:切片,简单的步长为-1, 即字符串的翻转(常用); a1s[::-1] print(a1)print() …

新闻发布系统登陆页

主要为前台设计&#xff0c;这真是一个细致活。另外用到圆角矩形制作&#xff0c;其实学会了也蛮简单的。 要学好后台对前台一定要有一定的了解并掌握一些相关知识。 以下为登录页代码&#xff1a; <% Page Language"C#" AutoEventWireup"true" CodeFil…

Tensorflow线程队列与IO操作

目录 Tensorflow线程队列与IO操作 1 线程和队列 1.1 前言 1.2 队列 1.3 队列管理器 1.4 线程协调器 2 文件读取 2.1 流程 2.2 文件读取API&#xff1a; 3 图像读取 3.1 图像读取基本知识 3.2 图像基本操作 3.3 图像读取API 3.4 图片批处理流程 3.5 读取图片案例 …

Django其他(站点、列表、上传

1.静态文件&#xff1a; 项目中的CSS、图片、js都是静态文件 一般会将静态文件放到一个单独的目录中&#xff0c;以方便管理 在html页面中调用时&#xff0c;也需要指定静态文件的路径&#xff0c;Django中提供了一种解析的方式配置静态文件路径 静态文件可以放在项目根目录下…

LeetCode 1298. 你能从盒子里获得的最大糖果数(BFS)

文章目录1. 题目2. 解题1. 题目 给你 n 个盒子&#xff0c;每个盒子的格式为 [status, candies, keys, containedBoxes] &#xff0c;其中&#xff1a; - 状态字 status[i]&#xff1a;整数&#xff0c;如果 box[i] 是开的&#xff0c;那么是 1 &#xff0c;否则是 0 。 - 糖…

给javascript初学者的24条最佳实践

1.使用 代替 JavaScript 使用2种不同的等值运算符&#xff1a;|! 和 |!&#xff0c;在比较操作中使用前者是最佳实践。 “如果两边的操作数具有相同的类型和值&#xff0c;返回true&#xff0c;!返回false。”——JavaScript&#xff1a;语言精粹 然而&#xff0c;当使用和&a…

Python面试题(二)列表去重,单例

1.Python里面如何实现tuple和list的转换python中&#xff0c;tuple和list均为内置类型&#xff0c; 以list作为参数将tuple类初始化&#xff0c;将返回tuple类型tuple([1,2,3]) #list转换为tuple以tuple作为参数将list类初始化&#xff0c;将返回list类型list((1,2,3)) #tuple转…

LeetCode 1614. 括号的最大嵌套深度

文章目录1. 题目2. 解题1. 题目 如果字符串满足一下条件之一&#xff0c;则可以称之为 有效括号字符串&#xff08;valid parentheses string&#xff0c;可以简写为 VPS&#xff09;&#xff1a; 字符串是一个空字符串 ""&#xff0c;或者是一个不为 "("…

[AngularJS]Chapter 1 AnjularJS简介

创建一个完美的Web应用程序是很令人激动的&#xff0c;但是构建这样应用的复杂度也是不可思议的。我们Angular团队的目标就是去减轻构建这样AJAX应用的复杂度。在谷歌我们经历过各种复杂的应用创建工作比如&#xff1a;GMail、Map和日历。我们认为我们有必要把这些经验总结下来…

Log4j框架配置文件

Log4j框架配置文件 1 Log4j的配置文件分类 Log4j支持两种配置文件格式&#xff1a;一中是以log4j.properties &#xff0c;另一种是 log4j.xml 2 Log4j的配置文件例子 ##自定义日志的输出级别log4j.rootLoggerWARN, stdout##自定义日志 log4j.logger.accessWARN, accesslog…

python面试总结(三)拷贝与通信

1.请写出下列结果&#xff1f;&#xff08;深拷贝与浅拷贝&#xff09; import copy a [1, 2, 3, 4, [a, b]] b a c copy.copy(a) d copy.deepcopy(a) a.append(5) a[4].append(c) print(a) print(b) print(c) print(d)# 答案如下&#xff1a; [1, 2, 3, 4, [a, b, c], 5] …

LeetCode 1615. 最大网络秩(出入度)

文章目录1. 题目2. 解题1. 题目 n 座城市和一些连接这些城市的道路 roads 共同组成一个基础设施网络。 每个 roads[i] [ai, bi] 都表示在城市 ai 和 bi 之间有一条双向道路。 两座不同城市构成的 城市对 的 网络秩 定义为&#xff1a;与这两座城市 直接 相连的道路总数。如果…

使用JSLint提高JS代码质量

随着富 Web 前端应用的出现&#xff0c;开发人员不得不重新审视并重视 JavaScript 语言的能力和使用&#xff0c;抛弃过去那种只靠“复制 / 粘贴”常用脚本完成简单前端任务的模式。JavaScript 语言本身是一种弱类型脚本语言&#xff0c;具有相对于 C 或 Java 语言更为松散的限…

Django工具:Git简介与基本操作

1.Git简介&#xff1a; 1.Git是目前世界上最先进的分布式版本控制系统 网址&#xff1a;http://github.com 2.总结git的两大特点&#xff1a; 版本控制&#xff1a;可以解决多人同时开发的代码问题&#xff0c;也可以解决找回历史代码的问题 分布式&#xff1a;Git是分布式…

用户画像系统应用

用户画像系统应用 1 用户信用等级分级 比如在银行根据分级决定给用户贷款的额度&#xff0c;以及贷款的时长&#xff0c;那么怎么对用户分级呢&#xff1f;首先收集大量用户的数据&#xff0c;包括基本属性信息以及用户在使用银行的借记卡&#xff0c;信用卡等等。如果是运营…

SVN或其他网盘类软件同步图标不显示的异常

因为Windows Explorer只支持15个ShellIcon显示 所以有些软件为了正常显示其同步状态&#xff0c;就会通过修改自己的ShellIcon名称来抢占这15个名额 只需在注册表中修改下他们的名称&#xff0c;并将所需要展示的Icon的名称顺序提前 重启Explorer进程即可&#xff1a; HKEY_L…