Tensorflow2.0笔记 - 自定义Layer和Model实现CIFAR10数据集的训练

       本笔记记录使用自定义Layer和Model来做CIFAR10数据集的训练。

        CIFAR10数据集下载:

        https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

        自定义的Layer和Model实现较为简单,参数量较少,并且没有卷积层和dropout等,最终准确率不高,仅做练习使用。

import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metricstf.__version__def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255y = tf.cast(y, dtype=tf.int32)return x,ybatchsize = 128
#CIFAR10数据集下载,可以直接使用网络下载
(x,y), (x_val, y_val) = datasets.cifar10.load_data()
#CIFAR10的标签(训练集)数据维度是[50000, 1],通过squeeze消除掉里面1的维度,变成[50000]
print("y.shape:", y.shape)
y = tf.squeeze(y)
print("squeezed y.shape:", y.shape)
y_val = tf.squeeze(y_val)
#进行onehot编码
y = tf.one_hot(y, depth=10)
y_val = tf.one_hot(y_val, depth=10)
print("Datasets: ", x.shape, " ", y.shape, " x.min():", x.min(), " x.max():", x.max())train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(preprocess).shuffle(10000).batch(batchsize)
test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(preprocess).batch(batchsize)sample = next(iter(train_db))
print("Batch:", sample[0].shape, sample[1].shape)#自定义Layer
class MyDense(layers.Layer):def __init__(self, input_dim, output_dim):super(MyDense, self).__init__()self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim], initializer=tf.random_uniform_initializer(0, 1.0))self.bias = self.add_weight(name='b', shape=[output_dim], initializer=tf.random_uniform_initializer(0, 1.0))#self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim])#self.bias = self.add_weight(name='b', shape=[output_dim])def call(self, inputs, training = None):x = inputs@self.kernel + self.biasreturn xclass MyNetwork(keras.Model):def __init__(self):super(MyNetwork, self).__init__()self.fc1 = MyDense(32 * 32 * 3, 512)self.fc2 = MyDense(512, 512)self.fc3 = MyDense(512, 256)self.fc4 = MyDense(256, 256)self.fc5 = MyDense(256, 10)def call(self, inputs, training = None):x = tf.reshape(inputs, [-1, 32 * 32 * 3])x = self.fc1(x)x = tf.nn.relu(x)x = self.fc2(x)x = tf.nn.relu(x)x = self.fc3(x)x = tf.nn.relu(x)x = self.fc4(x)x = tf.nn.relu(x)x = self.fc5(x)x = tf.nn.relu(x)#返回logitsreturn xtotal_epoches = 35
learn_rate = 0.001
network = MyNetwork()
network.compile(optimizer=optimizers.Adam(learning_rate=learn_rate),loss = tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['Accuracy'])
network.fit(train_db, epochs=total_epoches, validation_data=test_db, validation_freq=1)

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/790558.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于java+SpringBoot+Vue的图书个性化推荐系统的设计与实现

基于javaSpringBootVue的图书个性化推荐系统的设计与实现 开发语言: Java 数据库: MySQL技术: SpringBoot MyBatis Vue工具: IDEA/Eclipse、Navicat、Maven 系统展示 前台展示 首页:展示图书信息、好书推荐、留言反馈等。 图书信息:用户可以查看图…

easyExcel 模版导出 中间数据纵向延伸,并且对指定列进行合并

想要达到的效果 引入maven引用 <dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>3.2.1</version></dependency> 按照要求创建模版 备注 : 模板注意 用{} 来表示你要用的变量 如果本…

商务电子邮件: 在WorkPlace中高效且安全

高效和安全的沟通是任何组织成功的核心。在我们关于电子邮件类型的系列文章的第二期中&#xff0c;我们将重点关注商业电子邮件在促进无缝交互中的关键作用。当你身处重要的工作场环境时&#xff0c;本系列的每篇文章都提供了电子邮件的不同维度的视角。 “2024年&#xff0c;全…

01 使用ArcGIS生成节点路径

目录 1 测试数据准备 1.1 创建空的GDB文件及数据集 1.2 创建道路图层 1.3 绘制路网

计算机视觉之三维重建(6)---多视图几何(上)

文章目录 一、运动恢复结构问题&#xff08;SfM&#xff09;二、欧式结构恢复2.1 概述2.2 求解2.3 欧式结构恢复歧义 三、仿射结构恢复3.1 概述3.2 因式分解法3.3 总结3.4 仿射结构恢复歧义 一、运动恢复结构问题&#xff08;SfM&#xff09; 1. 运动恢复结构问题&#xff1a;通…

enqueue:oracle锁机制

实现锁的方式就是排队咯&#xff0c;那么排队就是有enqueue这么个结构来管理 管理锁的结构叫队列&#xff0c;即enqueue 所有和enqueue相关的函数都叫KSQ-- kernal service enqueue lock是从应用层面看到的锁&#xff0c;enqueue是oracle内部管理锁的一个结构。 可以用v$lock_…

python将visio转换为 PDF 文件

参考链接&#xff1a;在 Python 中將 Visio 轉換為 PDF | Python Visio 到 PDF 庫 (aspose.com) 下载软件包&#xff1a; pip install aspose-diagram-python 读取文件&#xff0c;保存为PDF # 此代碼示例演示如何使用 PDF 保存選項將 Visio 轉換為 PDF import aspose.dia…

