一文读懂最强中文NLP预训练模型ERNIE

基于飞桨开源的持续学习的语义理解框架ERNIE 2.0,及基于此框架的ERNIE 2.0预训练模型,在共计16个中英文任务上超越了BERT和XLNet, 取得了SOTA效果。本文带你进一步深入了解ERNIE的技术细节。

一:ERNIE 简介

1.1 简介

Google 最近提出的 BERT 模型,通过随机屏蔽15%的字或者word,利用 Transformer 的多层 self-attention 双向建模能力,在各项nlp 下游任务中(如 sentence pair classification task,singe sentence classification task,question answering task) 都取得了很好的成绩。但是,BERT 模型主要是聚焦在针对字或者英文word粒度的完形填空学习上面,没有充分利用训练数据当中词法结构,语法结构,以及语义信息去学习建模。比如 “我要买苹果手机”,BERT 模型 将 “我”,“要”, “买”,“苹”, “果”,“手”, “机” 每个字都统一对待,随机mask,丢失了“苹果手机” 是一个很火的名词这一信息,这个是词法信息的缺失。同时 我 + 买 + 名词 是一个非常明显的购物意图的句式,BERT 没有对此类语法结构进行专门的建模,如果预训练的语料中只有“我要买苹果手机”,“我要买华为手机”,哪一天出现了一个新的手机牌子比如栗子手机,而这个手机牌子在预训练的语料当中并不存在,没有基于词法结构以及句法结构的建模,对于这种新出来的词是很难给出一个很好的向量表示的,而ERNIE 通过对训练数据中的词法结构,语法结构,语义信息进行统一建模,极大地增强了通用语义表示能力,在多项任务中均取得了大幅度超越BERT的效果!!

 

1.2 下载地址(这么好用的模型赶紧下载起来吧!)

ERNIE 的Fine-tuning代码和英文预训练模型已通过飞桨开源

 

Github 地址:

https://github.com/PaddlePaddle/ERNIE

 

二:ERNIE 详解

2.1 ERNIE 结构

2.1.1 ERNIE 初探

640?wx_fmt=png

 

 

2.1.1 ERNIE 结构详解

 

640?wx_fmt=png

Figure 2:ERNIE 的encoder 结构详解 

 

相比transformer,ERNIE 基本上是 transformer 的encoder 部分,并且encoder 在结构上是全部一样的,但是并不共享权重,具体区别如下:

  • Transformer: 6 encoder layers, 512 hidden units, 8 attention heads

  • ERNIE Base: 12 encoder layers, 768 hidden units, 12 attention heads

  • ERNIE Large: 24 encoder layers,1024 hidden units, 16 attention heads

 

从输入上来看第一个输入是一个特殊的CLS, CLS 表示分类任务就像 transformer 的一般的encoder, ERINE 将一序列的words 输入到encoder 中。 每层使用self-attention, feed-word network, 然后把结果传入到下一个encoder。

 

2.1.2 ERNIE encoder 说明

encoder

encoder 由两层构成, 首先流入self-attention layer,self-attention layer 输出流入 feed-forward 神经网络。至于self-attention的结构,我们在这里不再展开,有兴趣的同学可以进入以下链接仔细阅读http://jalammar.github.io/illustrated-transformer/,来进一步了解self-attention的结构!!

 

640?wx_fmt=png

Figure 3: encoder 结构详解

 

embedding

最下层的encoder的输入是embedding的向量, 其他的encoder的输入,便是更下层的encoder的输出, 一般设置输入的vectors 的维度为512,同学们也可以自己设置。

 

640?wx_fmt=png

Figure 4: encoder 结构详解

 

2.2 : ERNIE 1.0 介绍

相比于BERT, ERNIE 1.0 改进了两种 masking 策略,一种是基于phrase (在这里是短语 比如 a series of, written等)的masking策略,另外一种是基于 entity(在这里是人名、位置、组织、产品等名词,比如Apple, J.K. Rowling)的masking 策略。在ERNIE 当中,将由多个字组成的phrase 或者entity 当成一个统一单元,相比于bert 基于字的mask,这个单元当中的所有字在训练的时候,统一被mask。对比直接将知识类的query 映射成向量然后直接加起来,ERNIE 通过统一mask的方式可以潜在地学习到知识的依赖以及更长的语义依赖来让模型更具泛化性。

 

640?wx_fmt=png

Figure 5: ERNIE 1.0 不同的mask 策略说明

 

 

2.3: ERNIE 2.0 介绍

传统的pre-training 模型主要基于文本中words  和 sentences 之间的共现进行学习。 事实上,训练文本数据中的词法结构、语法结构、语义信息也同样是很重要的。在命名实体识别中人名、机构名、组织名等名词包含概念信息对应了词法结构,句子之间的顺序对应了语法结构,文章中的语义相关性对应了语义信息。为了去发现训练数据中这些有价值的信息,在ERNIE 2.0 中,提出了一个预训练框架,可以在大型数据集合中进行增量训练。

 

