bert 的MLM框架任务-梯度累积

参考:BEHRT/task/MLM.ipynb at ca0163faf5ec09e5b31b064b20085f6608c2b6d1 · deepmedicine/BEHRT · GitHub

class BertConfig(Bert.modeling.BertConfig):def __init__(self, config):super(BertConfig, self).__init__(vocab_size_or_config_json_file=config.get('vocab_size'),hidden_size=config['hidden_size'],num_hidden_layers=config.get('num_hidden_layers'),num_attention_heads=config.get('num_attention_heads'),intermediate_size=config.get('intermediate_size'),hidden_act=config.get('hidden_act'),hidden_dropout_prob=config.get('hidden_dropout_prob'),attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),max_position_embeddings = config.get('max_position_embedding'),initializer_range=config.get('initializer_range'),)self.seg_vocab_size = config.get('seg_vocab_size')self.age_vocab_size = config.get('age_vocab_size')class TrainConfig(object):def __init__(self, config):self.batch_size = config.get('batch_size')self.use_cuda = config.get('use_cuda')self.max_len_seq = config.get('max_len_seq')self.train_loader_workers = config.get('train_loader_workers')self.test_loader_workers = config.get('test_loader_workers')self.device = config.get('device')self.output_dir = config.get('output_dir')self.output_name = config.get('output_name')self.best_name = config.get('best_name')file_config = {'vocab':'',  # vocabulary idx2token, token2idx'data': '',  # formated data 'model_path': '', # where to save model'model_name': '', # model name'file_name': '',  # log path
}
create_folder(file_config['model_path'])global_params = {'max_seq_len': 64,'max_age': 110,'month': 1,'age_symbol': None,'min_visit': 5,'gradient_accumulation_steps': 1
}optim_param = {'lr': 3e-5,'warmup_proportion': 0.1,'weight_decay': 0.01
}train_params = {'batch_size': 256,'use_cuda': True,'max_len_seq': global_params['max_seq_len'],'device': 'cuda:0'
}

模型:

BertVocab = load_obj(file_config['vocab'])
ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])data = pd.read_parquet(file_config['data'])
# remove patients with visits less than min visit
data['length'] = data['caliber_id'].apply(lambda x: len([i for i in range(len(x)) if x[i] == 'SEP']))
data = data[data['length'] >= global_params['min_visit']]
data = data.reset_index(drop=True)Dset = MLMLoader(data, BertVocab['token2idx'], ageVocab, max_len=train_params['max_len_seq'], code='caliber_id')
trainload = DataLoader(dataset=Dset, batch_size=train_params['batch_size'], shuffle=True, num_workers=3)model_config = {'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease + symbols for word embedding'hidden_size': 288, # word embedding and seg embedding hidden size'seg_vocab_size': 2, # number of vocab for seg embedding'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding'max_position_embedding': train_params['max_len_seq'], # maximum number of tokens'hidden_dropout_prob': 0.1, # dropout rate'num_hidden_layers': 6, # number of multi-head attention layers required'num_attention_heads': 12, # number of attention heads'attention_probs_dropout_prob': 0.1, # multi-head attention dropout rate'intermediate_size': 512, # the size of the "intermediate" layer in the transformer encoder'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler "gelu", 'relu', 'swish' are supported'initializer_range': 0.02, # parameter weight initializer range
}conf = BertConfig(model_config)
model = BertForMaskedLM(conf)model = model.to(train_params['device'])
optim = adam(params=list(model.named_parameters()), config=optim_param)

计算准确率:

def cal_acc(label, pred):logs = nn.LogSoftmax()label=label.cpu().numpy()ind = np.where(label!=-1)[0]truepred = pred.detach().cpu().numpy()truepred = truepred[ind]truelabel = label[ind]truepred = logs(torch.tensor(truepred))outs = [np.argmax(pred_x) for pred_x in truepred.numpy()]precision = skm.precision_score(truelabel, outs, average='micro')return precision

开始训练:

