详解Megatron中的数据混合算法(BlendableDataset)

🧑‍💻 本文主要讲解Megatron早期版本中的数据混合算法。

目录

  • 1. 数据混合
  • 2. 源码解析
  • 3. 证明部分&讨论
  • 4. 进一步优化

1. 数据混合

在谈源码之前,我们有必要先了解一下Megatron中的数据混合思想。

给定 n n n 个数据集 D 1 , D 2 , ⋯   , D n \mathcal{D}_1,\mathcal{D}_2,\cdots,\mathcal{D}_n D1,D2,,Dn 和对应的 n n n 个权重 w 1 , w 2 , ⋯   , w n w_1,w_2,\cdots,w_n w1,w2,,wn,我们要按照这些权重去混合 n n n 个数据集,设混合后的数据集为 D \mathcal{D} D

Megatron假定:

  • ∣ D ∣ = ∑ i = 1 n ∣ D i ∣ |\mathcal{D}|=\sum_{i=1}^n|\mathcal{D}_i| D=i=1nDi。即混合后的数据集大小等于混合前的各数据集大小之和。
  • D \mathcal{D} D 中有 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i Dwi 个样本来自 D i \mathcal{D}_i Di

那如何确定 D \mathcal{D} D 中到底有多少个样本是来自 D i \mathcal{D}_i Di 的呢?一种最直观的做法是,计算 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i Dwi,然后进行取整,但这种操作无法保证所有取整后的 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i Dwi 相加起来恰好是 ∣ D ∣ |\mathcal{D}| D 如果总和大于 ∣ D ∣ |\mathcal{D}| D,说明某些数据集被过采样了,应当减少相应数据集的采样数;如果总和小于 ∣ D ∣ |\mathcal{D}| D,说明某些数据集被欠采样了,应当增加相应数据集的采样数。可问题是,如何确定这些被过采样/欠采样的数据集呢?显然我们需要一个更加公平的算法。

我们可以把获取数据集 D \mathcal{D} D 看作是一个采样过程:一开始有 n n n 个数据源 { D i } i = 1 n \{\mathcal{D}_i\}_{i=1}^n {Di}i=1n,每一轮迭代,我们需要先从这 n n n 个数据源中选出一个数据源 D i \mathcal{D}_i Di,然后再从这个数据源中选出一个样本 S \mathcal{S} S 由于每一轮迭代只会选出一个样本,因此 ∣ D ∣ |\mathcal{D}| D 轮迭代结束后,我们便得到了 ∣ D ∣ |\mathcal{D}| D 个样本,这些样本构成了混合后的数据集 D \mathcal{D} D

每一轮迭代都会产生两个信息:要选取的数据源 D i \mathcal{D}_i Di,要从 D i \mathcal{D}_i Di 中选取的样本。我们可以考虑构造两个整数序列 P , S \mathcal{P},\mathcal{S} P,S,它们的长度均为 ∣ D ∣ |\mathcal{D}| D,含义如下:

  • P j \mathcal{P}_j Pj 代表的是第 j j j 轮迭代时,选取的数据源的下标。例如 P 10 = 3 \mathcal{P}_{10}=3 P10=3 意味着第 10 10 10 轮迭代选取的数据源是 D 3 \mathcal{D}_3 D3
  • S j \mathcal{S}_j Sj 代表的是第 j j j 轮迭代时,从数据源 D P j \mathcal{D}_{\mathcal{P}_j} DPj 选取的样本的下标。

由以上定义知, ∀ j \forall j j,都有 1 ≤ P j ≤ n 1\leq \mathcal{P}_j\leq n 1Pjn 1 ≤ S j ≤ ∣ D P j  ⁣ ∣ 1\leq \mathcal{S}_j\leq|\mathcal{D}_{\mathcal{P}_j}\!| 1SjDPj(下标均从 1 1 1 开始)。

接下来的问题是,如何确定每一轮的 P j \mathcal{P}_j Pj S j \mathcal{S}_j Sj 呢?

先谈 P j \mathcal{P}_j Pj。因为是一个从 1 1 1 ∣ D ∣ |\mathcal{D}| D 的一个逐步采样过程,在第 j j j 轮迭代时,我们已经抽取了 j − 1 j-1 j1个样本,接下来要确定第 j j j 个样本。根据Megatron的假定,在确定下来第 j j j 个样本后,这 j j j 个样本中应当有约 j ⋅ w i j\cdot w_i jwi 个样本是来自 D i \mathcal{D}_i Di 的。

