《扩散模型 从原理到实战》Hugging Face (三)

第四章 Diffusers 实战

安装Difffusers 库

pip install -qq -U diffusers datasets transformers accelerate ftfy pyarrow

扩散模型调度器

from diffusers import DDPMScheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

定义扩散模型

from diffusers import  UNet2DModeldef model():model = UNet2DModel(sample_size = 240,in_channels = 4,out_channels = 4,layers_per_block = 2,block_out_channels = (64,128,128,256),down_block_types=("DownBlock2D","DownBlock2D","AttnDownBlock2D","AttnDownBlock2D",),up_block_types=("AttnUpBlock2D","AttnUpBlock2D","UpBlock2D","UpBlock2D",))return model

创建扩散模型训练循环

import torch.utils.data.dataset
import torchvision
from dataset import dataset_brats_2D
from torchvision import transforms
from diffusers import DDPMScheduler
import model
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
import timeif __name__ == "__main__":device = torch.device('cuda')dataset = #自定义datasettrain_dl = DataLoader(dataset, 128, False, num_workers=1)timesteps = torch.linspace(0, 1000, 2).long().to(device)model = model.model().to(device)model = torch.nn.DataParallel(model, device_ids=[0])noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)losses = []loss_flag = 10e+10for epoch in range(100):for step, batch in enumerate(train_dl):clean_images = batch.to(device)noise = torch.randn(clean_images.shape).to(device)batch_size = clean_images.shape[0]timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (batch_size,),device=device).long()noisy_images = noise_scheduler.add_noise(clean_images,noise,timesteps)noisy_pred = model(noisy_images,timesteps,return_dict=False)[0]loss = F.mse_loss(noisy_pred, noise)loss.backward(loss)losses.append(loss.item())optimizer.step()optimizer.zero_grad()if (epoch +1) % 5 == 0:loss_last_epoch = sum(losses[-len(train_dl) :]) / len(train_dl)print(f"Epoch:{epoch + 1}, loss:{loss_last_epoch}")state = {'net': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}if loss_flag < loss:torch.save(state,"best.pth")loss_flag = loss

图像的生成

import time
import torchvision.utils
from diffusers import DDPMPipeline,DDPMScheduler
import cv2
import torch
import torchvision
from PIL import Image
import model
import numpy as np
import time
def show_images(x):x = x * 0.5 + 0.5grid = torchvision.utils.make_grid(x)grid_im = grid.detach().cpu().permute(1,2,0).clip(0,1) *255grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))return grid_imif __name__ == "__main__":device = torch.device('cuda')sample = torch.randn(1, 4, 240, 240).to(device)model = model.model().to(device)ckpt = torch.load(r"")#自己的checkpointmodel.load_state_dict({k.replace('module.', ''): v for k, v in ckpt['net'].items() if k.startswith('module.')})noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")flag = 0for k in range(10000):# start = time.time()sample = torch.randn(1, 4, 240, 240).to(device)for i,t in enumerate(noise_scheduler.timesteps):print(t)with torch.no_grad():residual = model(sample,t).samplesample = noise_scheduler.step(residual, t, sample).prev_sampletime_flag = time.time()print(sample.shape)image = show_images(sample[0][0])image.save(str(time_flag) +'_0'+'.png')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/29969.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

二维码展示与下载

1. 二维码展示(前端展示) 基于vue-qr组件实现 下载vue-qr组件 npm install vue-qr --save页面vueqr组件导入 import VueQr from vue-qr;export default {components: {VueQr} }页面展示 <el

VS安装FFmpeg库

