详解Megatron中的数据混合算法(BlendableDataset)

🧑‍💻 本文主要讲解Megatron早期版本中的数据混合算法。

目录

  • 1. 数据混合
  • 2. 源码解析
  • 3. 证明部分&讨论
  • 4. 进一步优化

1. 数据混合

在谈源码之前,我们有必要先了解一下Megatron中的数据混合思想。

给定 n n n 个数据集 D 1 , D 2 , ⋯ , D n \mathcal{D}_1,\mathcal{D}_2,\cdots,\mathcal{D}_n D1,D2,,Dn 和对应的 n n n 个权重 w 1 , w 2 , ⋯ , w n w_1,w_2,\cdots,w_n w1,w2,,wn,我们要按照这些权重去混合 n n n 个数据集,设混合后的数据集为 D \mathcal{D} D

Megatron假定:

  • ∣ D ∣ = ∑ i = 1 n ∣ D i ∣ |\mathcal{D}|=\sum_{i=1}^n|\mathcal{D}_i| D=i=1nDi。即混合后的数据集大小等于混合前的各数据集大小之和。
  • D \mathcal{D} D 中有 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i Dwi 个样本来自 D i \mathcal{D}_i Di

那如何确定 D \mathcal{D} D 中到底有多少个样本是来自 D i \mathcal{D}_i Di 的呢?一种最直观的做法是,计算 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i Dwi,然后进行取整,但这种操作无法保证所有取整后的 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i Dwi 相加起来恰好是 ∣ D ∣ |\mathcal{D}| D 如果总和大于 ∣ D ∣ |\mathcal{D}| D,说明某些数据集被过采样了,应当减少相应数据集的采样数;如果总和小于 ∣ D ∣ |\mathcal{D}| D,说明某些数据集被欠采样了,应当增加相应数据集的采样数。可问题是,如何确定这些被过采样/欠采样的数据集呢?显然我们需要一个更加公平的算法。

我们可以把获取数据集 D \mathcal{D} D 看作是一个采样过程:一开始有 n n n 个数据源 { D i } i = 1 n \{\mathcal{D}_i\}_{i=1}^n {Di}i=1n,每一轮迭代,我们需要先从这 n n n 个数据源中选出一个数据源 D i \mathcal{D}_i Di,然后再从这个数据源中选出一个样本 S \mathcal{S} S 由于每一轮迭代只会选出一个样本,因此 ∣ D ∣ |\mathcal{D}| D 轮迭代结束后,我们便得到了 ∣ D ∣ |\mathcal{D}| D 个样本,这些样本构成了混合后的数据集 D \mathcal{D} D

每一轮迭代都会产生两个信息:要选取的数据源 D i \mathcal{D}_i Di,要从 D i \mathcal{D}_i Di 中选取的样本。我们可以考虑构造两个整数序列 P , S \mathcal{P},\mathcal{S} P,S,它们的长度均为 ∣ D ∣ |\mathcal{D}| D,含义如下:

  • P j \mathcal{P}_j Pj 代表的是第 j j j 轮迭代时,选取的数据源的下标。例如 P 10 = 3 \mathcal{P}_{10}=3 P10=3 意味着第 10 10 10 轮迭代选取的数据源是 D 3 \mathcal{D}_3 D3
  • S j \mathcal{S}_j Sj 代表的是第 j j j 轮迭代时,从数据源 D P j \mathcal{D}_{\mathcal{P}_j} DPj 选取的样本的下标。

由以上定义知, ∀ j \forall j j,都有 1 ≤ P j ≤ n 1\leq \mathcal{P}_j\leq n 1Pjn 1 ≤ S j ≤ ∣ D P j ⁣ ∣ 1\leq \mathcal{S}_j\leq|\mathcal{D}_{\mathcal{P}_j}\!| 1SjDPj(下标均从 1 1 1 开始)。

接下来的问题是,如何确定每一轮的 P j \mathcal{P}_j Pj S j \mathcal{S}_j Sj 呢?

先谈 P j \mathcal{P}_j Pj。因为是一个从 1 1 1 ∣ D ∣ |\mathcal{D}| D 的一个逐步采样过程,在第 j j j 轮迭代时,我们已经抽取了 j − 1 j-1 j1 个样本,接下来要确定第 j j j 个样本。根据Megatron的假定,在确定下来第 j j j 个样本后,这 j j j 个样本中应当有约 j ⋅ w i j\cdot w_i jwi 个样本是来自 D i \mathcal{D}_i Di 的。

