拓展神经网络八股(入门级)

自制数据集

minst等数据集是别人打包好的,如果是本领域的数据集。自制数据集。

替换

把图片路径和标签文件输入到函数里,并返回输入特征和标签

要生成.npy格式的数据集,在进行读入训练集。

只需要把图片灰度值数据拼接到特征列表,标签添加到标签列表,提取操作函数如下:

def generateds(path, txt):f = open(txt, 'r')contents = f.readlines() #读取所有行f.close()x, y_ = [], []for content in contents:value = content.split()img_path = path + value[0]#找到图片索引路径img = Image.open(img_path) #图片打开img = np.array(img.convert('L')) # 图片变为8位灰度的npy格式的数据集                    img = img / 255.x.append(img)y_.append(value[1])print('loading:' + content) # 打印状态提示x = np.array(x)y_ = np.array(y_)y_ = y_astype(np.int64)return x, y_

 完整代码

import tensorflow as tf
from PIL import Image
import numpy as np
import ostrain_path = './fashion_image_label/fashion_train_jpg_60000/'
train_txt = './fashion_image_label/fashion_train_jpg_60000.txt'
x_train_savepath = './fashion_image_label/fashion_x_train.npy'
y_train_savepath = './fashion_image_label/fahion_y_train.npy'test_path = './fashion_image_label/fashion_test_jpg_10000/'
test_txt = './fashion_image_label/fashion_test_jpg_10000.txt'
x_test_savepath = './fashion_image_label/fashion_x_test.npy'
y_test_savepath = './fashion_image_label/fashion_y_test.npy'def generateds(path, txt):f = open(txt, 'r')contents = f.readlines()  # 按行读取f.close()x, y_ = [], []for content in contents:value = content.split()  # 以空格分开,存入数组img_path = path + value[0]img = Image.open(img_path)img = np.array(img.convert('L'))img = img / 255.x.append(img)y_.append(value[1])print('loading : ' + content)x = np.array(x)y_ = np.array(y_)y_ = y_.astype(np.int64)return x, y_if os.path.exists(x_train_savepath) and os.path.exists(y_train_savepath) and os.path.exists(x_test_savepath) and os.path.exists(y_test_savepath):print('-------------Load Datasets-----------------')x_train_save = np.load(x_train_savepath)y_train = np.load(y_train_savepath)x_test_save = np.load(x_test_savepath)y_test = np.load(y_test_savepath)x_train = np.reshape(x_train_save, (len(x_train_save), 28, 28))x_test = np.reshape(x_test_save, (len(x_test_save), 28, 28))
else:print('-------------Generate Datasets-----------------')x_train, y_train = generateds(train_path, train_txt)x_test, y_test = generateds(test_path, test_txt)print('-------------Save Datasets-----------------')x_train_save = np.reshape(x_train, (len(x_train), -1))x_test_save = np.reshape(x_test, (len(x_test), -1))np.save(x_train_savepath, x_train_save)np.save(y_train_savepath, y_train)np.save(x_test_savepath, x_test_save)np.save(y_test_savepath, y_test)model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

数据增强

如果数据量过少,模型见识不足。增加数据,提高泛化力。

用来应对因为拍照角度不同引起的图片变形

image_gen_train=tf,keras.preprocessing,image.ImageDataGenneratorP(...)

image_gen)train,fit(x_train)

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGeneratorfashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)  # 给数据增加一个维度,使数据和网络结构匹配image_gen_train = ImageDataGenerator(rescale=1. / 1.,  # 如为图像,分母为255时,可归至0~1rotation_range=45,  # 随机45度旋转width_shift_range=.15,  # 宽度偏移height_shift_range=.15,  # 高度偏移horizontal_flip=True,  # 水平翻转zoom_range=0.5  # 将图像随机缩放阈量50%
)
image_gen_train.fit(x_train)model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(image_gen_train.flow(x_train, y_train, batch_size=32), epochs=5, validation_data=(x_test, y_test),validation_freq=1)
model.summary()

 因为是标准MINST数据集,因此在准确度上看不出来,需要在具体应用中才能体现

断点续训

实时保存最优模型

 保存模型参数可以使用tensorflow提供的ModelCheckpoint(filepath=checkpoint_save,

                              save_weight_only,sabe_best_only)

参数提取

