Tensorflow2.0笔记 - 自定义Layer和Model实现CIFAR10数据集的训练

       本笔记记录使用自定义Layer和Model来做CIFAR10数据集的训练。

        CIFAR10数据集下载:

        https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

        自定义的Layer和Model实现较为简单,参数量较少,并且没有卷积层和dropout等,最终准确率不高,仅做练习使用。

import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metricstf.__version__def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255y = tf.cast(y, dtype=tf.int32)return x,ybatchsize = 128
#CIFAR10数据集下载,可以直接使用网络下载
(x,y), (x_val, y_val) = datasets.cifar10.load_data()
#CIFAR10的标签(训练集)数据维度是[50000, 1],通过squeeze消除掉里面1的维度,变成[50000]
print("y.shape:", y.shape)
y = tf.squeeze(y)
print("squeezed y.shape:", y.shape)
y_val = tf.squeeze(y_val)
#进行onehot编码
y = tf.one_hot(y, depth=10)
y_val = tf.one_hot(y_val, depth=10)
print("Datasets: ", x.shape, " ", y.shape, " x.min():", x.min(), " x.max():", x.max())train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(preprocess).shuffle(10000).batch(batchsize)
test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(preprocess).batch(batchsize)sample = next(iter(train_db))
print("Batch:", sample[0].shape, sample[1].shape)#自定义Layer
class MyDense(layers.Layer):def __init__(self, input_dim, output_dim):super(MyDense, self).__init__()self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim], initializer=tf.random_uniform_initializer(0, 1.0))self.bias = self.add_weight(name='b', shape=[output_dim], initializer=tf.random_uniform_initializer(0, 1.0))#self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim])#self.bias = self.add_weight(name='b', shape=[output_dim])def call(self, inputs, training = None):x = inputs@self.kernel + self.biasreturn xclass MyNetwork(keras.Model):def __init__(self):super(MyNetwork, self).__init__()self.fc1 = MyDense(32 * 32 * 3, 512)self.fc2 = MyDense(512, 512)self.fc3 = MyDense(512, 256)self.fc4 = MyDense(256, 256)self.fc5 = MyDense(256, 10)def call(self, inputs, training = None):x = tf.reshape(inputs, [-1, 32 * 32 * 3])x = self.fc1(x)x = tf.nn.relu(x)x = self.fc2(x)x = tf.nn.relu(x)x = self.fc3(x)x = tf.nn.relu(x)x = self.fc4(x)x = tf.nn.relu(x)x = self.fc5(x)x = tf.nn.relu(x)#返回logitsreturn xtotal_epoches = 35
learn_rate = 0.001
network = MyNetwork()
network.compile(optimizer=optimizers.Adam(learning_rate=learn_rate),loss = tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['Accuracy'])
network.fit(train_db, epochs=total_epoches, validation_data=test_db, validation_freq=1)

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/790558.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于java+SpringBoot+Vue的图书个性化推荐系统的设计与实现

基于javaSpringBootVue的图书个性化推荐系统的设计与实现 开发语言: Java 数据库: MySQL技术: SpringBoot MyBatis Vue工具: IDEA/Eclipse、Navicat、Maven 系统展示 前台展示 首页:展示图书信息、好书推荐、留言反馈等。 图书信息:用户可以查看图…

easyExcel 模版导出 中间数据纵向延伸,并且对指定列进行合并

想要达到的效果 引入maven引用 <dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>3.2.1</version></dependency> 按照要求创建模版 备注 : 模板注意 用{} 来表示你要用的变量 如果本…

商务电子邮件: 在WorkPlace中高效且安全

高效和安全的沟通是任何组织成功的核心。在我们关于电子邮件类型的系列文章的第二期中&#xff0c;我们将重点关注商业电子邮件在促进无缝交互中的关键作用。当你身处重要的工作场环境时&#xff0c;本系列的每篇文章都提供了电子邮件的不同维度的视角。 “2024年&#xff0c;全…

计算机视觉之三维重建(6)---多视图几何(上)

文章目录 一、运动恢复结构问题&#xff08;SfM&#xff09;二、欧式结构恢复2.1 概述2.2 求解2.3 欧式结构恢复歧义 三、仿射结构恢复3.1 概述3.2 因式分解法3.3 总结3.4 仿射结构恢复歧义 一、运动恢复结构问题&#xff08;SfM&#xff09; 1. 运动恢复结构问题&#xff1a;通…

enqueue:oracle锁机制

实现锁的方式就是排队咯&#xff0c;那么排队就是有enqueue这么个结构来管理 管理锁的结构叫队列&#xff0c;即enqueue 所有和enqueue相关的函数都叫KSQ-- kernal service enqueue lock是从应用层面看到的锁&#xff0c;enqueue是oracle内部管理锁的一个结构。 可以用v$lock_…

基于单片机的超声波测距仪设计_kaic

摘 要 如今社会持续深化转型&#xff0c;在人工智能领域&#xff0c;传感器采集外部数据&#xff0c;经过处理器对数 据运算和处理&#xff0c;从而实现相应的功能。比如自动驾驶技术中&#xff0c;超声波传感器应用广泛&#xff0c; 超声波是一种频率在 20khz 以上的声波&…

OpenHarmony实战:小型系统移植概述