考虑构造一个长度为 n n n 的序列 C \mathcal{C} C,该序列随着迭代不断更新。 C i \mathcal{C}_i Ci 代表当前已经从 D i \mathcal{D}_i Di 抽取了多少个样本。显然可知,第一轮迭代开始时,有 C i = 0 , i = 1 , 2 , ⋯ , n \mathcal{C}_i=0,\,i=1,2,\cdots,n Ci=0,i=1,2,,n。最后一轮迭代结束后,有 ∑ i = 1 n C i = ∣ D ∣ \sum_{i=1}^n\mathcal{C}_i=|\mathcal{D}| i=1nCi=D,并且

C i = { ∑ t = 1 j − 1 I ( P t = i ) , P j 确定前 ∑ t = 1 j I ( P t = i ) , P j 确定后 , ∀ i \mathcal{C}_i=\begin{cases} \sum_{t=1}^{j-1} I(\mathcal{P}_t=i),&\text{$\mathcal{P}_j$确定前} \\ \sum_{t=1}^{j} I(\mathcal{P}_t=i),&\text{$\mathcal{P}_j$确定后} \\ \end{cases},\quad \forall i Ci={t=1j1I(Pt=i),t=1jI(Pt=i),Pj确定前Pj确定后,i

回到对 P j \mathcal{P}_j Pj 的讨论中。假设在确定第 j j j 个样本前已经从 D i \mathcal{D}_i Di 中抽取了 C i \mathcal{C}_i Ci 个样本,在确定第 j j j 个样本后,诸 C i \mathcal{C}_i Ci有且仅有一个的值会增加 1 1 1,不妨记为 C k \mathcal{C}_k Ck,这个过程可以形容为

[ C 1 , ⋯ , C k , ⋯ , C n ] ⏟ 第 j 轮迭代开始时 → [ C 1 , ⋯ , C k + 1 , ⋯ , C n ] ⏟ 第 j 轮迭代结束时 [ j ⋅ w 1 , j ⋅ w 2 , ⋯ , j ⋅ w n ] ⏟ 理论值 \underbrace{[\mathcal{C}_1,\cdots,\mathcal{C}_k,\cdots,\mathcal{C}_n]}_{第j轮迭代开始时}\to\underbrace{[\mathcal{C}_1,\cdots,\mathcal{C}_{k}+1,\cdots,\mathcal{C}_n]}_{第j轮迭代结束时}\qquad \underbrace{[j\cdot w_1,j\cdot w_2,\cdots,j\cdot w_n]}_{理论值} j轮迭代开始时 [C1,,Ck,,Cn]j轮迭代结束时 [C1,,Ck+1,,Cn]理论值 [jw1,jw2,,jwn]

我们期望第 j j j 轮迭代结束时,诸 C i \mathcal{C}_i Ci 应当尽可能地接近理论值(在MSE下)。由于只能让其中一个 C k \mathcal{C}_k Ck 自增 1 1 1,显然有 k = arg max ⁡ i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(jwiCi)

再谈 S j \mathcal{S}_j Sj。在确定了数据源是 D k \mathcal{D}_k Dk 后,为了避免重复,我们应当做到不放回、随机地从中采样。如何做到这两点呢?我们可以在一开始就对 n n n 个数据源进行打乱,然后在采样的时候只需要从前往后进行,就可以做到以上两点。注意到 C i \mathcal{C}_i Ci 的值是从 0 0 0 开始,以步长为 1 1 1 依次递增,所以我们可以用每次更新完的 C i \mathcal{C}_i Ci 赋值给相应的 S j \mathcal{S}_j Sj,即 S j = 第 j 轮迭代结束时的 C i \mathcal{S}_j=第j轮迭代结束时的\mathcal{C}_i Sj=j轮迭代结束时的Ci

由此我们可以得到整个算法的伪代码:

2. 源码解析

Python部分:

class BlendableDataset(torch.utils.data.Dataset):def __init__(self, datasets, weights):self.datasets = datasetsnum_datasets = len(datasets)assert num_datasets == len(weights), "The number of datasets and weights must match."self.size = sum(len(dataset) for dataset in self.datasets)# Normalize weights.weights = np.array(weights, dtype=np.float64)sum_weights = np.sum(weights)assert sum_weights > 0.0, "Sum of weights must be positive."weights /= sum_weights# Build indices.start_time = time.time()assert num_datasets < 255, "Number of datasets must be less than 255."self.dataset_index = np.zeros(self.size, dtype=np.uint8)self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)helpers.build_blending_indices(self.dataset_index,self.dataset_sample_index,weights,num_datasets,self.size,torch.distributed.get_rank() == 0,)print_rank_0(f'> elapsed time for building blendable dataset indices: 'f'{time.time() - start_time:.2f} sec')def __len__(self):return self.sizedef __getitem__(self, idx):dataset_idx = self.dataset_index[idx]sample_idx = self.dataset_sample_index[idx]return {"dataset_idx": dataset_idx,**self.datasets[dataset_idx][sample_idx],}

C++部分:

void build_blending_indices(py::array_t<uint8_t> &dataset_index,py::array_t<int64_t> &dataset_sample_index,const py::array_t<double> &weights,const int32_t num_datasets,const int64_t size,const bool verbose
) {/* Given multiple datasets and a weighting array, build samplessuch that it follows those weights. */if (verbose) {std::cout << "> building indices for blendable datasets ..." << std::endl;}// Get the pointer access without the checks.auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();auto weights_ptr = weights.unchecked<1>();// Initialize buffer for number of samples used for each dataset.int64_t current_samples[num_datasets];for (int64_t i = 0; i < num_datasets; ++i) {current_samples[i] = 0;}// For each sample:for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {// Determine where the max error in sampling is happening.auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);int64_t max_error_index = 0;double max_error = weights_ptr[0] * sample_idx_double - static_cast<double>(current_samples[0]);for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {double error = weights_ptr[dataset_idx] * sample_idx_double - static_cast<double>(current_samples[dataset_idx]);if (error > max_error) {max_error = error;max_error_index = dataset_idx;}}// Populate the indices.dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];// Update the total samples.current_samples[max_error_index] += 1;}// Print infoif (verbose) {std::cout << " > sample ratios:" << std::endl;for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {auto ratio = static_cast<double>(current_samples[dataset_idx]) / static_cast<double>(size);std::cout << "   dataset " << dataset_idx << ", input: " << weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;}}
}

具体的算法实现是在C++的函数中,我们先来看Python部分。

self.size 实际上就是 ∣ D ∣ |\mathcal{D}| D,即混合后的数据集大小(从后面的 __len__ 也能看出)。在构造函数中,首先会对 weights 进行归一化,然后声明 P , S \mathcal{P},\mathcal{S} P,S 两个数组。注意 self.dataset_index 实际上就是 P \mathcal{P} Pself.dataset_sample_index 实际上就是 S \mathcal{S} S。由于 P \mathcal{P} P 的数据类型是 uint8,这表明其中元素的范围是 [ 0 , 2 8 − 1 = 255 ] [0,2^8-1=255] [0,281=255],故 P \mathcal{P} P 最多能表示 256 256 256 个数据集,而源码中规定了参与混合的数据集个数必须严格少于 255 255 255博主不是很懂这一点,看懂的小伙伴可以在评论区留言)。

再来看C++部分。前五个形参分别是 P , S , { w i } i , n , ∣ D ∣ \mathcal{P},\mathcal{S},\{w_i\}_i,n,|\mathcal{D}| P,S,{wi}i,n,D

C \mathcal{C} C 数组会在该函数中进行声明并初始化。随后的两个嵌套 for 循环则是整个算法的核心流程,注意到这里的实现中,sample_idx(即 j j j)是从 0 0 0 开始的,而算法伪代码中的 j j j 是从 1 1 1 开始的,所以一开始要执行 j = max ⁡ ( j , 1 ) j=\max(j,1) j=max(j,1) 以确保 j j j 至少是 1 1 1但这样做有一个弊端就是前两轮的循环里, j j j 的值是相同的,和我们期望的每一轮里 j j j 值不同相违背,这是源码中的一个缺陷,实际上应该计算 ( j + 1 ) ⋅ w i − C i (j+1)\cdot w_i-\mathcal{C}_i (j+1)wiCi)。内层循环中的 error 实际上就是 j ⋅ w i − C i j\cdot w_i-\mathcal{C}_i jwiCi。此外,由于 j j j 是从 0 0 0 开始的,所以 C P j \mathcal{C}_{\mathcal{P}_j} CPj 的更新要放到最后执行。