考虑构造一个长度为 n n n 的序列 C \mathcal{C} C,该序列随着迭代不断更新。 C i \mathcal{C}_i Ci 代表当前已经从 D i \mathcal{D}_i Di 抽取了多少个样本。显然可知,第一轮迭代开始时,有 C i = 0 ,   i = 1 , 2 , ⋯   , n \mathcal{C}_i=0,\,i=1,2,\cdots,n Ci=0,i=1,2,,n。最后一轮迭代结束后,有 ∑ i = 1 n C i = ∣ D ∣ \sum_{i=1}^n\mathcal{C}_i=|\mathcal{D}| i=1nCi=D,并且

C i = { ∑ t = 1 j − 1 I ( P t = i ) , P j 确定前 ∑ t = 1 j I ( P t = i ) , P j 确定后 , ∀ i \mathcal{C}_i=\begin{cases} \sum_{t=1}^{j-1} I(\mathcal{P}_t=i),&\text{$\mathcal{P}_j$确定前} \\ \sum_{t=1}^{j} I(\mathcal{P}_t=i),&\text{$\mathcal{P}_j$确定后} \\ \end{cases},\quad \forall i Ci={t=1j1I(Pt=i),t=1jI(Pt=i),Pj确定前Pj确定后,i

回到对 P j \mathcal{P}_j Pj 的讨论中。假设在确定第 j j j 个样本前已经从 D i \mathcal{D}_i Di 中抽取了 C i \mathcal{C}_i Ci 个样本,在确定第 j j j 个样本后,诸 C i \mathcal{C}_i Ci有且仅有一个的值会增加 1 1 1,不妨记为 C k \mathcal{C}_k Ck,这个过程可以形容为

[ C 1 , ⋯   , C k , ⋯   , C n ] ⏟ 第 j 轮迭代开始时 → [ C 1 , ⋯   , C k + 1 , ⋯   , C n ] ⏟ 第 j 轮迭代结束时 [ j ⋅ w 1 , j ⋅ w 2 , ⋯   , j ⋅ w n ] ⏟ 理论值 \underbrace{[\mathcal{C}_1,\cdots,\mathcal{C}_k,\cdots,\mathcal{C}_n]}_{第j轮迭代开始时}\to\underbrace{[\mathcal{C}_1,\cdots,\mathcal{C}_{k}+1,\cdots,\mathcal{C}_n]}_{第j轮迭代结束时}\qquad \underbrace{[j\cdot w_1,j\cdot w_2,\cdots,j\cdot w_n]}_{理论值} j轮迭代开始时 [C1,,Ck,,Cn]j轮迭代结束时 [C1,,Ck+1,,Cn]理论值 [jw1,jw2,,jwn]

我们期望第 j j j 轮迭代结束时,诸 C i \mathcal{C}_i Ci 应当尽可能地接近理论值(在MSE下)。由于只能让其中一个 C k \mathcal{C}_k Ck 自增 1 1 1,显然有 k = arg max ⁡ i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(jwiCi)

再谈 S j \mathcal{S}_j Sj。在确定了数据源是 D k \mathcal{D}_k Dk 后,为了避免重复,我们应当做到不放回、随机地从中采样。如何做到这两点呢?我们可以在一开始就对 n n n 个数据源进行打乱,然后在采样的时候只需要从前往后进行,就可以做到以上两点。注意到 C i \mathcal{C}_i Ci 的值是从 0 0 0 开始,以步长为 1 1 1 依次递增,所以我们可以用每次更新完的 C i \mathcal{C}_i Ci 赋值给相应的 S j \mathcal{S}_j Sj,即 S j = 第 j 轮迭代结束时的 C i \mathcal{S}_j=第j轮迭代结束时的\mathcal{C}_i Sj=j轮迭代结束时的Ci

由此我们可以得到整个算法的伪代码:

2. 源码解析

Python部分:

class BlendableDataset(torch.utils.data.Dataset):
    def __init__(self, datasets, weights):
        self.datasets = datasets
        num_datasets = len(datasets)
        assert num_datasets == len(weights), "The number of datasets and weights must match."

        self.size = sum(len(dataset) for dataset in self.datasets)

        # Normalize weights.
        weights = np.array(weights, dtype=np.float64)
        sum_weights = np.sum(weights)
        assert sum_weights > 0.0, "Sum of weights must be positive."
        weights /= sum_weights

        # Build indices.
        start_time = time.time()
        assert num_datasets < 255, "Number of datasets must be less than 255."
        self.dataset_index = np.zeros(self.size, dtype=np.uint8)
        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)

        helpers.build_blending_indices(
            self.dataset_index,
            self.dataset_sample_index,
            weights,
            num_datasets,
            self.size,
            torch.distributed.get_rank() == 0,
        )
        print_rank_0(f'> elapsed time for building blendable dataset indices: '
                     f'{time.time() - start_time:.2f} sec')

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        dataset_idx = self.dataset_index[idx]
        sample_idx = self.dataset_sample_index[idx]
        return {
            "dataset_idx": dataset_idx,
            **self.datasets[dataset_idx][sample_idx],
        }

C++部分:

void build_blending_indices(
    py::array_t<uint8_t> &dataset_index,
    py::array_t<int64_t> &dataset_sample_index,
    const py::array_t<double> &weights,
    const int32_t num_datasets,
    const int64_t size,
    const bool verbose
) {
    /* Given multiple datasets and a weighting array, build samples
       such that it follows those weights. */

    if (verbose) {
        std::cout << "> building indices for blendable datasets ..." << std::endl;
    }

    // Get the pointer access without the checks.
    auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
    auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
    auto weights_ptr = weights.unchecked<1>();

    // Initialize buffer for number of samples used for each dataset.
    int64_t current_samples[num_datasets];
    for (int64_t i = 0; i < num_datasets; ++i) {
        current_samples[i] = 0;
    }

    // For each sample:
    for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {

        // Determine where the max error in sampling is happening.
        auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
        int64_t max_error_index = 0;
        double max_error = weights_ptr[0] * sample_idx_double - static_cast<double>(current_samples[0]);
        for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
            double error = weights_ptr[dataset_idx] * sample_idx_double - static_cast<double>(current_samples[dataset_idx]);
            if (error > max_error) {
                max_error = error;
                max_error_index = dataset_idx;
            }
        }

        // Populate the indices.
        dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
        dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];

        // Update the total samples.
        current_samples[max_error_index] += 1;
    }

    // Print info
    if (verbose) {
        std::cout << " > sample ratios:" << std::endl;
        for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
            auto ratio = static_cast<double>(current_samples[dataset_idx]) / static_cast<double>(size);
            std::cout << "   dataset " << dataset_idx << ", input: " << weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
        }
    }
}