获取各层网络最优参数,可以在各个平台实现应用

model.trainable_variables 返回模型中可训练参数

acc/loss可视化

查看训练效果

history=model.fit()

import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as pltnp.set_printoptions(threshold=np.inf)fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])
model.summary()print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:file.write(str(v.name) + '\n')file.write(str(v.shape) + '\n')file.write(str(v.numpy()) + '\n')
file.close()###############################################    show   ################################################ 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']plt.subplot(1, 2, 1) 画出第一列
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()plt.subplot(1, 2, 2) #画出第二列
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

应用程序

给图识物

给出一张图片,输出预测结果

1.复现模型 Sequential加载模型

2.加载参数 load_weights(model_save_path)

3.预测结果

我们需要对颜色取反,我们的训练图片是黑底白字

减少了背景噪声的影响

from PIL import Image
import numpy as np
import tensorflow as tftype = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']model_save_path = './checkpoint/fashion.ckpt'
model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.load_weights(model_save_path)preNum = int(input("input the number of test pictures:"))
for i in range(preNum):image_path = input("the path of test picture:")img = Image.open(image_path)img=img.resize((28,28),Image.ANTIALIAS)img_arr = np.array(img.convert('L'))img_arr = 255 - img_arr  #每个像素点= 255 - 各自点当前灰度值img_arr = img_arr/255.0x_predict = img_arr[tf.newaxis,...]result = model.predict(x_predict)pred=tf.argmax(result, axis=1)print('\n')print(type[int(pred)])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/46573.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

进程间通信(下)

system V共享内存 共享内存区是最快的IPC形式。一旦这样的内存映射到共享它的进程的地址空间,这些进程间数据传递不再涉及到内核,换句话说是进程不再通过执行进入内核的系统调用来传递彼此的数据 共享内存示意图 通过上面的图,我们不难想到…

linux的学习(三):用户权限,查找,压缩命令

简介 关于用户权限,查找和压缩解压缩命令的简单使用 用户管理命令 useradd useradd:添加新用户,要root权限才能使用 useradd -g 组名 用户名:可以添加到组 创建成功会在 /home下有用户的主目录 passwd passwd 用户名&#x…

【WEB前端2024】3D智体编程:乔布斯3D纪念馆-第60集-agent训练资讯APP重点推荐AI资讯内容(含视频)

【WEB前端2024】3D智体编程:乔布斯3D纪念馆-第60集-agent训练资讯APP重点推荐AI资讯内容(含视频) 使用dtns.network德塔世界(开源的智体世界引擎),策划和设计《乔布斯超大型的开源3D纪念馆》的系列教程。d…

php反序列化--2--PHP反序列化漏洞基础知识

一、什么是反序列化? 反序列化是将序列化的字符串还原为PHP的值的过程。 二、如何反序列化 使用unserialize()函数来执行反序列化操作 代码1: $serializedStr O:8:"stdClass":1:{s:4:"data";s:6:"sample";}; $origina…

Android Service的解析

人不走空 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌赋:斯是陋室,惟吾德馨 Android服务,即Service,是Android四大组件之一,是一种程序后台运行的方案&am…

新增支持GIS地图、数据模型引擎升级、增强数据分析处理能力

为了帮助企业提升数据分析处理能力,Smartbi重点围绕产品易用性、用户体验、操作便捷性进行了更新迭代,同时重磅更新了体验中心。用更加匹配项目及业务需求的Smartbi,帮助企业真正发挥数据的价值,赋能决策经营与管理。 Smartbi用户…

js中使用原型链增加方法后,遍历对象的key-value时会遍历出方法

