机器学习探索计划——数据集划分

文章目录

  • 导包
  • 手写数据划分函数
  • 使用sklearn内置的划分数据函数
    • stratify=y理解举例

导包

import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import make_blobs

手写数据划分函数

x, y = make_blobs(
    n_samples = 300,
    n_features = 2,
    centers = 3,
    cluster_std = 1,
    center_box = (-10, 10),
    random_state = 666,
    return_centers = False
)

make_blobs:scikit-learn(sklearn)库中的一个函数,用于生成聚类任务中的合成数据集。它可以生成具有指定特征数和聚类中心数的随机数据集。

n_samples:生成的样本总数,本例中为 300。
n_features:生成的每个样本的特征数,本例中为 2。
centers:生成的簇的数量,本例中为 3。
cluster_std:每个簇中样本的标准差,本例中为 1。
center_box:每个簇中心的边界框(bounding box)范围,本例中为 (-10, 10)。
random_state:随机种子,用于控制数据的随机性,本例中为 666。
return_centers:是否返回生成的簇中心点,默认为 False,在本例中不返回。

plt.scatter(x[:, 0], x[:, 1], c = y, s = 15)
plt.show()

在这里插入图片描述

x[:, 0]:表示取 x 数据集中所有样本的第一个特征值。
x[:, 1]:表示取 x 数据集中所有样本的第二个特征值。
c=y:表示使用标签 y 对样本点进行颜色编码,即不同的标签值将使用不同的颜色进行展示。
s=15:表示散点的大小为 15,即每个样本点的显示大小。

index = np.arange(20)
np.random.shuffle(index)
index

output: array([12, 15, 7, 11, 14, 16, 6, 5, 0, 1, 2, 19, 13, 4, 18, 9, 8,
10, 3, 17])

np.random.permutation(20)

output: array([ 6, 4, 11, 13, 18, 1, 8, 3, 10, 9, 7, 0, 15, 17, 19, 16, 5,
2, 14, 12])

np.random.seed(666)
shuffle = np.random.permutation(len(x))
shuffle

output:
array([235, 169, 17, 92, 234, 15, 0, 152, 176, 243, 98, 260, 96,
123, 266, 220, 109, 286, 185, 177, 160, 11, 50, 246, 258, 254,
34, 229, 154, 66, 285, 214, 237, 95, 7, 205, 262, 281, 110,
64, 111, 87, 263, 38, 153, 129, 273, 255, 208, 56, 162, 106,
277, 224, 178, 265, 108, 104, 101, 158, 248, 29, 181, 62, 14,
75, 118, 201, 41, 150, 131, 183, 288, 291, 76, 293, 267, 1,
165, 12, 278, 53, 209, 114, 71, 135, 184, 206, 244, 61, 211,
213, 128, 3, 143, 296, 227, 242, 94, 251, 284, 253, 89, 49,
159, 35, 268, 249, 197, 55, 167, 146, 23, 283, 187, 173, 124,
68, 250, 189, 186, 5, 221, 65, 40, 119, 74, 22, 19, 59,
188, 231, 44, 137, 31, 256, 43, 85, 149, 134, 218, 120, 81,
67, 239, 195, 207, 240, 182, 179, 90, 216, 180, 47, 299, 30,
163, 193, 48, 245, 138, 28, 257, 125, 170, 157, 259, 290, 200,
203, 215, 238, 194, 121, 298, 73, 97, 8, 130, 105, 190, 6,
36, 27, 32, 144, 4, 117, 115, 171, 136, 84, 10, 113, 233,
247, 72, 292, 198, 252, 82, 228, 37, 39, 33, 280, 272, 79,
116, 172, 202, 226, 271, 145, 13, 78, 196, 274, 26, 297, 191,
232, 52, 20, 230, 18, 58, 294, 140, 132, 287, 217, 25, 133,
83, 99, 93, 21, 241, 168, 147, 275, 212, 127, 54, 199, 282,
107, 151, 289, 88, 100, 264, 45, 77, 295, 9, 166, 57, 80,
155, 279, 86, 219, 2, 269, 126, 102, 142, 192, 161, 103, 42,
261, 16, 175, 122, 174, 164, 112, 148, 24, 139, 276, 141, 204,
210, 69, 46, 63, 225, 270, 156, 223, 60, 51, 222, 91, 70,
236])

np.random.seed(666)使得随机数结果可复现

shuffle.shape

output: (300,)

train_size = 0.7
train_index = shuffle[:int(len(x) * train_size)]
test_index = shuffle[int(len(x) * train_size):]
train_index.shape, test_index.shape

output: ((210,), (90,))

x[train_index].shape, y[train_index].shape, x[test_index].shape, y[test_index].shape

output: ((210, 2), (210,), (90, 2), (90,))