一言以蔽之, j j j 1 1 1 开始,更新顺序为 P → C → S \mathcal{P}\to\mathcal{C}\to\mathcal{S} PCS j j j 0 0 0 开始,更新顺序为 P → S → C \mathcal{P}\to\mathcal{S}\to\mathcal{C} PSC

得到了 P , S \mathcal{P},\mathcal{S} P,S 数组后,我们便可得到混合后的数据集 D \mathcal{D} D

D j = D P j [ S j ] , j = 1 , 2 , ⋯ , ∣ D ∣ \mathcal{D}_j=\mathcal{D}_{\mathcal{P}_j}[\mathcal{S}_j],\quad j=1,2,\cdots,|\mathcal{D}| Dj=DPj[Sj],j=1,2,,D

其中 D i [ j ] \mathcal{D}_i[j] Di[j] 代表数据集 D i \mathcal{D}_i Di 中的第 j j j 个样本。

回到Python部分,__getitem__ 中传入的 idx 实际上就是 j j jself.datasets[dataset_idx][sample_idx] 实际上就是上述的 D P j [ S j ] \mathcal{D}_{\mathcal{P}_j}[\mathcal{S}_j] DPj[Sj]

3. 证明部分&讨论


Prop 1. \text{Prop} \;1.\, Prop1. 每一轮循环开始时所有误差加和为 1 1 1,即 ∑ i = 1 n e i = 1 \sum_{i=1}^n e_i=1 i=1nei=1,其中 e i ≜ j ⋅ w i − C i e_i\triangleq j\cdot w_i-\mathcal{C}_i eijwiCi