具体的算法实现是在C++的函数中,我们先来看Python部分。

self.size 实际上就是 ∣ D ∣ |\mathcal{D}| D,即混合后的数据集大小(从后面的 __len__ 也能看出)。在构造函数中,首先会对 weights 进行归一化,然后声明 P , S \mathcal{P},\mathcal{S} P,S 两个数组。注意 self.dataset_index 实际上就是 P \mathcal{P} Pself.dataset_sample_index 实际上就是 S \mathcal{S} S。由于 P \mathcal{P} P 的数据类型是 uint8,这表明其中元素的范围是 [ 0 , 2 8 − 1 = 255 ] [0,2^8-1=255] [0,281=255],故 P \mathcal{P} P 最多能表示 256 256 256 个数据集,而源码中规定了参与混合的数据集个数必须严格少于 255 255 255博主不是很懂这一点,看懂的小伙伴可以在评论区留言)。

再来看C++部分。前五个形参分别是 P , S , { w i } i , n , ∣ D ∣ \mathcal{P},\mathcal{S},\{w_i\}_i,n,|\mathcal{D}| P,S,{wi}i,n,D

C \mathcal{C} C 数组会在该函数中进行声明并初始化。随后的两个嵌套 for 循环则是整个算法的核心流程,注意到这里的实现中,sample_idx(即 j j j)是从 0 0 0 开始的,而算法伪代码中的 j j j 是从 1 1 1 开始的,所以一开始要执行 j = max ⁡ ( j , 1 ) j=\max(j,1) j=max(j,1) 以确保 j j j 至少是 1 1 1但这样做有一个弊端就是前两轮的循环里, j j j 的值是相同的,和我们期望的每一轮里 j j j 值不同相违背,这是源码中的一个缺陷,实际上应该计算 ( j + 1 ) ⋅ w i − C i (j+1)\cdot w_i-\mathcal{C}_i (j+1)wiCi)。内层循环中的 error 实际上就是 j ⋅ w i − C i j\cdot w_i-\mathcal{C}_i jwiCi。此外,由于 j j j 是从 0 0 0 开始的,所以 C P j \mathcal{C}_{\mathcal{P}_j} CPj 的更新要放到最后执行。

一言以蔽之, j j j 1 1 1 开始,更新顺序为 P → C → S \mathcal{P}\to\mathcal{C}\to\mathcal{S} PCS j j j 0 0 0 开始,更新顺序为 P → S → C \mathcal{P}\to\mathcal{S}\to\mathcal{C} PSC

得到了 P , S \mathcal{P},\mathcal{S} P,S 数组后,我们便可得到混合后的数据集 D \mathcal{D} D

D j = D P j [ S j ] , j = 1 , 2 , ⋯   , ∣ D ∣ \mathcal{D}_j=\mathcal{D}_{\mathcal{P}_j}[\mathcal{S}_j],\quad j=1,2,\cdots,|\mathcal{D}| Dj=DPj[Sj],j=1,2,,D

其中 D i [ j ] \mathcal{D}_i[j] Di[j] 代表数据集 D i \mathcal{D}_i Di 中的第 j j j 个样本。

