FastSpeech2中文语音合成就步解析:TTS数据训练实战篇

  1. 参考github网址:

GitHub - roedoejet/FastSpeech2: An implementation of Microsoft’s “FastSpeech 2: Fast and High-Quality End-to-End Text to Speech”

  1. 数据训练所用python 命令:

python3 train.py -p config/AISHELL3/preprocess.yaml -m config/AISHELL3/model.yaml -t config/AISHELL3/train.yaml

  1. 数据训练代码解析

3.1 代码架构overview:

通过 if __name__ == "__main__"运行整个py文件:

调用 “train.txt"和dataset.py加载数据,

调用utils文件夹下的model.py加载模型,声码器,

调用model文件夹下的loss.py中的FastSpeech2Loss class 设置损失函数,

用前面加载的模型和损失函数开始训练模型,导出结果并记录日志。

3.2 按训练步骤分解代码:

Step 0 : 定义可控训练参数, 调动main函数

if __name__ == "__main__":#Define Argsparser = argparse.ArgumentParser()parser.add_argument("--restore_step", type=int, default=0)parser.add_argument("-p","--preprocess_config",type=str,required=True,help="path to preprocess.yaml",)parser.add_argument("-m", "--model_config", type=str, required=True, help="path to model.yaml")parser.add_argument("-t", "--train_config", type=str, required=True, help="path to train.yaml")args = parser.parse_args() #args为可控训练参数# Read Configpreprocess_config = yaml.load(open(args.preprocess_config, "r"), Loader=yaml.FullLoader)model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)configs = (preprocess_config, model_config, train_config)#Run _main_ functionmain(args, configs)

Step 1 : 启动main函数,加载可控训练参数

def main(args, configs): print("Prepare training ...")#加载可控训练参数preprocess_config, model_config, train_config = configs

Step 2 : 从train.txt加载数据,并经由dataset.py和torch里的Dataloader处理

def main(args, configs):# Get datasetdataset = Dataset("train.txt", preprocess_config, train_config, sort=True, drop_last=True) #从 train.txt 中获取datasetbatch_size = train_config["optimizer"]["batch_size"]group_size = 4  # Set this larger than 1 to enable sorting in Dataset,初始值为4assert batch_size * group_size < len(dataset)loader = DataLoader(dataset,batch_size=batch_size * group_size,shuffle=True,collate_fn=dataset.collate_fn,)

Step 3 : 定义模型,声码器,损失函数

def main(args, configs):# Prepare modelmodel, optimizer = get_model(args, configs, device, train=True) #设置优化器# 将模型并行训练并移入计算设备中model = nn.DataParallel(model) # Model Has Been Defined# 计算模型参数量num_param = get_param_num(model) # Number of TTS Parameters: num_paramprint("Number of FastSpeech2 Parameters:", num_param)# 设置损失函数Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)# 加载声码器vocoder = get_vocoder(model_config, device)

Step 4 : 加载日志,在"./output/log/AISHELL3"目录建立train, val两个文件夹来记录日志

def main(args, configs):# Init loggerfor p in train_config["path"].values():os.makedirs(p, exist_ok=True)train_log_path = os.path.join(train_config["path"]["log_path"], "train")val_log_path = os.path.join(train_config["path"]["log_path"], "val")os.makedirs(train_log_path, exist_ok=True)os.makedirs(val_log_path, exist_ok=True)train_logger = SummaryWriter(train_log_path)val_logger = SummaryWriter(val_log_path)

Step 5 : 准备训练,加载可控训练参数

def main(args, configs):# Trainingstep = args.restore_step + 1epoch = 1grad_acc_step = train_config["optimizer"]["grad_acc_step"]grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]total_step = train_config["step"]["total_step"]log_step = train_config["step"]["log_step"]save_step = train_config["step"]["save_step"]synth_step = train_config["step"]["synth_step"]val_step = train_config["step"]["val_step"]outer_bar = tqdm(total=total_step, desc="Training", position=0)outer_bar.n = args.restore_stepouter_bar.update()

Step 6 : 准备训练,加载进度条,调动utils文件夹下tools.py中的to_device function来提取数据

    while True:inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)for batchs in loader:for batch in batchs:batch = to_device(batch, device)

Step 7 :开始训练,前向传播,计算损失,反向传播,梯度剪枝,更新模型权重参数

    #Load Datafor batch in batchs:batch = to_device(batch, device)# Forwardoutput = model(*(batch[2:]))# Cal Losslosses = Loss(batch, output)total_loss = losses[0]# Backwardtotal_loss = total_loss / grad_acc_steptotal_loss.backward()if step % grad_acc_step == 0:# Clipping gradients to avoid gradient explosionnn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)# Update weightsoptimizer.step_and_update_lr()optimizer.zero_grad()