P r o o f . Proof.\; Proof. 注意到第 j j j 轮循环开始时,此时一共只采样了 j − 1 j-1 j1 个样本,所以 ∑ i = 1 n C i = j − 1 \sum_{i=1}^n\mathcal{C}_i=j-1 i=1nCi=j1,从而

∑ i = 1 n e i = ∑ i = 1 n ( j ⋅ w i − C i ) = j ⋅ ∑ i = 1 n w i − ∑ i = 1 n C i = j − ∑ i = 1 n C i = j − ( j − 1 ) = 1 \sum_{i=1}^n e_i=\sum_{i=1}^n (j\cdot w_i-\mathcal{C}_i)=j\cdot\sum_{i=1}^n w_i-\sum_{i=1}^n\mathcal{C}_i=j-\sum_{i=1}^n\mathcal{C}_i=j-(j-1)=1 i=1nei=i=1n(jwiCi)=ji=1nwii=1nCi=ji=1nCi=j(j1)=1

进一步可知,每一轮循环结束时所有误差加和为 0 0 0


Prop 2. \text{Prop} \;2.\, Prop2. 假定下标从 1 1 1 开始,且 n = 2 n=2 n=2(即只有两个数据源)。若 e 1 ≥ 0.5 e_1\geq 0.5 e10.5,则 P j = 1 \mathcal{P}_j=1 Pj=1,否则 P j = 2 \mathcal{P}_j=2 Pj=2

P r o o f . Proof.\; Proof. e 1 > 0.5 e_1>0.5 e1>0.5 的情况显然。当 e 1 = e 2 = 0.5 e_1=e_2=0.5 e1=e2=0.5 时, arg max ⁡ \argmax argmax 会优先挑选下标最小的,故此时 P j \mathcal{P}_j Pj 仍是 1 1 1


Prop 3. \text{Prop} \;3.\, Prop3. 假定下标从 1 1 1 开始。可能存在一组 { D i } i \{\mathcal{D}_i\}_i {Di}i { w i } i \{w_i\}_i {wi}i,使得经由上述算法得到的 P , S \mathcal{P},\mathcal{S} P,S 数组, ∃ j , s.t. S j > ∣ D P j ∣ \exists \,j,\,\text{s.t.}\;\,\mathcal{S}_j>|\mathcal{D}_{\mathcal{P}_j}| j,s.t.Sj>DPj,意味着 __getitem__ 会出现下标越界的错误。

P r o o f . Proof.\; Proof. 构造特殊情形即可。令 n = 2 n=2 n=2 ∣ D 1 ∣ = ∣ D 2 ∣ = 2 |\mathcal{D}_1|=|\mathcal{D}_2|=2 D1=D2=2 w 1 = 0.1 , w 2 = 0.9 w_1=0.1,\,w_2=0.9 w1=0.1,w2=0.9

∣ D 1 ∣ + ∣ D 2 ∣ = 4 |\mathcal{D}_1|+|\mathcal{D}_2|=4 D1+D2=4 可知,总共会有 4 4 4 轮循环。且理应有 1 ≤ P j , S j ≤ 2 , j = 1 , 2 , 3 , 4 1\leq \mathcal{P}_j,\mathcal{S}_j\leq 2,\,j=1,2,3,4 1Pj,Sj2,j=1,2,3,4

