二、训练fashion_mnist数据集

一、加载fashion_mnist数据集

fashion_mnist数据集中数据为28*28大小的10分类衣物数据集
其中训练集60000张,测试集10000张

from tensorflow import keras
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as npfashion_mnist = keras.datasets.fashion_mnist
(train_images,train_labels),(test_images,test_labels) = fashion_mnist.load_data()print(train_images.shape)
"""
(60000, 28, 28)
"""
print(test_images.shape)
"""
(10000, 28, 28)
"""
print(train_labels.shape)
"""
(60000,)
"""
print(test_labels.shape)
"""
(60000,)
"""

光看像素值是不是能猜到这个图片是啥了?

print(train_images[0])#看一下训练集第一张图片28*28像素点的值
"""
[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   1   0   0  13  73   0   0   1   4   0   0   0   0   1   1   0][  0   0   0   0   0   0   0   0   0   0   0   0   3   0  36 136 127  62  54   0   0   0   1   3   4   0   0   3][  0   0   0   0   0   0   0   0   0   0   0   0   6   0 102 204 176 134 144 123  23   0   0   0   0  12  10   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0 155 236 207 178 107 156 161 109  64  23  77 130  72  15][  0   0   0   0   0   0   0   0   0   0   0   1   0  69 207 223 218 216 216 163 127 121 122 146 141  88 172  66][  0   0   0   0   0   0   0   0   0   1   1   1   0 200 232 232 233 229 223 223 215 213 164 127 123 196 229   0][  0   0   0   0   0   0   0   0   0   0   0   0   0 183 225 216 223 228 235 227 224 222 224 221 223 245 173   0][  0   0   0   0   0   0   0   0   0   0   0   0   0 193 228 218 213 198 180 212 210 211 213 223 220 243 202   0][  0   0   0   0   0   0   0   0   0   1   3   0  12 219 220 212 218 192 169 227 208 218 224 212 226 197 209  52][  0   0   0   0   0   0   0   0   0   0   6   0  99 244 222 220 218 203 198 221 215 213 222 220 245 119 167  56][  0   0   0   0   0   0   0   0   0   4   0   0  55 236 228 230 228 240 232 213 218 223 234 217 217 209  92   0][  0   0   1   4   6   7   2   0   0   0   0   0 237 226 217 223 222 219 222 221 216 223 229 215 218 255  77   0][  0   3   0   0   0   0   0   0   0  62 145 204 228 207 213 221 218 208 211 218 224 223 219 215 224 244 159   0][  0   0   0   0  18  44  82 107 189 228 220 222 217 226 200 205 211 230 224 234 176 188 250 248 233 238 215   0][  0  57 187 208 224 221 224 208 204 214 208 209 200 159 245 193 206 223 255 255 221 234 221 211 220 232 246   0][  3 202 228 224 221 211 211 214 205 205 205 220 240  80 150 255 229 221 188 154 191 210 204 209 222 228 225   0][ 98 233 198 210 222 229 229 234 249 220 194 215 217 241  65  73 106 117 168 219 221 215 217 223 223 224 229  29][ 75 204 212 204 193 205 211 225 216 185 197 206 198 213 240 195 227 245 239 223 218 212 209 222 220 221 230  67][ 48 203 183 194 213 197 185 190 194 192 202 214 219 221 220 236 225 216 199 206 186 181 177 172 181 205 206 115][  0 122 219 193 179 171 183 196 204 210 213 207 211 210 200 196 194 191 195 191 198 192 176 156 167 177 210  92][  0   0  74 189 212 191 175 172 175 181 185 188 189 188 193 198 204 209 210 210 211 188 188 194 192 216 170   0][  2   0   0   0  66 200 222 237 239 242 246 243 244 221 220 193 191 179 182 182 181 176 166 168  99  58   0   0][  0   0   0   0   0   0   0  40  61  44  72  41  35   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]]
"""

输出以下这个照片

plt.imshow(train_images[0])

在这里插入图片描述

二、开始训练模型

model = keras.Sequential([keras.layers.Flatten(input_shape=(28,28)),#照片完全展平,一维数组形式keras.layers.Dense(128,activation=tf.nn.relu),#128个神经元keras.layers.Dense(10,activation=tf.nn.softmax)#输出层0-9,一共十个
])