def train(e, loader):tr_loss = 0temp_loss = 0nb_tr_examples, nb_tr_steps = 0, 0cnt= 0start = time.time()for step, batch in enumerate(loader):cnt +=1batch = tuple(t.to(train_params['device']) for t in batch)age_ids, input_ids, posi_ids, segment_ids, attMask, masked_label = batchloss, pred, label = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, masked_lm_labels=masked_label)if global_params['gradient_accumulation_steps'] >1:loss = loss/global_params['gradient_accumulation_steps']loss.backward()temp_loss += loss.item()tr_loss += loss.item()nb_tr_examples += input_ids.size(0)nb_tr_steps += 1if step % 200==0:print("epoch: {}\t| cnt: {}\t|Loss: {}\t| precision: {:.4f}\t| time: {:.2f}".format(e, cnt, temp_loss/2000, cal_acc(label, pred), time.time()-start))temp_loss = 0start = time.time()if (step + 1) % global_params['gradient_accumulation_steps'] == 0:optim.step()optim.zero_grad()print("** ** * Saving fine - tuned model ** ** * ")model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-selfcreate_folder(file_config['model_path'])output_model_file = os.path.join(file_config['model_path'], file_config['model_name'])torch.save(model_to_save.state_dict(), output_model_file)cost = time.time() - startreturn tr_loss, cost

保存结果:

f = open(os.path.join(file_config['model_path'], file_config['file_name']), "w")
f.write('{}\t{}\t{}\n'.format('epoch', 'loss', 'time'))
for e in range(50):loss, time_cost = train(e, trainload)loss = loss/data_lenf.write('{}\t{}\t{}\n'.format(e, loss, time_cost))
f.close()    

某大模型对上面代码梯度累积的解释:

gradient_accumulation_steps 是一个用于在训练过程中累积梯度的参数。它的作用是在每次更新模型参数之前,累积多个 batch 的梯度,然后进行一次参数更新。这个技巧有助于在显存受限的情况下仍能使用较大的 batch size,并且可以稳定模型的训练过程。

对于BERT的MLM任务来说,通常会使用梯度累积,因为BERT是一个比较大的模型,使用较大的batch size可能会导致显存不足的问题,特别是在GPU资源受限的情况下。梯度累积可以帮助解决这个问题。

在你提供的代码中,gradient_accumulation_steps 的值被用来判断是否要对 loss 进行除法操作。如果 gradient_accumulation_steps 大于1,那么会将 loss 除以 gradient_accumulation_steps,以确保在多个batch的梯度累积后对参数进行更新。这是一个常见的做法。

因此,可以说对于BERT的MLM任务,使用梯度累积是一种常见的做法,特别是在资源受限的情况下。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/11098.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java入门-面向对象的三大特性

