🧑💻 本文主要讲解Megatron早期版本中的数据混合算法。
目录
- 1. 数据混合
- 2. 源码解析
- 3. 证明部分&讨论
- 4. 进一步优化
1. 数据混合
在谈源码之前,我们有必要先了解一下Megatron中的数据混合思想。
给定 n n n 个数据集 D 1 , D 2 , ⋯ , D n \mathcal{D}_1,\mathcal{D}_2,\cdots,\mathcal{D}_n D1,D2,⋯,Dn 和对应的 n n n 个权重 w 1 , w 2 , ⋯ , w n w_1,w_2,\cdots,w_n w1,w2,⋯,wn,我们要按照这些权重去混合 n n n 个数据集,设混合后的数据集为 D \mathcal{D} D。
Megatron假定:
- ∣ D ∣ = ∑ i = 1 n ∣ D i ∣ |\mathcal{D}|=\sum_{i=1}^n|\mathcal{D}_i| ∣D∣=∑i=1n∣Di∣。即混合后的数据集大小等于混合前的各数据集大小之和。
- D \mathcal{D} D 中有约 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i ∣D∣⋅wi 个样本来自 D i \mathcal{D}_i Di。
那如何确定 D \mathcal{D} D 中到底有多少个样本是来自 D i \mathcal{D}_i Di 的呢?一种最直观的做法是,计算 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i ∣D∣⋅wi,然后进行取整,但这种操作无法保证所有取整后的 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i ∣D∣⋅wi 相加起来恰好是 ∣ D ∣ |\mathcal{D}| ∣D∣。 如果总和大于 ∣ D ∣ |\mathcal{D}| ∣D∣,说明某些数据集被过采样了,应当减少相应数据集的采样数;如果总和小于 ∣ D ∣ |\mathcal{D}| ∣D∣,说明某些数据集被欠采样了,应当增加相应数据集的采样数。可问题是,如何确定这些被过采样/欠采样的数据集呢?显然我们需要一个更加公平的算法。
我们可以把获取数据集 D \mathcal{D} D 看作是一个采样过程:一开始有 n n n 个数据源 { D i } i = 1 n \{\mathcal{D}_i\}_{i=1}^n {Di}i=1n,每一轮迭代,我们需要先从这 n n n 个数据源中选出一个数据源 D i \mathcal{D}_i Di,然后再从这个数据源中选出一个样本 S \mathcal{S} S。 由于每一轮迭代只会选出一个样本,因此 ∣ D ∣ |\mathcal{D}| ∣D∣ 轮迭代结束后,我们便得到了 ∣ D ∣ |\mathcal{D}| ∣D∣ 个样本,这些样本构成了混合后的数据集 D \mathcal{D} D。
每一轮迭代都会产生两个信息:要选取的数据源 D i \mathcal{D}_i Di,要从 D i \mathcal{D}_i Di 中选取的样本。我们可以考虑构造两个整数序列 P , S \mathcal{P},\mathcal{S} P,S,它们的长度均为 ∣ D ∣ |\mathcal{D}| ∣D∣,含义如下:
- P j \mathcal{P}_j Pj 代表的是第 j j j 轮迭代时,选取的数据源的下标。例如 P 10 = 3 \mathcal{P}_{10}=3 P10=3 意味着第 10 10 10 轮迭代选取的数据源是 D 3 \mathcal{D}_3 D3。
- S j \mathcal{S}_j Sj 代表的是第 j j j 轮迭代时,从数据源 D P j \mathcal{D}_{\mathcal{P}_j} DPj 选取的样本的下标。
由以上定义知, ∀ j \forall j ∀j,都有 1 ≤ P j ≤ n 1\leq \mathcal{P}_j\leq n 1≤Pj≤n, 1 ≤ S j ≤ ∣ D P j ∣ 1\leq \mathcal{S}_j\leq|\mathcal{D}_{\mathcal{P}_j}\!| 1≤Sj≤∣DPj∣(下标均从 1 1 1 开始)。
接下来的问题是,如何确定每一轮的 P j \mathcal{P}_j Pj 和 S j \mathcal{S}_j Sj 呢?
先谈 P j \mathcal{P}_j Pj。因为是一个从 1 1 1 到 ∣ D ∣ |\mathcal{D}| ∣D∣ 的一个逐步采样过程,在第 j j j 轮迭代时,我们已经抽取了 j − 1 j-1 j−1 个样本,接下来要确定第 j j j 个样本。根据Megatron的假定,在确定下来第 j j j 个样本后,这 j j j 个样本中应当有约 j ⋅ w i j\cdot w_i j⋅wi 个样本是来自 D i \mathcal{D}_i Di 的。
考虑构造一个长度为 n n n 的序列 C \mathcal{C} C,该序列随着迭代不断更新。 C i \mathcal{C}_i Ci 代表当前已经从 D i \mathcal{D}_i Di 抽取了多少个样本。显然可知,第一轮迭代开始时,有 C i = 0 , i = 1 , 2 , ⋯ , n \mathcal{C}_i=0,\,i=1,2,\cdots,n Ci=0,i=1,2,⋯,n。最后一轮迭代结束后,有 ∑ i = 1 n C i = ∣ D ∣ \sum_{i=1}^n\mathcal{C}_i=|\mathcal{D}| ∑i=1nCi=∣D∣,并且
C i = { ∑ t = 1 j − 1 I ( P t = i ) , P j 确定前 ∑ t = 1 j I ( P t = i ) , P j 确定后 , ∀ i \mathcal{C}_i=\begin{cases} \sum_{t=1}^{j-1} I(\mathcal{P}_t=i),&\text{$\mathcal{P}_j$确定前} \\ \sum_{t=1}^{j} I(\mathcal{P}_t=i),&\text{$\mathcal{P}_j$确定后} \\ \end{cases},\quad \forall i Ci={∑t=1j−1I(Pt=i),∑t=1jI(Pt=i),Pj确定前Pj确定后,∀i
回到对 P j \mathcal{P}_j Pj 的讨论中。假设在确定第 j j j 个样本前已经从 D i \mathcal{D}_i Di 中抽取了 C i \mathcal{C}_i Ci 个样本,在确定第 j j j 个样本后,诸 C i \mathcal{C}_i Ci 中有且仅有一个的值会增加 1 1 1,不妨记为 C k \mathcal{C}_k Ck,这个过程可以形容为
[ C 1 , ⋯ , C k , ⋯ , C n ] ⏟ 第 j 轮迭代开始时 → [ C 1 , ⋯ , C k + 1 , ⋯ , C n ] ⏟ 第 j 轮迭代结束时 [ j ⋅ w 1 , j ⋅ w 2 , ⋯ , j ⋅ w n ] ⏟ 理论值 \underbrace{[\mathcal{C}_1,\cdots,\mathcal{C}_k,\cdots,\mathcal{C}_n]}_{第j轮迭代开始时}\to\underbrace{[\mathcal{C}_1,\cdots,\mathcal{C}_{k}+1,\cdots,\mathcal{C}_n]}_{第j轮迭代结束时}\qquad \underbrace{[j\cdot w_1,j\cdot w_2,\cdots,j\cdot w_n]}_{理论值} 第j轮迭代开始时 [C1,⋯,Ck,⋯,Cn]→第j轮迭代结束时 [C1,⋯,Ck+1,⋯,Cn]理论值 [j⋅w1,j⋅w2,⋯,j⋅wn]
我们期望第 j j j 轮迭代结束时,诸 C i \mathcal{C}_i Ci 应当尽可能地接近理论值(在MSE下)。由于只能让其中一个 C k \mathcal{C}_k Ck 自增 1 1 1,显然有 k = arg max i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(j⋅wi−Ci)。
再谈 S j \mathcal{S}_j Sj。在确定了数据源是 D k \mathcal{D}_k Dk 后,为了避免重复,我们应当做到不放回、随机地从中采样。如何做到这两点呢?我们可以在一开始就对 n n n 个数据源进行打乱,然后在采样的时候只需要从前往后进行,就可以做到以上两点。注意到 C i \mathcal{C}_i Ci 的值是从 0 0 0 开始,以步长为 1 1 1 依次递增,所以我们可以用每次更新完的 C i \mathcal{C}_i Ci 赋值给相应的 S j \mathcal{S}_j Sj,即 S j = 第 j 轮迭代结束时的 C i \mathcal{S}_j=第j轮迭代结束时的\mathcal{C}_i Sj=第j轮迭代结束时的Ci。
由此我们可以得到整个算法的伪代码:
2. 源码解析
Python部分:
class BlendableDataset(torch.utils.data.Dataset):def __init__(self, datasets, weights):self.datasets = datasetsnum_datasets = len(datasets)assert num_datasets == len(weights), "The number of datasets and weights must match."self.size = sum(len(dataset) for dataset in self.datasets)# Normalize weights.weights = np.array(weights, dtype=np.float64)sum_weights = np.sum(weights)assert sum_weights > 0.0, "Sum of weights must be positive."weights /= sum_weights# Build indices.start_time = time.time()assert num_datasets < 255, "Number of datasets must be less than 255."self.dataset_index = np.zeros(self.size, dtype=np.uint8)self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)helpers.build_blending_indices(self.dataset_index,self.dataset_sample_index,weights,num_datasets,self.size,torch.distributed.get_rank() == 0,)print_rank_0(f'> elapsed time for building blendable dataset indices: 'f'{time.time() - start_time:.2f} sec')def __len__(self):return self.sizedef __getitem__(self, idx):dataset_idx = self.dataset_index[idx]sample_idx = self.dataset_sample_index[idx]return {"dataset_idx": dataset_idx,**self.datasets[dataset_idx][sample_idx],}
C++部分:
void build_blending_indices(py::array_t<uint8_t> &dataset_index,py::array_t<int64_t> &dataset_sample_index,const py::array_t<double> &weights,const int32_t num_datasets,const int64_t size,const bool verbose
) {/* Given multiple datasets and a weighting array, build samplessuch that it follows those weights. */if (verbose) {std::cout << "> building indices for blendable datasets ..." << std::endl;}// Get the pointer access without the checks.auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();auto weights_ptr = weights.unchecked<1>();// Initialize buffer for number of samples used for each dataset.int64_t current_samples[num_datasets];for (int64_t i = 0; i < num_datasets; ++i) {current_samples[i] = 0;}// For each sample:for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {// Determine where the max error in sampling is happening.auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);int64_t max_error_index = 0;double max_error = weights_ptr[0] * sample_idx_double - static_cast<double>(current_samples[0]);for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {double error = weights_ptr[dataset_idx] * sample_idx_double - static_cast<double>(current_samples[dataset_idx]);if (error > max_error) {max_error = error;max_error_index = dataset_idx;}}// Populate the indices.dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];// Update the total samples.current_samples[max_error_index] += 1;}// Print infoif (verbose) {std::cout << " > sample ratios:" << std::endl;for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {auto ratio = static_cast<double>(current_samples[dataset_idx]) / static_cast<double>(size);std::cout << " dataset " << dataset_idx << ", input: " << weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;}}
}
具体的算法实现是在C++的函数中,我们先来看Python部分。
self.size
实际上就是 ∣ D ∣ |\mathcal{D}| ∣D∣,即混合后的数据集大小(从后面的 __len__
也能看出)。在构造函数中,首先会对 weights
进行归一化,然后声明 P , S \mathcal{P},\mathcal{S} P,S 两个数组。注意 self.dataset_index
实际上就是 P \mathcal{P} P,self.dataset_sample_index
实际上就是 S \mathcal{S} S。由于 P \mathcal{P} P 的数据类型是 uint8
,这表明其中元素的范围是 [ 0 , 2 8 − 1 = 255 ] [0,2^8-1=255] [0,28−1=255],故 P \mathcal{P} P 最多能表示 256 256 256 个数据集,而源码中规定了参与混合的数据集个数必须严格少于 255 255 255(博主不是很懂这一点,看懂的小伙伴可以在评论区留言)。
再来看C++部分。前五个形参分别是 P , S , { w i } i , n , ∣ D ∣ \mathcal{P},\mathcal{S},\{w_i\}_i,n,|\mathcal{D}| P,S,{wi}i,n,∣D∣。
C \mathcal{C} C 数组会在该函数中进行声明并初始化。随后的两个嵌套 for
循环则是整个算法的核心流程,注意到这里的实现中,sample_idx
(即 j j j)是从 0 0 0 开始的,而算法伪代码中的 j j j 是从 1 1 1 开始的,所以一开始要执行 j = max ( j , 1 ) j=\max(j,1) j=max(j,1) 以确保 j j j 至少是 1 1 1(但这样做有一个弊端就是前两轮的循环里, j j j 的值是相同的,和我们期望的每一轮里 j j j 值不同相违背,这是源码中的一个缺陷,实际上应该计算 ( j + 1 ) ⋅ w i − C i (j+1)\cdot w_i-\mathcal{C}_i (j+1)⋅wi−Ci)。内层循环中的 error
实际上就是 j ⋅ w i − C i j\cdot w_i-\mathcal{C}_i j⋅wi−Ci。此外,由于 j j j 是从 0 0 0 开始的,所以 C P j \mathcal{C}_{\mathcal{P}_j} CPj 的更新要放到最后执行。
一言以蔽之, j j j 从 1 1 1 开始,更新顺序为 P → C → S \mathcal{P}\to\mathcal{C}\to\mathcal{S} P→C→S; j j j 从 0 0 0 开始,更新顺序为 P → S → C \mathcal{P}\to\mathcal{S}\to\mathcal{C} P→S→C。
得到了 P , S \mathcal{P},\mathcal{S} P,S 数组后,我们便可得到混合后的数据集 D \mathcal{D} D:
D j = D P j [ S j ] , j = 1 , 2 , ⋯ , ∣ D ∣ \mathcal{D}_j=\mathcal{D}_{\mathcal{P}_j}[\mathcal{S}_j],\quad j=1,2,\cdots,|\mathcal{D}| Dj=DPj[Sj],j=1,2,⋯,∣D∣
其中 D i [ j ] \mathcal{D}_i[j] Di[j] 代表数据集 D i \mathcal{D}_i Di 中的第 j j j 个样本。
回到Python部分,__getitem__
中传入的 idx
实际上就是 j j j,self.datasets[dataset_idx][sample_idx]
实际上就是上述的 D P j [ S j ] \mathcal{D}_{\mathcal{P}_j}[\mathcal{S}_j] DPj[Sj]。
3. 证明部分&讨论
Prop 1. \text{Prop} \;1.\, Prop1. 每一轮循环开始时所有误差加和为 1 1 1,即 ∑ i = 1 n e i = 1 \sum_{i=1}^n e_i=1 ∑i=1nei=1,其中 e i ≜ j ⋅ w i − C i e_i\triangleq j\cdot w_i-\mathcal{C}_i ei≜j⋅wi−Ci。
P r o o f . Proof.\; Proof. 注意到第 j j j 轮循环开始时,此时一共只采样了 j − 1 j-1 j−1 个样本,所以 ∑ i = 1 n C i = j − 1 \sum_{i=1}^n\mathcal{C}_i=j-1 ∑i=1nCi=j−1,从而
∑ i = 1 n e i = ∑ i = 1 n ( j ⋅ w i − C i ) = j ⋅ ∑ i = 1 n w i − ∑ i = 1 n C i = j − ∑ i = 1 n C i = j − ( j − 1 ) = 1 \sum_{i=1}^n e_i=\sum_{i=1}^n (j\cdot w_i-\mathcal{C}_i)=j\cdot\sum_{i=1}^n w_i-\sum_{i=1}^n\mathcal{C}_i=j-\sum_{i=1}^n\mathcal{C}_i=j-(j-1)=1 i=1∑nei=i=1∑n(j⋅wi−Ci)=j⋅i=1∑nwi−i=1∑nCi=j−i=1∑nCi=j−(j−1)=1
进一步可知,每一轮循环结束时所有误差加和为 0 0 0。
Prop 2. \text{Prop} \;2.\, Prop2. 假定下标从 1 1 1 开始,且 n = 2 n=2 n=2(即只有两个数据源)。若 e 1 ≥ 0.5 e_1\geq 0.5 e1≥0.5,则 P j = 1 \mathcal{P}_j=1 Pj=1,否则 P j = 2 \mathcal{P}_j=2 Pj=2。
P r o o f . Proof.\; Proof. e 1 > 0.5 e_1>0.5 e1>0.5 的情况显然。当 e 1 = e 2 = 0.5 e_1=e_2=0.5 e1=e2=0.5 时, arg max \argmax argmax 会优先挑选下标最小的,故此时 P j \mathcal{P}_j Pj 仍是 1 1 1。
Prop 3. \text{Prop} \;3.\, Prop3. 假定下标从 1 1 1 开始。可能存在一组 { D i } i \{\mathcal{D}_i\}_i {Di}i 和 { w i } i \{w_i\}_i {wi}i,使得经由上述算法得到的 P , S \mathcal{P},\mathcal{S} P,S 数组, ∃ j , s.t. S j > ∣ D P j ∣ \exists \,j,\,\text{s.t.}\;\,\mathcal{S}_j>|\mathcal{D}_{\mathcal{P}_j}| ∃j,s.t.Sj>∣DPj∣,意味着 __getitem__
会出现下标越界的错误。
P r o o f . Proof.\; Proof. 构造特殊情形即可。令 n = 2 n=2 n=2, ∣ D 1 ∣ = ∣ D 2 ∣ = 2 |\mathcal{D}_1|=|\mathcal{D}_2|=2 ∣D1∣=∣D2∣=2, w 1 = 0.1 , w 2 = 0.9 w_1=0.1,\,w_2=0.9 w1=0.1,w2=0.9。
由 ∣ D 1 ∣ + ∣ D 2 ∣ = 4 |\mathcal{D}_1|+|\mathcal{D}_2|=4 ∣D1∣+∣D2∣=4 可知,总共会有 4 4 4 轮循环。且理应有 1 ≤ P j , S j ≤ 2 , j = 1 , 2 , 3 , 4 1\leq \mathcal{P}_j,\mathcal{S}_j\leq 2,\,j=1,2,3,4 1≤Pj,Sj≤2,j=1,2,3,4。
利用 Prop 2 \text{Prop} \;2 Prop2 快速计算:
-
第一轮循环,计算误差 e 1 = 1 ⋅ w 1 − 0 = 0.1 < 0.5 e_1=1\cdot w_1-0=0.1<0.5 e1=1⋅w1−0=0.1<0.5,故 P 1 = 2 \mathcal{P}_1=2 P1=2, C = { 0 , 1 } \mathcal{C}=\{0,1\} C={0,1}, S 1 = C 2 = 1 \mathcal{S}_1=\mathcal{C}_2=1 S1=C2=1。
-
第二轮循环,计算误差 e 1 = 2 ⋅ w 1 − 0 = 0.2 < 0.5 e_1=2\cdot w_1-0=0.2<0.5 e1=2⋅w1−0=0.2<0.5,故 P 2 = 2 \mathcal{P}_2=2 P2=2, C = { 0 , 2 } \mathcal{C}=\{0,2\} C={0,2}, S 2 = C 2 = 2 \mathcal{S}_2=\mathcal{C}_2=2 S2=C2=2。
-
第三轮循环,计算误差 e 1 = 3 ⋅ w 1 − 0 = 0.3 < 0.5 e_1=3\cdot w_1-0=0.3<0.5 e1=3⋅w1−0=0.3<0.5,故 P 3 = 2 \mathcal{P}_3=2 P3=2, C = { 0 , 3 } \mathcal{C}=\{0,3\} C={0,3}, S 3 = C 2 = 3 \mathcal{S}_3=\mathcal{C}_2=3 S3=C2=3。
-
第四轮循环,计算误差 e 1 = 4 ⋅ w 1 − 0 = 0.4 < 0.5 e_1=4\cdot w_1-0=0.4<0.5 e1=4⋅w1−0=0.4<0.5,故 P 4 = 2 \mathcal{P}_4=2 P4=2, C = { 0 , 4 } \mathcal{C}=\{0,4\} C={0,4}, S 4 = C 2 = 4 \mathcal{S}_4=\mathcal{C}_2=4 S4=C2=4。
由上可知 j = 3 , 4 j=3,4 j=3,4 满足要求。
Prop 4. \text{Prop} \;4.\, Prop4. 在MSE下,要使 [ C 1 , ⋯ , C k + 1 , ⋯ , C n ] [\mathcal{C}_1,\cdots,\mathcal{C}_{k}+1,\cdots,\mathcal{C}_n] [C1,⋯,Ck+1,⋯,Cn] 尽可能接近 [ j ⋅ w 1 , j ⋅ w 2 , ⋯ , j ⋅ w n ] [j\cdot w_1,j\cdot w_2,\cdots,j\cdot w_n] [j⋅w1,j⋅w2,⋯,j⋅wn],应当有 k = arg max i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(j⋅wi−Ci)。
P r o o f . Proof.\; Proof. 注意到
Δ MSE = MSE a f t e r − MSE b e f o r e = 1 n [ ( j ⋅ w k − C k − 1 ) 2 − ( j ⋅ w k − C k ) 2 ] = 1 n [ 1 − 2 ( j ⋅ w k − C k ) ] \begin{aligned} \Delta \text{MSE}=\text{MSE}_{after}-\text{MSE}_{before}&=\frac1n[(j\cdot w_k-\mathcal{C}_k-1)^2-(j\cdot w_k-\mathcal{C}_k)^2] \\ &=\frac1n[1-2(j\cdot w_k-\mathcal{C}_k)] \end{aligned} ΔMSE=MSEafter−MSEbefore=n1[(j⋅wk−Ck−1)2−(j⋅wk−Ck)2]=n1[1−2(j⋅wk−Ck)]
由上式可知,要使 Δ MSE \Delta \text{MSE} ΔMSE 越小,应使 j ⋅ w k − C k j\cdot w_k-\mathcal{C}_k j⋅wk−Ck 越大,故 k = arg max i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(j⋅wi−Ci)。
Prop 5. \text{Prop} \;5.\, Prop5. 假定下标从 1 1 1 开始。若 w 1 = w 2 = ⋯ = w n = 1 / n w_1=w_2=\cdots=w_n=1/n w1=w2=⋯=wn=1/n,令 ∣ D ∣ = q ⋅ n + r |\mathcal{D}|=q\cdot n+r ∣D∣=q⋅n+r,其中 q q q 是商, r r r 是余数,则有
P = [ 1 , 2 , ⋯ , n ] ∗ q + [ 1 , 2 , ⋯ , r ] S = [ 1 , 1 , ⋯ , 1 ] + [ 2 , 2 , ⋯ , 2 ] + ⋯ + [ q , q , ⋯ , q ] ⏟ 每个列表的长度均为 n + [ q + 1 , q + 1 , ⋯ , q + 1 ] ⏟ 长度为 r C = [ q + 1 , q + 1 , ⋯ , q + 1 ⏟ r 个 , q , q , ⋯ , q ⏟ n − r 个 ] \begin{aligned} \mathcal{P}&=[1,2,\cdots,n] * q + [1,2,\cdots,r] \\ \mathcal{S}&=\underbrace{[1,1,\cdots,1] + [2,2,\cdots,2] + \cdots+[q,q,\cdots,q]}_{每个列表的长度均为 n}+\underbrace{[q+1,q+1,\cdots,q+1]}_{长度为r} \\ \mathcal{C}&=[\underbrace{q+1,q+1,\cdots,q+1}_{r个},\underbrace{q,q,\cdots,q}_{n-r个}] \end{aligned} PSC=[1,2,⋯,n]∗q+[1,2,⋯,r]=每个列表的长度均为n [1,1,⋯,1]+[2,2,⋯,2]+⋯+[q,q,⋯,q]+长度为r [q+1,q+1,⋯,q+1]=[r个 q+1,q+1,⋯,q+1,n−r个 q,q,⋯,q]
上述的 ∗ * ∗ 和 + + + 均是列表运算符。
P r o o f . Proof.\; Proof. 证明留给读者。
讨论:
Prop 3. \text{Prop} \;3.\, Prop3. 中提到了可能会出现下标越界的错误,为了避免这个错误,我们可以在得到 P , S \mathcal{P},\mathcal{S} P,S 数组后,对 S \mathcal{S} S 进行更新(假定下标从 1 1 1 开始):
S j = S j mod ( ∣ D P j ∣ + 1 ) , j = 1 , 2 , ⋯ , ∣ D ∣ \mathcal{S}_j=\mathcal{S}_j\;\text{mod}\; (|\mathcal{D}_{\mathcal{P}_j}|+1),\quad j=1,2,\cdots,|\mathcal{D}| Sj=Sjmod(∣DPj∣+1),j=1,2,⋯,∣D∣
例如某个数据集是 [ 1 , 2 , 3 , 4 , 5 ] [1,2,3,4,5] [1,2,3,4,5],如果要从这个数据集采样 8 8 8 个样本,则原先的算法会在采样第 6 6 6 个样本时抛出下标越界错误,修正后的算法的采样结果为 [ 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 ] [1,2,3,4,5,1,2,3] [1,2,3,4,5,1,2,3]。
为什么Megatron源码里没有规避这个错误但在使用的过程中却好像并没有遇到bug呢?注意到 self.datasets[dataset_idx]
实际上指向的是 megatron/data/gpt_dataset.py
中的 GPTDataset
类,在混合数据集场景下,Megatron会预先根据权重计算每个数据集所需要的样本数,然后根据这个样本数构建 GPTDataset
,而非根据document数去构建。所以,即使对于两个完全相同的数据集,当赋予它们的权重不同时,所得到的 GPTDataset
的长度也不同,这一点可以通过向 BlendableDataset
源码中加入以下代码来验证:
for i, dataset in enumerate(self.datasets):print(f"dataset {i}: {len(dataset)}")
因为 GPTDataset
的长度已经根据权重做出了相应的调整,所以绝大部分时候是不会出现bug的,但我们依然可以构造极端情形来触发bug。
考虑在训练脚本中提供两个完全相同的路径,但却赋予它们不同的权重,如下:
--train-data-path 0.001 /path/to/your/data_text_document 0.999 /path/to/your/data_text_document
然后在 BlendableDataset
源码中的 __getitem__
方法中固定索引,即:
def __getitem__(self, idx):idx = self.size - 1 # 意味着我们总是取BlendableDataset的最后一个样本dataset_idx = self.dataset_index[idx]sample_idx = self.dataset_sample_index[idx]return {"dataset_idx": dataset_idx,**self.datasets[dataset_idx][sample_idx],}
这样就可以稳定的触发下标越界的bug。
📝 注意到从
GPTDataset
中取出来的是sample,所以Megatron的混合算法实际上是以sample为单位的,而非以document为单位。
4. 进一步优化
根据 Prop 3. \text{Prop} \;3.\, Prop3. 和 Prop 5. \text{Prop} \;5.\, Prop5. 以及其他细节,我们有以下几个优化方向:
- 修复可能会出现的下标越界错误(可通过取余来实现)。
- 在等权重情形下加速混合(利用
numpy
)。 - 支持更多数据集进行混合(修改
uint8
为其他类型)。
假设相应的接口名为 make_blendable_dataset
,它接收两个形参:datasets
和 weights
。前者是一个二维列表,包含了要进行混合的数据集(每个数据集是一个一维列表),后者是一个一维列表,包含了每个数据集的权重。
使用Python进行实现:
from typing import List, Any, Union
import numpy as np
import random
from tqdm import tqdmdef make_blendable_dataset(datasets: List[List[Any]], weights: List[Union[float, int]]) -> List[Any]:num_datasets = len(datasets)assert num_datasets == len(weights), "The number of datasets must match the number of weights."# Shufflesize = 0for dataset in datasets:size += len(dataset)random.shuffle(dataset)# Normalize weightsweights = np.array(weights, dtype=np.float64)assert np.all(weights > 0), "All weights must be positive."weights /= weights.sum()# Determine if all weights are equalif np.ptp(weights) < 1e-5:q, r = divmod(size, num_datasets)dataset_index = np.concatenate([np.tile(np.arange(num_datasets, dtype=np.int16), q),np.arange(r, dtype=np.int16)])dataset_sample_index = np.concatenate([np.repeat(np.arange(q, dtype=np.int64), num_datasets),np.full(r, q, dtype=np.int64)])current_samples = np.full(num_datasets, q, dtype=np.int64)current_samples[:r] += 1else:dataset_index = np.zeros(size, dtype=np.int16)dataset_sample_index = np.zeros(size, dtype=np.int64)current_samples = np.zeros(num_datasets, dtype=np.int64)for sample_idx in tqdm(range(size), desc="Calculating error"):errors = weights * (sample_idx + 1) - current_samplesmax_error_index = np.argmax(errors)dataset_index[sample_idx] = max_error_indexdataset_sample_index[sample_idx] = current_samples[max_error_index]current_samples[max_error_index] += 1print(f"Ratios:")for i in range(num_datasets):print(f"input: {weights[i]}, achieved: {current_samples[i] / size}")# Blendres = []for i in tqdm(range(size), desc="Blending"):dataset_idx = dataset_index[i]sample_idx = dataset_sample_index[i] % len(datasets[dataset_idx])res.append(datasets[dataset_idx][sample_idx])return res