支持4K高分辨率,PixArt-Sigma最新文生图落地经验

PixArt-Sigma是由华为诺亚方舟实验室、大连理工大学和香港大学的研究人员共同开发的一个先进的文本到图像(Text-to-Image,T2I)生成模型。

PixArt-Sigma是在PixArt-alpha的基础上进一步改进的模型,旨在生成高质量的4K分辨率图像。

PixArt-Sigma通过整合高级元素和采用由弱到强式训练方法,这种策略有助于模型逐渐学习并优化图像细节,从而提高了生成图像的保真度和与文本提示的对齐程度。

PixArt-Sigma在美学质量上与当前顶级的文本到图像产品如DALL·E 3和Midjourney V6不相上下,并且在遵循文本提示方面表现出色。

PixArt-Sigma的生成能力支持高分辨率海报和壁纸的创作,有效支持电影和游戏等行业高质量视觉内容的制作。

github项目地址:https://github.com/PixArt-alpha/PixArt-sigma。

一、环境安装

1、python环境

建议安装python版本在3.10以上。

2、pip库安装

pip install torch==2.3.0+cu118 torchvision==0.18.0+cu118 torchaudio==2.3.0 --extra-index-url https://download.pytorch.org/whl/cu118

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

3、SDXL-VAE模型下载

git lfs install

git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers

3、PixArt-Sigma模型下载

python tools/download.py

、功能测试

1、命令行运行测试

(1)python代码调用测试
 

