【代码学习】EAT复现+代码分析

论文:Efficient Emotional Adaptation for Audio-Driven Talking-Head Generation

代码:yuangan/EAT_code: Official code for ICCV 2023 paper: "Efficient Emotional Adaptation for Audio-Driven Talking-Head Generation". (github.com)

1. 训练

1.1 A2KP Training

training A2KP transformer with latent and pca loss:pretrain_a2kp.py

if __name__ == "__main__":if sys.version_info[0] < 3:raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")# 调用parser.parse_args()来解析命令行参数,并将结果存储在opt变量中。parser = ArgumentParser()parser.add_argument("--config", default="config/vox-transformer.yaml", help="path to config")parser.add_argument("--mode", default="train", choices=["train",])parser.add_argument("--gen", default="spade", choices=["original", "spade"])parser.add_argument("--log_dir", default='./output/', help="path to log into")parser.add_argument("--checkpoint", default='./00000189-checkpoint.pth.tar', help="path to checkpoint to restore")#parser.add_argument("--device_ids", default="0, 1, 2, 3, 4, 5, 6, 7", type=lambda x: list(map(int, x.split(','))),parser.add_argument("--device_ids", default="0, 1", type=lambda x: list(map(int, x.split(','))),help="Names of the devices comma separated.")parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")parser.set_defaults(verbose=False)opt = parser.parse_args()# 打开配置文件,并使用yaml库加载配置文件中的内容,并将结果存储在config变量中。with open(opt.config) as f:config = yaml.load(f, Loader=yaml.FullLoader)# log dir when checkpoint is set# if opt.checkpoint is not None:#     log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])# else:# 根据配置文件的路径和当前时间生成一个日志目录。它使用了os.path模块来操作路径,并使用strftime函数生成日期和时间字符串。log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())# 根据选择的opt.gen参数创建不同类型的生成器模型对象。根据配置文件中的参数,调用相应的生成器类进行初始化。if opt.gen == 'original':generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],**config['model_params']['common_params'])elif opt.gen == 'spade':generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],**config['model_params']['common_params'])# 检查CUDA是否可用,并将生成器模型移动到指定的设备上(如果可用)。如果设置了verbose标志,则打印生成器模型的结构。if torch.cuda.is_available():print('cuda is available')generator.to(opt.device_ids[0])if opt.verbose:print(generator)discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'],**config['model_params']['common_params'])if torch.cuda.is_available():discriminator.to(opt.device_ids[0])if opt.verbose:print(discriminator)# 创建关键点检测器模型对象,并将其移动到指定的设备上(如果可用)。根据配置文件中的参数,调用相应的关键点检测器类进行初始化。kp_detector = KPDetector(**config['model_params']['kp_detector_params'],**config['model_params']['common_params'])if torch.cuda.is_available():kp_detector.to(opt.device_ids[0])if opt.verbose:print(kp_detector)# 创建音频到关键点转换器模型对象,并将其移动到指定的设备上(如果可用)。根据配置文件中的参数,调用相应的音频到关键点转换器类进行初始化。audio2kptransformer = Audio2kpTransformer(**config['model_params']['audio2kp_params'])if torch.cuda.is_available():audio2kptransformer.to(opt.device_ids[0])# 创建数据集对象。根据配置文件中的参数,调用相应的数据集类进行初始化。dataset = FramesWavsDatasetMEL25(is_train=(opt.mode == 'train'), **config['dataset_params'])if not os.path.exists(log_dir):os.makedirs(log_dir)if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):copy(opt.config, log_dir)# 根据opt.mode参数的值决定进行训练或其他操作。如果设置为"train",则调用train函数进行模型训练,传递所需的参数。if opt.mode == 'train':print("Training...")train(config, generator, discriminator, kp_detector, audio2kptransformer, opt.checkpoint, log_dir, dataset, opt.device_ids)

training A2KP transformer with all loss :pretrain_a2kp_img.py

