sklearn 生成数据集 make_classification 参数详解:创建3类不平衡分类数据实战

📅 2026/7/5 3:57:25 👁️ 阅读次数 📝 编程学习
sklearn 生成数据集 make_classification 参数详解:创建3类不平衡分类数据实战

深入解析sklearn的make_classification:创建三类不平衡分类数据集实战

在机器学习实践中,高质量的数据集是算法验证和模型训练的基础。当我们需要测试分类算法在不平衡数据上的表现时,手动创建符合特定需求的数据集就显得尤为重要。scikit-learn提供的make_classification函数是一个强大的工具,它允许我们生成各种复杂度的分类数据集,包括样本数量不平衡的多类数据集。

1. make_classification函数核心参数解析

make_classification是scikit-learn中用于生成模拟分类数据集的函数,它提供了丰富的参数来控制数据的分布特性。让我们先了解其中最关键的一些参数:

from sklearn.datasets import make_classification # 基础参数示例 X, y = make_classification( n_samples=1250, # 总样本量 n_features=20, # 特征总数 n_informative=5, # 有效特征数 n_redundant=2, # 冗余特征数 n_classes=3, # 类别数 weights=[0.8, 0.15, 0.05], # 类别权重 flip_y=0.01, # 随机噪声比例 random_state=42 # 随机种子 )

1.1 控制数据结构的参数

  • n_informative:真正对分类有贡献的特征数量。这些特征将用于构建分类的超平面。
  • n_redundant:由信息特征线性组合生成的特征数量,增加了数据的真实感。
  • n_repeated:直接从信息特征和冗余特征中复制的特征数量。
# 特征类型对分类的影响比较 import pandas as pd params = { '仅信息特征': {'n_informative': 5, 'n_redundant': 0, 'n_repeated': 0}, '含冗余特征': {'n_informative': 3, 'n_redundant': 2, 'n_repeated': 0}, '含重复特征': {'n_informative': 3, 'n_redundant': 1, 'n_repeated': 1} } results = [] for desc, config in params.items(): X, y = make_classification(n_samples=1000, n_features=5, **config, random_state=42) # 这里可以添加分类性能评估代码 results.append({'配置': desc, '特征类型': str(config)}) pd.DataFrame(results)

1.2 控制类别不平衡的参数

  • weights:每个类别的相对比例。例如[0.8, 0.15, 0.05]将生成80%、15%和5%的三个类别。
  • n_clusters_per_class:每个类别由多少个簇组成,增加此值会使类别分布更复杂。

注意:当设置不平衡权重时,确保weights数组的和为1.0,否则sklearn会自动进行归一化处理。

2. 创建三类不平衡数据集实战

现在,让我们按照题目要求创建一个三类不平衡数据集,样本数分别为1000、200和50。

2.1 数据集生成代码

from sklearn.datasets import make_classification import matplotlib.pyplot as plt import numpy as np # 生成三类不平衡数据集 X, y = make_classification( n_samples=1250, # 1000+200+50=1250 n_features=2, # 二维特征方便可视化 n_informative=2, # 两个有效特征 n_redundant=0, # 无冗余特征 n_classes=3, # 三个类别 weights=[0.8, 0.16, 0.04], # 对应1000,200,50 class_sep=1.5, # 控制类别间距 flip_y=0.02, # 加入少量噪声 random_state=42 ) # 查看各类样本数量 unique, counts = np.unique(y, return_counts=True) print(dict(zip(unique, counts))) # 输出:{0: 1000, 1: 200, 2: 50}

2.2 数据集可视化

# 可视化三类数据分布 plt.figure(figsize=(10, 6)) colors = ['#1f77b4', '#ff7f0e', '#2ca02c'] markers = ['o', 's', '^'] for i, color, marker in zip([0, 1, 2], colors, markers): plt.scatter(X[y == i, 0], X[y == i, 1], color=color, marker=marker, label=f'Class {i} (n={counts[i]})', edgecolor='k', alpha=0.7) plt.title('Three-class Imbalanced Dataset (1000:200:50)') plt.xlabel('Feature 1') plt.ylabel('Feature 2') plt.legend() plt.grid(True) plt.show()

2.3 参数调整技巧

在实际应用中,我们可能需要调整参数以获得更符合需求的数据分布:

  1. class_sep:控制类别间的分离程度。值越大,类别越容易区分。
  2. flip_y:随机交换标签的比例,增加数据噪声。
  3. cluster_std:控制每个簇的标准差,影响簇的紧密度。
# 调整class_sep的效果比较 plt.figure(figsize=(15, 5)) for i, sep in enumerate([0.5, 1.0, 1.5], 1): X, y = make_classification( n_samples=1250, n_features=2, n_classes=3, weights=[0.8, 0.16, 0.04], class_sep=sep, random_state=42 ) plt.subplot(1, 3, i) for class_idx in [0, 1, 2]: plt.scatter(X[y == class_idx, 0], X[y == class_idx, 1], alpha=0.6, label=f'Class {class_idx}') plt.title(f'class_sep={sep}') plt.legend() plt.tight_layout() plt.show()