def my_train_test_split(x, y, train_size = 0.7, random_state = None):
    if random_state:
        np.random.seed(random_state)
    shuffle = np.random.permutation(len(x))
    train_index = shuffle[:int(len(x) * train_size)]
    test_index = shuffle[int(len(x) * train_size):]
    return x[train_index], x[test_index], y[train_index], y[test_index]
x_train, x_test, y_train, y_test = my_train_test_split(x, y, train_size=0.7, random_state=233)
x_train.shape, x_test.shape, y_train.shape, y_test.shape

output: ((210, 2), (90, 2), (210,), (90,))

plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, s=15)  # y_train一样的,颜色相同
plt.show()

在这里插入图片描述

plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test, s=15)
plt.show()

在这里插入图片描述

使用sklearn内置的划分数据函数

from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.7, random_state=233)
x_train.shape, x_test.shape, y_train.shape, y_test.shape

output: ((210, 2), (90, 2), (210,), (90,))

from collections import Counter
Counter(y_test)

output: Counter({2: 34, 0: 29, 1: 27})

x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.7, random_state=666, stratify=y)

stratify=y: 使用标签 y 进行分层采样,确保训练集和测试集中的类别分布相对一致。
这样做的好处是,在训练过程中,模型可以接触到各个类别的样本,从而更好地学习每个类别的特征和模式,提高模型的泛化能力。

Counter(y_test)

output: Counter({1: 30, 0: 30, 2: 30})

stratify=y理解举例

x = np.random.randn(1000, 2)  # 1000个样本,2个特征
y = np.concatenate([np.zeros(800), np.ones(200)])  # 800个负样本,200个正样本

# 使用 stratify 进行分层采样
x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.7, random_state=42, stratify=y)

# 打印训练集中正负样本的比例。通过使用 np.mean,我们可以方便地计算出比例或平均值,以了解数据集的分布情况或对模型性能进行评估。
print("训练集中正样本比例:", np.mean(y_train == 1))
print("训练集中负样本比例:", np.mean(y_train == 0))

# 打印测试集中正负样本的比例
print("测试集中正样本比例:", np.mean(y_test == 1))
print("测试集中负样本比例:", np.mean(y_test == 0))

output:
训练集中正样本比例: 0.2
训练集中负样本比例: 0.8
测试集中正样本比例: 0.2
测试集中负样本比例: 0.8

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/184322.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【UE5】资源(Asset)

了解UE游戏的基本构成 资源(Asset): 在UE中,资源(Asset)是指游戏中使用到的各种素材,例如模型、纹理、材质、声音、动画、蓝图、数据表格、关卡等(通常以uasset结尾),他…

freeswitch设置多个execute_on_media

概述 freeswitch是一款简单好用的VOIP开源软交换平台。 fs中有非常多的接口和通道变量,使用方式多变。 官方文档有时候也仅仅是介绍了最基本的使用方法和格式。 环境 centos:CentOS release 7.0 (Final)或以上版本 freeswitch:v1.6 G…

办公技巧:Word中插入图片、形状、文本框排版技巧

目录 一、插入图片排版技巧 二、添加形状排版技巧 三、插入“文本框”排版技巧 我们平常在制作word时候经常会遇到插入选项卡下的图片、形状和文本框这三种情况下,那么如何使得Word文档当中添加这三个元素的同时,又能保证样式美观呢,今天小…

Leetcode200. 岛屿数量

Every day a Leetcode 题目来源:200. 岛屿数量 解法1:深度优先搜索 设目前指针指向一个岛屿中的某一点 (i, j),寻找包括此点的岛屿边界。 从 (i, j) 向此点的上下左右 (i1,j),(i-1,j),(i,j1),(i,j-1) …

静态链表的结构设计与主要操作功能的实现(初始化,头插,尾插,判空,删除,输出,清空,销毁)