驱动主要包含两部分&#xff0c;平台驱动和器件驱动。平台驱动主要包括通常在SOC内的GPIO、I2C、SPI等&#xff1b;器件驱动则主要包含通常在SOC外的器件&#xff0c;如 LCD、TP、WLAN等 图1 OpenHarmony 驱动分类 HDF驱动被设计为可以跨OS使用的驱动程序&#xff0c;HDF驱动框…

【WebKit架构讲解】

&#x1f308;个人主页:程序员不想敲代码啊 &#x1f3c6;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f44d;点赞⭐评论⭐收藏 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff0c;让我们共…

Nginx从安装到高可用实用教程!

一、Nginx安装 1、去官网http://nginx.org/下载对应的nginx包&#xff0c;推荐使用稳定版本 2、上传nginx到linux系统 3、安装依赖环境 (1)安装gcc环境 yum install gcc-c(2)安装PCRE库&#xff0c;用于解析正则表达式 yum install -y pcre pcre-devel(3)zlib压缩和解压缩…

解决el-table设置固定高度后,展示不同列时表格高度变小bug

解决el-table设置固定高度后&#xff0c;展示不同列时表格高度变小bug 1、需求分析2、解决方案 1、需求分析 在el-table使用过程中&#xff0c;选择多个参数展示更多列时会出现高度变小问题究其原因可知是el-table列动态发生变化后&#xff0c;el-table__body-wrapper的高度变…

CNAS软件测试公司有什么好处?如何选择靠谱的软件测试公司?

CNAS认可是中国合格评定国家认可委员会的英文缩写&#xff0c;由国家认证认可监督管理委员会批准设立并授权的国家认可机构&#xff0c;统一负责对认证机构、实验室和检验机构等相关机构的认可工作。 在软件测试行业&#xff0c;CNAS认可具有重要意义。它标志着一个软件测试公…

2024阿里云老用户服务器优惠价格99元和199元

阿里云服务器租用价格表2024年最新&#xff0c;云服务器ECS经济型e实例2核2G、3M固定带宽99元一年&#xff0c;轻量应用服务器2核2G3M带宽轻量服务器一年61元&#xff0c;ECS u1服务器2核4G5M固定带宽199元一年&#xff0c;2核4G4M带宽轻量服务器一年165元12个月&#xff0c;2核…

基于Unet的BraTS 3d 脑肿瘤医学图像分割,从nii.gz文件中切分出2D图片数据

1、前言 3D图像分割一直是医疗领域的难题&#xff0c;在这方面nnunet已经成为了标杆&#xff0c;不过nnunet教程较少&#xff0c;本人之前跑了好久&#xff0c;一直目录报错、格式报错&#xff0c;反正哪里都是报错等等。并且&#xff0c;nnunet对于硬件的要求很高&#xff0c…

mac、windows 电脑安装使用多个版本的node

我们为啥要安装多个不同版本的node&#xff1f; 开发旧项目时&#xff0c;使用低版本Nodejs。开发新项目时&#xff0c;需使用高版本Node.js。可使用n同时安装多个版本Node.js&#xff0c;并切换到指定版本Node.js。 mac电脑安装 一、全局安装 npm install -g n 二、mac电脑…

Elasticsearch 压测实践总结

背景 搜索、ES运维场景离不开压力测试。 1.宿主机层面变更&#xff1a;参数调优 & 配置调整 & 硬件升级2.集群层面变更&#xff1a;参数调优3.索引层面变更&#xff1a;mapping调整 当然还有使用层面变更&#xff0c;使用API调优&#xff08;不属于该文章的讨论范围…

四川古力未来科技抖音小店:安全便捷,购物新体验

在数字化浪潮席卷全球的今天&#xff0c;电商平台的安全性与便捷性成为了消费者最为关心的问题。四川古力未来科技有限公司&#xff0c;凭借其强大的技术实力和深厚的行业经验&#xff0c;为广大消费者带来了一个安全可靠的购物新选择——古力未来科技抖音小店。 古力未来科技抖…

Twitter Api查询用户粉丝列表

如果大家为了获取实现方式代码的话可能要让大家失望了&#xff0c;这边文章主要是为了节省大家开发时间&#xff0c;少点坑。https://api.twitter.com/2/users/:id/followers &#xff0c;这个接口很熟悉吧&#xff0c;他是推特提供的获取用户关注者&#xff08;粉丝&#xff0…

基于AI智能识别技术的智慧展览馆视频监管方案设计

一、建设背景 随着科技的不断进步和社会安全需求的日益增长&#xff0c;展览馆作为展示文化、艺术和科技成果的重要场所&#xff0c;其安全监控系统的智能化升级已成为当务之急。为此&#xff0c;旭帆科技&#xff08;TSINGSEE青犀&#xff09;基于视频智能分析技术推出了展览…

再拓信创生态圈|宁盾身份域管与深信服桌面云完成兼容互认证

近日&#xff0c;宁盾国产化身份域管&#xff08;即身份目录服务软件&#xff09;与深信服桌面云系统aDesk完成产品兼容性互认证。经过共同严格测试&#xff0c;宁盾国产化身份域管能够与深信服桌面云系统兼容对接运行&#xff0c;双方相互兼容&#xff0c;共同为企事业单位提供…

H5面临的网络安全威胁和防范措施

H5&#xff0c;是基于HTML5技术的网页文件。HTML&#xff0c;全称Hyper Text Markup Language&#xff0c;即超文本标记语言&#xff0c;由Web的发明者Tim Berners-Lee与同事Daniel W. Connolly共同创立。作为SGML的一种应用&#xff0c;HTML编写的超文本文档能够独立于各种操作…