3. 不平衡数据集的挑战与处理

三类不平衡数据集(1000:200:50)带来了几个独特的挑战:

3.1 评估指标的选择

对于不平衡数据集,准确率(accuracy)往往不是合适的评估指标。考虑使用:

  • 混淆矩阵(Confusion Matrix)
  • 精确率(Precision)、召回率(Recall)
  • F1-score(特别是加权F1)
  • ROC AUC(对于多类问题使用macro或weighted)
from sklearn.metrics import classification_report from sklearn.model_selection import train_test_split from sklearn.linear_model import LogisticRegression # 划分训练测试集 X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.3, random_state=42, stratify=y ) # 训练简单逻辑回归模型 model = LogisticRegression(multi_class='auto', solver='lbfgs') model.fit(X_train, y_train) # 评估 y_pred = model.predict(X_test) print(classification_report(y_test, y_pred))

3.2 处理不平衡的常用方法

  1. 重采样技术

    • 过采样少数类(如SMOTE)
    • 欠采样多数类
  2. 类别权重

    • 在模型中使用class_weight参数
# 使用类别权重的逻辑回归 weighted_model = LogisticRegression( multi_class='auto', solver='lbfgs', class_weight='balanced' # 自动调整权重 ) weighted_model.fit(X_train, y_train) # 比较结果 print("不加权模型:") print(classification_report(y_test, model.predict(X_test))) print("\n加权模型:") print(classification_report(y_test, weighted_model.predict(X_test)))

3.3 实际应用中的注意事项

  • 保持测试集的分层抽样(stratify)以确保分布一致
  • 考虑使用交叉验证时也保持分层
  • 对于极度不平衡的类别(如50样本),可能需要收集更多数据或考虑异常检测方法

4. 高级应用:自定义数据分布

make_classification还允许我们通过n_clusters_per_class参数创建更复杂的数据分布。

4.1 多簇分布示例

# 每个类别由多个簇组成 X, y = make_classification( n_samples=1250, n_features=2, n_informative=2, n_classes=3, weights=[0.8, 0.16, 0.04], n_clusters_per_class=2, # 每个类有2个簇 random_state=42 ) # 可视化 plt.figure(figsize=(8, 6)) for i in [0, 1, 2]: plt.scatter(X[y == i, 0], X[y == i, 1], alpha=0.6, label=f'Class {i}') plt.title('Multi-cluster Distribution') plt.legend() plt.show()

4.2 添加非线性可分性

通过组合make_classification和其他生成函数,可以创建更复杂的数据分布:

from sklearn.datasets import make_moons, make_circles # 生成一个月牙形和一个圆形数据集 X1, y1 = make_moons(n_samples=1000, noise=0.1, random_state=42) X2, y2 = make_circles(n_samples=250, noise=0.1, factor=0.5, random_state=42) # 合并并创建三类不平衡数据集 X = np.vstack([X1, X2]) y = np.hstack([np.zeros(1000), np.ones(200), np.full(50, 2)]) # 1000:200:50 # 可视化 plt.figure(figsize=(8, 6)) plt.scatter(X[y==0, 0], X[y==0, 1], alpha=0.5, label='Class 0 (1000)') plt.scatter(X[y==1, 0], X[y==1, 1], alpha=0.5, label='Class 1 (200)') plt.scatter(X[y==2, 0], X[y==2, 1], alpha=0.5, label='Class 2 (50)') plt.legend() plt.title('Complex Non-linear Imbalanced Dataset') plt.show()

4.3 高维数据生成与可视化

虽然我们通常使用二维数据进行可视化演示,但make_classification可以生成任意维度的数据:

# 生成20维特征的三类不平衡数据 X_highdim, y_highdim = make_classification( n_samples=1250, n_features=20, n_informative=8, n_redundant=4, n_classes=3, weights=[0.8, 0.16, 0.04], random_state=42 ) # 使用PCA降维可视化 from sklearn.decomposition import PCA pca = PCA(n_components=2) X_pca = pca.fit_transform(X_highdim) plt.figure(figsize=(8, 6)) for i in [0, 1, 2]: plt.scatter(X_pca[y_highdim == i, 0], X_pca[y_highdim == i, 1], alpha=0.6, label=f'Class {i}') plt.title('20D Data Visualized with PCA') plt.legend() plt.show()

提示:在高维数据生成时,适当增加n_informative参数的值,以确保在高维空间中有足够的判别特征。同时,class_sep可能需要调整得更大一些,因为高维空间中点之间的距离行为与低维空间不同。