基于单片机的超声波测距仪设计_kaic

摘 要 如今社会持续深化转型&#xff0c;在人工智能领域&#xff0c;传感器采集外部数据&#xff0c;经过处理器对数 据运算和处理&#xff0c;从而实现相应的功能。比如自动驾驶技术中&#xff0c;超声波传感器应用广泛&#xff0c; 超声波是一种频率在 20khz 以上的声波&…

HTML优化SEO的实用技巧

在网站开发中&#xff0c;除了关注设计和用户体验&#xff0c;SEO&#xff08;搜索引擎优化&#xff09;也是提升网站流量和可见度的关键。合理的HTML结构和元素运用能够帮助搜索引擎更好地理解页面内容&#xff0c;从而提高搜索排名。以下是一些基于HTML的SEO优化技巧&#xf…

OpenHarmony实战:小型系统移植概述

驱动主要包含两部分&#xff0c;平台驱动和器件驱动。平台驱动主要包括通常在SOC内的GPIO、I2C、SPI等&#xff1b;器件驱动则主要包含通常在SOC外的器件&#xff0c;如 LCD、TP、WLAN等 图1 OpenHarmony 驱动分类 HDF驱动被设计为可以跨OS使用的驱动程序&#xff0c;HDF驱动框…

【WebKit架构讲解】

&#x1f308;个人主页:程序员不想敲代码啊 &#x1f3c6;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f44d;点赞⭐评论⭐收藏 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff0c;让我们共…

Nginx从安装到高可用实用教程!

一、Nginx安装 1、去官网http://nginx.org/下载对应的nginx包&#xff0c;推荐使用稳定版本 2、上传nginx到linux系统 3、安装依赖环境 (1)安装gcc环境 yum install gcc-c(2)安装PCRE库&#xff0c;用于解析正则表达式 yum install -y pcre pcre-devel(3)zlib压缩和解压缩…

java面试题(3)|解释 null 和 “null“ 之间的区别,并举例说明它们在编程中的使用场景

null 和 "null" 之间的区别主要在于语义和数据类型上&#xff1a; null 是一个特殊的值&#xff0c;通常用于表示缺少有效值或未定义的变量。在许多编程语言中&#xff0c;null是一个关键字&#xff0c;表示空值。例如&#xff0c;在Java中&#xff0c;当一个对象尚…

解决el-table设置固定高度后,展示不同列时表格高度变小bug

解决el-table设置固定高度后&#xff0c;展示不同列时表格高度变小bug 1、需求分析2、解决方案 1、需求分析 在el-table使用过程中&#xff0c;选择多个参数展示更多列时会出现高度变小问题究其原因可知是el-table列动态发生变化后&#xff0c;el-table__body-wrapper的高度变…

sqlite在非主键创建一个自增字段

sqlite 自增比较奇葩&#xff0c;自增字段必须建在主键上&#xff0c;但主键很重要。不是每种情况都是给自增去做。比如要实现replace into 时&#xff0c; 要主键作为更新标识。用自增很难实现。 开工&#xff1a; 1、建立一个主表&#xff0c;主表的ID是自增ID&#xff0c;…

CNAS软件测试公司有什么好处?如何选择靠谱的软件测试公司?

CNAS认可是中国合格评定国家认可委员会的英文缩写&#xff0c;由国家认证认可监督管理委员会批准设立并授权的国家认可机构&#xff0c;统一负责对认证机构、实验室和检验机构等相关机构的认可工作。 在软件测试行业&#xff0c;CNAS认可具有重要意义。它标志着一个软件测试公…

站群服务器如何提高搜索引擎排名

站群服务器是一种专门为多个相关联的网站提供支持的服务器&#xff0c;旨在通过网站集合的形式提高搜索引擎排名和曝光度。那么站群服务器如何提高搜索引擎排名呢?Rak部落小编为您整理发布。 站群服务器提高搜索引擎排名的原理主要在于以下几个方面&#xff1a; - **提高网站…

websocket 对于手游的意义

WebSocket作为一个HTTP的升级协议&#xff0c;其实对HTTP协议用的不多&#xff0c;主要是消息头相关部分&#xff0c;WebScoket协议最初的动机应该是给网页应用增加一个更贴近实时环境的通讯方式&#xff0c;让某些网页应用得到更佳的通讯质量&#xff08;双工&#xff0c;低延…

2024阿里云老用户服务器优惠价格99元和199元

阿里云服务器租用价格表2024年最新&#xff0c;云服务器ECS经济型e实例2核2G、3M固定带宽99元一年&#xff0c;轻量应用服务器2核2G3M带宽轻量服务器一年61元&#xff0c;ECS u1服务器2核4G5M固定带宽199元一年&#xff0c;2核4G4M带宽轻量服务器一年165元12个月&#xff0c;2核…

基于Unet的BraTS 3d 脑肿瘤医学图像分割,从nii.gz文件中切分出2D图片数据

1、前言 3D图像分割一直是医疗领域的难题&#xff0c;在这方面nnunet已经成为了标杆&#xff0c;不过nnunet教程较少&#xff0c;本人之前跑了好久&#xff0c;一直目录报错、格式报错&#xff0c;反正哪里都是报错等等。并且&#xff0c;nnunet对于硬件的要求很高&#xff0c…