在Visual Studio中安装FFmpeg库通常涉及以下步骤: 下载FFmpeg:访问FFmpeg官网(https://ffmpeg.org/download.html)下载对应于您的操作系统的预编译二进制文件。 解压FFmpeg:将下载的压缩包解压到您选择的目录。 配置系统环境变量:将FFmpeg的bin目录添加到系统的PATH环境变…

JimuReport 积木报表 v1.7.6 版本发布,免费的低代码报表

项目介绍 一款免费的数据可视化报表工具&#xff0c;含报表和大屏设计&#xff0c;像搭建积木一样在线设计报表&#xff01;功能涵盖&#xff0c;数据报表、打印设计、图表报表、大屏设计等&#xff01; Web 版报表设计器&#xff0c;类似于excel操作风格&#xff0c;通过拖拽完…

Python构造TCP三次握手、传输数据、四次挥手pcap数据包并打乱顺序

Python构造数据包&#xff0c;包含&#xff1a; TCP三次握手、 传输数据、 四次挥手 实现 随机乱序TCP数据包 from scapy.all import * from scapy.all import Ether, IP, TCP, UDP, wrpcap from abc import ABC, abstractmethod import random import dpkt from scapy.all…

6月18日(周二)美股行情总结:纳指七日连创新高,英伟达市值全球第一,苹果微软回落,油价七周最高

美国5月零售销售意外走软&#xff0c;尽管一众美联储官员均鹰派发声支持多等待通胀数据再做决策&#xff0c;市场仍抬升对年内降息两次的押注。标普500指数在七天里第六天上涨并再创新高&#xff0c;标普科技板块连续七天创新高、期间累涨8.6%&#xff0c;道指一周高位&#xf…

MySQL----慢查询日志

慢日志 MySQL可以设置慢查询日志&#xff0c;当SQL执行的时间超过我们设定的时间&#xff0c;那么这些SQL就会被记录在慢查询日志当中&#xff0c;然后我们通过查看日志&#xff0c;用explain分析这些SQL的执行计划&#xff0c;来判定为什么效率低下。 查看相关信息 show va…

Oracle Database 23ai 创建新用户

Oracle Database 23ai 创建新用户 1. 创建新用户2. 配置 cohere 认证 1. 创建新用户 sqlplus syslocalhost:1521/orclpdb1 as sysdbaCreate bigfile tablespace tbs100 Datafile bigtbs_f100.dbf SIZE 1G AUTOEXTEND ON next 32m maxsize unlimited extent management local …

iOS 18 终于更新了 iOS 隐藏 App 功能,这次是真的隐藏

如何锁定或隐藏 App 我们一起来看看 iOS 如何隐藏软件&#xff0c;下面是具体的操作步骤&#xff1a; iOS 隐藏 App 的第一步肯定是找到你想隐藏或锁定的应用程序&#xff0c;然后长按它的图标&#xff0c;在长按之后出现的选项中我们选择“需要 Face ID”。 然后在新弹出的选…

web版的数字孪生,选择three.js、unity3D、还是UE4

数字孪生分为客户端版和web端版&#xff0c;开发引擎多种多用&#xff0c;本文重点分析web端版采用哪种引擎最合适&#xff0c; 贝格前端工场结合实际经验和网上主流说法&#xff0c;为您讲解。 一、数字孪生的web版和桌面版 数字孪生的Web版和桌面版是数字孪生技术在不同平台…

Mamba: Linear-Time Sequence Modeling with Selective State Spaces论文笔记

文章目录 Mamba: Linear-Time Sequence Modeling with Selective State Spaces摘要引言 相关工作(SSMs)离散化计算线性时间不变性(LTI)结构和尺寸一般状态空间模型SSMs架构S4(补充)离散数据的连续化: 基于零阶保持技术做连续化并采样循环结构表示: 方便快速推理卷积结构表示: 方…

对SpringBoot入门案例的关键点

我们SpringBoot的入门案例中&#xff0c;即做了两个重要工作&#xff1a; 配置pom.xml文件写启动类 1.pom.xml依赖配置文件 ①帮助我们进行版本控制的父模块 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter…

Inpaint_2024软件最新版下载-inpaint下载安装2024-inpaint下载最新版本

众多使用者向我们证明了高效去除背景无关游客&#xff0c;只需要花费几秒钟在照片上选择不必要的对象或人员&#xff0c;剩下的交给Inpaint。准确来讲快速去水印&#xff0c;用Inpaint,选中水印&#xff0c;一键清除&#xff0c;还你一个干净整洁的图形。我们都知道快速去水印&…

【2024】kafka streams的详细使用与案例练习(2)

目录 前言使用1、整体结构1.1、序列化 2、 Kafka Streams 常用的 API2.1、 StreamsBuilder2.2、 KStream 和 KTable2.3、 filter和 filterNot2.4、 map 和 mapValues2.5、 flatMap 和 flatMapValues2.6、 groupByKey 和 groupBy2.7、 count、reduce 和 aggregate2.8、 join 和 …

基于EasyAnimate模型的视频生成最佳实践

EasyAnimate是阿里云PAI平台自主研发的DiT的视频生成框架&#xff0c;它提供了完整的高清长视频生成解决方案&#xff0c;包括视频数据预处理、VAE训练、DiT训练、模型推理和模型评测等。本文为您介绍如何在PAI平台集成EasyAnimate并一键完成模型推理、微调及部署的实践流程。 …

shader的优化,specialization constants

volkan specialization_constants 与Uniform buffer objects (UBOs) 和 Push constants不同的是 specialization constants 可以在shader编译前设置控制量&#xff0c;从而能够删除无用代码和静态展开循环( remove unused code blocks and statically unroll)。不但缩减shader…

【Python特征工程系列】基于方差分析的特征重要性分析(案例+源码)

这是我的第304篇原创文章。 一、引言 方差分析&#xff08;Analysis of Variance&#xff0c;简称ANOVA&#xff09;是一种统计方法&#xff0c;用于比较两个或多个组之间的平均值是否存在显著差异。 方法简介&#xff1a; ANOVA 通过分解总方差为组间方差和组内方差&#x…

MySQL入门学习.子查询.IN

IN 子查询是 MySQL 中一种常见的子查询类型&#xff0c;用于在查询中确定一个值是否在另一个查询的结果集中。IN 子查询的特点是简洁明了&#xff0c;它可以在一个查询中方便地检查一个值是否在一组值中&#xff0c;非常适用于需要进行条件验证或关联查询的情况。 在 MySQL 中&…

怪物猎人物语什么时候上线?游戏售价多少?

怪物猎人物语是一款全新的RPG游戏&#xff0c;玩家在游戏中将化身为骑士&#xff0c;不断与怪物建立羁绊、不断成长&#xff0c;踏上前往外面世界的旅程&#xff0c;且最终目的地是以狩猎怪物为生的猎人世界。因为最近有不少玩家在关注这款游戏&#xff0c;所以下面就给大家分享…