查看模型的结构
第一层784个,flatten层将输入的2828图像进行展开,排列成一行,2828=784

第二层128个,128个神经元;100480个参数,第一层的784和第二层的128全排列,784*128=100352,每一个都有一个bias偏置项,100352+128=100480

第三层10个,也就是10分类,10个不同的类别,到时候输出10个概率值,哪个大就是哪一类;1290个参数,第二层128个神经元,分别于10进行全排列,128*10=1280,每一个都有一个bias偏置项,1280+10=1290

model.summary()
"""
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
=================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
"""

为了使得效果更好,将数据集中的图像像素值都归一化到0-1之间

train_images_y = train_images/255#对训练图像归一化

训练50次

model.compile(optimizer="adam",loss="sparse_categorical_crossentropy",metrics=['accuracy'])#指定优化方法和损失函数
model.fit(train_images_y,train_labels,epochs=50)#训练

因为模型训练的时候传入的时训练集归一化之后的图像
故,模型评估的时候也需要对测试集进行归一化图像

test_images_y = test_images/255#测试评估的时候需要对测试图像也要归一化
model.evaluate(test_images_y,test_labels)#evaluate评估效果
"""
[0.5110174604289234, 0.8845]
"""

从测试集中挑选几个进行测试,实际上会输出10个值,也就是可能性的概率值,最大的就是预测的类别

model.predict([[test_images[0]/255]])
"""
array([[2.2063166e-16, 1.1835037e-17, 7.4574429e-23, 2.0577940e-22,4.3680589e-17, 2.7080047e-08, 3.8249505e-15, 3.4797877e-06,1.4701404e-10, 9.9999654e-01]], dtype=float32)
"""

筛选模型预测出的值最大的那个

print(np.argmax(model.predict([[test_images[0]/255]])))
"""
9
"""

看下这个图片的实际标签

print(test_labels[0])
"""
9
"""

预测值和实际值一样,说明预测对了

展示下这个图片

plt.imshow(train_images[0])

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/377649.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

jquerymobile 切换页面时候闪烁问题

https://github.com/jquery/jquery-mobile/commit/acbec71e29b6acec6cd2087e84e8434fecc0053f 可以修改css好像是个bug -4,9 4,10 * Dual licensed under the MIT (MIT-LICENSE.txt) or GPL (GPL-LICENSE.txt) licenses.*/.spin {--webkit-animation-name: spin;--webkit-an…

二分法:两个有序数组长度为N,找到第N、N+1大的数

题目 两个有序数组长度为N,找到第N、N1大的数 思路1:双指针,O(N)复杂度 简述思路: 如果当前A指针指向的数组A的内容小于B指针指向的数组B的内容,那么A指针往右移动,然后nums(当前已经遍历过的数字个数)也…

Javascript -- In

http://www.caveofprogramming.com/articles/javascript-2/javascript-in-using-the-in-operator-to-iterate-through-arrays-and-objects/ http://msdn.microsoft.com/en-us/library/ie/9k25hbz2(vvs.94).aspx转载于:https://www.cnblogs.com/daishuguang/p/3392310.html

三、自动终止训练

有时候,当模型损失函数值预期的效果时,就可以结束训练了,一方面节约时间,另一方面防止过拟合 此时,设置损失函数值小于0.4,训练停止 from tensorflow import keras import tensorflow as tf import matplo…

矩阵形状| 使用Python的线性代数

Prerequisite: Linear Algebra | Defining a Matrix 先决条件: 线性代数| 定义矩阵 In the python code, we will add two Matrices. We can add two Matrices only and only if both the matrices have the same dimensions. Therefore, knowing the dimensions o…

[数据库]oracle客户端连服务器错误

昨天晚上和今天上午用11g客户端连同事10g服务器,报错: The Network Adapter could not establish the connection 检查尝试了好多次都没好。 用程序连,依旧是报这个错,所以一查就解决了! 参考:http://apps…

ASP.NET 抓取网页内容

(转)ASP.NET 抓取网页内容 ASP.NET 抓取网页内容-文字 ASP.NET 中抓取网页内容是非常方便的,而其中更是解决了 ASP 中困扰我们的编码问题。 需要三个类:WebRequest、WebResponse、StreamReader。 WebRequest、WebRespo…

leetcode 53. 最大子序和 动态规划解法、贪心法以及二分法