回到Python部分,__getitem__ 中传入的 idx 实际上就是 j j jself.datasets[dataset_idx][sample_idx] 实际上就是上述的 D P j [ S j ] \mathcal{D}_{\mathcal{P}_j}[\mathcal{S}_j] DPj[Sj]

3. 证明部分&讨论


Prop    1.   \text{Prop} \;1.\, Prop1. 每一轮循环开始时所有误差加和为 1 1 1,即 ∑ i = 1 n e i = 1 \sum_{i=1}^n e_i=1 i=1nei=1,其中 e i ≜ j ⋅ w i − C i e_i\triangleq j\cdot w_i-\mathcal{C}_i eijwiCi

P r o o f .    Proof.\; Proof. 注意到第 j j j 轮循环开始时,此时一共只采样了 j − 1 j-1 j1 个样本,所以 ∑ i = 1 n C i = j − 1 \sum_{i=1}^n\mathcal{C}_i=j-1 i=1nCi=j1,从而

∑ i = 1 n e i = ∑ i = 1 n ( j ⋅ w i − C i ) = j ⋅ ∑ i = 1 n w i − ∑ i = 1 n C i = j − ∑ i = 1 n C i = j − ( j − 1 ) = 1 \sum_{i=1}^n e_i=\sum_{i=1}^n (j\cdot w_i-\mathcal{C}_i)=j\cdot\sum_{i=1}^n w_i-\sum_{i=1}^n\mathcal{C}_i=j-\sum_{i=1}^n\mathcal{C}_i=j-(j-1)=1 i=1nei=i=1n(jwiCi)=ji=1nwii=1nCi=ji=1nCi=j(j1)=1

进一步可知,每一轮循环结束时所有误差加和为 0 0 0


Prop    2.   \text{Prop} \;2.\, Prop2. 假定下标从 1 1 1 开始,且 n = 2 n=2 n=2(即只有两个数据源)。若 e 1 ≥ 0.5 e_1\geq 0.5 e10.5,则 P j = 1 \mathcal{P}_j=1 Pj=1,否则 P j = 2 \mathcal{P}_j=2 Pj=2

P r o o f .    Proof.\; Proof. e 1 > 0.5 e_1>0.5 e1>0.5 的情况显然。当 e 1 = e 2 = 0.5 e_1=e_2=0.5 e1=e2=0.5 时, arg max ⁡ \argmax argmax 会优先挑选下标最小的,故此时 P j \mathcal{P}_j Pj 仍是 1 1 1


Prop    3.   \text{Prop} \;3.\, Prop3. 假定下标从 1 1 1 开始。可能存在一组 { D i } i \{\mathcal{D}_i\}_i {Di}i { w i } i \{w_i\}_i {wi}i,使得经由上述算法得到的 P , S \mathcal{P},\mathcal{S} P,S 数组, ∃   j ,   s.t.     S j > ∣ D P j ∣ \exists \,j,\,\text{s.t.}\;\,\mathcal{S}_j>|\mathcal{D}_{\mathcal{P}_j}| j,s.t.Sj>DPj,意味着 __getitem__ 会出现下标越界的错误。

P r o o f .    Proof.\; Proof. 构造特殊情形即可。令 n = 2 n=2 n=2 ∣ D 1 ∣ = ∣ D 2 ∣ = 2 |\mathcal{D}_1|=|\mathcal{D}_2|=2 D1=D2=2 w 1 = 0.1 ,   w 2 = 0.9 w_1=0.1,\,w_2=0.9 w1=0.1,w2=0.9

∣ D 1 ∣ + ∣ D 2 ∣ = 4 |\mathcal{D}_1|+|\mathcal{D}_2|=4 D1+D2=4 可知,总共会有 4 4 4 轮循环。且理应有 1 ≤ P j , S j ≤ 2 ,   j = 1 , 2 , 3 , 4 1\leq \mathcal{P}_j,\mathcal{S}_j\leq 2,\,j=1,2,3,4 1Pj,Sj2,j=1,2,3,4