大致同上

1.2 Emotional Adaptation Training

prompt_st_dp_eam3d.py:

大致同上

2. 数据处理

视频预处理:preprocess_video.py

from glob import glob
import os# 使用glob模块获取指定目录下所有的.mp4文件路径
allmp4s = glob('./video/*.mp4')# 设置目标文件夹路径,并确保该文件夹存在
path_fps25='./video_fps25'
os.makedirs(path_fps25, exist_ok=True)# 遍历每个.mp4文件
for mp4 in allmp4s:# 获取文件名(不带路径)name = os.path.basename(mp4)# 使用ffmpeg命令将视频转换为25帧每秒的视频,并设置音频参数os.system(f'ffmpeg -y -i {mp4} -filter:v fps=25 -ac 1 -ar 16000 -crf 10 {path_fps25}/{name}')# 使用ffmpeg命令将上一步生成的视频转换为.wav格式的音频文件os.system(f'ffmpeg -y -i {path_fps25}/{name} {path_fps25}/{name[:-4]}.wav')#============== extract lmk for crop =================
# 提取关键点信息用于裁剪
print('============== extract lmk for crop =================')
os.system(f'python extract_lmks_eat.py {path_fps25}')#======= extract speech in deepspeech_features =======
# 提取语音特征
print('======= extract speech in deepspeech_features =======')
os.chdir('./deepspeech_features/')
os.system(f'python extract_ds_features.py --input=../{path_fps25}')
os.chdir('../')
os.system('python deepfeature32.py')#=================== crop videos =====================
# 裁剪视频
print('=================== crop videos =====================')
os.chdir('./vid2vid/')
os.system('python data_preprocess.py --dataset_mode preprocess_eat')
os.chdir('../')#========== extract latent from cropped videos =======
#从裁剪的视频中提取潜在特征
print('========== extract latent from cropped videos =======')
os.system('python videos2img.py')
os.system('python latent_extractor.py')#=========== extract poseimg from latent =============
# 从潜在特征中提取姿势图像
print('=========== extract poseimg from latent =============')
os.system('python generate_poseimg.py')

之后,Extract the bbox for training:preprocess/extract_bbox.py

fa = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device='cuda')
#初始化了一个人脸对齐(Face Alignment)模型对象,CUDA加速。def detect_bbox(img_names):bboxs = []for img_name in img_names:img = img_as_float32(io.imread(img_name)).transpose((2, 0, 1))img = np.transpose(img[np.newaxis], (0,2,3,1))[...,::-1]bbox = fa.get_detections_for_batch(img*255)if bbox is not None:bboxs.append(bbox[0])else:bboxs.append(None)assert(len(bboxs)==len(img_names))return bboxs

这个函数用于检测一组图像中的人脸边界框。它接受一个图像文件名列表作为输入,并返回相应的人脸边界框列表。该函数首先加载图像文件,然后将其转换为指定的格式,并调用人脸对齐模型的get_detections_for_batch方法来获取人脸边界框。如果检测到了人脸边界框,则将其添加到bboxs列表中;否则,将None添加到列表中。

def main(args):file_images = glob('/data2/gy/lrw/lrw_images/*')file_images.sort()p = args.partt = len(file_images)for fi in tqdm(file_images[t*p:t*(p+1)]):out = basename(fi)outpath =f'/data2/gy/lrw/lrw_bbox/{out}.npy'if exists(outpath):continueimages = glob(fi+'/*.jpg')images.sort()bboxs = detect_bbox(images)np.save(outpath, bboxs)if __name__ == "__main__":parser = ArgumentParser()parser.add_argument("--files", default="*", help="filenames")parser.add_argument("--part", default="0", type=int, help="part")args = parser.parse_args()main(args)

