git仓库:https://github.com/FoundationVision/LlamaGen
数据集准备
如果用ImageFolder读取,则最好和ImageNet一致。
data_path/class_1/image_001.jpgimage_002.jpg...class_2/image_003.jpgimage_004.jpg......class_n/image_005.jpgimage_006.jpg...
则
def build_imagenet(args, transform):return ImageFolder(args.data_path, transform=transform)
如果是train,val,test,最好整理成
data_path/train/class_1/image_001.jpgimage_002.jpg...class_2/image_003.jpgimage_004.jpg......val/class_1/image_005.jpgimage_006.jpg...class_2/image_007.jpgimage_008.jpg......test/class_1/image_009.jpgimage_010.jpg...class_2/image_011.jpgimage_012.jpg......
读取:
train_dataset = datasets.ImageFolder(root=args.data_path + '/train', transform=transform)# 加载验证集
val_dataset = datasets.ImageFolder(root=args.data_path + '/val', transform=transform)# 加载测试集
test_dataset = datasets.ImageFolder(root=args.data_path + '/test', transform=transform)
数据集预处理
NCCL_IB_DISABLE=1 NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=3 torchrun \
--nnodes=1 --nproc_per_node=1 --node_rank=0 \
--master_addr=localhost \
autoregressive/train/extract_codes_c2i.py \
--vq-ckpt ./pretrained_models/vq_ds16_c2i.pt \
--data-path 你的数据集 \
--code-path VQGAN处理的数据集放在哪 \--ten-crop \--crop-range 1.1 \--image-size 256
这里改成自己数据集的长度
ten-crop是作者定义的一种数据增强,每一个图片生成10个crop。最好修改一下这里的代码,训练的时候仅仅取一个。
注释掉这个self.flip
训练
NCCL_IB_DISABLE=1 NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=4,5 torchrun \
--nnodes=1 --nproc_per_node=2 --node_rank=0 \
--master_addr=localhost \
--master_port=8902 \
./autoregressive/train/train_c2i.py \
--cloud-save-path xxx \
--code-path 之前放VQGAN处理后数据集的地方 \
--image-size 256 \
--gpt-model GPT-B
生成
修改类别,权重
parser.add_argument("--num-classes", type=int, default=xxx)
label定义:
我的生成结果(数据集用了TinyImageNet的8个类)
300step
1500step