利用 Prop    2 \text{Prop} \;2 Prop2 快速计算:

  • 第一轮循环,计算误差 e 1 = 1 ⋅ w 1 − 0 = 0.1 < 0.5 e_1=1\cdot w_1-0=0.1<0.5 e1=1w10=0.1<0.5,故 P 1 = 2 \mathcal{P}_1=2 P1=2 C = { 0 , 1 } \mathcal{C}=\{0,1\} C={0,1} S 1 = C 2 = 1 \mathcal{S}_1=\mathcal{C}_2=1 S1=C2=1

  • 第二轮循环,计算误差 e 1 = 2 ⋅ w 1 − 0 = 0.2 < 0.5 e_1=2\cdot w_1-0=0.2<0.5 e1=2w10=0.2<0.5,故 P 2 = 2 \mathcal{P}_2=2 P2=2 C = { 0 , 2 } \mathcal{C}=\{0,2\} C={0,2} S 2 = C 2 = 2 \mathcal{S}_2=\mathcal{C}_2=2 S2=C2=2

  • 第三轮循环,计算误差 e 1 = 3 ⋅ w 1 − 0 = 0.3 < 0.5 e_1=3\cdot w_1-0=0.3<0.5 e1=3w10=0.3<0.5,故 P 3 = 2 \mathcal{P}_3=2 P3=2 C = { 0 , 3 } \mathcal{C}=\{0,3\} C={0,3} S 3 = C 2 = 3 \mathcal{S}_3=\mathcal{C}_2=3 S3=C2=3

  • 第三轮循环,计算误差 e 1 = 4 ⋅ w 1 − 0 = 0.4 < 0.5 e_1=4\cdot w_1-0=0.4<0.5 e1=4w10=0.4<0.5,故 P 4 = 2 \mathcal{P}_4=2 P4=2 C = { 0 , 4 } \mathcal{C}=\{0,4\} C={0,4} S 4 = C 2 = 4 \mathcal{S}_4=\mathcal{C}_2=4 S4=C2=4

由上可知 j = 3 , 4 j=3,4 j=3,4 满足要求。


Prop    4.   \text{Prop} \;4.\, Prop4. 在MSE下,要使 [ C 1 , ⋯   , C k + 1 , ⋯   , C n ] [\mathcal{C}_1,\cdots,\mathcal{C}_{k}+1,\cdots,\mathcal{C}_n] [C1,,Ck+1,,Cn] 尽可能接近 [ j ⋅ w 1 , j ⋅ w 2 , ⋯   , j ⋅ w n ] [j\cdot w_1,j\cdot w_2,\cdots,j\cdot w_n] [jw1,jw2,,jwn],应当有 k = arg max ⁡ i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(jwiCi)

P r o o f .    Proof.\; Proof. 注意到

Δ MSE = MSE a f t e r − MSE b e f o r e = 1 n [ ( j ⋅ w k − C k − 1 ) 2 − ( j ⋅ w k − C k ) 2 ] = 1 n [ 1 − 2 ( j ⋅ w k − C k ) ] \begin{aligned} \Delta \text{MSE}=\text{MSE}_{after}-\text{MSE}_{before}&=\frac1n[(j\cdot w_k-\mathcal{C}_k-1)^2-(j\cdot w_k-\mathcal{C}_k)^2] \\ &=\frac1n[1-2(j\cdot w_k-\mathcal{C}_k)] \end{aligned} ΔMSE=MSEafterMSEbefore=n1[(jwkCk1)2(jwkCk)2]=n1[12(jwkCk)]

由上式可知,要使 Δ MSE \Delta \text{MSE} ΔMSE 越小,应使 j ⋅ w k − C k j\cdot w_k-\mathcal{C}_k jwkCk 越大,故 k = arg max ⁡ i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(jwiCi)


Prop    5.   \text{Prop} \;5.\, Prop5. 假定下标从 1 1 1 开始。若 w 1 = w 2 = ⋯ = w n = 1 / n w_1=w_2=\cdots=w_n=1/n w1=w2==wn=1/n,令 ∣ D ∣ = q ⋅ n + r |\mathcal{D}|=q\cdot n+r D=qn+r,其中 q q q 是商, r r r 是余数,则有

P = [ 1 , 2 , ⋯   , n ] ∗ q + [ 1 , 2 , ⋯   , r ] S = [ 1 , 1 , ⋯   , 1 ] + [ 2 , 2 , ⋯   , 2 ] + ⋯ + [ q , q , ⋯   , q ] ⏟ 每个列表的长度均为 n + [ q + 1 , q + 1 , ⋯   , q + 1 ] ⏟ 长度为 r C = [ q + 1 , q + 1 , ⋯   , q + 1 ⏟ r 个 , q , q , ⋯   , q ⏟ n − r 个 ] \begin{aligned} \mathcal{P}&=[1,2,\cdots,n] * q + [1,2,\cdots,r] \\ \mathcal{S}&=\underbrace{[1,1,\cdots,1] + [2,2,\cdots,2] + \cdots+[q,q,\cdots,q]}_{每个列表的长度均为 n}+\underbrace{[q+1,q+1,\cdots,q+1]}_{长度为r} \\ \mathcal{C}&=[\underbrace{q+1,q+1,\cdots,q+1}_{r个},\underbrace{q,q,\cdots,q}_{n-r个}] \end{aligned} PSC=[1,2,,n]q+[1,2,,r]=每个列表的长度均为n [1,1,,1]+[2,2,,2]++[q,q,,q]+长度为r [q+1,q+1,,q+1]=[r q+1,q+1,,q+1,nr q,q,,q]