这部分代码定义了一个名为main的函数,并在脚本最后调用该函数。main函数从指定路径中获取图像文件列表,然后遍历每个图像文件夹。对于每个图像文件夹,它首先构建输出路径和文件名,检查是否已经存在保存的结果(如果存在,则跳过),然后获取图像文件列表并排序。接下来,它调用detect_bbox函数来进行人脸边界框检测,并将结果保存到指定的输出路径中。

整体而言,这段代码利用人脸对齐模型进行人脸边界框检测,并将结果保存到.npy文件中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/726855.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开源爬虫技术在金融行业市场分析中的应用与实战解析

一、项目介绍 在当今信息技术飞速发展的时代&#xff0c;数据已成为企业最宝贵的资产之一。特别是在${industry}领域&#xff0c;海量数据的获取和分析对于企业洞察市场趋势、优化产品和服务至关重要。在这样的背景下&#xff0c;爬虫技术应运而生&#xff0c;它能够高效地从互…

企业级数字人形象自定义解决方案

在品牌传播、线上营销等领域&#xff0c;一个独特且符合企业形象的数字人形象&#xff0c;无疑能为企业带来更强的品牌识别度和市场竞争力。美摄科技&#xff0c;作为业界领先的数字人形象解决方案提供商&#xff0c;凭借多年的技术积累和深厚的行业经验&#xff0c;推出了一套…

UnityAPI的学习——Matrix4x4类

在脚本中通常用Vector3、Quaternion、Transform等类的属性和方法来对物体进行交换&#xff0c;Matrix4x4类通常用在一些比较特殊的地方&#xff0c;如对摄像机的非标准投影变换。 Matrix4x4类实例方法 在Matrix4x4类中&#xff0c;涉及的实例方法有MultiplyPoint方法、Multip…

单机Kubenetes集群——KinD安装

文章目录 前言一、Linux安装二、安装docker三、创建单节点集群四、kubectl安装总结 前言 KinD&#xff1a;单机测试K8s集群 源码&#xff1a;https://github.com/kubernetes-sigs/kind 官方文档&#xff1a;https://kind.sigs.k8s.io/docs/user/quick-start/ 一、Linux安装 (…

Linux第68步_旧字符设备驱动的一般模板

file_operations结构体中的函数就是我们要实现的具体操作函数。 注意&#xff1a; register_chrdev()和 unregister_chrdev()这两个函数是老版本驱动使用的。现在新字符设备驱动已经不再使用这两个函数&#xff0c;而是使用Linux内核推荐的新字符设备驱动API函数。 1、创建C…

公众号公司主体变更如何操作?

公众号迁移有什么用&#xff1f;只能改主体吗&#xff1f;好多朋友都想做公众号迁移&#xff0c;但是又不太清楚具体有啥用&#xff0c;今天我就来详细说说。首先&#xff0c;公众号迁移最重要的作用就是可以修改主体。比如你的公众号原来是 A 公司的&#xff0c;现在 A 公司不…

Linux 系统上卸载 Docker

停止 Docker 服务&#xff1a; sudo systemctl stop docker卸载 Docker 程序包&#xff1a; 不同的 Linux 发行版有不同的包管理工具&#xff0c;以下是一些常见的发行版的卸载命令&#xff1a; 对于使用 apt 的系统&#xff08;如 Ubuntu、Debian&#xff09;&#xff1a;sudo…

SpringCloud(20)之Skywalking Agent原理剖析

一、Agent原理剖析 使用Skywalking的时候&#xff0c;并没有修改程序中任何一行 Java 代码&#xff0c;这里便使用到了 Java Agent 技术&#xff0c;我 们接下来展开对Java Agent 技术的学习。 1.1 Java Agent Java Agent 是从 JDK1.5 开始引入的&#xff0c;算是一个比较老的…

Ruby CanCanCan 动态定义方法

灵感来自这里https://github.com/kristianmandrup/cantango/wiki/CanCan-vs-CanTango 如果权限不多,我们可以通过这种方式来定义 class CanCan::Abilitydef initialize user, options = {}if !usercan :read, :allendif useradmin_rules if user.roles.include? :adminedit…

