AIGC笔记--Diffuser的训练pipeline

1--简单训练pipeline

import time
import numpy as np
import torch
from PIL import Image
import torchvision
import torch.nn.functional as F
from datasets import load_dataset
from torchvision import transforms
from matplotlib import pyplot as plt
from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline# 数据增广
def transform(examples):preprocess = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5], [0.5]),])images = [preprocess(image.convert("RGB")) for image in examples["image"]]return {"images": images}def process_dataset(batch_size):# 加载数据集# dataset = load_dataset("huggan/smithsonian_butterflies_subset", split = "train")dataset = load_dataset("/data-home/liujinfu/Diffuser/Data/smithsonian_butterflies_subset", split = "train")# 调用自定义的transform函数dataset.set_transform(transform)# 设置dataloadertrain_dataloader = torch.utils.data.DataLoader(dataset, batch_size = batch_size, shuffle = True)return train_dataloaderdef train_loop(train_dataloader, noise_scheduler, model, num_epoches, device):# 优化器optimizer = torch.optim.AdamW(model.parameters(), lr = 4e-4)losses = []start_time = time.time() for epoch in range(num_epoches):for _, batch in enumerate(train_dataloader): # 遍历clean_images = batch["images"].to(device) # B C H W# sample noise to add to the imagesnoise = torch.randn(clean_images.shape).to(clean_images.device) # B C H Wbs = clean_images.shape[0] # 64# sample a random timestep for each imagetimesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs, ), device = clean_images.device).long() # B# Add noise to the clean images according to the noise magnitude at each timestepnoisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps) # 加噪# Get model predictionnoise_pred = model(noisy_images, timesteps, return_dict=False)[0]# Calculate the lossloss = F.mse_loss(noise_pred, noise) # 计算预测噪音和真实噪音之间的损失loss.backward(loss)losses.append(loss.item())# Update the model parameters with the optimizeroptimizer.step()optimizer.zero_grad()if (epoch + 1) % 5 == 0:loss_last_epoch = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")end_time = time.time()elapsed_time = end_time - start_time # 记录训练时间print("time cost: ", elapsed_time)return lossesdef vis(losses):# 可视化 lossfig, axs = plt.subplots(1, 2, figsize=(12, 4))axs[0].plot(losses)axs[1].plot(np.log(losses))return figdef generate(model, noise_scheduler):# 1. create a pipelineimage_pipe = DDPMPipeline(unet = model, scheduler = noise_scheduler)pipeline_output = image_pipe()return pipeline_output.images[0]# 可视化生成图像
def show_images(x):"""Given a batch of images x, make a grid and convert to PIL"""x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)grid = torchvision.utils.make_grid(x)grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))return grid_imdef make_grid(images, size=64):"""Given a list of PIL images, stack them together into a line for easy viewing"""output_im = Image.new("RGB", (size * len(images), size))for i, im in enumerate(images):output_im.paste(im.resize((size, size)), (i * size, 0))return output_imdef main():# 获取训练集image_size = 32 batch_size = 64train_dataloader = process_dataset(batch_size = batch_size)# 设置Schedulernoise_scheduler = DDPMScheduler(num_train_timesteps = 1000, beta_schedule = "squaredcos_cap_v2") # 创建Unet modeldevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = UNet2DModel(sample_size = image_size, # target image resolutionin_channels = 3,out_channels = 3,layers_per_block = 2, # how many resnet layers to use per Unet blockblock_out_channels = (64, 128,128, 256),down_block_types = ("DownBlock2D","DownBlock2D","AttnDownBlock2D","AttnDownBlock2D",),up_block_types=("AttnUpBlock2D","AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention"UpBlock2D","UpBlock2D", # a regular ResNet upsampling block),  ).to(device)# 开始训练losses = train_loop(train_dataloader = train_dataloader, noise_scheduler = noise_scheduler,model = model,num_epoches = 30, device = device)fig = vis(losses)fig.savefig("./loss.png")# 生成一张图片gen_img = generate(model, noise_scheduler)gen_img.save("./generate1.png")# 随机初始化噪音生成图片sample = torch.randn(8, 3, 32, 32).to(device)for i, t in enumerate(noise_scheduler.timesteps): # 反向去噪# Get model predwith torch.no_grad():residual = model(sample, t).sample# Update sample with stepsample = noise_scheduler.step(residual, t, sample).prev_sample# 可视化生成的图片grid_im = show_images(sample)grid_im.save("./genearate2.png")print("All Done!")if __name__ == "__main__":main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/834473.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