import os
import re
import sys
import argparse
from datetime import datetime
from pathlib import Pathimport torch
from torch import nn
from torchvision.utils import save_image
from tqdm import tqdm
from diffusers.models import AutoencoderKL
from transformers import T5EncoderModel, T5Tokenizerfrom diffusion.model.utils import prepare_prompt_ar
from diffusion import IDDPM, DPMS, SASolverSampler
from diffusion.model.nets import PixArtMS_XL_2, PixArt_XL_2
from diffusion.data.datasets import get_chunks
import diffusion.data.datasets.utils as ds_utils
from tools.download import find_modelclass ImageGenerator:def __init__(self, args):self.args = argsself.device = "cuda" if torch.cuda.is_available() else "cpu"self.seed = args.seedself._set_env()self._load_model_components()def _set_env(self):torch.manual_seed(self.seed)torch.set_grad_enabled(False)for _ in range(30):torch.randn(1, 4, self.args.image_size, self.args.image_size)def _load_model_components(self):self.latent_size = self.args.image_size // 8self.max_sequence_length = {"alpha": 120, "sigma": 300}[self.args.version]self.pe_interpolation = self.args.image_size / 512self.micro_condition = self.args.version == 'alpha' and self.args.image_size == 1024self.sample_steps_dict = {'iddpm': 100, 'dpm-solver': 20, 'sa-solver': 25}self.sample_steps = self.args.step if self.args.step != -1 else self.sample_steps_dict[self.args.sampling_algo]self.weight_dtype = torch.float16self._load_main_model()self._load_vae()self._load_text_components()def _load_main_model(self):if self.args.image_size in [512, 1024, 2048] or self.args.version == 'sigma':self.model = PixArtMS_XL_2(input_size=self.latent_size,pe_interpolation=self.pe_interpolation,micro_condition=self.micro_condition,model_max_length=self.max_sequence_length,).to(self.device)else:self.model = PixArt_XL_2(input_size=self.latent_size,pe_interpolation=self.pe_interpolation,model_max_length=self.max_sequence_length,).to(self.device)print("Generating sample from ckpt: %s" % self.args.model_path)state_dict = find_model(self.args.model_path)state_dict['state_dict'].pop('pos_embed', None)missing, unexpected = self.model.load_state_dict(state_dict['state_dict'], strict=False)print('Missing keys: ', missing)print('Unexpected keys', unexpected)self.model.eval()self.model.to(self.weight_dtype)self.base_ratios = getattr(ds_utils, f'ASPECT_RATIO_{self.args.image_size}', ds_utils.ASPECT_RATIO_1024)def _load_vae(self):vae_path = "output/pretrained_models/sd-vae-ft-ema" if self.args.sdvae else f"{self.args.pipeline_load_from}/vae"self.vae = AutoencoderKL.from_pretrained(vae_path).to(self.device).to(self.weight_dtype)def _load_text_components(self):self.tokenizer = T5Tokenizer.from_pretrained(self.args.pipeline_load_from, subfolder="tokenizer")self.text_encoder = T5EncoderModel.from_pretrained(self.args.pipeline_load_from, subfolder="text_encoder").to(self.device)null_caption_token = self.tokenizer("", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt").to(self.device)self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, attention_mask=null_caption_token.attention_mask)[0]def generate_images(self, items: list):save_root = self._prepare_save_directory()self._visualize(items, save_root)def _prepare_save_directory(self):work_dir = 'output'try:epoch_name = re.search(r'.*epoch_(\d+).*', self.args.model_path).group(1)step_name = re.search(r'.*step_(\d+).*', self.args.model_path).group(1)except:epoch_name = 'unknown'step_name = 'unknown'img_save_dir = os.path.join(work_dir, 'vis')os.umask(0o000)  # file permission: 666; dir permission: 777os.makedirs(img_save_dir, exist_ok=True)save_root = os.path.join(img_save_dir, f"{datetime.now().date()}_{self.args.dataset}_epoch{epoch_name}_step{step_name}_scale{self.args.cfg_scale}_step{self.sample_steps}_size{self.args.image_size}_bs{self.args.bs}_samp{self.args.sampling_algo}_seed{self.seed}")print("save_root: ", save_root)os.makedirs(save_root, exist_ok=True)return save_root@torch.inference_mode()def _visualize(self, items, save_root):for chunk in tqdm(list(get_chunks(items, self.args.bs)), unit='batch'):prompts, hw, ar = self._prepare_prompts_and_configs(chunk)caption_embs, emb_masks, null_y = self._get_text_embeddings(prompts)with torch.no_grad():samples = self._run_sampling(hw, ar, caption_embs, emb_masks, null_y)self._save_images(samples, save_root)def _prepare_prompts_and_configs(self, chunk):prompts = []if self.args.bs == 1:timestamp = datetime.now().strftime("%Y%m%d%H%M%S")save_path = os.path.join(save_root, f"{timestamp}.jpg")if os.path.exists(save_path):returnprompt_clean, _, hw, ar, custom_hw = prepare_prompt_ar(chunk[0], self.base_ratios, device=self.device, show=False)latent_size_h, latent_size_w = int(hw[0, 0] // 8), int(hw[0, 1] // 8)prompts.append(prompt_clean.strip())else:hw = torch.tensor([[self.args.image_size, self.args.image_size]], dtype=torch.float, device=self.device).repeat(self.args.bs, 1)ar = torch.tensor([[1.]], device=self.device).repeat(self.args.bs, 1)for prompt in chunk:prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())latent_size_h, latent_size_w = self.latent_size, self.latent_sizereturn prompts, hw, ardef _get_text_embeddings(self, prompts):caption_token = self.tokenizer(prompts, max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt").to(self.device)caption_embs = self.text_encoder(caption_token.input_ids, attention_mask=caption_token.attention_mask)[0]emb_masks = caption_token.attention_maskcaption_embs = caption_embs[:, None]null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None]print(f'finish embedding')return caption_embs, emb_masks, null_ydef _run_sampling(self, hw, ar, caption_embs, emb_masks, null_y):model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks)if self.args.sampling_algo == 'iddpm':z = torch.randn(len(prompts), 4, latent_size_h, latent_size_w, device=self.device).repeat(2, 1, 1, 1)model_kwargs['y'] = torch.cat([caption_embs, null_y])model_kwargs['cfg_scale'] = self.args.cfg_scalediffusion = IDDPM(str(self.sample_steps))samples = diffusion.p_sample_loop(self.model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True,device=self.device)samples, _ = samples.chunk(2, dim=0)elif self.args.sampling_algo == 'dpm-solver':z = torch.randn(len(prompts), 4, latent_size_h, latent_size_w, device=self.device)dpm_solver = DPMS(self.model.forward_with_dpmsolver,condition=caption_embs,uncondition=null_y,cfg_scale=self.args.cfg_scale,model_kwargs=model_kwargs)samples = dpm_solver.sample(z,steps=self.sample_steps,order=2,skip_type="time_uniform",method="multistep",)elif self.args.sampling_algo == 'sa-solver':sa_solver = SASolverSampler(self.model.forward_with_dpmsolver, device=self.device)samples = sa_solver.sample(S=25,batch_size=len(prompts),shape=(4, latent_size_h, latent_size_w),eta=1,conditioning=caption_embs,unconditional_conditioning=null_y,unconditional_guidance_scale=self.args.cfg_scale,model_kwargs=model_kwargs,)[0]samples = samples.to(self.weight_dtype)samples = self.vae.decode(samples / self.vae.config.scaling_factor).sampletorch.cuda.empty_cache()return samplesdef _save_images(self, samples, save_root):os.umask(0o000)  # file permission: 666; dir permission: 777for i, sample in enumerate(samples):timestamp = datetime.now().strftime("%Y%m%d%H%M%S")save_path = os.path.join(save_root, f"{timestamp}.jpg")print("Saving path: ", save_path)save_image(sample, save_path, nrow=1, normalize=True, value_range=(-1, 1))def get_args():parser = argparse.ArgumentParser()parser.add_argument('--image_size', default=1024, type=int)parser.add_argument('--version', default='sigma', type=str)parser.add_argument("--pipeline_load_from", default='PixArt-sigma-model/pixart_sigma_sdxlvae_T5_diffusers',type=str, help="Download for loading text_encoder, ""tokenizer and vae from https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers")parser.add_argument('--txt_file', default='asset/test.txt', type=str)parser.add_argument('--model_path', default='PixArt-sigma-model/PixArt-Sigma-XL-2-1024-MS.pth', type=str)parser.add_argument('--sdvae', action='store_true', help='sd vae')parser.add_argument('--bs', default=1, type=int)parser.add_argument('--cfg_scale', default=4.5, type=float)parser.add_argument('--sampling_algo', default='dpm-solver', type=str, choices=['iddpm', 'dpm-solver', 'sa-solver'])parser.add_argument('--seed', default=0, type=int)parser.add_argument('--dataset', default='custom', type=str)parser.add_argument('--step', default=-1, type=int)parser.add_argument('--save_name', default='test_sample', type=str)return parser.parse_args()if __name__ == '__main__':args = get_args()generator = ImageGenerator(args)with open(args.txt_file, 'r') as f:items = [item.strip() for item in f.readlines()]generator.generate_images(items)

