论文地址:https://arxiv.org/pdf/2309.03241v2.pdf
项目地址:https://github.com/THUDM/MathGLM#arithmetic-tasks
数据集格式:
读取数据集代码:
def make_loaders(args, create_dataset_function):"""makes training/val/testArgs:args.train_data, args.valid_data, args.test_data: str. Paths to the dataset.args.split: str. format: "8,1,1". how to split train_data.args.dataset_type: use to create the right datasets. """make_dataset = partial(make_dataset_full, create_dataset_function=create_dataset_function)world_size = torch.distributed.get_world_size(group=mpu.get_data_parallel_group())batch_size = args.batch_size * world_sizeeval_batch_size = batch_sizeif args.eval_batch_size is not None:eval_batch_size = args.eval_batch_size * world_sizesplit = get_split(args)data_set_args = {'path': args.train_data,'split': split,}eval_set_args = copy.copy(data_set_args)eval_set_args['split'] = [1.]# make datasets splits and tokenizertrain = Nonevalid = Nonetest = Noneif args.train_data is not None:train = make_dataset(**data_set_args, args=args, dataset_weights=args.train_data_weights, is_train_data=True)if should_split(split):train, valid, test = train# make training and val dataset if necessaryif valid is None and args.valid_data is not None:eval_set_args['path'] = args.valid_datavalid = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)if test is None and args.test_data is not None:eval_set_args['path'] = args.test_datatest = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)# wrap datasets with data loaderif train is not None and args.batch_size > 0:train = make_data_loader(train, batch_size, args, split='train')args.do_train = Trueelse:args.do_train = Falseeval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_sizeif valid is not None:valid = make_data_loader(valid, eval_batch_size, args, split='val')args.do_valid = Trueelse:args.do_valid = Falseif test is not None:test = make_data_loader(test, eval_batch_size, args, split='test')args.do_test = Trueelse:args.do_test = Falsereturn train, valid, test
数据读取后:
/home/user/zjb/SAT/mathglm/continue_train_mathglm.py
get_batch(data_iterator, args, timers):# 传入参数data_iterator
在continue_train_mathglm.py
文件中,get_batch
函数从data_iterator
获取数据,然后将其转换为词向量。这是通过调用get_batch
函数中的mpu.broadcast_data
函数实现的。
然后,在forward_step
函数中,get_batch
函数的返回值被传递给模型。模型接收的输入是词向量,而不是原始的字符串。
这种转换是因为模型不能直接处理原始的文本数据。模型需要的是一种数值表示,通常是词向量,这样才能进行数学运算。因此,原始的字符串数据需要被转换为词向量。
在create_dataset_function
函数中,你可以看到这个转换过程。process_fn
函数接收一个字符串row
,然后使用tokenizer._encode(value)
将其转换为词向量。这个词向量然后被添加到ids
列表中,最后返回一个包含词向量的字典。
所以,data_iterator
中的数据是字符串,因为这是原始的输入数据。然后,这些数据被转换为词向量,以便可以被模型处理。
其中data_iterator是dataloaderlter如图所示
data如图所示
data_b
mpu.broadcast_data(keys, data, datatype) 是一个函数调用,它来自于 SwissArmyTransformer 库中的 mpu 模块。这个函数的作用是在分布式环境中广播数据。
create_dataset_function
函数在continue_train_mathglm.py
文件中。这个函数用于创建一个数据集,它接收一个路径和参数,然后返回一个MathDataset
对象。在这个函数中,它定义了一个process_fn
函数,这个函数用于处理每一行数据,将其转换为词向量。