640?wx_fmt=png

Figure 6: ERNIE 2.0 框架

 

2.3.1 ERNIE 2.0 结构

ERNIE 2.0 中有一个很重要的概念便是连续学习(Continual Learning),连续学习的目的是在一个模型中顺序训练多个不同的任务,以便在学习下个任务当中可以记住前一个学习任务学习到的结果。通过使用连续学习,可以不断积累新的知识,模型在新任务当中可以用历史任务学习到参数进行初始化,一般来说比直接开始新任务的学习会获得更好的效果。

 

a: 预训练连续学习

ERNIE 的预训练连续学习分为两步,首先,连续用大量的数据与先验知识连续构建不同的预训练任务。其次,不断的用预训练任务更新ERNIE 模型。

 

对于第一步,ERNIE 2.0 分别构建了词法级别,语法级别,语义级别的预训练任务。所有的这些任务,都是基于无标注或者弱标注的数据。需要注意的是,在连续训练之前,首先用一个简单的任务来初始化模型,在后面更新模型的时候,用前一个任务训练好的参数来作为下一个任务模型初始化的参数。这样不管什么时候,一个新的任务加进来的时候,都用上一个模型的参数初始化保证了模型不会忘记之前学习到的知识。通过这种方式,在连续学习的过程中,ERNIE 2.0 框架可以不断更新并记住以前学习到的知识可以使得模型在新任务上获得更好的表现。我们在下面的e, f, g 中会具体介绍ERNIE 2.0 构建哪些预训练任务,并且这些预训练任务起了什么作用。

 

在图7中,介绍了ERNIE2.0连续学习的架构。这个架构包含了一系列共享文本encoding layers 来 encode 上下文信息。这些encoder layers 的参数可以被所有的预训练任务更新。有两种类型的 loss function,一种是sequence level 的loss, 一种是word level的loss。在ERNIE 2.0 预训练中,一个或多个sentence level的loss function可以和多个token level的loss functions 结合来共同更新模型。

 

640?wx_fmt=png

Figure 7: ERINE 2.0 连续学习流程

 

b: encoder

ERNIE 2.0 用了我们前文提到的transformer 结构encoder,结构基本一致,但是权重并不共享。

 

c: task embedding.

ERNIE 2.0 用了不同的task id 来标示预训练任务,task id 从1 到N 对应下面的e, f ,g中提到的预训练任务。对应的token segment position 以及task embedding 被用来作为模型的输入。

 

640?wx_fmt=png

Figure 8: ERNIE 2.0 连续学习详解

 

 

e: 构建词法级别的预训练任务,来获取训练数据中的词法信息

1: knowledge masking task,即 ERNIE 1.0 中的entity mask 以及 phrase entity mask 来获取phrase 以及entity的先验知识,相较于 sub-word masking, 该策略可以更好的捕捉输入样本局部和全局的语义信息。

 

2: Capitalization Prediction Task,大写的词比如Apple相比于其他词通常在句子当中有特定的含义,所以在ERNIE 2.0 加入一个任务来判断一个词是否大写。

 

3: Token-Document Relation Prediction Task,类似于tf-idf,预测一个词在文中的A 段落出现,是否会在文中的B 段落出现。如果一个词在文章当中的许多部分出现一般就说明这个词经常被用到或者和这个文章的主题相关。通过识别这个文中关键的的词, 这个任务可以增强模型去获取文章的关键词语的能力。

 

f: 构建语法级别的预训练任务,来获取训练数据中的语法信息

1:  Sentence Reordering Task,在训练当中,将paragraph 随机分成1 到m 段,将所有的组合随机shuffle。我们让pre-trained 的模型来识别所有的这些segments正确的顺序。这便是一个k 分类任务

 

640?wx_fmt=png

 

通常来说,这些sentence 重排序任务能够让pre-trained 模型学习到document 中不同sentence 的关系。

 

2: Sentence Distance Task, 构建一个三分类任务来判别句子的距离,0表示两个句子是同一个文章中相邻的句子,1表示两个句子是在同一个文章,但是不相邻,2表示两个句子是不同的文章。通过构建这样一个三分类任务去判断句对 (sentence pairs) 位置关系 (包含邻近句子、文档内非邻近句子、非同文档内句子 3 种类别),更好的建模语义相关性。

 

g:构建语义级别的预训练任务,来获取训练数据中的语义任务

1: Discourse Relation Task,除了上面的distance task,ERNIE通过判断句对 (sentence pairs) 间的修辞关系 (semantic & rhetorical relation),更好的学习句间语义。

 

2: IR Relevance Task,在这里主要是利用baidu 的日志来获取这个关系,将query 作为第一个sentence,title 作为第二个 sentence。0 表示强关系, 1 表示弱关系,2表示无关系,通过类似google-distance 的关系来衡量 两个query之间的语义相关性,更好的建模句对相关性。

 