上述的 ∗ * + + + 均是列表运算符

P r o o f .    Proof.\; Proof. 证明留给读者。


讨论:

Prop    3.   \text{Prop} \;3.\, Prop3. 中提到了可能会出现下标越界的错误,为了避免这个错误,我们可以在得到 P , S \mathcal{P},\mathcal{S} P,S 数组后,对 S \mathcal{S} S 进行更新(假定下标从 1 1 1 开始):

S j = S j    mod    ( ∣ D P j ∣ + 1 ) , j = 1 , 2 , ⋯   , ∣ D ∣ \mathcal{S}_j=\mathcal{S}_j\;\text{mod}\; (|\mathcal{D}_{\mathcal{P}_j}|+1),\quad j=1,2,\cdots,|\mathcal{D}| Sj=Sjmod(DPj+1),j=1,2,,D

例如某个数据集是 [ 1 , 2 , 3 , 4 , 5 ] [1,2,3,4,5] [1,2,3,4,5],如果要从这个数据集采样 8 8 8 个样本,则原先的算法会在采样第 6 6 6 个样本时抛出下标越界错误,修正后的算法的采样结果为 [ 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 ] [1,2,3,4,5,1,2,3] [1,2,3,4,5,1,2,3]

为什么Megatron源码里没有规避这个错误但在使用的过程中却好像并没有遇到bug呢?注意到 self.datasets[dataset_idx] 实际上指向的是 megatron/data/gpt_dataset.py 中的 GPTDataset 类,在混合数据集场景下,Megatron会预先根据权重计算每个数据集所需要的样本数,然后根据这个样本数构建 GPTDataset,而非根据document数去构建。所以,即使对于两个完全相同的数据集,当赋予它们的权重不同时,所得到的 GPTDataset 的长度也不同,这一点可以通过向 BlendableDataset 源码中加入以下代码来验证:

for i, dataset in enumerate(self.datasets):
    print(f"dataset {i}: {len(dataset)}")

因为 GPTDataset 的长度已经根据权重做出了相应的调整,所以绝大部分时候是不会出现bug的,但我们依然可以构造极端情形来触发bug。

考虑在训练脚本中提供两个完全相同的路径,但却赋予它们不同的权重,如下:

--train-data-path 0.001 /path/to/your/data_text_document 0.999 /path/to/your/data_text_document

然后在 BlendableDataset 源码中的 __getitem__ 方法中固定索引,即:

def __getitem__(self, idx):
	idx = self.size - 1  # 意味着我们总是取BlendableDataset的最后一个样本
    dataset_idx = self.dataset_index[idx]
    sample_idx = self.dataset_sample_index[idx]
    return {
        "dataset_idx": dataset_idx,
        **self.datasets[dataset_idx][sample_idx],
    }

这样就可以稳定的触发下标越界的bug。

📝 注意到从 GPTDataset 中取出来的是sample,所以Megatron的混合算法实际上是以sample为单位的,而非以document为单位。

4. 进一步优化

根据 Prop    3.   \text{Prop} \;3.\, Prop3. Prop    5.   \text{Prop} \;5.\, Prop5. 以及其他细节,我们有以下几个优化方向:

  • 修复可能会出现的下标越界错误(可通过取余来实现)。
  • 在等权重情形下加速混合(利用 numpy)。
  • 支持更多数据集进行混合(修改 uint8 为其他类型)。

假设相应的接口名为 make_blendable_dataset,它接收两个形参:datasetsweights。前者是一个二维列表,包含了要进行混合的数据集(每个数据集是一个一维列表),后者是一个一维列表,包含了每个数据集的权重。

使用Python进行实现:

from typing import List, Any, Union
import numpy as np
import random
from tqdm import tqdm

