1. Building the dataset
The dataloader is built on top of the datasets library; I have not yet explored its underlying code.
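2. Modifying the model architecture (optional)
Reuse existing pretrained weights to train the model wherever possible. The weights will not necessarily fit the architecture completely, so adjust it case by case. Any parameters that could not be loaded from the pretrained weights have to be trained from scratch; there is no fine-tuning for that part.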
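One way to reuse a checkpoint after changing the architecture is to copy only the entries whose name and shape still match, and leave everything else at its fresh initialization. The sketch below is a minimal illustration under stated assumptions, not the method of any particular script: `select_compatible` and `load_partial` are hypothetical helper names, and the model is assumed to expose the PyTorch-style `state_dict()` and `load_state_dict(..., strict=False)` methods.

```python
def select_compatible(pretrained, target):
    """Split a checkpoint into entries that fit the target model and entries that do not.

    Both arguments map parameter names to tensor-like objects with a `.shape`.
    """
    compatible, skipped = {}, []
    for name, weight in pretrained.items():
        if name in target and tuple(weight.shape) == tuple(target[name].shape):
            compatible[name] = weight
        else:
            skipped.append(name)
    return compatible, skipped


def load_partial(model, pretrained):
    """Load matching weights into `model`.

    Returns (checkpoint keys that were skipped, model keys left at initialization).
    """
    target = model.state_dict()
    compatible, skipped = select_compatible(pretrained, target)
    # strict=False accepts a state dict that does not cover every parameter.
    model.load_state_dict(compatible, strict=False)
    from_scratch = [name for name in target if name not in compatible]
    return skipped, from_scratch
```

The second return value lists the parameters that stay at their initial values and therefore have to be trained from scratch.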
3. Unconditional sample generation
First, set up the environment.
train_unconditional.py still needs a closer read.
The UNet created in train_unconditional.py:
model = UNet2DModel(
    sample_size=args.resolution,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)
# Initialize the scheduler
accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
if accepts_prediction_type:
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=args.ddpm_num_steps,
        beta_schedule=args.ddpm_beta_schedule,
        prediction_type=args.prediction_type,
    )
else:
    noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)

# Initialize the optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)
The default noise scheduler (sampler) and optimizer.
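The `accepts_prediction_type` check in the scheduler snippet is a general pattern for supporting several library versions: inspect the constructor signature and pass an argument only if the constructor declares it. A minimal standalone sketch follows. `OldScheduler`, `NewScheduler`, and `build_scheduler` are hypothetical stand-ins written for illustration; they are not the real `DDPMScheduler`.

```python
import inspect


class OldScheduler:
    """Hypothetical stand-in for a library version without `prediction_type`."""

    def __init__(self, num_train_timesteps=1000):
        self.num_train_timesteps = num_train_timesteps


class NewScheduler:
    """Hypothetical stand-in for a library version that accepts `prediction_type`."""

    def __init__(self, num_train_timesteps=1000, prediction_type="epsilon"):
        self.num_train_timesteps = num_train_timesteps
        self.prediction_type = prediction_type


def build_scheduler(scheduler_cls, num_train_timesteps, prediction_type):
    """Pass `prediction_type` only when the constructor declares that parameter."""
    kwargs = {"num_train_timesteps": num_train_timesteps}
    if "prediction_type" in inspect.signature(scheduler_cls.__init__).parameters:
        kwargs["prediction_type"] = prediction_type
    return scheduler_cls(**kwargs)


old = build_scheduler(OldScheduler, 1000, "sample")
new = build_scheduler(NewScheduler, 1000, "sample")
print(hasattr(old, "prediction_type"), new.prediction_type)  # prints: False sample
```

Building the keyword arguments in a dict avoids duplicating the constructor call in both branches, which is the main difference from the if/else form used in the script.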
dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")

augmentations = transforms.Compose(
    [
        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
        transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)
The data loading module.
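As a quick check on the last two transforms: `ToTensor()` scales 8-bit pixel values to [0, 1], and `Normalize([0.5], [0.5])` then computes (x - 0.5) / 0.5, so the model sees inputs in [-1, 1]. The plain-Python sketch below reproduces only that arithmetic. `to_model_range` and `to_pixel` are hypothetical helper names, and the inverse mapping is an assumption about how one might convert generated samples back to pixels, not something taken from the script.

```python
def to_model_range(pixel):
    """Map an 8-bit pixel value (0..255) to [-1, 1], as ToTensor + Normalize([0.5], [0.5]) do."""
    x = pixel / 255.0       # ToTensor: [0, 255] -> [0, 1]
    return (x - 0.5) / 0.5  # Normalize(mean=0.5, std=0.5): [0, 1] -> [-1, 1]


def to_pixel(value):
    """Assumed inverse for viewing samples: [-1, 1] -> 0..255, clamped to the valid range."""
    x = min(max(value * 0.5 + 0.5, 0.0), 1.0)
    return round(x * 255)


for p in (0, 128, 255):
    print(p, to_model_range(p), to_pixel(to_model_range(p)))
```

Black (0) maps to -1.0 and white (255) maps to 1.0; values outside [-1, 1] are clamped on the way back.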
accelerate launch train_unconditional.py \
  --dataset_name="huggan/flowers-102-categories" \
  --output_dir="ddpm-ema-flowers-64" \
  --mixed_precision="fp16" \
  --push_to_hub
Single-GPU training.
accelerate launch --multi_gpu train_unconditional.py \
  --dataset_name="huggan/flowers-102-categories" \
  --output_dir="ddpm-ema-flowers-64" \
  --mixed_precision="fp16" \
  --push_to_hub
Multi-GPU training.
Inference code (one small open question: how is this path determined?)