视频监控平台:交通运输标准JTT808设备SDK接入源代码函数分享

目录 一、JT/T 808标准简介 (一)概述 (二)协议特点 1、通信方式 2、鉴权机制 3、消息分类 (三)协议主要内容 1、位置信息 2、报警信息 3、车辆控制 4、数据转发 二、代码和解释 (一…

《ESP8266通信指南》13-Lua 简单入门(打印数据)

往期 《ESP8266通信指南》12-Lua 固件烧录-CSDN博客 《ESP8266通信指南》11-Lua开发环境配置-CSDN博客 《ESP8266通信指南》10-MQTT通信(Arduino开发)-CSDN博客 《ESP8266通信指南》9-TCP通信(Arudino开发)-CSDN博客 《ESP82…

水平滑动与垂直滑动菜单

水平滑动菜单 <!DOCTYPE html> <html><head><meta charset"utf-8"><title></title><style>*{margin: 0;padding: 0;}ul{background-color: #000;}ul li{text-shadow: none;display: inline-block;height: 40px;}ul li a{…

AJAX知识点(前后端交互技术)

原生AJAX AJAX全称为Asynchronous JavaScript And XML,就是异步的JS和XML&#xff0c;通过AJAX可以在浏览器中向服务器发送异步请求&#xff0c;最大的优势&#xff1a;无需刷新就可获取数据。 AJAX不是新的编程语言&#xff0c;而是一种将现有的标准组合在一起使用的新方式 …

二维数组的动态分配

C语言 malloc函数 一维数组 int* buf ( int * )malloc( 10*sizeof( int )); 二维数组 // 定义二维数组的行数和列数 int row 3; int col 4;// 动态分配一个能存放row个指针的数组&#xff0c;这些指针将分别指向每行的首地址 int** buf (int**)malloc(row * sizeof(int*…

C语言【文件操作 2】

文章目录 前言顺序读写函数的介绍fputc && fgetcfputcfgetc fputs && fgetsfputsfgets fprintf && fscanffprintffscanf fwrite && freadfwritefread 文件的随机读写fseek函数偏移量ftell函数rewind函数 文件的结束判断被错误使用的feof 结语 …

04_kafka_java-api

文章目录 使用api 实现 topic 增删改查pom.xmllog4j.properties创建、查询 Topic生产者、消费者 api自定义 生产者分区发送策略自定义序列化器自定义 生产者拦截器offset 提交控制确认-acks 与 重试-retries幂等消息生产者事务生产者消费者事务 04_kafka_java-api 使用api 实现…

Linux与windows网络管理

文章目录 一、TCP/IP1.1、TCP/IP概念TCP/IP是什么TCP/IP的作用TCP/IP的特点TCP/IP的工作原理 1.2、TCP/IP网络发展史1.3、OSI网络模型1.4、TCP/IP网络模型1.5、linux中配置网络网络配置文件位置DNS配置文件主机名配置文件常用网络查看命令 1.6、windows中配置网络CMD中网络常用…

认识卷积神经网络

我们现在开始了解卷积神经网络&#xff0c;卷积神经网络是深度学习在计算机视觉领域的突破性成果&#xff0c;在计算机视觉领域&#xff0c;往往我们输入的图像都很大&#xff0c;使用全连接网络的话&#xff0c;计算的代价较高&#xff0c;图像也很难保留原有的特征&#xff0…

python 和 MATLAB 都能绘制的母亲节花束!!

hey 母亲节快到了&#xff0c;教大家用python和MATLAB两种语言绘制花束~这段代码是我七夕节发的&#xff0c;我对代码进行了简化&#xff0c;同时自己整了个python版本 MATLAB 版本代码 function roseBouquet_M() % author : slandarer% 生成花朵数据 [xr,tr]meshgrid((0:24).…

tile 跟slice 是什么关系?一个tile可以包含多个slice吗?TILE在图形渲染中是什么概念?有什么作用

在H.264&#xff08;也称为AVC&#xff09;中&#xff0c;slice 和 tile 是两个与编码和解码过程相关的概念&#xff0c;但它们有着不同的用途和定义。 Slice&#xff1a; 一个slice是编码图像&#xff08;如帧或场&#xff09;的一部分。在H.264中&#xff0c;一幅图像可以被分…

鲁棒性能优化问题

鲁棒性能优化问题是一类基于不确定性和干扰的优化问题。鲁棒优化是一种在内部结构和外部环境不确定环境下进行优化的新方法。其核心思想是在考虑参数的不确定性或外部环境的扰动时&#xff0c;寻找一个最优解&#xff0c;该解在所有可能的情况下都能保持较好的性能。 鲁棒优化…

前端WebSocket

WebSocket定义 WebSocket是HTML5提供的一种浏览器与服务器进行全双工通讯的网络技术。是一种在Web浏览器和服务器之间建立持久连接的通信协议。它允许服务器主动向浏览器发送数据&#xff0c;而不需要浏览器发起请求。相比起传统的HTTP请求-响应模式&#xff0c;WebSocket能够…

力扣---三数之和(Java、模拟)

题目描述&#xff1a; 给定一个包含 n 个整数的数组 nums&#xff0c;判断 nums 中是否存在三个元素 a &#xff0c;b &#xff0c;c &#xff0c;使得 a b c 0 &#xff1f;请找出所有和为 0 且 不重复 的三元组。 示例 1&#xff1a; 输入&#xff1a;nums [-1,0,1,2,-…

jQuery-1.语法、选择器、节点操作

jQuery jQueryJavaScriptQuery&#xff0c;是一个JavaScript函数库&#xff0c;为编写JavaScript提供了更高效便捷的接口。 jQuery安装 去官网下载jQuery&#xff0c;1.x版本练习就够用 jQuery引用 <script src"lib/jquery-1.11.2.min.js"></script>…

我的Transformer专栏来啦

五一节前吹的牛&#xff0c;五一期间没完成&#xff0c;今天忙里偷闲&#xff0c;给完成了。 那就是初步拟定了一个《Transformer最后一公里》的写作大纲。 之前一直想写一系列Transformer架构的算法解析文章&#xff0c;但因为一直在忙&#xff08;虽然不知道在忙啥&#xf…

倍思|西圣开放式耳机哪个好用?热门机型深度测评!

在数字化生活的浪潮中&#xff0c;耳机已成为我们不可或缺的伴侣。然而&#xff0c;长时间佩戴传统的耳机容易导致的耳道疼痛等问题&#xff0c;严重的话将影响听力。许多人开始寻找更为舒适的佩戴体验。开放式耳机因为不需要需直接插入耳道的设计&#xff0c;逐渐受到大众的青…

Apipost使用心得,让接口文档变得更清晰,更快捷

Idea和Apipost结合使用 Idea 安装插件Apipost-Helper-2.0 在【file】–>【settings】–>【Plugins】搜索 “Apipost-Helper-2.0”–>【install】&#xff0c;重启Idea 编写controller接口 在idea中编写业务功能及接口之后&#xff0c;在controller中鼠标【右键】单…

Linux下的SPI通信

SPI通信 一. 1.SPI简介: SPI 是一种高速,全双工,同步串行总线。 SPI 有主从俩种模式通常由一个主设备和一个或者多个从设备组从。SPI不支持多主机。 SPI通信至少需要四根线,分别是 MISO(主设备数据输入,从设备输出),MOSI (主设数据输出从设备输入),SCLK(时钟信号),CS/SS…