Step 8 : 当训练步数到达预先设定的log_step时,调动utils文件夹下tool.py里的log function,记录loss和step

                if step % log_step == 0:losses = [l.item() for l in losses]message1 = "Step {}/{}, ".format(step, total_step)message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(*losses)with open(os.path.join(train_log_path, "log.txt"), "a") as f:f.write(message1 + message2 + "\n")outer_bar.write(message1 + message2)log(train_logger, step, losses=losses)

Step 9 : 当训练步数到达预先设定的synth_step时,调动utils文件夹下tool.py里的log function 和 synth_one_sample function(具体用来干什么没看懂)

                if step % synth_step == 0:fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(batch,output,vocoder,model_config,preprocess_config,)log(train_logger,fig=fig,tag="Training/step_{}_{}".format(step, tag),)sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]log(train_logger,audio=wav_reconstruction,sampling_rate=sampling_rate,tag="Training/step_{}_{}_reconstructed".format(step, tag),)log(train_logger,audio=wav_prediction,sampling_rate=sampling_rate,tag="Training/step_{}_{}_synthesized".format(step, tag),)

Step 10 : 当训练步数到达预先设定的val_step时,调动evaluate.py里的evaluate function来进行evaluation,并记录在log/AISHELL3/val/log.txt

                if step % val_step == 0:model.eval()message = evaluate(model, step, configs, val_logger, vocoder)with open(os.path.join(val_log_path, "log.txt"), "a") as f:f.write(message + "\n")outer_bar.write(message)model.train()

Step 11 : 当训练步数到达预先设定的save_step时,保存训练模型

                if step % save_step == 0:torch.save({"model": model.module.state_dict(),"optimizer": optimizer._optimizer.state_dict(),},os.path.join(train_config["path"]["ckpt_path"],"{}.pth.tar".format(step),),)

Step 12 : 当训练步数到达预先设定的total_step时,退出训练

                if step == total_step:quit()step += 1outer_bar.update(1)inner_bar.update(1)epoch += 1
  1. 数据训练代码的输出

在train_log_path和val_log_path输出日志

在ckpt_path输出训练过程中按照save_step存储的模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/41122.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ida动态调试-cnblog

ida动态调试 传递启动ida服务 android_server在ida\dbgsrv目录中 adb push android_server /data/local/tmp/chmod 755 /data/local/tmp/android_server /data/local/tmp/android_serveradb forward tcp:23946 tcp:23946ida报错:大多是手机端口被占用 报错提示&#xff1a; …

java面试-java基础(下)

文章目录 一、和equals区别&#xff1f;二、hashcode方法作用&#xff1f;两个对象的hashCode方法相同&#xff0c;则equals方法也一定为true吗&#xff1f;三、为什么重写equals方法就一定要重写hashCode方法&#xff1f;四、Java中的参数传递时传值呢还是传引用&#xff1f;五…

期末上分站——计组(3)

复习题21-42 21、指令周期是指__C_。 A. CPU从主存取出一条指令的时间 B. CPU执行一条指令的时间 C. CPU从主存取出一条指令的时间加上执行这条指令的时间。 D. 时钟周期时间 22、微型机系统中外设通过适配器与主板的系统总线相连接&#xff0c;其功能是__D_。 A. 数据缓冲和…

数据库可视化管理工具dbeaver试用及问题处理。

本文记录了在内网离线安装数据库可视化管理工具dbeaver的过程和相关问题处理方法。 一、下载dbeaver https://dbeaver.io/download/ 笔者测试时Windows平台最新版本为&#xff1a;dbeaver-ce-24.1.1-x86_64-setup.exe 二、安装方法 一路“下一步”即可 三、问题处理 1、问…

【深度学习】vscode 命令行下的debug

其实我一直知道vscode可以再命令行下进行debug。 比如 python aaa.py --bb1 --cc2 以前的做法是 去aaa.py 写死bb和cc 然后直接debug。 直到今天我遇到这个&#xff1a; hydra hydra.main(version_baseNone, config_name/home/justin/Desktop/code/python_project/WASB-SBDT-m…

Truffle学习笔记

Truffle学习笔记 安装truffle, 注意: 虽然目前truffle最新版是 5.0.0, 但是经过我实践之后, 返现和v4有很多不同(比如: web3.eth.accounts; 都获取不到账户), 还是那句话: “nodejs模块的版本问题会搞死人的 !” 目前4.1.15之前的版本都不能用了, 只能安装v4.1.15 npm instal…

新手学Cocos报错 [Assets] Failed to open

两个都在偏好设置里面调&#xff08;文件下面的偏好设置&#xff09;&#xff1a; 1.设置中文&#xff1f; 2.报错 [Assets] Failed to open&#xff1f; 这样在点击打开ts文件的时候就不会报错&#xff0c;并且用vscode编辑器打开了&#xff0c; 同样也可以改成你们自己喜欢…

LabVIEW在图像处理中的应用