原因:js使用原型链实现方法时,这个方法默认是可迭代的,所以在遍历时就会被遍历出来, 例: Array.prototype.remove function(n){return this.slice(0,n).concat(this.slice(n1,this.length));}var cc ["cccaaaa…

wifi信号处理的CRC8、CRC32

🧑🏻个人简介:具有3年工作经验,擅长通信算法的MATLAB仿真和FPGA实现。代码事宜,私信博主,程序定制、设计指导。 🚀wifi信号处理的CRC8、CRC32 目录 🚀1.CRC概述 🚀1.C…

定时器的计数模式 定时器中断时钟配置

目录 一,定时器的计数模式 二,定时器中断时钟的配置 三,输入和输出原理 四,PWM波的小简介 一,定时器的计数模式 1.1 定时器的计数模式分别有三种 1.2 定时器溢出的时间(中断,事件产生的时间…

QT多线程下,信号槽分别在什么线程中执行,如何控制?

可以通过connect的第五个参数进行控制信号槽执行时所在的线程 connect有几种连接方式,直接连接、队列连接和 自动连接 直接连接(Qt::DirectConnection):信号槽在信号发出者所在的线程中执行 队列连接(Qt::QueuedConn…

python初学者知识点笔记更新

文章目录 1.main函数入口2.__init__.py 文件作用3.from .applications import server解释4.变量没有修饰,直接创建使用1. 内置数据类型和函数2. 类和对象3.总结 5.mod app.__module__6.集合对比区分集合类型:混合集合类型 7.安装包失败 1.main函数入口 …

vitest 单元测试应用与配置

vitest 应用与配置 一、简介 Vitest 旨在将自己定位为 Vite 项目的首选测试框架,即使对于不使用 Vite 的项目也是一个可靠的替代方案。它本身也兼容一些Jest的API用法。 二、安装vitest // npm npm install -D vitest // yarn yarn add -D vitest // pnpm pnpm …

Linux 06-01:简易shell编写

考虑一下这个与shell典型的互动:ls、ps 用下图的时间轴来表示事件的发生次序。其中时间从左向右。shell由标识为sh的方块代表,它随着时间的流逝从左向右移动。shell从用户读入字符串"ls"。shell建立一个新的进程,然后在那个进程中运…

vs code 启动react项目,执行npm start报错原因分析

1.执行 npm start错误信息:npm : 无法将“npm”项识别为 cmdlet、函数、脚本文件或可运行程序的名称。请检查名称的拼写,如果包括路径,请确保路径正确,然后再试一次。 所在位置 行:1 字符: 1 npm start ~~~ CategoryInfo …

2024年5000元投影仪推荐:五千元最值得买的三款家用激光投影推荐

五千元是很多家庭购买投影仪会选择的价位,这个价位的投影一般属于中高端产品,如果懂配置,知道怎么选的朋友可以选到一款性价比颇高的投影,但是如果不会选不懂配置可能会花冤枉钱。所以五千元价位的投影该如何选择?市面…

企业知识库用不起来?试一下用HelpLook同步钉钉组织架构

提升企业管理和协同效率已成为增强竞争力的关键。企业通过知识管理,搭建内部知识库,将分散的经验和知识转化为系统化流程,减少重复解释,促进业务高效运作。这为企业提供了坚实的基础。 企业知识库面临的挑战 尽管传统知识库内容丰…

Jeecgboot vue3的选择部门组件JSelectDept如何实现只查询本级以及子级的部门

jeecgboot vue3的文档:地址 JSelectDept组件实现了弹窗然后选择部门返回的功能,但部门是所有数据,不符合需求,所以在原有代码上稍微改动了一下 组件属性值如下: 当serverTreeDatafalse的时候,从后端查询…

2024年7月9日~2024年7月15日周报

目录 一、前言 二、完成情况 2.1 特征图保存方法 2.1.1 定义网络模型 2.1.2 定义保存特征图的钩子函数 2.1.3 为模型层注册钩子 2.1.4 运行模型并检查特征图 2.2 实验情况 三、下周计划 一、前言 本周的7月11日~7月14日参加了机器培训的学习讨论会,对很多概…

通过MATLAB控制TI毫米波雷达的工作状态之TLV数据解析及绘制

前言 前一章博主介绍了如何基于设计视图中的这些组件结合MATLAB代码来实现TI毫米波雷达数据的实时采集。这一章将在此基础上实现TI毫米波雷达的TLV数据解析。过程中部分算法会涉及到一些简单的毫米波雷达相关算法,需要各位有一定的毫米波雷达基础。 TLV数据之协议解析 紧着…

数字孪生技术如何助力低空经济飞跃式发展?

一、什么是低空经济? 低空经济,是一个以通用航空产业为主导的经济形态,它涵盖了低空飞行、航空旅游、航空物流、应急救援等多个领域。它以垂直起降型飞机和无人驾驶航空器为载体,通过载人、载货及其他作业等多场景低空飞行活动&a…