(2)web端测试

未完......

更多详细的内容欢迎关注:杰哥新技术
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/49060.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024牛客暑期多校第四场

A-LCT 带权并查集&#xff0c;维护一下每个点在当前树的深度和以它为根能找到的最深的深度。‘ #include<bits/stdc.h>using namespace std; typedef long long ll; const int N 1e6 100;int fa[N],ans[N],val[N];int find(int x){if(fa[x]x)return x;int tfa[x];fa[x…

C++初学(3)

面向对象编程&#xff08;OOP&#xff09;的本质是设计并拓展自己的数据类型&#xff0c;设计自己的数据类型就是让类型与数据匹配。内置的C类型分为两组&#xff1a;基本类型和复合类型。这里我们将介绍基本类型的整数和浮点数 3.1、简单变量 3.1.1、变量名 C必须遵循几种简…

场外期权如何报价?名义本金是什么?

今天带你了解场外期权如何报价&#xff1f;名义本金是什么&#xff1f;投资者首先需要挑选自己想要进行期权交易的沪深上市公司股票。选出股票后&#xff0c;需要将股票信息、预期的操作时间&#xff08;如期限&#xff09;、看涨或看跌的选择以及预计的交易金额等信息报给场外…

计算机网络(四)数字签名和CA认证

什么是数字签名和CA认证&#xff1f; 数字签名 数字签名的过程通常涉及以下几个步骤&#xff1a; 信息哈希&#xff1a;首先&#xff0c;发送方使用一个哈希函数&#xff08;如SHA-256&#xff09;对要发送的信息&#xff08;如电子邮件、文件等&#xff09;生成一个固定长度…

全链路追踪 性能监控,GO 应用可观测全面升级

作者&#xff1a;古琦 01 介绍 随着 Kubernetes 和容器化技术的普及&#xff0c;Go 语言不仅在云原生基础组件领域广泛应用&#xff0c;也在各类业务场景中占据了重要地位。如今&#xff0c;越来越多的新兴业务选择 Golang 作为首选编程语言。得益于丰富的 RPC 框架&#xff…

Golang实现Word模板内容填充导出

这里我们使用一个广泛使用且免费处理 .docx 文件的库&#xff0c;github.com/nguyenthenguyen/docx. 安装 github.com/nguyenthenguyen/docx 库 首先&#xff0c;确保你已经安装了 docx 库&#xff1a; go get github.com/nguyenthenguyen/docx使用 docx 库处理 Word 模板 …

ubuntu实践

目录 扩容 本机上ping不通新建立的虚拟机 ssh连接 装sshd ssh客户端版本较低&#xff0c;会报key exchange算法不匹配问题 ubuntun上装docker 将centos7下的安装包改造成适配 ubuntu的包 参考文章 扩容 Hyper-V 管理器安装的ubutun扩容磁盘空间说明_hype-v磁盘扩容-…

复现open-mmlab的mmsegmentation详细细节

复现open-mmlab的mmsegmentation详细细节 1.配置环境2.数据处理3.训练 1.配置环境 stage1&#xff1a;创建python环境 conda create --name openmmlab python3.8 -y conda activate openmmlabstage2&#xff1a;安装pytorch&#xff08;这里我是以torch1.10.0为例&#xff09…

VINS-Fusion 回环检测pose_graph_node

VINS-Fusion回环检测,在节点pose_graph_node中启动。 pose_graph_node总体流程如下: 重点看process线程。 process线程中,将订阅的图像、点云、位姿时间戳对齐,对齐后分别存入image_msg、point_msg、pose_msg。pose_msg为VIO后端优化发布的位姿。 一、创建关键帧keyFram…