def make_blendable_dataset(datasets: List[List[Any]], weights: List[Union[float, int]]) -> List[Any]:
    num_datasets = len(datasets)
    assert num_datasets == len(weights), "The number of datasets must match the number of weights."

    # Shuffle
    size = 0
    for dataset in datasets:
        size += len(dataset)
        random.shuffle(dataset)

    # Normalize weights
    weights = np.array(weights, dtype=np.float64)
    assert np.all(weights > 0), "All weights must be positive."
    weights /= weights.sum()

    # Determine if all weights are equal
    if np.ptp(weights) < 1e-5:
        q, r = divmod(size, num_datasets)
        dataset_index = np.concatenate([
            np.tile(np.arange(num_datasets, dtype=np.int16), q),
            np.arange(r, dtype=np.int16)
        ])
        dataset_sample_index = np.concatenate([
            np.repeat(np.arange(q, dtype=np.int64), num_datasets),
            np.full(r, q, dtype=np.int64)
        ])
        current_samples = np.full(num_datasets, q, dtype=np.int64)
        current_samples[:r] += 1
    else:
        dataset_index = np.zeros(size, dtype=np.int16)
        dataset_sample_index = np.zeros(size, dtype=np.int64)
        current_samples = np.zeros(num_datasets, dtype=np.int64)

        for sample_idx in tqdm(range(size), desc="Calculating error"):
            errors = weights * (sample_idx + 1) - current_samples
            max_error_index = np.argmax(errors)
            dataset_index[sample_idx] = max_error_index
            dataset_sample_index[sample_idx] = current_samples[max_error_index]
            current_samples[max_error_index] += 1

    print(f"Ratios:")
    for i in range(num_datasets):
        print(f"input: {weights[i]}, achieved: {current_samples[i] / size}")

    # Blend
    res = []
    for i in tqdm(range(size), desc="Blending"):
        dataset_idx = dataset_index[i]
        sample_idx = dataset_sample_index[i] % len(datasets[dataset_idx])
        res.append(datasets[dataset_idx][sample_idx])

    return res

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/406329.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Go基本数据类型

bool类型 布尔型的值只可以是常量 true 或者 false。⼀个简单的例⼦&#xff1a;var b bool true 数值型 1. 整数型 可以简单讲解⼀下⼆进制和位数的关系&#xff0c;以及int和uint的关系 int8 有符号 8 位整型 (-128 到 127) ⻓度&#xff1a;8bitint16 有符号 16 位整型…

探索海外短剧系统开发:引领潮流的创新与实践

随着全球化的深入发展&#xff0c;海外短剧市场正逐渐成为文化交流与娱乐产业的新热点。本文将探讨海外短剧系统开发的重要性、机遇与挑战&#xff0c;以及如何应对这些挑战&#xff0c;以实现海外短剧市场的可持续发展。 一、海外短剧系统开发的重要性 海外短剧系统开发不仅…

BUGKU-WEB 备份是个好习惯

题目描述 题目截图如下&#xff1a; 进入场景看看&#xff1a; 解题思路 看源码看提示&#xff1a;备份是个好习惯扫描目录md5弱比较 相关工具 御剑md5解密&#xff1a;https://www.somd5.com/ 解题步骤 看到的这串字符&#xff0c;有点像md5&#xff1f; d41d8cd98…

Web3之光:揭秘数字创新的未来

随着数字化时代的深入发展&#xff0c;Web3正以其独特的技术和理念&#xff0c;为我们打开数字创新的崭新视角。作为数字化时代的新兴力量&#xff0c;Web3将深刻影响着我们的生活、工作和社会。本文将揭秘Web3的奥秘&#xff0c;探讨其在数字创新领域的前景和潜力。 1. 重新定…

【论文精读】Segment Anything

Segment Anything 前言Abstract1. Introduction2. Segment Anything Task3. Segment Anything Model4. Segment Anything Data Engine5. Segment Anything Dataset6. Segment Anything RAI Analysis7. Zero-Shot Transfer Experiments7.1. Zero-Shot Single Point Valid Mask E…

深度学习基础(二)卷积神经网络(CNN)

之前的章节我们初步介绍了深度学习相关基础知识和训练神经网络&#xff1a; 深度学习基础&#xff08;一&#xff09;神经网络基本原理-CSDN博客文章浏览阅读924次&#xff0c;点赞13次&#xff0c;收藏19次。在如今的科技浪潮中&#xff0c;神经网络作为人工智能的核心技术之…

【微服务】国内微服务生态标准-SpringCloud Alibaba

现在已经是21世纪的二十年代&#xff0c;在未来的很长时间&#xff0c;以互联网、IOT物联网为代表的分布式应用必将越来越多&#xff0c;大量的软件企业对掌握微服务与高可用、高性能、高并发的架构人才也必定趋之若鹜。我们可以看看现阶段针对软件架构师的招聘需求和薪资&…

【元宵佳节】砖一祝您节日快乐!

元宵节的由来 相传&#xff0c;汉文帝(前179-前157年)为庆祝周勃于正月十五勘平诸吕之乱&#xff0c;每逢此夜&#xff0c;必出言游玩&#xff0c;与民同乐&#xff0c;在古代&#xff0c;夜同宵&#xff0c;正月又称元月&#xff0c;汉文帝就将正月十五定为元宵节&#xff0c…

算法打卡day1|数组篇|Leetcode 704.二分查找、27.移除元素

数组理论基础 数组是存放在连续内存空间上的相同类型数据的集合&#xff0c;可以方便的通过下标索引的方式获取到下标下对应的数据。 1.数组下标都是从0开始的。 2.数组内存空间的地址是连续的。 正是因为数组的在内存空间的地址是连续的&#xff0c;所以我们在删除或者增添…