abVIEW作为一种图形化编程环境&#xff0c;不仅在数据采集和仪器控制领域表现出色&#xff0c;还在图像处理方面具有强大的功能。借助其Vision Development Module&#xff0c;LabVIEW提供了丰富的图像处理工具&#xff0c;广泛应用于工业检测、医学影像、自动化控制等多个领域…

Apache Seata应用侧启动过程剖析——RM TM如何与TC建立连接

本文来自 Apache Seata官方文档&#xff0c;欢迎访问官网&#xff0c;查看更多深度文章。 本文来自 Apache Seata官方文档&#xff0c;欢迎访问官网&#xff0c;查看更多深度文章。 Apache Seata应用侧启动过程剖析——RM & TM如何与TC建立连接 前言 看过官网 README 的第…

Android最近任务显示的图片

Android最近任务显示的图片 1、TaskSnapshot截图1.1 snapshotTask1.2 drawAppThemeSnapshot 2、导航栏显示问题3、Recentan按键进入最近任务 1、TaskSnapshot截图 frameworks/base/services/core/java/com/android/server/wm/TaskSnapshotController.java frameworks/base/cor…

IPython 性能评估工具的较量:%%timeit 与 %timeit 的差异解析

IPython 性能评估工具的较量&#xff1a;%%timeit 与 %timeit 的差异解析 在 IPython 的世界中&#xff0c;性能评估是一项至关重要的任务。%%timeit 和 %timeit 是两个用于测量代码执行时间的魔术命令&#xff0c;但它们之间存在一些关键的差异。本文将深入探讨这两个命令的不…

2786. 访问数组中的位置使分数最大

2786. 访问数组中的位置使分数最大 题目链接&#xff1a;2786. 访问数组中的位置使分数最大 代码如下&#xff1a; //参考链接:https://leetcode.cn/problems/visit-array-positions-to-maximize-score/solutions/2810335/dp-by-kkkk-16-tn9f class Solution { public:long …

vue-router 4汇总

一、vue和vue-router版本&#xff1a; "vue": "^3.4.29", "vue-router": "^4.4.0" 二、路由传参&#xff1a; 方式一&#xff1a; 路由配置&#xff1a;/src/router/index.ts import {createRouter,createWebHistory } from &quo…

探索 WebKit 的缓存迷宫:深入理解其高效缓存机制

探索 WebKit 的缓存迷宫&#xff1a;深入理解其高效缓存机制 在当今快速变化的网络世界中&#xff0c;WebKit 作为领先的浏览器引擎之一&#xff0c;其缓存机制对于提升网页加载速度、减少服务器负载以及改善用户体验起着至关重要的作用。本文将深入探讨 WebKit 的缓存机制&am…

代码随想录leetcode200题之额外题目

目录 1 介绍2 训练3 参考 1 介绍 本博客用来记录代码随想录leetcode200题之额外题目相关题目。 2 训练 题目1&#xff1a;1365. 有多少小于当前数字的数字 解题思路&#xff1a;二分查找。 C代码如下&#xff0c; class Solution { public:vector<int> smallerNumb…

卷积神经网络(CNN)和循环神经网络(RNN) 的区别与联系

卷积神经网络&#xff08;CNN&#xff09;和循环神经网络&#xff08;RNN&#xff09;是两种广泛应用于深度学习的神经网络架构&#xff0c;它们在设计理念和应用领域上有显著区别&#xff0c;但也存在一些联系。 ### 卷积神经网络&#xff08;CNN&#xff09; #### 主要特点…

解决C++编译时的产生的skipping incompatible xxx 错误

问题 我在编译项目时&#xff0c;产生了一个 /usr/bin/ld: skipping incompatible ../../xxx/ when searching for -lxxx 的编译错误&#xff0c;如下图所示&#xff1a; 解决方法 由图中的错误可知&#xff0c;在编译时&#xff0c;是能够在我们指定目录下的 *.so 动态库的…

python函数和c的区别有哪些

Python有很多内置函数&#xff08;build in function&#xff09;&#xff0c;不需要写头文件&#xff0c;Python还有很多强大的模块&#xff0c;需要时导入便可。C语言在这一点上远不及Python&#xff0c;大多时候都需要自己手动实现。 C语言中的函数&#xff0c;有着严格的顺…

Java基础(六)——继承

个人简介 &#x1f440;个人主页&#xff1a; 前端杂货铺 ⚡开源项目&#xff1a; rich-vue3 &#xff08;基于 Vue3 TS Pinia Element Plus Spring全家桶 MySQL&#xff09; &#x1f64b;‍♂️学习方向&#xff1a; 主攻前端方向&#xff0c;正逐渐往全干发展 &#x1…

【Web】

1、配仓库 [rootlocalhost yum.repos.d]# vi rpm.repo ##本地仓库标准写法 [baseos] namemiaoshubaseos baseurl/mnt/BaseOS gpgcheck0 [appstream] namemiaoshuappstream baseurlfile:///mnt/AppStream gpgcheck0 2、挂载 [rootlocalhost ~]mount /dev/sr0 /mnt mount: /m…