三: 代码梳理

3.1 : 预训练脚本

 

  1. set -eux
  2. export FLAGS_eager_delete_tensor_gb=0
  3. export FLAGS_sync_nccl_allreduce=1
  4. export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
  5. python ./pretrain_launch.py  \   
  6. --nproc_per_node 8 \    
  7. --selected_gpus 0,1,2,3,4,5,6,7 \    
  8. --node_ips $(hostname -i) \    
  9. --node_id 0 \.
  10. /train.py  --use_cuda True \                
  11. --is_distributed False\                
  12. --use_fast_executor True \                
  13. --weight_sharing True \                
  14. --in_tokens true \                
  15. --batch_size 8192 \                
  16. --vocab_path ./config/vocab.txt \               
  17. --train_filelist ./data/train_filelist \               
  18. --valid_filelist ./data/valid_filelist \                
  19. --validation_steps 100 \               
  20. --num_train_steps 1000000 \                
  21. --checkpoints ./checkpoints \               
  22. --save_steps 10000 \                
  23. --ernie_config_path ./config/ernie_config.json \               
  24. --learning_rate 1e-4 \              
  25. --use_fp16 false \                
  26. --weight_decay 0.01 \               
  27. --max_seq_len 512 \                
  28. --skip_steps 10

脚本初始化代码 pretrain_launch.py

  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. from __future__ import unicode_literals
  5. from __future__ import absolute_import
  6. from __future__ import division
  7. import sys
  8. import subprocess
  9. import os
  10. import six
  11. import copy
  12. import argparse
  13. import time
  14. import logging
  15. from utils.args import ArgumentGroup, print_arguments,     prepare_logger
  16. from pretrain_args import parser as worker_parser
  17. # yapf: disable
  18. parser = argparse.ArgumentParser(__doc__)
  19. multip_g = ArgumentGroup(parser, "multiprocessing",
  20. "start paddle training using multi-processing mode.")
  21. multip_g.add_arg("node_ips", str, None,
  22. "paddle trainer ips")
  23. multip_g.add_arg("node_id", int, 0,
  24. "the trainer id of the node for multi-node distributed training.")
  25. multip_g.add_arg("print_config", bool, True,
  26. "print the config of multi-processing mode.")
  27. multip_g.add_arg("current_node_ip", str, None,
  28. "the ip of current node.")
  29. multip_g.add_arg("split_log_path", str, "./log",
  30. "log path for each trainer.")
  31. multip_g.add_arg("log_prefix", str, "",
  32. "the prefix name of job log.")
  33. multip_g.add_arg("nproc_per_node", int, 8,
  34. "the number of process to use on each node.")
  35. multip_g.add_arg("selected_gpus", str, "0,1,2,3,4,5,6,7",
  36. "the gpus selected to use.")
  37. multip_g.add_arg("training_script", str, None, "the program/script to be lauched "
  38. "in parallel followed by all the arguments",     positional_arg=True)
  39. multip_g.add_arg("training_script_args", str, None,
  40. "training script args", positional_arg=True, nargs=argparse.REMAINDER)
  41. # yapf: enable
  42. log = logging.getLogger()
  43. def start_procs(args):
  44. procs = []
  45. log_fns = []
  46. default_env = os.environ.copy()
  47. node_id = args.node_id
  48. node_ips = [x.strip() for x in args.node_ips.split(',')]
  49. current_ip = args.current_node_ip
  50. if args.current_node_ip is None:
  51. assert len(node_ips) == 1
  52. current_ip = node_ips[0]
  53. log.info(current_ip)
  54. num_nodes = len(node_ips)
  55. selected_gpus = [x.strip() for x in args.selected_gpus.split(',')]
  56. selected_gpu_num = len(selected_gpus)
  57. all_trainer_endpoints = ""
  58. for ip in node_ips:
  59. for i in range(args.nproc_per_node):
  60. if all_trainer_endpoints != "":
  61. all_trainer_endpoints += ","
  62. all_trainer_endpoints += "%s:617%d" % (ip, i)
  63. nranks = num_nodes * args.nproc_per_node
  64. gpus_per_proc = args.nproc_per_node % selected_gpu_num
  65. if gpus_per_proc == 0:
  66. gpus_per_proc =  selected_gpu_num // args.nproc_per_node
  67. else:
  68. gpus_per_proc =  selected_gpu_num // args.nproc_per_node + 1
  69. log.info(gpus_per_proc)
  70. selected_gpus_per_proc = [selected_gpus[i:i + gpus_per_proc] for i in range(0, len(selected_gpus), gpus_per_proc)]
  71. if args.print_config:
  72. log.info("all_trainer_endpoints: %s"
  73. ", node_id: %s"
  74. ", current_ip: %s"
  75. ", num_nodes: %s"
  76. ", node_ips: %s"
  77. ", gpus_per_proc: %s"
  78. ", selected_gpus_per_proc: %s"
  79. ", nranks: %s" % (
  80. all_trainer_endpoints,
  81. node_id,
  82. current_ip,
  83. num_nodes,
  84. node_ips,
  85. gpus_per_proc,
  86. selected_gpus_per_proc,
  87. nranks))
  88. current_env = copy.copy(default_env)
  89. procs = []
  90. cmds = []
  91. log_fns = []
  92. for i in range(0, args.nproc_per_node):
  93. trainer_id = node_id * args.nproc_per_node + i
  94. current_env.update({
  95. "FLAGS_selected_gpus""%s" % ",".join([str(s) for s in selected_gpus_per_proc[i]]),
  96. "PADDLE_TRAINER_ID" : "%d" % trainer_id,
  97. "PADDLE_CURRENT_ENDPOINT""%s:617%d" % (current_ip, i),
  98. "PADDLE_TRAINERS_NUM""%d" % nranks,
  99. "PADDLE_TRAINER_ENDPOINTS": all_trainer_endpoints,
  100. "PADDLE_NODES_NUM""%d" % num_nodes
  101. })
  102. try:
  103. idx = args.training_script_args.index('--is_distributed')
  104. args.training_script_args[idx + 1] = 'true'
  105. except ValueError:
  106. args.training_script_args += ['--is_distributed''true']
  107. cmd = [sys.executable, "-u",
  108. args.training_script] + args.training_script_args
  109. cmds.append(cmd)
  110. if args.split_log_path:
  111. fn = open("%s/%sjob.log.%d" % (args.split_log_path, args.log_prefix, trainer_id), "a")
  112. log_fns.append(fn)
  113. process = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
  114. else:
  115. process = subprocess.Popen(cmd, env=current_env)
  116. log.info('subprocess launched')
  117. procs.append(process)
  118. try:
  119. for i in range(len(procs)):
  120. proc = procs[i]
  121. proc.wait()
  122. if len(log_fns) > 0:
  123. log_fns[i].close()
  124. if proc.returncode != 0:    
  125. raise subprocess.CalledProcessError(returncode=procs[i].returncode,
  126. cmd=cmds[i])
  127. else:
  128. log.info("proc %d finsh" % i)
  129. except KeyboardInterrupt as e:
  130. for p in procs:
  131. log.info('killing %s' % p)
  132. p.terminate()
  133. def main(args):
  134. if args.print_config:
  135. print_arguments(args)
  136. start_procs(args)
  137. if __name__ == "__main__":
  138. prepare_logger(log)
  139. lanch_args = parser.parse_args()
  140. pretraining_args = worker_parser.parse_args(
  141. lanch_args.training_script_args)
  142. init_path = pretraining_args.init_checkpoint
  143. if init_path and not pretraining_args.use_fp16:
  144. os.system('rename .master "" ' + init_path + '/*.master')
  145. main(lanch_args)