LiveQing视频点播流媒体RTMP推流服务功能-支持配置开启 HTTPS 服务什么时候需要开启HTTPS服务

LiveQing视频点播流媒体RTMP推流服务功能支持配置开启 HTTPS 服务什么时候需要开启HTTPS服务 1、配置开启HTTPS1.1、准备https证书1.1.1、选择Nginx类型证书下载 1.2、配置 开启 HTTPS1.2.1 web页面配置1.2.2 配置文件配置 2、验证HTTPS服务3、为什么要开启HTTPS3.1、安全性要求…

免费的数据恢复软件哪个好?这10个数据恢复软件可以试试

遇到电脑、硬盘或U盘等设备中数据丢失&#xff0c;不用着急&#xff0c;数据恢复软件来帮你。 在遇到数据丢失的问题时&#xff0c;很多朋友都会很着急也不知道该怎么办。作为数据恢复小白&#xff0c;我们可以选择使用数据恢复软件进行扫描恢复。现在市面上的数据恢复软件很多…

AI技术那些事儿:揭开潜伏在你生活中的高科技小能手

你有没有发现&#xff0c;现在的生活里有些“看不见”的聪明家伙&#xff0c;它们时时刻刻在帮咱们忙活呢&#xff1f;从早上用语音命令打开窗帘、播报新闻&#xff0c;到晚上喊一声关灯睡觉&#xff0c;这些都离不开人工智能&#xff08;简称AI&#xff09;的助攻。今天咱就掰…

C++笔记:二叉搜索树(Binary Search Tree)

文章目录 二叉搜索树的概念二叉搜索树操作1. 框架搭建2. 遍历3. 查找迭代实现递归实现 4. 插入迭代实现递归实现 5. 删除迭代实现递归实现 6. 析构与销毁7. 拷贝构造与赋值重载 二叉搜索树的应用二叉搜索树的性能分析二叉搜索树模拟实现源码 二叉搜索树的概念 二叉搜索树又称二…

泛微e-office系统敏感信息泄露漏洞

声明 本文仅用于技术交流&#xff0c;请勿用于非法用途 由于传播、利用此文所提供的信息而造成的任何直接或者间接的后果及损失&#xff0c;均由使用者本人负责&#xff0c;文章作者不为此承担任何责任 1、系统简介 泛微e-office系统是标准、易用、快速部署上线的专业协同OA软…

gitlab升级

查看当前版本 cat /opt/gitlab/embedded/service/gitlab-rails/VERSION12.9.3-ee 备份数据 gitlab-rake gitlab:backup:create备份默认位置在 /var/opt/gitlab/backups/ 备份配置数据&#xff08;git配置目录&#xff09; tar -zcvf gitlab12.9.3-ee.tar.gz /etc/gitlab备…

01VScode开发stm32环境搭建

title: VScode开发stm32环境搭建 tags: STM32vscode 1.准备工作 1.下载并安装VSCODE 在百度上搜索vscode记住一定要是官方的 不然你自己就是在给自己下毒2345全来了 打红圈一定要有不然就是在垃圾网站上下的 VSCode下载链接 选一个适合你的      安装正常流程走就行不再…

达尔克仪器设备邀您观摩2024生物发酵产品与技术装备展

参展企业介绍 达尔克本着“诚信、求精、协作、创新”的企业精神&#xff0c;以专业的技术、优良的品质和完善的服务&#xff0c;与广大客户共创辉煌未来。 我们的产品——包括水质分析、压力、温度、流量、物位等工业在线过程控制仪表、其他类型传感器以及自动化控制方案——全…

docker运行onlyoffice,并配置https访问【参考仅用】

官方说明&#xff1a; Installing ONLYOFFICE Docs for Docker on a local server - ONLYOFFICEhttps://helpcenter.onlyoffice.com/installation/docs-developer-install-docker.aspx 一、容器端口、目录卷映射 sudo docker run --name容器名称 --restartalways -i -t -d -p…

论文阅读——SimpleClick

SimpleClick: Interactive Image Segmentation with Simple Vision Transformers 模型直接在VIT上增加交互是分割 用VIT MAE方法训练的预训练权重 用交互式分割方法微调&#xff0c;微调流程&#xff1a; 1、在当前分割自动模拟点击&#xff0c;没有人为提供的点击 受到RITM启发…

python[6]

类和对象 面向对象编程–说白就是让对象干活 创建类&#xff1a;class 类名&#xff1a; 创建类对象 对象名 类名&#xff08;&#xff09; 构造方法 1、构造方法的名称是__init__ 2、构造方法的作用&#xff1f; 构建类对象的时候会自动运行 构建类对象的传参会传递给构造…