【信号处理】基于变分自编码器(VAE)的图片典型增强方法实现

关于

深度学习中,经常面临图片数据量较小的问题,此时,对数据进行增强,显得比较重要。传统的图片增强方法包括剪切,增加噪声,改变对比度等等方法,但是,对于后端任务的性能提升有限。所以,变分自编码器被用来实现深度数据增强。

变分自编码器的主要缺点在于生成图像过于平滑和模糊,图像细节重建不足。

常见的图像增强方法:https://www.tensorflow.org/tutorials/images/data_augmentation

工具

数据集下载地址: CIFAR-10 and CIFAR-100 datasets

方法实现

加载数据和必要的库函数
import tensorflow.compat.v1.keras.backend as K
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
import matplotlib.pyplot as plt
import numpy as np
from numpy import random
import tensorflow_datasets as tfds
import keras
from keras.models import Model
from keras.layers import Conv2D, Conv2DTranspose, Input, Flatten, Dense, Lambda, Reshapextrain , ytrain = tfds.as_numpy(tfds.load('cifar10',split='train',batch_size=-1,as_supervised=True,))
xtest , ytest = tfds.as_numpy(tfds.load('cifar10',split='test',batch_size=-1,as_supervised=True,))
xtrain = (xtrain.astype('float32'))/255
xtest = (xtest.astype('float32'))/255height=32
width=32
channels=3
print(f"Train Shape: {xtrain.shape},Test Shape: {xtest.shape}")
plt.imshow(xtrain[0])

编码器模型搭建
input_shape=(height,width,channels)
latent_dims=3072input_img= Input(shape=input_shape, name='encoder_input')
x=Conv2D(128, 4, padding='same', activation='relu',strides=2)(input_img)
x=Conv2D(256, 4, padding='same', activation='relu',strides=2)(x)
x=Conv2D(512, 4, padding='same', activation='relu',strides=2)(x)
x=Conv2D(1024, 4, padding='same', activation='relu',strides=2)(x)
conv_shape = K.int_shape(x)
x=Flatten()(x)
x=Dense(3072, activation='relu')(x)
z_mean=Dense(latent_dims, name='latent_mean')(x)
z_sigma=Dense(latent_dims, name='latent_sigma')(x)def sampler(args):z_mean, z_sigma = argseps = K.random_normal(shape=(K.shape(z_mean)[0], K.int_shape(z_mean)[1]))return z_mean + K.exp(z_sigma / 2) * epsz = Lambda(sampler, output_shape=(latent_dims, ), name='z')([z_mean, z_sigma])encoder = Model(input_img, [z_mean, z_sigma, z], name='encoder')
print(encoder.summary())

 解码器模型构建