利用 Prop 2 \text{Prop} \;2 Prop2 快速计算:

  • 第一轮循环,计算误差 e 1 = 1 ⋅ w 1 − 0 = 0.1 < 0.5 e_1=1\cdot w_1-0=0.1<0.5 e1=1w10=0.1<0.5,故 P 1 = 2 \mathcal{P}_1=2 P1=2 C = { 0 , 1 } \mathcal{C}=\{0,1\} C={0,1} S 1 = C 2 = 1 \mathcal{S}_1=\mathcal{C}_2=1 S1=C2=1

  • 第二轮循环,计算误差 e 1 = 2 ⋅ w 1 − 0 = 0.2 < 0.5 e_1=2\cdot w_1-0=0.2<0.5 e1=2w10=0.2<0.5,故 P 2 = 2 \mathcal{P}_2=2 P2=2 C = { 0 , 2 } \mathcal{C}=\{0,2\} C={0,2} S 2 = C 2 = 2 \mathcal{S}_2=\mathcal{C}_2=2 S2=C2=2

  • 第三轮循环,计算误差 e 1 = 3 ⋅ w 1 − 0 = 0.3 < 0.5 e_1=3\cdot w_1-0=0.3<0.5 e1=3w10=0.3<0.5,故 P 3 = 2 \mathcal{P}_3=2 P3=2 C = { 0 , 3 } \mathcal{C}=\{0,3\} C={0,3} S 3 = C 2 = 3 \mathcal{S}_3=\mathcal{C}_2=3 S3=C2=3

  • 第四轮循环,计算误差 e 1 = 4 ⋅ w 1 − 0 = 0.4 < 0.5 e_1=4\cdot w_1-0=0.4<0.5 e1=4w10=0.4<0.5,故 P 4 = 2 \mathcal{P}_4=2 P4=2 C = { 0 , 4 } \mathcal{C}=\{0,4\} C={0,4} S 4 = C 2 = 4 \mathcal{S}_4=\mathcal{C}_2=4 S4=C2=4

由上可知 j = 3 , 4 j=3,4 j=3,4 满足要求。


Prop 4. \text{Prop} \;4.\, Prop4. 在MSE下,要使 [ C 1 , ⋯ , C k + 1 , ⋯ , C n ] [\mathcal{C}_1,\cdots,\mathcal{C}_{k}+1,\cdots,\mathcal{C}_n] [C1,,Ck+1,,Cn] 尽可能接近 [ j ⋅ w 1 , j ⋅ w 2 , ⋯ , j ⋅ w n ] [j\cdot w_1,j\cdot w_2,\cdots,j\cdot w_n] [jw1,jw2,,jwn],应当有 k = arg max ⁡ i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(jwiCi)

P r o o f . Proof.\; Proof. 注意到

Δ MSE = MSE a f t e r − MSE b e f o r e = 1 n [ ( j ⋅ w k − C k − 1 ) 2 − ( j ⋅ w k − C k ) 2 ] = 1 n [ 1 − 2 ( j ⋅ w k − C k ) ] \begin{aligned} \Delta \text{MSE}=\text{MSE}_{after}-\text{MSE}_{before}&=\frac1n[(j\cdot w_k-\mathcal{C}_k-1)^2-(j\cdot w_k-\mathcal{C}_k)^2] \\ &=\frac1n[1-2(j\cdot w_k-\mathcal{C}_k)] \end{aligned} ΔMSE=MSEafterMSEbefore=n1[(jwkCk1)2(jwkCk)2]=n1[12(jwkCk)]

由上式可知,要使 Δ MSE \Delta \text{MSE} ΔMSE 越小,应使 j ⋅ w k − C k j\cdot w_k-\mathcal{C}_k jwkCk 越大,故 k = arg max ⁡ i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(jwiCi)


Prop 5. \text{Prop} \;5.\, Prop5. 假定下标从 1 1 1 开始。若 w 1 = w 2 = ⋯ = w n = 1 / n w_1=w_2=\cdots=w_n=1/n w1=w2==wn=1/n,令 ∣ D ∣ = q ⋅ n + r |\mathcal{D}|=q\cdot n+r D=qn+r,其中 q q q 是商, r r r 是余数,则有

P = [ 1 , 2 , ⋯ , n ] ∗ q + [ 1 , 2 , ⋯ , r ] S = [ 1 , 1 , ⋯ , 1 ] + [ 2 , 2 , ⋯ , 2 ] + ⋯ + [ q , q , ⋯ , q ] ⏟ 每个列表的长度均为 n + [ q + 1 , q + 1 , ⋯ , q + 1 ] ⏟ 长度为 r C = [ q + 1 , q + 1 , ⋯ , q + 1 ⏟ r 个 , q , q , ⋯ , q ⏟ n − r 个 ] \begin{aligned} \mathcal{P}&=[1,2,\cdots,n] * q + [1,2,\cdots,r] \\ \mathcal{S}&=\underbrace{[1,1,\cdots,1] + [2,2,\cdots,2] + \cdots+[q,q,\cdots,q]}_{每个列表的长度均为 n}+\underbrace{[q+1,q+1,\cdots,q+1]}_{长度为r} \\ \mathcal{C}&=[\underbrace{q+1,q+1,\cdots,q+1}_{r个},\underbrace{q,q,\cdots,q}_{n-r个}] \end{aligned} PSC=[1,2,,n]q+[1,2,,r]=每个列表的长度均为n [1,1,,1]+[2,2,,2]++[q,q,,q]+长度为r [q+1,q+1,,q+1]=[r q+1,q+1,,q+1,nr q,q,,q]