mac|安装PostgreSQL

1、官网下载&#xff1a;EDB: Open-Source, Enterprise Postgres Database Management 选择需要的版本&#xff1a; 双击得到的.dmg文件 双击&#xff0c;弹窗选择打开&#xff0c;一路next&#xff0c;然后输入你要设置的密码&#xff0c;默认账号名字为&#xff1a;postgres…

项目一缓存商品

文章目录 概要整体架构流程技术细节小结 概要 因为商品是经常被浏览的,所以数据库的访问量就问大大增加,造成负载过大影响性能,所以我们需要把商品缓存到redis当中,因为redis是存在内存中的,所以效率会比MySQL的快. 整体架构流程 技术细节 我们在缓存时需要保持数据的一致性所…

面试场景题系列--(2)短 URL 生成器设计:百亿短 URL 怎样做到无冲突?--xunznux

文章目录 面试场景题&#xff1a;短 URL 生成器设计&#xff1a;百亿短 URL 怎样做到无冲突&#xff1f;1. 需求分析2. 短链接生成算法2.1 自增法2.2 散列函数法2.3 预生成法 3. 部署模型3.1 其他部署方案 4. 设计4.1 重定向响应码4.2 短 URL 预生成文件及预加载4.3 用户自定义…

个人百度百科怎么创建?

百度百科词条分为企业词条、品牌词条、人物词条等&#xff0c;个人百度百科创建的需求量很大&#xff0c;各式各样的人物需求都有。现在凡是要推广个人的人&#xff0c;创建百度百科都是其中一个必要的步骤。 作为一个有知名度的人物&#xff0c;拥有一个百度百科从侧面也证明了…

基于微信小程序+SpringBoot+Vue的自习室选座与门禁系统(带1w+文档)

基于微信小程序SpringBootVue的自习室选座与门禁系统(带1w文档) 基于微信小程序SpringBootVue的自习室选座与门禁系统(带1w文档) 本课题研究的研学自习室选座与门禁系统让用户在小程序端查看座位&#xff0c;预定座位&#xff0c;支付座位价格&#xff0c;该系统让用户预定座位…

CentOS搭建Apache服务器

安装对应的软件包 [roothds ~]# yum install httpd mod_ssl -y 查看防火墙的状态和selinux [roothds ~]# systemctl status firewalld [roothds ~]# cat /etc/selinux/config 若未关闭&#xff0c;则关闭防火墙和selinux [roothds ~]# systemctl stop firewalld [roothds ~]# …

ARM功耗管理之autosleep和睡眠锁实验

安全之安全(security)博客目录导读 ARM功耗管理精讲与实战汇总参见&#xff1a;Arm功耗管理精讲与实战 思考&#xff1a;睡眠唤醒实验&#xff1f;压力测试&#xff1f;Suspend-to-Idle/RAM/Disk演示&#xff1f; 1、实验环境准备 2、软件代码准备 3、唤醒源 4、Suspend-…

18.jdk源码阅读之CopyOnWriteArrayList

1. 写在前面 CopyOnWriteArrayList 是 Java 中的一种线程安全的 List 实现&#xff0c;基于“写时复制”&#xff08;Copy-On-Write&#xff09;机制。下面几个问题大家可以先思考下&#xff0c;在阅读源码的过程中都会解答&#xff1a; CopyOnWriteArrayList 适用于哪些场景…

Profinet转ModbusTCP网关模块的配置与应用详解

Profinet转ModbusTCP网关模块&#xff08;XD-ETHPN20&#xff09;是一种常见的工业通信设备&#xff0c;广泛应用于现代工业自动化系统中。通过使用Profinet转Modbus TCP网关模块&#xff08;XD-ETHPN20&#xff09;将Profinet协议转换成Modbus TCP协议&#xff0c;实现了不同网…

抓包工具Charles

1、抓包的目的 遇到问题需要进行分析 发现bug需要定位 检查数据传输的安全性 接口测试时&#xff0c;开发给的需求文档不详细 在弱网环境下APP的测试 2、Charles是java语言编写的程序&#xff0c;本质是一个代理服务器&#xff0c;通过拦截服务端和客户端的http请求&#xff0…

【SpringCloud】企业认证、分布式事务,分布式锁方案落地-2

目录 高并发缓存三问 - 穿透 缓存穿透 概念 现象举例 解决方案 缓存穿透 - 预热架构 缓存穿透 - 布隆过滤器 布隆过滤器 布隆过滤器基本思想​编辑 了解 高并发缓存三问 - 击穿 缓存击穿 高并发缓存三问 - 雪崩 缓存雪崩 解决方案 总结 为什么要使用数据字典&…