STL中push_back和emplace_back效率的对比

文章目录 过程对比1.通过构造参数向vector中插入对象&#xff08;emplace_back更高效&#xff09;2.通过插入实例对象&#xff08;调用copy函数&#xff09;3.通过插入临时对象&#xff08;调用move函数&#xff09; 效率对比emplace_back 的缺点 我们以STL中的vector容器为例。…

解决 Pandas 导出文件出现 dtype: object 字样

文章目录 1. 问题2. 解决方法 1. 问题 python 用 pandas 输出 excel 文件时&#xff0c;发现有些列的单元格出现 “dtype: object” 的字样&#xff0c;如下图&#xff1a; 这是 pandas 没有处理好导致的 2. 解决方法 结果用 .values 进行输出&#xff0c;这样就转成字符串…

Vue的Diff详解

在 Vue 中&#xff0c;当我们更新数据时&#xff0c;Vue 会自动更新视图&#xff0c;这个过程就是虚拟 DOM 的 diff 算法。虚拟 DOM 是一种以 JavaScript 对象的形式表示 DOM 节点的方式&#xff0c;它可以更快地计算出需要更新的节点&#xff0c;从而提高渲染效率。 接下来&a…

聊天室项目

服务器 #include <myhead.h> #define SER_IP "192.168.122.39" #define SER_PORT 8888 typedef struct Node //链表存储客户端的所有信息 {struct sockaddr_in cin; //存储客户端的网络地址信息struct Node *next; }*List; typedef struct Message//消息结构…

洛谷 P1731 [NOI1999] 生日蛋糕

题目 题目链接 自己没看题解写的&#xff0c;摸石头过河&#xff0c;解释一下 首先&#xff0c;输入输出都是正整数。先搞定输入&#xff0c;再判断条件&#xff0c;如果无解&#xff0c;输出0&#xff0c;否则输出蛋糕外表面面积Q&#xff08;这里用全局变量&#xff0c;开l…

数据库的分类和特点介绍

#基础概念# #入门 数据库的主要分类 关系型数据库&#xff08;RDBMS&#xff09; 数据以表格形式存储&#xff0c;通过预定义的关系模型建立数据间的连接&#xff0c;使用SQL作为查询语言。常见的例子包括MySQL、Oracle、SQL Server、PostgreSQL、IBM DB2等。 非关系型数据库…

SEDEX验厂审核重点

SEDEX验厂简介 在全球化的今天&#xff0c;供应链的透明性和可持续性越来越受到人们的关注。为了确保供应链的合规性和可持续性&#xff0c;许多企业开始采用SEDEX验厂这一方法。SEDEX验厂是一种基于国际劳工组织&#xff08;ILO&#xff09;核心劳工标准的供应链审核体系&…

数据库:2024/3/6

作业1&#xff1a;使用C语言完成数据库的增删改 代码&#xff1a; #include <myhead.h>//定义添加员工信息函数 int Add_worker(sqlite3 *ppDb) {//准备sql语句printf("请输入要添加的员工信息:\n");//从终端获取员工信息char rbuf[128]"";fgets(r…

算法-买卖股票的最佳时机

1、题目来源 121. 买卖股票的最佳时机 - 力扣&#xff08;LeetCode&#xff09; 2、题目描述 给定一个数组 prices &#xff0c;它的第 i 个元素 prices[i] 表示一支给定股票第 i 天的价格。 你只能选择 某一天 买入这只股票&#xff0c;并选择在 未来的某一个不同的日子 卖…

React学习

&#x1f4d1;前言 本文主要是【React】——React基础的文章&#xff0c;如果有什么需要改进的地方还请大佬指出⛺️ &#x1f3ac;作者简介&#xff1a;大家好&#xff0c;我是听风与他&#x1f947; ☁️博客首页&#xff1a;CSDN主页听风与他 &#x1f304;每日一句&#x…