论文:Efficient Emotional Adaptation for Audio-Driven Talking-Head Generation
代码:yuangan/EAT_code: Official code for ICCV 2023 paper: "Efficient Emotional Adaptation for Audio-Driven Talking-Head Generation". (github.com)
1. 训练
1.1 A2KP Training
training A2KP transformer with latent and pca loss:pretrain_a2kp.py
if __name__ == "__main__":if sys.version_info[0] < 3:raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")# 调用parser.parse_args()来解析命令行参数,并将结果存储在opt变量中。parser = ArgumentParser()parser.add_argument("--config", default="config/vox-transformer.yaml", help="path to config")parser.add_argument("--mode", default="train", choices=["train",])parser.add_argument("--gen", default="spade", choices=["original", "spade"])parser.add_argument("--log_dir", default='./output/', help="path to log into")parser.add_argument("--checkpoint", default='./00000189-checkpoint.pth.tar', help="path to checkpoint to restore")#parser.add_argument("--device_ids", default="0, 1, 2, 3, 4, 5, 6, 7", type=lambda x: list(map(int, x.split(','))),parser.add_argument("--device_ids", default="0, 1", type=lambda x: list(map(int, x.split(','))),help="Names of the devices comma separated.")parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")parser.set_defaults(verbose=False)opt = parser.parse_args()# 打开配置文件,并使用yaml库加载配置文件中的内容,并将结果存储在config变量中。with open(opt.config) as f:config = yaml.load(f, Loader=yaml.FullLoader)# log dir when checkpoint is set# if opt.checkpoint is not None:# log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])# else:# 根据配置文件的路径和当前时间生成一个日志目录。它使用了os.path模块来操作路径,并使用strftime函数生成日期和时间字符串。log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())# 根据选择的opt.gen参数创建不同类型的生成器模型对象。根据配置文件中的参数,调用相应的生成器类进行初始化。if opt.gen == 'original':generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],**config['model_params']['common_params'])elif opt.gen == 'spade':generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],**config['model_params']['common_params'])# 检查CUDA是否可用,并将生成器模型移动到指定的设备上(如果可用)。如果设置了verbose标志,则打印生成器模型的结构。if torch.cuda.is_available():print('cuda is available')generator.to(opt.device_ids[0])if opt.verbose:print(generator)discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'],**config['model_params']['common_params'])if torch.cuda.is_available():discriminator.to(opt.device_ids[0])if opt.verbose:print(discriminator)# 创建关键点检测器模型对象,并将其移动到指定的设备上(如果可用)。根据配置文件中的参数,调用相应的关键点检测器类进行初始化。kp_detector = KPDetector(**config['model_params']['kp_detector_params'],**config['model_params']['common_params'])if torch.cuda.is_available():kp_detector.to(opt.device_ids[0])if opt.verbose:print(kp_detector)# 创建音频到关键点转换器模型对象,并将其移动到指定的设备上(如果可用)。根据配置文件中的参数,调用相应的音频到关键点转换器类进行初始化。audio2kptransformer = Audio2kpTransformer(**config['model_params']['audio2kp_params'])if torch.cuda.is_available():audio2kptransformer.to(opt.device_ids[0])# 创建数据集对象。根据配置文件中的参数,调用相应的数据集类进行初始化。dataset = FramesWavsDatasetMEL25(is_train=(opt.mode == 'train'), **config['dataset_params'])if not os.path.exists(log_dir):os.makedirs(log_dir)if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):copy(opt.config, log_dir)# 根据opt.mode参数的值决定进行训练或其他操作。如果设置为"train",则调用train函数进行模型训练,传递所需的参数。if opt.mode == 'train':print("Training...")train(config, generator, discriminator, kp_detector, audio2kptransformer, opt.checkpoint, log_dir, dataset, opt.device_ids)
training A2KP transformer with all loss :pretrain_a2kp_img.py
大致同上
1.2 Emotional Adaptation Training
prompt_st_dp_eam3d.py:
大致同上
2. 数据处理
视频预处理:preprocess_video.py
from glob import glob
import os# 使用glob模块获取指定目录下所有的.mp4文件路径
allmp4s = glob('./video/*.mp4')# 设置目标文件夹路径,并确保该文件夹存在
path_fps25='./video_fps25'
os.makedirs(path_fps25, exist_ok=True)# 遍历每个.mp4文件
for mp4 in allmp4s:# 获取文件名(不带路径)name = os.path.basename(mp4)# 使用ffmpeg命令将视频转换为25帧每秒的视频,并设置音频参数os.system(f'ffmpeg -y -i {mp4} -filter:v fps=25 -ac 1 -ar 16000 -crf 10 {path_fps25}/{name}')# 使用ffmpeg命令将上一步生成的视频转换为.wav格式的音频文件os.system(f'ffmpeg -y -i {path_fps25}/{name} {path_fps25}/{name[:-4]}.wav')#============== extract lmk for crop =================
# 提取关键点信息用于裁剪
print('============== extract lmk for crop =================')
os.system(f'python extract_lmks_eat.py {path_fps25}')#======= extract speech in deepspeech_features =======
# 提取语音特征
print('======= extract speech in deepspeech_features =======')
os.chdir('./deepspeech_features/')
os.system(f'python extract_ds_features.py --input=../{path_fps25}')
os.chdir('../')
os.system('python deepfeature32.py')#=================== crop videos =====================
# 裁剪视频
print('=================== crop videos =====================')
os.chdir('./vid2vid/')
os.system('python data_preprocess.py --dataset_mode preprocess_eat')
os.chdir('../')#========== extract latent from cropped videos =======
#从裁剪的视频中提取潜在特征
print('========== extract latent from cropped videos =======')
os.system('python videos2img.py')
os.system('python latent_extractor.py')#=========== extract poseimg from latent =============
# 从潜在特征中提取姿势图像
print('=========== extract poseimg from latent =============')
os.system('python generate_poseimg.py')
之后,Extract the bbox for training:preprocess/extract_bbox.py
fa = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device='cuda')
#初始化了一个人脸对齐(Face Alignment)模型对象,CUDA加速。def detect_bbox(img_names):bboxs = []for img_name in img_names:img = img_as_float32(io.imread(img_name)).transpose((2, 0, 1))img = np.transpose(img[np.newaxis], (0,2,3,1))[...,::-1]bbox = fa.get_detections_for_batch(img*255)if bbox is not None:bboxs.append(bbox[0])else:bboxs.append(None)assert(len(bboxs)==len(img_names))return bboxs
这个函数用于检测一组图像中的人脸边界框。它接受一个图像文件名列表作为输入,并返回相应的人脸边界框列表。该函数首先加载图像文件,然后将其转换为指定的格式,并调用人脸对齐模型的get_detections_for_batch
方法来获取人脸边界框。如果检测到了人脸边界框,则将其添加到bboxs列表中;否则,将None添加到列表中。
def main(args):file_images = glob('/data2/gy/lrw/lrw_images/*')file_images.sort()p = args.partt = len(file_images)for fi in tqdm(file_images[t*p:t*(p+1)]):out = basename(fi)outpath =f'/data2/gy/lrw/lrw_bbox/{out}.npy'if exists(outpath):continueimages = glob(fi+'/*.jpg')images.sort()bboxs = detect_bbox(images)np.save(outpath, bboxs)if __name__ == "__main__":parser = ArgumentParser()parser.add_argument("--files", default="*", help="filenames")parser.add_argument("--part", default="0", type=int, help="part")args = parser.parse_args()main(args)
这部分代码定义了一个名为main
的函数,并在脚本最后调用该函数。main
函数从指定路径中获取图像文件列表,然后遍历每个图像文件夹。对于每个图像文件夹,它首先构建输出路径和文件名,检查是否已经存在保存的结果(如果存在,则跳过),然后获取图像文件列表并排序。接下来,它调用detect_bbox
函数来进行人脸边界框检测,并将结果保存到指定的输出路径中。
整体而言,这段代码利用人脸对齐模型进行人脸边界框检测,并将结果保存到.npy文件中。