题目 给定一个整数数组 nums ,找到一个具有最大和的连续子数组(子数组最少包含一个元素),返回其最大和。 示例: 输入: [-2,1,-3,4,-1,2,1,-5,4] 输出: 6 解释: 连续子数组 [4,-1,2,1] 的和最大,为 6。 进阶: 如果你…

四、卷积神经网络(Convolution Neural Networks)

一、CNN(Convolution Neural Networks) 卷积神经网络基本思想:识别物体的特征,来进行判断物体 卷积Convolution:过滤器filter中的数值与图片像素值对应相乘再相加,6 * 6卷积一次(步数为1)变成4 * 4 Max Pooling:对卷积…

POJ3096Surprising Strings(map)

题意:输入很多字符串,以星号结束。判断每个字符串是不是“Surprising Strings”,判断方法是:以“ZGBG”为例,“0-pairs”是ZG,GB,BG,这三个子串不相同,所以是“0-unique”…

vs助手使用期过 编译CEGUI的问题:error C2061: 语法错误: 标识符“__RPC__out_xcount_part” VS2010...

第一个问题,下一个破解版的VX_A.dll,将其覆盖以前的dll即可, 但是目录有所要求,如下: XP系统:系统盘\Documents and Settings\用户名\Local Settings\Application win7或者vistaData\Microsoft\VisualStud…

五、项目实战---识别人和马

一、准备训练数据 下载数据集 validation验证集 train训练集 数据集结构如下: 将数据集解压到自己选择的目录下就行 最后的结构效果如下: 二、构建模型 ImageDataGenerator 真实数据中,往往图片尺寸大小不一,需要裁剪成一样…

leetcode 122. 买卖股票的最佳时机 II 思考分析

目录题目贪心法题目 给定一个数组,它的第 i 个元素是一支给定股票第 i 天的价格。 设计一个算法来计算你所能获取的最大利润。你可以尽可能地完成更多的交易(多次买卖一支股票)。 注意:你不能同时参与多笔交易(你必…

css设置a连接禁用样式_使用CSS禁用链接

css设置a连接禁用样式Question: 题: Links are one of the most essential aspects of any web page or website. They play a very important role in making our website or web page quite responsive or interactive. So the topic for discussion is quite pe…

服务器出现 HTTP 错误代码,及解决方法

HTTP 400 - 请求无效 HTTP 401.1 - 未授权:登录失败 HTTP 401.2 - 未授权:服务器配置问题导致登录失败 HTTP 401.3 - ACL 禁止访问资源 HTTP 401.4 - 未授权:授权被筛选器拒绝 HTTP 401.5 - 未授权:ISAPI 或 CGI 授权失败 HTTP 40…

leetcode 55. 跳跃游戏 思考分析

题目 给定一个非负整数数组,你最初位于数组的第一个位置。 数组中的每个元素代表你在该位置可以跳跃的最大长度。 判断你是否能够到达最后一个位置。 示例1: 输入: [2,3,1,1,4] 输出: true 解释: 我们可以先跳 1 步,从位置 0 到达 位置 1…

六、项目实战---识别猫和狗

一、准备数据集 kagglecatsanddogs网上一搜一大堆,这里我就不上传了,需要的话可以私信 导包 import os import zipfile import random import shutil import tensorflow as tf from tensorflow.keras.optimizers import RMSprop from tensorflow.kera…

修改shell终端提示信息

PS1:就是用户平时的提示符。PS2:第一行没输完,等待第二行输入的提示符。公共设置位置:/etc/profile echo $PS1可以看到当前提示符设置例如:显示绿色,并添加时间和shell版本export PS1"\[\e[32m\][\uyou are right…

java 字谜_计算字谜的出现次数

java 字谜Problem statement: 问题陈述: Given a string S and a word C, return the count of the occurrences of anagrams of the word in the text. Both string and word are in lowercase letter. 给定一个字符串S和一个单词C ,返回该单词在文本…

Origin绘制热重TG和微分热重DTG曲线

一、导入数据 二、传到Origin中 三、热重TG曲线 temp为横坐标、mass为纵坐标 绘制折线图 再稍微更改下格式 字体加粗,Times New Roman 曲线宽度设置为2 横纵坐标数值格式为Times New Roman 根据实际情况改下横纵坐标起始结束位置 四、微分热重DTG曲线 点击曲线…