目录 一.静态链表的结构设计 二.静态链表的结构设计示意图 三.静态链表的实现 四.静态链表的总结 一.静态链表的结构设计 typedef struct SNode {int data;//数据int next;//后继指针(下标) }SNode,SLinkList[MAXSIZE]; 二.静态链表的结构设计示意图 0:有效数据链的头节点;…

ATA-3080功率放大器在海底管道悬跨振动激振器检测中的应用

海底管道悬跨振动检测是指对海底管道在悬跨(即管道跨越两个支撑点之间的区域)段发生的振动进行监测和分析的过程。为了实现海底管道悬跨振动检测,通常使用以下几种方法: 1.加速度传感器:通过在管道表面安装加速度传感器…

现在可以手动获取真随机数吗?

获取真正的随机数并不像获取伪随机数那样简单,因为真随机数的产生依赖于物理过程或者其他难以预测的现象。在计算机科学中,通常使用的是伪随机数,它们是通过算法生成的,看起来像是随机的,但实际上是可以重现的。 如果…

新生儿散光:原因、科普和注意事项

引言: 散光是一种常见的眼睛问题,虽然在新生儿时期相对较少见,但了解其原因、科普相关知识,并提供一些建议的注意事项,对于婴儿的视力健康至关重要。本文将深入探讨新生儿散光的原因、相关科普知识,并为父…

新的centos7.9安装jenkins—(一)

更多ruoyi-nbcio功能请看演示系统 gitee源代码地址 前后端代码: https://gitee.com/nbacheng/ruoyi-nbcio 演示地址:RuoYi-Nbcio后台管理系统 因为是用java8,所以还是要最后java8版本的jenkins,版本号是2.346.3,后…

​ 一文带你了解多文件混淆加密

目录 🔒 一文带你了解 JavaScript 多文件混淆加密 ipaguard加密前 ipaguard加密后 ​ 🔒 一文带你了解 JavaScript 多文件混淆加密 JavaScript 代码多文件混淆加密可以有效保护源代码不被他人轻易盗取。虽然前端的 JS 无法做到纯粹的加密&#xff0c…

Echarts 大屏注册自定义地图解析文件流报错问题解决

效果图: 1、首先通过后台接口获取到SVG图片的文件流,postman能够正确解析出文件流,前端调用api时需要设置返回的响应格式为image/svg+xml格式,否则解析失败 拿到文件流后是这样的 <?xml version="1.0" encoding="utf-8"?> <!-- Generator: …

6.3.WebRTC中的SDP类的结构

在上节课中呢&#xff0c;我向你介绍了sdp协议&#xff0c; 那这节课呢&#xff0c;我们再来看看web rtc中。是如何存储sdp的&#xff1f;也就是sdp的类结构&#xff0c;那在此之前呢&#xff1f;我们先对sdp的内容啊&#xff0c;做一下分类。因为在上节课中呢&#xff0c;虽然…

软件设计不是CRUD(6):低耦合模块设计实战——组织机构模块(上)

组织机构功能是应用系统中常见的业务功能之一&#xff0c;但是不同性质、不同行业背景、不同使用场景的应用系统对组织机构功能的要求可能完全不一样。所以使用这样的功能对低耦合模块设计进行示例性的讲解是比较具有代表性的。在后续的几篇文章中&#xff0c;我们会首先进行示…

linux磁盘清理

目录 排查过程1、查看磁盘占用情况2. 按照占用大小进行倒排-当前目录及其子目录3.当前目录磁盘占用情况 清理命令 排查过程 1、查看磁盘占用情况 df -hdf -h 命令用于显示磁盘空间的使用情况&#xff0c;以人类可读的方式呈现&#xff0c;其中&#xff1a;df 是 “disk free”…

ROS2编译Python节点来发布和订阅的实践《2》

通过熟悉&#xff1a;ROS2对比ROS1的一些变化与优势&#xff08;全新安装ROS2以及编译错误处理&#xff09;《1》 我们大概了解到了ROS2的重新设计带来的巨大优势&#xff0c;最核心的就是去掉了roscore&#xff0c;这样就避免了因为节点管理器崩溃而使整个系统都崩溃的场景出现…

机器学习/sklearn 笔记:K-means,kmeans++,MiniBatchKMeans,二分Kmeans

1 K-means介绍 1.0 方法介绍 KMeans算法通过尝试将样本分成n个方差相等的组来聚类&#xff0c;该算法要求指定群集的数量。它适用于大量样本&#xff0c;并已在许多不同领域的广泛应用领域中使用。KMeans算法将一组样本分成不相交的簇&#xff0c;每个簇由簇中样本的平均值描…

【ChatGLM2-6B】Docker下部署及微调

【ChatGLM2-6B】小白入门及Docker下部署 一、简介1、ChatGLM2是什么2、组成部分3、相关地址 二、基于Docker安装部署1、前提2、CentOS7安装NVIDIA显卡驱动1&#xff09;查看服务器版本及显卡信息2&#xff09;相关依赖安装3&#xff09;显卡驱动安装 2、 CentOS7安装NVIDIA-Doc…

idea 问题合集

调试按钮失效&#xff1a; 依次点击&#xff1a;Modules-web-src-Sources&#xff0c;重启IDEA即可&#xff08;网上看到的方法&#xff0c;原因呢未明&#xff09;

Modbus故障码速查手册(故障码含义、分析原因、详细解读)

Modbus故障码速查手册 文章目录 Modbus故障码速查手册引言故障码表故障详解0x01 IllegalFunction0x02 IllegalDataAddress0x03 IllegalDataValue0x04 SlaveDeviceFailure0x05 Acknowledge0x06 SlaveDeviceBusy0x08 MemoryParityError0x0A GatewayPathUnavailable0x0B GatewayTa…

java spring-boot 修改打包的jar包名称

修改pom文件 <finalName>lzwd</finalName><build><finalName>lzwd</finalName><plugins><plugin><groupId>org.springframework.boot</groupId><artifactId>spring-boot-maven-plugin</artifactId></plu…