训练代码 train.py

  1. #   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. #     http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """ERNIE pretraining."""
  15. from __future__ import absolute_import
  16. from __future__ import division
  17. from __future__ import print_function
  18. from __future__ import unicode_literals
  19. from __future__ import absolute_import
  20. import os
  21. import time
  22. import multiprocessing
  23. import logging
  24. import numpy as np
  25. import paddle.fluid as fluid
  26. from reader.pretraining import ErnieDataReader
  27. from model.ernie_v1 import ErnieModel, ErnieConfig
  28. from optimization import optimization
  29. from utils.args import print_arguments, check_cuda, prepare_logger
  30. from utils.init import init_checkpoint, init_pretraining_params
  31. from pretrain_args import parser
  32. log = logging.getLogger()
  33. args = parser.parse_args()
  34. # yapf: enable.
  35. def create_model(pyreader_name, ernie_config):
  36. pyreader = fluid.layers.py_reader(
  37. capacity=70,
  38. shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
  39. [-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1], [-11],
  40. [-11], [-11]],
  41. dtypes=[
  42. 'int64''int64''int64''float32''int64''int64''int64'
  43. ],
  44. lod_levels=[0000000],
  45. name=pyreader_name,
  46. use_double_buffer=True)
  47. (src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos,
  48. labels) = fluid.layers.read_file(pyreader)
  49. ernie = ErnieModel(
  50. src_ids=src_ids,
  51. position_ids=pos_ids,
  52. sentence_ids=sent_ids,
  53. input_mask=input_mask,
  54. config=ernie_config,
  55. weight_sharing=args.weight_sharing,
  56. use_fp16=args.use_fp16)
  57. next_sent_acc, mask_lm_loss, total_loss = ernie.get_pretraining_output(
  58. mask_label, mask_pos, labels)
  59. return pyreader, next_sent_acc, mask_lm_loss, total_loss
  60. def predict_wrapper(args,
  61. exe,
  62. ernie_config,
  63. test_prog=None,
  64. pyreader=None,
  65. fetch_list=None):
  66. # Context to do validation.
  67. filelist = args.test_filelist if args.do_test else args.valid_filelist
  68. data_reader = ErnieDataReader(
  69. filelist,
  70. vocab_path=args.vocab_path,
  71. batch_size=args.batch_size,
  72. voc_size=ernie_config['vocab_size'],
  73. shuffle_files=False,
  74. epoch=1,
  75. max_seq_len=args.max_seq_len,
  76. is_test=True)
  77. if args.do_test:
  78. assert args.init_checkpoint is not None, "[FATAL] Please use --init_checkpoint '/path/to/checkpoints' \
  79.                                                to specify you pretrained model checkpoints"
  80. init_pretraining_params(exe, args.init_checkpoint, test_prog)
  81. def predict(exe=exe, pyreader=pyreader):
  82. pyreader.decorate_tensor_provider(data_reader.data_generator())
  83. pyreader.start()
  84. cost = 0
  85. lm_cost = 0
  86. acc = 0
  87. steps = 0
  88. time_begin = time.time()
  89. while True:
  90. try:
  91. each_next_acc, each_mask_lm_cost, each_total_cost = exe.run(
  92. fetch_list=fetch_list, program=test_prog)
  93. acc += each_next_acc
  94. lm_cost += each_mask_lm_cost
  95. cost += each_total_cost
  96. steps += 1
  97. if args.do_test and steps % args.skip_steps == 0:
  98. log.info("[test_set] steps: %d" % steps)
  99. except fluid.core.EOFException:
  100. pyreader.reset()
  101. break
  102. used_time = time.time() - time_begin
  103. return cost, lm_cost, acc, steps, (args.skip_steps / used_time)
  104. return predict
  105. def test(args):
  106. ernie_config = ErnieConfig(args.ernie_config_path)
  107. ernie_config.print_config()
  108. test_prog = fluid.Program()
  109. test_startup = fluid.Program()
  110. with fluid.program_guard(test_prog, test_startup):
  111. with fluid.unique_name.guard():
  112. test_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
  113. pyreader_name='test_reader', ernie_config=ernie_config)
  114. test_prog = test_prog.clone(for_test=True)
  115. place = fluid.CUDAPlace(0if args.use_cuda == True else fluid.CPUPlace()
  116. exe = fluid.Executor(place)
  117. exe.run(test_startup)
  118. predict = predict_wrapper(
  119. args,
  120. exe,
  121. ernie_config,
  122. test_prog=test_prog,
  123. pyreader=test_pyreader,
  124. fetch_list=[next_sent_acc.name, mask_lm_loss.name, total_loss.name])
  125. log.info("test begin")
  126. loss, lm_loss, acc, steps, speed = predict()
  127. log.info(
  128. "[test_set] loss: %f, global ppl: %f, next_sent_acc: %f, speed: %f steps/s"
  129. % (np.mean(np.array(loss) / steps),
  130. np.exp(np.mean(np.array(lm_loss) / steps)),
  131. np.mean(np.array(acc) / steps), speed))
  132. def train(args):
  133. log.info("pretraining start")
  134. ernie_config = ErnieConfig(args.ernie_config_path)
  135. ernie_config.print_config()
  136. train_program = fluid.Program()
  137. startup_prog = fluid.Program()
  138. with fluid.program_guard(train_program, startup_prog):
  139. with fluid.unique_name.guard():
  140. train_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
  141. pyreader_name='train_reader', ernie_config=ernie_config)
  142. scheduled_lr, _ = optimization(
  143. loss=total_loss,
  144. warmup_steps=args.warmup_steps,
  145. num_train_steps=args.num_train_steps,
  146. learning_rate=args.learning_rate,
  147. train_program=train_program,
  148. startup_prog=startup_prog,
  149. weight_decay=args.weight_decay,
  150. scheduler=args.lr_scheduler,
  151. use_fp16=args.use_fp16,
  152. use_dynamic_loss_scaling=args.use_dynamic_loss_scaling,
  153. init_loss_scaling=args.init_loss_scaling,
  154. incr_every_n_steps=args.incr_every_n_steps,
  155. decr_every_n_nan_or_inf=args.decr_every_n_nan_or_inf,
  156. incr_ratio=args.incr_ratio,
  157. decr_ratio=args.decr_ratio)
  158. test_prog = fluid.Program()
  159. with fluid.program_guard(test_prog, startup_prog):
  160. with fluid.unique_name.guard():
  161. test_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
  162. pyreader_name='test_reader', ernie_config=ernie_config)
  163. test_prog = test_prog.clone(for_test=True)
  164. if len(fluid.cuda_places()) == 0:
  165. raise RuntimeError('not cuda device cound, check ur env setting')
  166. if args.use_cuda:
  167. place = fluid.cuda_places()[0]
  168. dev_count = fluid.core.get_cuda_device_count()
  169. else:
  170. place = fluid.CPUPlace()
  171. dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
  172. log.info("Device count %d" % dev_count)
  173. log.info("theoretical memory usage: ")
  174. log.info(fluid.contrib.memory_usage(
  175. program=train_program, batch_size=args.batch_size // args.max_seq_len))
  176. nccl2_num_trainers = 1
  177. nccl2_trainer_id = 0
  178. log.info("args.is_distributed: %s" % args.is_distributed)
  179. if args.is_distributed:
  180. worker_endpoints_env = os.getenv("PADDLE_TRAINER_ENDPOINTS")
  181. worker_endpoints = worker_endpoints_env.split(",")
  182. trainers_num = len(worker_endpoints)
  183. current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
  184. trainer_id = worker_endpoints.index(current_endpoint)
  185. if trainer_id == 0:
  186. log.info("train_id == 0, sleep 60s")
  187. time.sleep(60)
  188. log.info("worker_endpoints:{} trainers_num:{} current_endpoint:{} \
  189.            trainer_id:{}".format(worker_endpoints, trainers_num,
  190. current_endpoint, trainer_id))
  191. # prepare nccl2 env.
  192. config = fluid.DistributeTranspilerConfig()
  193. config.mode = "nccl2"
  194. t = fluid.DistributeTranspiler(config=config)
  195. t.transpile(
  196. trainer_id,
  197. trainers=worker_endpoints_env,
  198. current_endpoint=current_endpoint,
  199. program=train_program,
  200. startup_program=startup_prog)
  201. nccl2_num_trainers = trainers_num
  202. nccl2_trainer_id = trainer_id
  203. exe = fluid.Executor(place)
  204. exe.run(startup_prog)
  205. if args.init_checkpoint and args.init_checkpoint != "":
  206. init_checkpoint(exe, args.init_checkpoint, train_program, args.use_fp16)
  207. data_reader = ErnieDataReader(
  208. filelist=args.train_filelist,
  209. batch_size=args.batch_size,
  210. vocab_path=args.vocab_path,
  211. voc_size=ernie_config['vocab_size'],
  212. epoch=args.epoch,
  213. max_seq_len=args.max_seq_len,
  214. generate_neg_sample=args.generate_neg_sample)
  215. exec_strategy = fluid.ExecutionStrategy()
  216. if args.use_fast_executor:
  217. exec_strategy.use_experimental_executor = True
  218. exec_strategy.num_threads = dev_count
  219. exec_strategy.num_iteration_per_drop_scope = min(10, args.skip_steps)
  220. build_strategy = fluid.BuildStrategy()
  221. build_strategy.remove_unnecessary_lock = False
  222. train_exe = fluid.ParallelExecutor(
  223. use_cuda=args.use_cuda,
  224. loss_name=total_loss.name,
  225. build_strategy=build_strategy,
  226. exec_strategy=exec_strategy,
  227. main_program=train_program,
  228. num_trainers=nccl2_num_trainers,
  229. trainer_id=nccl2_trainer_id)
  230. if args.valid_filelist and args.valid_filelist != "":
  231. predict = predict_wrapper(
  232. args,
  233. exe,
  234. ernie_config,
  235. test_prog=test_prog,
  236. pyreader=test_pyreader,
  237. fetch_list=[
  238. next_sent_acc.name, mask_lm_loss.name, total_loss.name
  239. ])
  240. train_pyreader.decorate_tensor_provider(data_reader.data_generator())
  241. train_pyreader.start()
  242. steps = 0
  243. cost = []
  244. lm_cost = []
  245. acc = []
  246. time_begin = time.time()
  247. while steps < args.num_train_steps:
  248. try:
  249. steps += nccl2_num_trainers
  250. skip_steps = args.skip_steps * nccl2_num_trainers
  251. if nccl2_trainer_id != 0:
  252. train_exe.run(fetch_list=[])
  253. continue
  254. if steps % skip_steps != 0:
  255. train_exe.run(fetch_list=[])
  256. else:
  257. each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = train_exe.run(
  258. fetch_list=[
  259. next_sent_acc.name, mask_lm_loss.name, total_loss.name,
  260. scheduled_lr.name
  261. ])
  262. acc.extend(each_next_acc)
  263. lm_cost.extend(each_mask_lm_cost)
  264. cost.extend(each_total_cost)
  265. log.info("feed_queue size %d" % train_pyreader.queue.size())
  266. time_end = time.time()
  267. used_time = time_end - time_begin
  268. epoch, current_file_index, total_file, current_file, mask_type = data_reader.get_progress(
  269. )
  270. log.info("current learning_rate:%f" % np_lr[0])
  271. log.info(
  272. "epoch: %d, progress: %d/%d, step: %d, loss: %f, "
  273. "ppl: %f, next_sent_acc: %f, speed: %f steps/s, file: %s, mask_type: %s"
  274. % (epoch, current_file_index, total_file, steps,
  275. np.mean(np.array(cost)),
  276. np.mean(np.exp(np.array(lm_cost))),
  277. np.mean(np.array(acc)), skip_steps / used_time,
  278. current_file, mask_type))
  279. cost = []
  280. lm_cost = []
  281. acc = []
  282. time_begin = time.time()
  283. if steps % args.save_steps == 0:
  284. save_path = os.path.join(args.checkpoints, "step_" + str(steps))
  285. fluid.io.save_persistables(exe, save_path, train_program)
  286. if args.valid_filelist and steps % args.validation_steps == 0:
  287. vali_cost, vali_lm_cost, vali_acc, vali_steps, vali_speed = predict(
  288. )
  289. log.info("[validation_set] epoch: %d, step: %d, "
  290. "loss: %f, global ppl: %f, batch-averged ppl: %f, "
  291. "next_sent_acc: %f, speed: %f steps/s" %
  292. (epoch, steps, np.mean(np.array(vali_cost) / vali_steps),
  293. np.exp(np.mean(np.array(vali_lm_cost) / vali_steps)),
  294. np.mean(np.exp(np.array(vali_lm_cost) / vali_steps)),
  295. np.mean(np.array(vali_acc) / vali_steps), vali_speed))
  296. except fluid.core.EOFException:
  297. train_pyreader.reset()
  298. break
  299. if __name__ == '__main__':
  300. prepare_logger(log)
  301. print_arguments(args)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/479799.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

BERT原理、代码、相关模型、精调技巧,看这个就够了

星标/置顶小屋&#xff0c;带你解锁最萌最前沿的NLP、搜索与推荐技术2018 年 10 月&#xff0c;由 Google 推出的 BERT 模型一鸣惊人&#xff0c;刷爆了各路榜单&#xff0c;甚至超越了人类基线分数&#xff0c;实现了 NLP 领域里程碑式的突破。 如今&#xff0c;对于 NLP 算法…

论文浅尝 | 利用 KG Embedding 进行问题回答

论文笔记整理&#xff1a;吴杨&#xff0c;浙江大学计算机学院&#xff0c;知识图谱、NLP方向。http://research.baidu.com/Public/uploads/5c1c9a58317b3.pdf动机本文主要针对基于知识库的问题回答中的简单问题&#xff0c;也就是问题的答案只涉及KG中的一跳&#xff0c;此类问…

想成为阿里160万年薪的P8架构师?你必须掌握如下6大技能体系!

程序设计和开发 数据结构和算法&#xff1a;常用数据结构&#xff0c;排序&#xff0c;检索等 面向对象编程、设计模式&#xff0c;掌握建模语言和建模工具&#xff1a;UML、MVC编程思想 高质量编码能力&#xff1a;重用性&#xff0c;低耦合&#xff0c;可扩展性&#xff0c…

技术动态 | 知识图谱的策展

作者&#xff1a;Jiaoyan Chen, Senior Researcher, Department of Computer Science, University of Oxford, Research interests: Knowledge Base, Knowledge-based Learning, Machine Learning Explanation.知识图谱在众多的领域中发挥了重要作用&#xff0c;比如聊天机器人…

21届校招薪资曝光:严重倒挂老员工!

源 | 量子位一开始&#xff0c;还以为是科技互联网公司招聘的新把式。因为就在最近&#xff0c;一张美团应届生薪资的截图&#xff0c;在各大社区和校招群里火了。仅仅算法岗、开发岗的薪资白菜价&#xff0c;就有27k15.5&#xff0c;算下来&#xff0c;年薪就有41万。虽然这两…

从Java程序员进阶到架构师,6大核心技能要领详解

“ java架构师技能将分为如下6大环节&#xff1a;数据结构和算法&#xff0c;Java高级特性&#xff0c;Java web核心&#xff0c;数据库&#xff0c;Java框架与必备工具&#xff0c;系统架构设计。 希望能真正帮助到从程序员进阶到架构师之路的朋友。 数据结构和算法 算法分…

领域应用 | ​英文抗生素药物医学知识图谱 IASO1.0 版发布 线上试用正式启动

本文转载自公众号&#xff1a;PKU自然语言处理前沿。近日&#xff0c;由北京大学互联网信息工程研发中心&#xff08;CIRE&#xff09;开发的英语医学知识图谱英文抗生素药物医学知识图谱IASO1.0发布&#xff0c;面向公众正式开放试用。IASO是利用自然语言处理与文本挖掘技术&a…

谷歌大改Transformer注意力,速度大涨,显存大降!

源 | 机器之心导读考虑到 Transformer 对于机器学习最近一段时间的影响&#xff0c;这样一个研究就显得异常引人注目了。Transformer 有着巨大的内存和算力需求&#xff0c;因为它构造了一个注意力矩阵&#xff0c;需求与输入呈平方关系。谷歌大脑 Krzysztof Choromanski 等人最…

阿里P7架构师要求:Web核心+开源框架+大型网站架构!含面试题目!

阿里P7技能&#xff08;一&#xff09;&#xff1a;数据结构和算法&#xff1a; 常用数据结构&#xff1a;链表、堆与栈、哈希表等&#xff0c;常用的排序等。 掌握&#xff1a;精通 阿里P7技能&#xff08;二&#xff09;&#xff1a;java高级 java相关的高级特性&#xff1…

LeetCode 986. 区间列表的交集

文章目录1. 题目信息2. 解题1. 题目信息 给定两个由一些闭区间组成的列表&#xff0c;每个区间列表都是成对不相交的&#xff0c;并且已经排序。 返回这两个区间列表的交集。 &#xff08;形式上&#xff0c;闭区间 [a, b]&#xff08;其中 a < b&#xff09;表示实数 x …

论文浅尝 | 学习开发知识图谱中的长期关系依赖 - ICML 2019 ​

本文转载自公众号&#xff1a;南大Websoft。 论文&#xff1a;https://arxiv.org/abs/1905.04914代码&#xff1a;https://github.com/nju-websoft/RSN背景知识图谱结构化地存储着大量现实世界中的事实。其中&#xff0c;每个事实都以三元组 (s, r, o) 的方式进行描述&#xf…

一张图看懂小米千亿美金生态链产品

小米上市近在眼前&#xff0c;最快5月初提交IPO申请&#xff0c;再到小米IPO股指不断攀升&#xff0c;估值直奔1000亿美金以上&#xff0c;小米用了7年时间&#xff0c;这在整个互联网的发展史上&#xff0c;也算是火箭般的发展速度。 今天我们一起复盘看看小米的千亿美金生态…

Pycharm使用远程服务器运行代码

pycharm下载专业版&#xff0c;然后用学生邮箱申请个激活码&#xff08;我这里申请了个账号&#xff0c;更方便&#xff09;。 连上厦大VPN&#xff0c;再用pycharm高级版可以直接连到学校的GPU服务器&#xff0c;这样平时不在学校也能调试服务器了。 厦大VPN设置 pycharm下载…

吐槽贴:用ELECTRA、ALBERT之前,你真的了解它们吗?

文 | 苏剑林单位 | 追一科技编 | 兔子酱在预训练语言模型中&#xff0c;ALBERT和ELECTRA算是继BERT之后的两个“后起之秀”。它们从不同的角度入手对BERT进行了改进&#xff0c;最终提升了效果&#xff08;至少在不少公开评测数据集上是这样&#xff09;&#xff0c;因此也赢得…

LeetCode 56. 合并区间(优先队列)

文章目录1. 题目信息2. 解题2.1 报错的答案2.2 优先队列解题1. 题目信息 给出一个区间的集合&#xff0c;请合并所有重叠的区间。 示例 1:输入: [[1,3],[2,6],[8,10],[15,18]] 输出: [[1,6],[8,10],[15,18]] 解释: 区间 [1,3] 和 [2,6] 重叠, 将它们合并为 [1,6]. 示例 2:输入…

论文浅尝 | 基于复杂查询图编码的知识库问答

论文笔记整理&#xff1a;谭亦鸣&#xff0c;东南大学博士生&#xff0c;研究方向为知识库问答。来源&#xff1a;EMNLP 2018链接&#xff1a;https://www.aclweb.org/anthology/D18-1242文章表示&#xff0c;复杂问答所面对的问题往往包含多种实体和关系&#xff08;来自知识库…

阿里Java架构师精通资料:性能优化+亿级并发架构汇总+架构选型

分布式并发架构 微服务、Docker容器的基本原理、架构设计&#xff0c;以及应用场景。 缓存&#xff1a;Redis、Memcached、CDN、本地缓存 搜索引擎的选型&#xff1a;Lucene、Solr等选型与比较 应用服务器雪崩&#xff1a;长事务、SQL超时、同步接口引起的雪崩场景&#xff…

Google Cloud TPUs支持Pytorch框架啦!

文 | Sherry在2019年PyTorch开发者大会上&#xff0c;Facebook&#xff0c;Google和Salesforce Research联合宣布启动PyTorch-TPU项目。项目的目标是在保持PyTorch的灵活性的同时让社区尽可能容易地利用云TPU提供的高性能计算。团队创建了PyTorch/XLA这个repo&#xff0c;它可以…

LeetCode 231. 2的幂 LeetCode 338. 比特位计数(2进制1的个数)

文章目录1. 题目信息2. 解题拓展&#xff1a;求一个数n的2进制有多少个1&#xff1f;LeetCode 3381. 题目信息 给定一个整数&#xff0c;编写一个函数来判断它是否是 2 的幂次方。 示例 1:输入: 1 输出: true 解释: 20 1 示例 2:输入: 16 输出: true 解释: 24 16 示例 3:输…