decoder_input = Input(shape=(latent_dims, ), name='decoder_input')
x = Dense(conv_shape[1]*conv_shape[2]*conv_shape[3], activation='relu')(decoder_input)
x = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)
x = Conv2DTranspose(256, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(128, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(64, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(3, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(channels, 3, padding='same', activation='sigmoid', name='decoder_output')(x)
decoder = Model(decoder_input, x, name='decoder')
decoder.summary()
z_decoded = decoder(z)class CustomLayer(keras.layers.Layer):def vae_loss(self, x, z_decoded):x = K.flatten(x)z_decoded = K.flatten(z_decoded)# Reconstruction loss (as we used sigmoid activation we can use binarycrossentropy)recon_loss = keras.metrics.binary_crossentropy(x, z_decoded)# KL divergencekl_loss = -5e-4 * K.mean(1 + z_sigma - K.square(z_mean) - K.exp(z_sigma), axis=-1)return K.mean(recon_loss + kl_loss)# add custom loss to the classdef call(self, inputs):x = inputs[0]z_decoded = inputs[1]loss = self.vae_loss(x, z_decoded)self.add_loss(loss, inputs=inputs)return x

 

整体模型构建
y = CustomLayer()([input_img, z_decoded])vae = Model(input_img, y, name='vae')
vae.compile(optimizer='adam', loss=None)
vae.summary()

 

模型训练

history=vae.fit(xtrain, verbose=2, epochs = 100, batch_size = 64, validation_split = 0.2)
 训练可视化
f = plt.figure(figsize=(10,7))
f.add_subplot()
#Adding Subplot
plt.plot(history.epoch, history.history['loss'], label = "loss") # Loss curve for training set
plt.plot(history.epoch, history.history['val_loss'], label = "val_loss") # Loss curve for validation setplt.title("Loss Curve",fontsize=18)
plt.xlabel("Epochs",fontsize=15)
plt.ylabel("Loss",fontsize=15)
plt.grid(alpha=0.3)
plt.legend()
plt.savefig("VAE_Loss_Trial5.png")
plt.show()

 中间编码特征可视化
mu, _, _ = encoder.predict(xtest)
#Plot dim1 and dim2 for mu
plt.figure(figsize=(10, 10))
plt.scatter(mu[:, 0], mu[:, 1], c=ytest, cmap='brg')
plt.xlabel('dim 1')
plt.ylabel('dim 2')
plt.colorbar()
plt.show()
plt.savefig("VAE_Colourbar_Trial5.png")

 

数据增强生成
#RANDOM GENERATION
def generate():n=20figure = np.zeros((width *2 , height * 10, channels))#Create a Grid of latent variables, to be provided as inputs to decoder.predict
#Creating vectors within range -5 to 5 as that seems to be the range in latent spacefor k in range(2):for l in range(10):z_sample =random.rand(3072)z_out=np.array([z_sample])x_decoded = decoder.predict(z_out)digit = x_decoded[0].reshape(width, height, channels)figure[k * width: (k + 1) * width,l * height: (l + 1) * height] = digitplt.figure(figsize=(10, 10))
#Reshape for visualizationfig_shape = np.shape(figure)figure = figure.reshape((fig_shape[0], fig_shape[1],3))plt.imshow(figure, cmap='gnuplot2')plt.show()  plt.savefig("VAE_imagesgen_Trial5.png")

解码器图像重建
#IMAGE RECONSTRUCT USING TEST SET IMGS
def reconstruct():num_imgs = 6rand = np.random.randint(1, xtest.shape[0]-6) xtestsample = xtest[rand:rand+num_imgs]x_encoded = np.array(encoder.predict(xtestsample))latent_xtest=x_encoded[2]x_decoded = decoder.predict(latent_xtest)rows = 2 # defining no. of rows in figurecols = 3 # defining no. of colums in figurecell_size = 1.5f = plt.figure(figsize=(cell_size*cols,cell_size*rows*2)) # defining a figure f.tight_layout()for i in range(rows):for j in range(cols): f.add_subplot(rows*2,cols, (2*i*cols)+(j+1)) # adding sub plot to figure on each iterationplt.imshow(xtestsample[i*cols + j]) plt.axis("off")for j in range(cols): f.add_subplot(rows*2,cols,((2*i+1)*cols)+(j+1)) # adding sub plot to figure on each iterationplt.imshow(x_decoded[i*cols + j]) plt.axis("off")f.suptitle("Autoencoder Results - Cifar10",fontsize=18)plt.savefig("VAE_imagesrecons_Trial5.png")plt.show()

 

代码获取

已经附在文章底部,自行拿取。

项目开发,相关问题咨询,欢迎交流沟通。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/790873.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ObjectiveC-08-OOP面向对象程序设计-类的分离与组合

本节用一简短的文章来说下是ObjectiveC中的类。类其实是OOP中的一个概念,概念上简单来讲类是它是一组关系密切属性的集合,所谓的关系就是对现实事物的抽象。 上面提到的关系包括很多种,比如has a, is a,has some等&…

小程序滑动删除组件+全选批量删除组件+附源码

小程序滑动删除组件全选批量删除组件附源码 说明 使用 uni-app、uview 组件开发,全端(微信小程序、QQ小程序、抖音小程序等等) 支持滑动删除组件、支持左滑删除、长按进入批量删除、全选删除、长按弹窗删除、 组件式开发,文章…

【VUE】ruoyi框架自带页面可正常缓存,新页面缓存无效

ruoyi框架自带页面可正常缓存,新页面缓存无效 背景: 用若依框架进行开发时,发现ruoyi自带的页面缓存正常,而新开发的页面即使设置了缓存,当重新进入页面时依旧刷新了接口。 原因:页面name与 getRouters …

外围极简便携式T12电烙铁(CH32X035)-第二篇

文章目录 系列文章目录前言一、pandas是什么?二、使用步骤 1.引入库2.读入数据总结 一、工程简介 原理图: PCB: 外壳: BOM: 二、功能模块介绍 1、 |----系统初始化 0:填写系统初值 …

OpenHarmony实战:Makefile方式组织编译的库移植

以yxml库为例,其移植过程如下文所示。 源码获取 从仓库获取yxml源码,其目录结构如下表: 表1 源码目录结构 名称描述yxml/bench/benchmark相关代码yxml/test/测试输入输出文件,及测试脚本yxml/Makefile编译组织文件yxml/.gitat…

水果销售(源码+文档)

水果销售管理系统(小程序、ios、安卓都可部署) 文件包含内容程序简要说明含有功能项目截图客户端添加地址首页商品详细意见反馈待发货商品分类我的代付款我的地址搜索防骗指南资料修改登录注册 后端管理分类管理反馈管理订单管理商品管理用户管理 文件包…

搜索与图论——拓扑排序

有向图的拓扑排序就是图的宽度优先遍历的一个应用 有向无环图一定存在拓扑序列(有向无环图又被称为拓扑图),有向有环图一定不存在拓扑序列。无向图没有拓扑序列。 拓扑序列:将一个图排成拓扑序后,所有的边都是从前指…

C 字符串

在 C 语言中,字符串实际上是使用空字符 \0 结尾的一维字符数组。因此,\0 是用于标记字符串的结束。 空字符(Null character)又称结束符,缩写 NUL,是一个数值为 0 的控制字符,\0 是转义字符&…

CAD Plant3D 2023 下载地址及安装教程

CAD Plant3D是一款专业的三维工厂设计软件,用于在工业设备和管道设计领域进行建模和绘图。它是Autodesk公司旗下的AutoCAD系列产品之一,专门针对工艺、石油、化工、电力等行业的设计和工程项目。 CAD Plant3D提供了一套丰富的工具和功能,帮助…

计算机网络-HTTP相关知识-HTTP的发展

HTTP/1.1 特点: 简单:HTTP/1.1的报文格式包括头部和主体,头部信息是键值对的形式,使得其易于理解和使用。灵活和易于扩展:HTTP/1.1的请求方法、URL、状态码、头字段等都可以自定义和扩展,使得其具有很高的…

docker------docker入门

🎈个人主页:靓仔很忙i 💻B 站主页:👉B站👈 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:Linux 🤝希望本文对您有所裨益,如有不足之处&#…

好用的Android Studio插件管理器

1.使用阿里云的通义灵码方便快速开发 1.1下载插件File->plugin->marketplace 搜索 Tongyilingma然后安装重启登录阿里云,确认 1.2 使用方法 输入信息描述 比如 //写一段冒泡排序然后换行,输入public/private/protected方法会自动生成联想代码…

机器学习——几个线性模型的简介

目录 形式 假设 一元回归例子理解最小二乘法 多元回归 广义线性回归 对数线性回归 逻辑回归 线性判别分析 形式 线性说白了就是初中的一次函数的一种应用,根据不同的(x,y)拟合出一条直线以预测,从而解决各种分类或回归问题,假设有 n …

03原理图:接口、无线、电机、STM32主控、整体模块化设计总结

接口部分 一、TTL 转 USB 驱动电路设计 方案很多,本设计采用的芯片是 CH340E 。 该芯片内部已经集成了振荡器,不需要外部增加晶振。如果其他型号的芯片内部没有振荡器,则外面需要加一个晶振。 再看这篇笔记的时候,你可能有点懵…

蓝桥杯第十三届电子类单片机组决赛程序设计

前言 一、决赛题目 1.比赛题目 2.题目解读 二、功能实现 1.关于定时器资源 1)超声波和NE555需要的定时器资源 2)定时器2 2.单位切换 3.数据长度不足时,高位熄灭 4.AD/DA多通道的处理 5.PWM输出 6.长按功能的实现 三、完整代码演…

Qt C++ | Qt 元对象系统、信号和槽及事件(第一集)

01 元对象系统 一、元对象系统基本概念 1、Qt 的元对象系统提供的功能有:对象间通信的信号和槽机制、运行时类型信息和动态属性系统等。 2、元对象系统是 Qt 对原有的 C++进行的一些扩展,主要是为实现信号和槽机制而引入的, 信号和槽机制是 Qt 的核心特征。 3、要使用元…

三星加强Bixby智能:迈向生成式AI,抗衡谷歌Gemini

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

AI2.0时代如何快速落地AI智能应用开发,抓住时代机会

写在前面的话 当我们提到人工智能时也就是AI的时候呢,我们大多数人首先想到的可能就是像chatGPT这样的聊天机器人,这些聊天机器人通过理解,还有生成自然语言可以给我们提供一些信息,这个是AI最终的形态吗或者AI最终的形式吗&…

【STM32嵌入式系统设计与开发】——16InputCapture(输入捕获应用)

这里写目录标题 STM32资料包: 百度网盘下载链接:链接:https://pan.baidu.com/s/1mWx9Asaipk-2z9HY17wYXQ?pwd8888 提取码:8888 一、任务描述二、任务实施1、工程文件夹创建2、函数编辑(1)主函数编辑&#…

代码随想录阅读笔记-二叉树【合并二叉树】

题目 给定两个二叉树,想象当你将它们中的一个覆盖到另一个上时,两个二叉树的一些节点便会重叠。 你需要将他们合并为一个新的二叉树。合并的规则是如果两个节点重叠,那么将他们的值相加作为节点合并后的新值,否则不为 NULL 的节…