面向对象三大特性 封装 什么是封装 封装 是将代码及其处理的数据绑定在一起的一种编程机制,该机制保证了程序和数据都不受外部干扰且不被误用。 封装的作用 访问控制符 方法传参-值传递 传参类型是基本类型 程序案例: public static void main(St…

C++笔记(体系结构与内核分析)

1.OOP面向对象编程 vs. GP泛型编程 OOP将data和method放在一起,目的是通过封装、继承、多态提高软件的可维护性和可扩展性GP将data和method分开,可以将任何容器与任何算法结合使用,只要容器满足塞饭所需的迭代器类型 2.算法与仿函数的区别 …

Flutter 中的 Column 小部件:全面指南

Flutter 中的 Column 小部件:全面指南 在 Flutter 中,Column 是一个垂直布局的小部件,用于将子控件沿着垂直轴排列。Column 与 Row 相对,Row 是水平布局,而 Column 则是垂直布局。它非常适合用来创建列式布局&#xf…

AJAX前端与后端交互技术知识点以及案例

Promise promise对象用于表示一个异步操作的最终完成(或失败)及其结果值 好处: 逻辑更清晰了解axios函数内部运作机制成功和失败状态,可以关联对应处理程序能解决回调函数地狱问题 /*** 目标:使用Promise管理异步任…

Linux-02

Linux常用命令: ls: 列出目录touch: 创建文件 touch test.txt echo:往文件写内容echo "i love linux" >>test.txtcd:切换目录pwd:显示目前的目录mkdir:创建一个新的目录 mkdir dai:创建目录dai mkdir -p test1/t…

基于springboot的物业服务平台的设计与实现

基于springboot的物业服务平台的设计与实现 摘 要:本文针对社区物业服务管理现状,采用B/S系统架构并选择MySQL数据库作为系统的数据存储系统,设计并实现一个以Spring Boot为后端框架、Vue为前端框架的社区物业服务管理平台。与传统的物业服务管理方式相比,该系统取代了传统…

印象笔记使用技巧

印象笔记(Evernote)是一款广泛使用的笔记应用,它帮助用户整理个人信息、文件和备忘录。以下是一些提高在印象笔记中效率的使用技巧: ### 1. 使用标签和笔记本组织笔记 - **建立笔记本**:为不同的项目或类别创建笔记本…

2024-Docker常用命令大全

Docker是一个流行的开源容器化平台,它允许开发人员将应用程序及其依赖项打包到可移植的容器中,并可以轻松地发布到任何Linux机器上。以下是Docker的一些常用命令总结: 一、帮助与启动类命令 启动Docker:sudo systemctl start do…

如何查看打包后的jar包启动方法main方法

背景 有时候我们在引用一个jar包的时候,想查看一个jar包的结构,这时候查看启动类就比较重要,因为一些关键配置是在启动类上的,这里教大家如何查看这个启动类(springboot项目) 步骤 首先打开jar包预览结构,可以使用解压缩工具直接双击打开或者预览结构 打开路径 META-INF/MA…

springfox.documentation.spi.DocumentationType没有OAS_30(从swagger2转到swagger3出现的问题)

直接开讲: 查看源码根本没有OAS_30的类型选择 右键package的springfox找到maven下载的包,打开到资源管理器 可以看到项目优先使用2版本的jar包,但是OAS_30只在3版本中才有,意思就是让项目优先使用以下图片中的3.0.0jar包 解决办法…

[AutoSar]BSW_Diagnostic_004 ReadDataByIdentifier(0x22)的配置和实现

目录 关键词平台说明背景一、配置DcmDspDataInfos二、配置DcmDspDatas三、创建DcmDspDidInfos四、创建DcmDspDids五、总览六、创建一个ASWC七、mapping DCM port八、打开davinci developer,创建runnabl九、生成代码 关键词 嵌入式、C语言、autosar、OS、BSW、UDS、…

系统和功能测试:确保软件的功能和易用性

目录 概述 功能测试 LOSED 模型 用例的设计 等价类划分 边界值分析 循环结构测试的综合方法 因果图 决策表 功能图 正交实验设计 易用性测试 内部易用性测试 外部易用性测试 功能性测试 正向功能性测试 负向功能性测试 功能性测试工具 结语 概述 在软件开发…

课堂练习——路由策略

需求:将1.1.1.0/24网段重发布到网络中,不允许出现次优路径,实现全网可达。 在R1上重发布1.1.1.0/24网段,但是需要过滤192.168.12.0/24和192.168.13.0/24在R2和R3上执行双向重发布 因为R1引入的域外路由信息的优先级为150&#xff…

8.微服务项目结合SpringSecurity项目结构

项目结构 acl_parent:创建父工程用来管理依赖版本 common service_base&#xff1a;工具类 spring_security: Spring Security相关配置 infrastructure api_gateway: 网关 service service_acl: 实现权限管理功能代码 acl_parent的pom.xml <?xml version"1.0" …

Go 之 interface接口理解

go语言并没有面向对象的相关概念&#xff0c;go语言提到的接口和java、c等语言提到的接口不同&#xff0c;它不会显示的说明实现了接口&#xff0c;没有继承、子类、implements关键词。go语言通过隐性的方式实现了接口功能&#xff0c;相对比较灵活。 interface是go语言的一大…

系统和非功能性测试:全面提升软件的性能和可靠性

目录 非功能性测试的需求 负载测试、压力测试和性能测试的概念 负载测试 压力测试 性能测试 负载测试 性能测试 压力测试 性能测试工具 兼容性测试 安全性测试 容错性测试 可靠性测试 结语 在软件开发中&#xff0c;系统测试和非功能性测试是确保软件质量的关键环…

vue组件循环依赖

1、两个组件相互引用导致的报错&#xff0c;代码如下&#xff1a; <!-- Component A --> <template><div>{{ message }}<B /></div> </template><script> import B from ./B.vue;export default {components: { B },data() {return…

STM32 | STC-USB驱动安装Windows 10(64 位)

Windows 10&#xff08;64 位&#xff09;安装方法 由于 Windows10 64 位操作系统在默认状态下&#xff0c;对于没有数字签名的驱动程序是不能安装成功的。所以在安装 STC-USB 驱动前&#xff0c;需要按照如下步骤&#xff0c;暂时跳过数字签名&#xff0c;即可顺利安装成功。…

Transformers中加载预训练模型的过程剖析(二)

接着 Transformers中加载预训练模型的过程剖析(一) 来讲模型初始和载入预训练权重。 初始化模型 经过以上两个步骤,对于下载的 shibing624/text2vec-base-chinese-paraphrase 预训练模型(底层的模型是ernie),我们从代码 model = AutoModel.from_pretrained(model_path) 成功…

镜像制作过程

镜像制作过程 Centos镜像制作 虚拟机系统安装将网卡转换为eth0在install安装时按tab健加入一下配置 net.ifnames=0 biosdevname=0