上述的 ∗ * + + + 均是列表运算符

P r o o f . Proof.\; Proof. 证明留给读者。


讨论:

Prop 3. \text{Prop} \;3.\, Prop3. 中提到了可能会出现下标越界的错误,为了避免这个错误,我们可以在得到 P , S \mathcal{P},\mathcal{S} P,S 数组后,对 S \mathcal{S} S 进行更新(假定下标从 1 1 1 开始):

S j = S j mod ( ∣ D P j ∣ + 1 ) , j = 1 , 2 , ⋯ , ∣ D ∣ \mathcal{S}_j=\mathcal{S}_j\;\text{mod}\; (|\mathcal{D}_{\mathcal{P}_j}|+1),\quad j=1,2,\cdots,|\mathcal{D}| Sj=Sjmod(DPj+1),j=1,2,,D

例如某个数据集是 [ 1 , 2 , 3 , 4 , 5 ] [1,2,3,4,5] [1,2,3,4,5],如果要从这个数据集采样 8 8 8 个样本,则原先的算法会在采样第 6 6 6 个样本时抛出下标越界错误,修正后的算法的采样结果为 [ 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 ] [1,2,3,4,5,1,2,3] [1,2,3,4,5,1,2,3]

为什么Megatron源码里没有规避这个错误但在使用的过程中却好像并没有遇到bug呢?注意到 self.datasets[dataset_idx] 实际上指向的是 megatron/data/gpt_dataset.py 中的 GPTDataset 类,在混合数据集场景下,Megatron会预先根据权重计算每个数据集所需要的样本数,然后根据这个样本数构建 GPTDataset,而非根据document数去构建。所以,即使对于两个完全相同的数据集,当赋予它们的权重不同时,所得到的 GPTDataset 的长度也不同,这一点可以通过向 BlendableDataset 源码中加入以下代码来验证:

for i, dataset in enumerate(self.datasets):print(f"dataset {i}: {len(dataset)}")

因为 GPTDataset 的长度已经根据权重做出了相应的调整,所以绝大部分时候是不会出现bug的,但我们依然可以构造极端情形来触发bug。

考虑在训练脚本中提供两个完全相同的路径,但却赋予它们不同的权重,如下:

--train-data-path 0.001 /path/to/your/data_text_document 0.999 /path/to/your/data_text_document

然后在 BlendableDataset 源码中的 __getitem__ 方法中固定索引,即:

def __getitem__(self, idx):idx = self.size - 1  # 意味着我们总是取BlendableDataset的最后一个样本dataset_idx = self.dataset_index[idx]sample_idx = self.dataset_sample_index[idx]return {"dataset_idx": dataset_idx,**self.datasets[dataset_idx][sample_idx],}

这样就可以稳定的触发下标越界的bug。

📝 注意到从 GPTDataset 中取出来的是sample,所以Megatron的混合算法实际上是以sample为单位的,而非以document为单位。

4. 进一步优化

根据 Prop 3. \text{Prop} \;3.\, Prop3. Prop 5. \text{Prop} \;5.\, Prop5. 以及其他细节,我们有以下几个优化方向:

  • 修复可能会出现的下标越界错误(可通过取余来实现)。
  • 在等权重情形下加速混合(利用 numpy)。
  • 支持更多数据集进行混合(修改 uint8 为其他类型)。

假设相应的接口名为 make_blendable_dataset,它接收两个形参:datasetsweights。前者是一个二维列表,包含了要进行混合的数据集(每个数据集是一个一维列表),后者是一个一维列表,包含了每个数据集的权重。

使用Python进行实现:

from typing import List, Any, Union
import numpy as np
import random
from tqdm import tqdmdef make_blendable_dataset(datasets: List[List[Any]], weights: List[Union[float, int]]) -> List[Any]:num_datasets = len(datasets)assert num_datasets == len(weights), "The number of datasets must match the number of weights."# Shufflesize = 0for dataset in datasets:size += len(dataset)random.shuffle(dataset)# Normalize weightsweights = np.array(weights, dtype=np.float64)assert np.all(weights > 0), "All weights must be positive."weights /= weights.sum()# Determine if all weights are equalif np.ptp(weights) < 1e-5:q, r = divmod(size, num_datasets)dataset_index = np.concatenate([np.tile(np.arange(num_datasets, dtype=np.int16), q),np.arange(r, dtype=np.int16)])dataset_sample_index = np.concatenate([np.repeat(np.arange(q, dtype=np.int64), num_datasets),np.full(r, q, dtype=np.int64)])current_samples = np.full(num_datasets, q, dtype=np.int64)current_samples[:r] += 1else:dataset_index = np.zeros(size, dtype=np.int16)dataset_sample_index = np.zeros(size, dtype=np.int64)current_samples = np.zeros(num_datasets, dtype=np.int64)for sample_idx in tqdm(range(size), desc="Calculating error"):errors = weights * (sample_idx + 1) - current_samplesmax_error_index = np.argmax(errors)dataset_index[sample_idx] = max_error_indexdataset_sample_index[sample_idx] = current_samples[max_error_index]current_samples[max_error_index] += 1print(f"Ratios:")for i in range(num_datasets):print(f"input: {weights[i]}, achieved: {current_samples[i] / size}")# Blendres = []for i in tqdm(range(size), desc="Blending"):dataset_idx = dataset_index[i]sample_idx = dataset_sample_index[i] % len(datasets[dataset_idx])res.append(datasets[dataset_idx][sample_idx])return res

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/701133.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

给定一个边与边可能相交的多边形,求它的轮廓线

大家好&#xff0c;我是前端西瓜哥。 最近遇到一个需求&#xff0c;给定一个多边形&#xff08;边与边可能相交&#xff09;&#xff0c;求这个多边形的轮廓线。 需要注意的是&#xff0c;轮廓线多边形内不能有空洞&#xff0c;使用的不是常见的非零绕数规则&#xff08;nonze…

Java+SpringBoot,打造极致申报体验

✍✍计算机编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java实战 |…

2024全国水科技大会暨流域水环境治理与水生态修复论坛(六)

论坛召集人 冯慧娟 中国环境科学研究院流域中心研究员 刘 春 河北科技大学环境与工程学院院长、教授 一、会议背景 为深入贯彻“山水林田湖是一个生命共同体”的重要指示精神&#xff0c;大力实施生态优先绿色发展战略&#xff0c;积极践行人、水、自然和谐共生理念&…

opencascade在vs和qt下改变视图方向和设置线框模式

一.改变视图方向&#xff08;以顶部视图为例&#xff09; 1.在qt的界面代码中设置好 2.在view.h中设置好槽函数 3.在lzzcad.cpp中设置槽与信号的连接&#xff0c;并在工具栏上显示 4.在view.cpp中给出函数实现 5.给出快捷键实现方式 二.设置线框模式 同上&#xff0c;加入函数…

[深度学习]yolov9+deepsort+pyqt5实现目标追踪

【YOLOv9DeepSORTPyQt5追踪介绍】 随着人工智能技术的飞速发展&#xff0c;目标追踪在视频监控、自动驾驶等领域的应用日益广泛。其中&#xff0c;YOLOv9作为先进的目标检测算法&#xff0c;结合DeepSORT多目标追踪算法和PyQt5图形界面库&#xff0c;能够为用户提供高效、直观…

python-可视化篇-简单-条形图输出主要省份GDP排名情况

条形图输出主要省份GDP排名情况 代码 gdp广东:97277.77:107671.07 江苏:92595.40:99631.52 山东:76469.70:71067.5 浙江:56197.00:62353 河南:48055.90:54259.2 四川:40678.10:46615.82 湖北:39366.60:45828.31 湖南:36425.78:39752.12 河北:36010.30:35104.5 福建:35804.04:…

windows安装 RabbitMQ

首先打开 RabbitMQ 官网&#xff0c;点击 Get Started(开始) 点击 Download Installation(下载安装)。 这里提供了两种方式进行安装&#xff0c;我们使用第二种方法。 使用 chocolatey以管理用户身份使用官方安装程序 往下滑&#xff0c;第二种方法需要 Erlang 的依赖&#x…

mongoose httpserver浅析

文章目录 前言一、结构体及其功能二、函数MG_LOGmg_http_listenmg_mgr_poll question参考链接 前言 mongoose是一款基于C/C的网络库&#xff0c;可以实现TCP, UDP, HTTP, WebSocket, MQTT通讯。mongoose是的嵌入式网络程序更快、健壮&#xff0c;易于实现。 mongoose只有mong…

qt波位图

1&#xff0c;QPainter 绘制&#xff0c;先绘制这一堆蓝色的东西, 2&#xff0c;在用定时器&#xff1a;QTimer&#xff0c;配合绘制棕色的圆。用到取余&#xff0c;取整 #pragma once#include <QWidget> #include <QPaintEvent>#include <QTimer>QT_BEGIN_…

LangChain Agent v0.2.0简明教程 (上)

快速入门指南 – LangChain中文网 langchain源码剖析系列课程 九天玩转Langchain! 1. LangChain是什么2. LangChain Expression Language (LCEL)Runnable 接口3. Model I/O3.1 Prompt Templates3.2 Language Model3.3 Output ParsersUse case(Q&A with RAG)1. LangChain…

【踩坑】PyTorch中指定GPU不生效和GPU编号不一致问题

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhang.cn] 指定GPU不生效问题 解释&#xff1a;就是使用os.environ["CUDA_VISIBLE_DEVICES"] "1"后&#xff0c;后面使用起来仍然是cuda0. 解决&#xff1a;在最开头就使用 import os os.environ[&…

sentinel整合nacos在gateway中实现限流

sentinel整合nacos在gateway中实现限流 一、应用层面完成网关整合nacos和sentinel实现限流 前沿 启动nacos与sentinel的jar的启动&#xff0c;这里不细讲 sentinel官网 https://github.com/alibaba/Sentinel/wiki/%E4%B8%BB%E9%A1%B5 sentinel 下载地址 https://github.com/…

车载电子电器架构 —— 电气架构开发计划

车载电子电器架构 —— 电气架构开发计划 我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你的不对。非必要不费力证明…

实现KingSCADA系统按钮弹窗出现位置随点击位置变化。

哈喽&#xff0c;你好啊&#xff0c;我是雷工&#xff01; 在用KingSCADA做项目时&#xff0c;当我们点击不同的控制按钮&#xff0c;都可以弹出对应的控制弹窗。 在常规不做设置的情况下弹窗都是出现在固定的位置&#xff0c;要么一直出现在左上角&#xff0c;要么一直出现在…

【Java】常用实用类及java集合框架(实验六)

目录 一、实验目的 二、实验内容 三、实验小结 3.1 常用实用类 3.2 Java集合框架 一、实验目的 1、掌握java常用类的方法 2、掌握String类与数值类型数据的相互转化 3、掌握正则表达式的应用 4、掌握常用集合的创建和操作方法 二、实验内容 1、菜单的内容如下&#x…

南邮概率统计与随机过程练习册答案

**南京邮电大学** **概率统计与随机过程练习册答案简介** 本文档是一份精心整理的南京邮电大学概率统计与随机过程课程的练习册答案集。它旨在为学习该课程的学生提供一个详尽的解题参考,帮助他们更好地理解和掌握概率论与统计学的基本概念和方法。 **内容概览:** - **章节…

抖音视频评论数据提取软件|抖音数据抓取工具

一、开发背景&#xff1a; 在业务需求中&#xff0c;我们经常需要下载抖音视频。然而&#xff0c;在网上找到的视频通常只能通过逐个复制链接的方式进行抓取和下载&#xff0c;这种操作非常耗时。我们希望能够通过关键词自动批量抓取并选择性地下载抖音视频。因此&#xff0c;为…

git 拉取远程分支到本地

背景&#xff1a; 我的 github 上的远程仓库上除了 main 分支外还提交了好几个别的分支&#xff0c;现在我换机器了&#xff0c;git clone 原仓库后只剩 main 分支&#xff0c;我要把其他分支拉下来到本地。 1. 查看所有远程remote分支 git branch -r 比如我这里&#xff1…

深入浅出:探究过完备字典矩阵

在数学和信号处理的世界里&#xff0c;我们总是在寻找表达数据的最佳方式。在这篇博文中&#xff0c;我们将探讨一种特殊的矩阵——过完备字典矩阵&#xff0c;这是线性代数和信号处理中一个非常有趣且实用的概念。 什么是过完备字典矩阵&#xff1f; 首先&#xff0c;我们先…

认识K8S

K8S K8S 的全称为 Kubernetes (K12345678S) 是一个跨主机容器编排工具 作用 用于自动部署、扩展和管理“容器化&#xff08;containerized&#xff09;应用程序”的开源系统。 可以理解成 K8S 是负责自动化运维管理多个容器化程序&#xff08;比如 Docker&#xff09;的集群…