K—近邻算法实际应用案例

K—近邻算法实际应用案例

  • 1. 案例1:鸢尾花种类预测
    • 1.1 数据集获取和属性介绍
      • 1.1.1 scikit-learn中的数据集介绍
      • 1.1.2 sklearn数据集返回值介绍
    • 1.2 数据可视化介绍(查看数据分布)
    • 1.3 数据集的划分
    • 1.4 特征工程
      • 1.4.1 归一化
      • 1.4.2 标准化
    • 1.5 鸢尾花种类预测
    • 1.6 交叉验证和网格搜索
      • 1.6.1 交叉验证
      • 1.6.2 网格搜索
      • 6.1.3 鸢尾花实例

1. 案例1:鸢尾花种类预测


Iris数据集是常用的分类实验数据集,由Fisher,1936年收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。

关于数据集的具体介绍:

  • 特征值-4个:花瓣、花萼的长度、宽度
  • 目标值-3个:setosa、vericolor、virginica

该虹膜数据集包含150行数据,包括来自美国的三个相关鸢尾花种类50个样本:山鸢尾,虹膜锦葵,变色鸢尾。

在这里插入图片描述

注意:

在机器学习的相关细分领域,数据的收集工作一般是由行业专家来完成的。因为只有这些专家才知道哪些特征重要,哪些特征不重要。


1.1 数据集获取和属性介绍


1.1.1 scikit-learn中的数据集介绍


1. scikit-learn数据集API介绍

scikit-learn中获取数据的操作是在一个大类sklearn.datasets下的。

1)sklearn小数据集

  • 使用sklearn.datasets.load_*()获取小数据集,数据包含在datasets里,从本地获取,也就是活sklearn中有一小部分已经下好的数据集;
  • 例:使用sklearn.datasets.load_iris()获取鸢尾花数据集。sklearn.datasets.load_iris()返回值是鸢尾花数据集。
from sklearn.datasets import load_iris

# 1. 数据集获取
# 1.1 获取小数据集用datasets.load_*()
iris = load_iris()
print(iris) # 会显示一堆的内容,但是阅读起来不方便

2)sklearn大数据集

  • 使用datasets.fetch_*(data_home=None)获取大规模数据集,需要从网上下载,函数的第一个参数data_home,表示数据集下载的目录,默认是~/scikit_learn_data/
  • 例:使用sklearn.datasets.fetch_20newsgroups(data_home=None, subset='train')获取20类新闻数据。参数subject可传’train’、‘test’、或’all’,是可选的,表示要加载的数据集类型(测试集、训练集、全集)。
from sklearn.datasets import fetch_20newsgroups

# 1.2 获取大数据集用datasets.fetch_*(data_home=None)
news = fetch_20newsgroups()
print(news)

1.1.2 sklearn数据集返回值介绍


loadfetch返回的数据类型是datasets.base.Bunch(字典格式):

  • data:特征数据数组,是二维numpy.ndarray数组;
  • target:标签数据,是一位numpy.ndarray数组;
  • DESCR:数据描述信息;
  • feature_names:特证名;
  • target_names:标签名。
from sklearn.datasets import load_iris

# 1. 数据集获取
iris = load_iris()

# 2. 数据集属性描述
print('数据集的特征值:\n', iris.data)
print('数据集的目标值:\n', iris['target'])
print('数据集的特征值名字:\n', iris.feature_names)
print('数据集的目标值名字:\n', iris.target_names)
print('数据集的描述:\n', iris.DESCR)

输出:

数据集的特征值:
 [[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]
 [5.4 3.9 1.7 0.4]
 [4.6 3.4 1.4 0.3]
 [5.  3.4 1.5 0.2]
 [4.4 2.9 1.4 0.2]
 [4.9 3.1 1.5 0.1]
 [5.4 3.7 1.5 0.2]
 [4.8 3.4 1.6 0.2]
 [4.8 3.  1.4 0.1]
 [4.3 3.  1.1 0.1]
 [5.8 4.  1.2 0.2]
 [5.7 4.4 1.5 0.4]
 [5.4 3.9 1.3 0.4]
 [5.1 3.5 1.4 0.3]
 [5.7 3.8 1.7 0.3]
 [5.1 3.8 1.5 0.3]
 [5.4 3.4 1.7 0.2]
 [5.1 3.7 1.5 0.4]
 [4.6 3.6 1.  0.2]
 [5.1 3.3 1.7 0.5]
 [4.8 3.4 1.9 0.2]
 [5.  3.  1.6 0.2]
 [5.  3.4 1.6 0.4]
 [5.2 3.5 1.5 0.2]
 [5.2 3.4 1.4 0.2]
 [4.7 3.2 1.6 0.2]
 [4.8 3.1 1.6 0.2]
 [5.4 3.4 1.5 0.4]
 [5.2 4.1 1.5 0.1]
 [5.5 4.2 1.4 0.2]
 [4.9 3.1 1.5 0.2]
 [5.  3.2 1.2 0.2]
 [5.5 3.5 1.3 0.2]
 [4.9 3.6 1.4 0.1]
 [4.4 3.  1.3 0.2]
 [5.1 3.4 1.5 0.2]
 [5.  3.5 1.3 0.3]
 [4.5 2.3 1.3 0.3]
 [4.4 3.2 1.3 0.2]
 [5.  3.5 1.6 0.6]
 [5.1 3.8 1.9 0.4]
 [4.8 3.  1.4 0.3]
 [5.1 3.8 1.6 0.2]
 [4.6 3.2 1.4 0.2]
 [5.3 3.7 1.5 0.2]
 [5.  3.3 1.4 0.2]
 [7.  3.2 4.7 1.4]
 [6.4 3.2 4.5 1.5]
 [6.9 3.1 4.9 1.5]
 [5.5 2.3 4.  1.3]
 [6.5 2.8 4.6 1.5]
 [5.7 2.8 4.5 1.3]
 [6.3 3.3 4.7 1.6]
 [4.9 2.4 3.3 1. ]
 [6.6 2.9 4.6 1.3]
 [5.2 2.7 3.9 1.4]
 [5.  2.  3.5 1. ]
 [5.9 3.  4.2 1.5]
 [6.  2.2 4.  1. ]
 [6.1 2.9 4.7 1.4]
 [5.6 2.9 3.6 1.3]
 [6.7 3.1 4.4 1.4]
 [5.6 3.  4.5 1.5]
 [5.8 2.7 4.1 1. ]
 [6.2 2.2 4.5 1.5]
 [5.6 2.5 3.9 1.1]
 [5.9 3.2 4.8 1.8]
 [6.1 2.8 4.  1.3]
 [6.3 2.5 4.9 1.5]
 [6.1 2.8 4.7 1.2]
 [6.4 2.9 4.3 1.3]
 [6.6 3.  4.4 1.4]
 [6.8 2.8 4.8 1.4]
 [6.7 3.  5.  1.7]
 [6.  2.9 4.5 1.5]
 [5.7 2.6 3.5 1. ]
 [5.5 2.4 3.8 1.1]
 [5.5 2.4 3.7 1. ]
 [5.8 2.7 3.9 1.2]
 [6.  2.7 5.1 1.6]
 [5.4 3.  4.5 1.5]
 [6.  3.4 4.5 1.6]
 [6.7 3.1 4.7 1.5]
 [6.3 2.3 4.4 1.3]
 [5.6 3.  4.1 1.3]
 [5.5 2.5 4.  1.3]
 [5.5 2.6 4.4 1.2]
 [6.1 3.  4.6 1.4]
 [5.8 2.6 4.  1.2]
 [5.  2.3 3.3 1. ]
 [5.6 2.7 4.2 1.3]
 [5.7 3.  4.2 1.2]
 [5.7 2.9 4.2 1.3]
 [6.2 2.9 4.3 1.3]
 [5.1 2.5 3.  1.1]
 [5.7 2.8 4.1 1.3]
 [6.3 3.3 6.  2.5]
 [5.8 2.7 5.1 1.9]
 [7.1 3.  5.9 2.1]
 [6.3 2.9 5.6 1.8]
 [6.5 3.  5.8 2.2]
 [7.6 3.  6.6 2.1]
 [4.9 2.5 4.5 1.7]
 [7.3 2.9 6.3 1.8]
 [6.7 2.5 5.8 1.8]
 [7.2 3.6 6.1 2.5]
 [6.5 3.2 5.1 2. ]
 [6.4 2.7 5.3 1.9]
 [6.8 3.  5.5 2.1]
 [5.7 2.5 5.  2. ]
 [5.8 2.8 5.1 2.4]
 [6.4 3.2 5.3 2.3]
 [6.5 3.  5.5 1.8]
 [7.7 3.8 6.7 2.2]
 [7.7 2.6 6.9 2.3]
 [6.  2.2 5.  1.5]
 [6.9 3.2 5.7 2.3]
 [5.6 2.8 4.9 2. ]
 [7.7 2.8 6.7 2. ]
 [6.3 2.7 4.9 1.8]
 [6.7 3.3 5.7 2.1]
 [7.2 3.2 6.  1.8]
 [6.2 2.8 4.8 1.8]
 [6.1 3.  4.9 1.8]
 [6.4 2.8 5.6 2.1]
 [7.2 3.  5.8 1.6]
 [7.4 2.8 6.1 1.9]
 [7.9 3.8 6.4 2. ]
 [6.4 2.8 5.6 2.2]
 [6.3 2.8 5.1 1.5]
 [6.1 2.6 5.6 1.4]
 [7.7 3.  6.1 2.3]
 [6.3 3.4 5.6 2.4]
 [6.4 3.1 5.5 1.8]
 [6.  3.  4.8 1.8]
 [6.9 3.1 5.4 2.1]
 [6.7 3.1 5.6 2.4]
 [6.9 3.1 5.1 2.3]
 [5.8 2.7 5.1 1.9]
 [6.8 3.2 5.9 2.3]
 [6.7 3.3 5.7 2.5]
 [6.7 3.  5.2 2.3]
 [6.3 2.5 5.  1.9]
 [6.5 3.  5.2 2. ]
 [6.2 3.4 5.4 2.3]
 [5.9 3.  5.1 1.8]]
数据集的目标值:
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
数据集的特征值名字:
 ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
数据集的目标值名字:
 ['setosa' 'versicolor' 'virginica']
数据集的描述:
 .. _iris_dataset:

Iris plants dataset
--------------------

**Data Set Characteristics:**

    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica
                
    :Summary Statistics:

    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
    ============== ==== ==== ======= ===== ====================

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988

The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken
from Fisher's paper. Note that it's the same as in R, but not as in the UCI
Machine Learning Repository, which has two wrong data points.

This is perhaps the best known database to be found in the
pattern recognition literature.  Fisher's paper is a classic in the field and
is referenced frequently to this day.  (See Duda & Hart, for example.)  The
data set contains 3 classes of 50 instances each, where each class refers to a
type of iris plant.  One class is linearly separable from the other 2; the
latter are NOT linearly separable from each other.

.. topic:: References

   - Fisher, R.A. "The use of multiple measurements in taxonomic problems"
     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
     Mathematical Statistics" (John Wiley, NY, 1950).
   - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.
     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.
   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
     Structure and Classification Rule for Recognition in Partially Exposed
     Environments".  IEEE Transactions on Pattern Analysis and Machine
     Intelligence, Vol. PAMI-2, No. 1, 67-71.
   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions
     on Information Theory, May 1972, 431-433.
   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II
     conceptual clustering system finds 3 classes in the data.
   - Many, many more ...

1.2 数据可视化介绍(查看数据分布)


通过创建一些图,来查看不同类别是如何通过特征来区分的。在理想情况下,标签类将由一个或多个特征对完美分割,在现实世界中,这种情况很少发生。

seaborn介绍:

  • seaborn基于Matplotlib核心库进行了更高级的API封装,可以让你轻松地画出更漂亮的图形。

seaborn的常用API介绍:

  • seaborn.lmplot()是一个非常有用的方法,它会在绘制二维散点图时,自动完成回归拟合。
    • sns.lmplot()里的x,y分别表示横纵坐标的列名;
    • data:数据集,是DataFrame类型;
    • hue:代表按照species即花的类别分类显示;
    • fit_reg:表示是否进行线性拟合。
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

from sklearn.datasets import load_iris

# 1. 数据集获取
iris = load_iris()

# 2. 数据集属性描述
# ...

# 3. 数据可视化
# 把数据转换成dataframe的格式
iris_d = pd.DataFrame(iris.data, columns=['Sepal_length', 'Sepal_width', 'Petal_length', 'Petal_width'])
iris_d['Species'] = iris.target # 增加一列目标值


def plot_iris(iris_d, col1, col2):
    sns.lmplot(x=col1, y=col2, data=iris_d, hue='Species', fit_reg=False)
    plt.xlabel(col1)
    plt.ylabel(col2)
    plt.title('鸢尾花种类分布图')
    plt.show()


plot_iris(iris_d, 'Petal_width', 'Sepal_length') # 这两个特征也可以换

在这里插入图片描述


1.3 数据集的划分


1. 机器学习一般的数据集会划分为两个部分:

  • 训练数据:用于训练,构建模型;
  • 测试数据:在模型检验时使用,用于评估模型是否有效。

2. 划分比例:

  • 训练集:70%,80%,75%;
  • 测试集:30%,20%,25%。

3. 数据集划分api:sklearn.model_selection.train_test_split(arrays, *options)

1)参数:

  • x:数据集的特征值;
  • y:数据集的标签值;
  • test_size:测试集占全集的百分比,一般穿float类型数据;
  • random_state:随机数种子,不同的种子会造成不同的随机取样结果;相同的种子采样结果相同。

2)返回值:

  • x_trainx_testy_trainy_test:一共四个返回值,一定要按顺序接收。分别是训练集特征值,测试集特征值,训练集目标值,测试集目标值。
from sklearn.datasets import load_iris, fetch_20newsgroups
from sklearn.model_selection import train_test_split

from pylab import mpl
# 设置显示中文字体
mpl.rcParams["font.sans-serif"] = ["SimHei"]

# 1. 数据集获取
iris = load_iris()

# 2. 数据集属性描述
# ...

# 3. 数据可视化
# ...

# 4. 数据集的划分
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=22)
print('训练集的特征值:\n', x_train)
print('训练集的目标值:\n', y_train)
print('测试集的特征值:\n', x_test)
print('测试集的目标值:\n', y_test)

# 可以发现,随机数种子一样,测试集的目标值一样,即测试结果一样
x_train1, x_test1, y_train1, y_test1 = train_test_split(iris.data, iris.target, test_size=0.2, random_state=2)
x_train2, x_test2, y_train2, y_test2 = train_test_split(iris.data, iris.target, test_size=0.2, random_state=2)
print('测试集的目标值:\n', y_test1)
print('测试集的目标值:\n', y_test2)

输出:

训练集的特征值:
 [[4.8 3.1 1.6 0.2]
 [5.4 3.4 1.5 0.4]
 [5.5 2.5 4.  1.3]
 [5.5 2.6 4.4 1.2]
 [5.7 2.8 4.5 1.3]
 [5.  3.4 1.6 0.4]
 [5.1 3.4 1.5 0.2]
 [4.9 3.6 1.4 0.1]
 [6.9 3.1 5.4 2.1]
 [6.7 2.5 5.8 1.8]
 [7.  3.2 4.7 1.4]
 [6.3 3.3 4.7 1.6]
 [5.4 3.9 1.3 0.4]
 [4.4 3.2 1.3 0.2]
 [6.7 3.  5.  1.7]
 [5.6 3.  4.1 1.3]
 [5.7 2.5 5.  2. ]
 [6.5 3.  5.8 2.2]
 [5.  3.6 1.4 0.2]
 [6.1 2.8 4.  1.3]
 [6.  3.4 4.5 1.6]
 [6.7 3.  5.2 2.3]
 [5.7 4.4 1.5 0.4]
 [5.4 3.4 1.7 0.2]
 [5.  3.5 1.3 0.3]
 [4.8 3.  1.4 0.1]
 [5.5 4.2 1.4 0.2]
 [4.6 3.6 1.  0.2]
 [7.2 3.2 6.  1.8]
 [5.1 2.5 3.  1.1]
 [6.4 3.2 4.5 1.5]
 [7.3 2.9 6.3 1.8]
 [4.5 2.3 1.3 0.3]
 [5.  3.  1.6 0.2]
 [5.7 3.8 1.7 0.3]
 [5.  3.3 1.4 0.2]
 [6.2 2.2 4.5 1.5]
 [5.1 3.5 1.4 0.2]
 [6.4 2.9 4.3 1.3]
 [4.9 2.4 3.3 1. ]
 [6.3 2.5 4.9 1.5]
 [6.1 2.8 4.7 1.2]
 [5.9 3.2 4.8 1.8]
 [5.4 3.9 1.7 0.4]
 [6.  2.2 4.  1. ]
 [6.4 2.8 5.6 2.1]
 [4.8 3.4 1.9 0.2]
 [6.4 3.1 5.5 1.8]
 [5.9 3.  4.2 1.5]
 [6.5 3.  5.5 1.8]
 [6.  2.9 4.5 1.5]
 [5.5 2.4 3.8 1.1]
 [6.2 2.9 4.3 1.3]
 [5.2 4.1 1.5 0.1]
 [5.2 3.4 1.4 0.2]
 [7.7 2.6 6.9 2.3]
 [5.7 2.6 3.5 1. ]
 [4.6 3.4 1.4 0.3]
 [5.8 2.7 4.1 1. ]
 [5.8 2.7 3.9 1.2]
 [6.2 3.4 5.4 2.3]
 [5.9 3.  5.1 1.8]
 [4.6 3.1 1.5 0.2]
 [5.8 2.8 5.1 2.4]
 [5.1 3.5 1.4 0.3]
 [6.8 3.2 5.9 2.3]
 [4.9 3.1 1.5 0.1]
 [5.5 2.3 4.  1.3]
 [5.1 3.7 1.5 0.4]
 [5.8 2.7 5.1 1.9]
 [6.7 3.1 4.4 1.4]
 [6.8 3.  5.5 2.1]
 [5.2 2.7 3.9 1.4]
 [6.7 3.1 5.6 2.4]
 [5.3 3.7 1.5 0.2]
 [5.  2.  3.5 1. ]
 [6.6 2.9 4.6 1.3]
 [6.  2.7 5.1 1.6]
 [6.3 2.3 4.4 1.3]
 [7.7 3.  6.1 2.3]
 [4.9 3.  1.4 0.2]
 [4.6 3.2 1.4 0.2]
 [6.3 2.7 4.9 1.8]
 [6.6 3.  4.4 1.4]
 [6.9 3.1 4.9 1.5]
 [4.3 3.  1.1 0.1]
 [5.6 2.7 4.2 1.3]
 [4.8 3.4 1.6 0.2]
 [7.6 3.  6.6 2.1]
 [7.7 2.8 6.7 2. ]
 [4.9 2.5 4.5 1.7]
 [6.5 3.2 5.1 2. ]
 [5.1 3.3 1.7 0.5]
 [6.3 2.9 5.6 1.8]
 [6.1 2.6 5.6 1.4]
 [5.  3.4 1.5 0.2]
 [6.1 3.  4.6 1.4]
 [5.6 3.  4.5 1.5]
 [5.1 3.8 1.5 0.3]
 [5.6 2.8 4.9 2. ]
 [4.4 3.  1.3 0.2]
 [5.5 2.4 3.7 1. ]
 [4.7 3.2 1.6 0.2]
 [6.7 3.3 5.7 2.5]
 [5.2 3.5 1.5 0.2]
 [6.4 2.7 5.3 1.9]
 [6.3 2.8 5.1 1.5]
 [4.4 2.9 1.4 0.2]
 [6.1 3.  4.9 1.8]
 [4.9 3.1 1.5 0.2]
 [5.  2.3 3.3 1. ]
 [4.8 3.  1.4 0.3]
 [5.8 4.  1.2 0.2]
 [6.3 3.4 5.6 2.4]
 [5.4 3.  4.5 1.5]
 [7.1 3.  5.9 2.1]
 [6.3 3.3 6.  2.5]
 [5.1 3.8 1.9 0.4]
 [6.4 2.8 5.6 2.2]
 [7.7 3.8 6.7 2.2]]
训练集的目标值:
 [0 0 1 1 1 0 0 0 2 2 1 1 0 0 1 1 2 2 0 1 1 2 0 0 0 0 0 0 2 1 1 2 0 0 0 0 1
 0 1 1 1 1 1 0 1 2 0 2 1 2 1 1 1 0 0 2 1 0 1 1 2 2 0 2 0 2 0 1 0 2 1 2 1 2
 0 1 1 1 1 2 0 0 2 1 1 0 1 0 2 2 2 2 0 2 2 0 1 1 0 2 0 1 0 2 0 2 2 0 2 0 1
 0 0 2 1 2 2 0 2 2]
测试集的特征值:
 [[5.4 3.7 1.5 0.2]
 [6.4 3.2 5.3 2.3]
 [6.5 2.8 4.6 1.5]
 [6.3 2.5 5.  1.9]
 [6.1 2.9 4.7 1.4]
 [6.8 2.8 4.8 1.4]
 [6.7 3.1 4.7 1.5]
 [6.  3.  4.8 1.8]
 [5.6 2.9 3.6 1.3]
 [5.  3.2 1.2 0.2]
 [6.9 3.2 5.7 2.3]
 [5.7 3.  4.2 1.2]
 [7.4 2.8 6.1 1.9]
 [7.2 3.6 6.1 2.5]
 [5.  3.5 1.6 0.6]
 [7.9 3.8 6.4 2. ]
 [5.6 2.5 3.9 1.1]
 [5.7 2.8 4.1 1.3]
 [6.  2.2 5.  1.5]
 [5.7 2.9 4.2 1.3]
 [5.1 3.8 1.6 0.2]
 [6.9 3.1 5.1 2.3]
 [5.5 3.5 1.3 0.2]
 [5.8 2.6 4.  1.2]
 [5.8 2.7 5.1 1.9]
 [4.7 3.2 1.3 0.2]
 [7.2 3.  5.8 1.6]
 [6.5 3.  5.2 2. ]
 [6.7 3.3 5.7 2.1]
 [6.2 2.8 4.8 1.8]]
测试集的目标值:
 [0 2 1 2 1 1 1 2 1 0 2 1 2 2 0 2 1 1 2 1 0 2 0 1 2 0 2 2 2 2]
测试集的目标值:
 [0 0 2 0 0 2 0 2 2 0 0 0 0 0 1 1 0 1 2 1 1 1 2 1 1 0 0 2 0 2]
测试集的目标值:
 [0 0 2 0 0 2 0 2 2 0 0 0 0 0 1 1 0 1 2 1 1 1 2 1 1 0 0 2 0 2]

1.4 特征工程


1.4.1 归一化


1. 定义:

通过对原始数据进行变换把数据映射到一个区间(默认为[0, 1])之间。

2. 公式:

X ′ = x − m i n m a x − m i n X'=\dfrac{x-min}{max-min} X=maxminxmin        X ′ ′ = X ′ ∗ ( m x − m i ) + m i X''=X'*(mx-mi)+mi X′′=X(mxmi)+mi

该公式作用于每一列, m a x max max为一列的最大值, m i n min min为一列的最小值, X ′ ′ X'' X′′为最终结果, m x mx mx m i mi mi分别为指定区间值默认 m x mx mx为1, m i mi mi为0。

在这里插入图片描述

3. API:

  • sklearn.preprocessing.MinMaxScaler(feature_range=(0,1)...)
    • feature_range:作用是指定区间,即 m x mx mx m i mi mi
  • MinMaxScaler.fit_transform(X)
    • X:要处理的数据;
    • 返回值:转换后的array

4. 数据计算:

我们将使用以下数据进行演示:

data = pd.DataFrame(np.random.randint(200, 4000, size=(5, 4)))
print(data)
print(data.shape)
'''     
      0     1     2     3
0  1042  2383  3304  1735
1  1671   583  3095  1299
2   585   202  1782   310
3  1598  2106  3108  1896
4  1844  3765  2852   516
(5, 4)
'''

先实例化MinMaxScaler,再通过fit_transform()处理数据:

import numpy as np
import pandas as pd

from sklearn.preprocessing import MinMaxScaler

data = pd.DataFrame(np.random.randint(200, 4000, size=(5, 4)))
print(data)
print(data.shape)

transfer = MinMaxScaler(feature_range=(2, 3)) # 实例化转换器
ret = transfer.fit_transform(data[[0, 1, 2]]) # 只处理前三列,目标值不用处理
print("最小值最大值归一化处理结果:\n", ret)
print(ret.shape)

输出:

      0     1     2     3
0  1042  2383  3304  1735
1  1671   583  3095  1299
2   585   202  1782   310
3  1598  2106  3108  1896
4  1844  3765  2852   516
(5, 4)
最小值最大值归一化处理结果:
 [[2.3629865  2.61212461 3.        ]
 [2.86258936 2.10693236 2.86268068]
 [2.         2.         2.        ]
 [2.80460683 2.53438114 2.87122208]
 [3.         3.         2.70302234]]
(5, 3)

5. 问题:如果数据中有异常值怎么办?

最小值和最大值非常容易受到异常点的影响,所以这归一化方法的鲁棒性(稳定性)较差,只适用于传统精确小数据场景。

可以通过标准化解决。


1.4.2 标准化


1. 定义:

通过对原始数据进行变换把数据变换到均值为0,标准差为1的范围内。

2. 公式:

X ′ = x − m e a n σ X'=\dfrac{x-mean}{\sigma} X=σxmean

作用于每一列, m e a n mean mean为平均值, σ \sigma σ为标准差。

  • 对于归一化:如果出现异常点,影响了最大值和最小值,那么结果显然会发生改变;
  • 对于标准化:如果出现异常点,由于具有一定数据量,少量的异常点对于平均值的影响并不大,从而方差改变较小。

3. API:

  • sklearn.preprocessing.StandardScaler()
    • 处理之后,每一列的所有数据都聚集在均值0附近,标准差为1;
  • StandardScaler.fit_transform(X)
    • X:要处理的数据;
    • 返回值:转换后的形状相同的array

4. 数据计算:

准备数据:

import pandas as pd
import numpy as np

data = pd.DataFrame(np.random.randint(200, 4000, size=(5, 4)))
print(data)
'''
      0     1     2     3
0  3682  1293  1565  2305
1   900  2344  2176  3036
2   884  2296  2413   292
3  1987   486  1279  1088
4  2206  2416  1155   990
'''

先实例化StandardScaler,再用fit_transform()处理数据:

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler

data = pd.DataFrame(np.random.randint(200, 4000, size=(5, 4)))
print(data)
# 1. 实例化一个转换器嘞
transfer = StandardScaler()
# 2. 调用fit_transform
ret = transfer.fit_transform(data[[0, 1, 2]])
print('标准化的结果:\n', ret)
print('每一列特征的平均值:\n', transfer.mean_)
print('每一列特征的方差:\n', transfer.var_)

输出:

      0     1     2     3
0  3682  1293  1565  2305
1   900  2344  2176  3036
2   884  2296  2413   292
3  1987   486  1279  1088
4  2206  2416  1155   990
标准化的结果:
 [[ 1.6993148  -0.62243456 -0.30818698]
 [-1.00180151  0.75768933  0.92577267]
 [-1.01733633  0.69465798  1.40441168]
 [ 0.05359512 -1.6821491  -0.88578511]
 [ 0.26622793  0.85223635 -1.13621227]]
每一列特征的平均值:
 [1931.8 1767.  1717.6]
每一列特征的方差:
 [1060785.76  579921.6   245177.44]

1.5 鸢尾花种类预测


1. 再识K-近邻算法API:sklearn.neighbors.KNeighborsClassifier(n_neighbors=5, algorithm=‘auto’)

  • n_neighborsint类型数据,可选,默认为5,表示查询的邻居数;
  • algorithm:{‘auto’, ‘ball_tree’, ‘kd_tree’, ‘brute’}
    • auto:默认参数为auto,表示自动选择算法;
    • brute:蛮力搜索,也就是线性扫描,当训练集很大时,计算非常耗时;
    • kd_tree:构造kd树存储数据,以便对其进行快速检索。在维数小于20时效率高;
    • ball_tree:是为了客服kd树高维失效而发明的,其构造过程是以质心C和半径r分割样本空间,每个节点是一个超球体。

2. 用K近邻算法实现鸢尾花的种类预测:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier

# 1. 获取数据
iris = load_iris()

# 2. 数据基本处理
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=22)

# 3. 特征工程 - 特征预处理
transfer = StandardScaler() # 标准化
x_train = transfer.fit_transform(x_train)
x_test = transfer.fit_transform(x_test)

# 4. 机器学习-KNN
# 4.1 实例化一个估计器
estimator = KNeighborsClassifier(n_neighbors=5)
# 4.2 模型训练
estimator.fit(x_train, y_train)

# 5. 模型评估
# 5.1 预测值结果输出
y_pre = estimator.predict(x_test) # 预测值
print('预测值是:\n', y_pre)
print('预测值和真实值的对比:\n', y_pre==y_test)
# 5.2 准确率计算
score = estimator.score(x_test, y_test)
print('准确率为:\n', score)

输出:

预测值是:
 [0 2 1 1 1 1 1 1 1 0 2 1 2 2 0 2 1 1 1 1 0 2 0 1 1 0 1 1 2 1]
预测值和真实值的对比:
 [ True  True  True False  True  True  True False  True  True  True  True
  True  True  True  True  True  True False  True  True  True  True  True
 False  True False False  True False]
准确率为:
 0.7666666666666667

1.6 交叉验证和网格搜索


1.6.1 交叉验证


1. 什么是交叉验证:

将拿到的训练集数据,再细分为训练集和验证集。将数据分为S份,其中一份作为验证集,然后经过S组的测试,每次更换不同的验证集。会得到了S组泛化误差,取平均值为最终。又称S折交叉验证。

2. 交叉验证过程:

1)随机将训练数据等分成k份,S1,S2,…,Sk;

2)对于每一个模型Mi,算法执行k次,每次选择一个Sj作为验证集,而其他作为训练集来训练模型Mi,把训练得到的模型在Sj上进行测试,这样一来,每次都会得到一个误差E,最后对k次得到的误差求平均,就可以得到Mi的泛化误差;

3)算法选择具有最小泛化误差的模型作为最终模型,并且在整个训练集上再次训练该模型,从而得到最终的模型。

  • 以4折交叉验证为例:

在这里插入图片描述

将训练集数据等分成4份,用不同的验证集做4次训练,得到4组准确率,求这4组准确率的平均数。用1减去这个平均数可以得到一个泛化误差。

3. 交叉验证的目的:

为了让被评估的模型更加准确可信。

注意,只是准确可信,不能优化模型。


1.6.2 网格搜索


1. 什么是网格搜索:

通常情况下,有很多参数是需要手动指定的(如K-近邻算法中的K值),这种参数叫超参数。但是手动指定过程繁杂,所以需要对模型预设几种参数组合。每组超参数都采用交叉验证来进行评估。最后选出最优参数组合,建立模型。

2. 交叉验证、网格搜索(模型选择与调优)API:

sklearn.model_selection.GridSearchCV(estimator, param_grid=None, cv=None)

1)参数:

  • estimator:估计器对象;
  • param_grid:估计器参数,字典类型数据;
  • cv:指定几折交叉验证。

2)结果分析:

  • best_score_:在交叉验证中得到的最好结果;
  • best_estimator_:最好的参数模型(有时打印时不显示参数,可用best_params_代替);
  • cv_results_:每次交叉验证后的验证集准确率结果和训练集准确率结果。

6.1.3 鸢尾花实例


from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier

# 1. 获取数据
iris = load_iris()

# 2. 数据基本处理
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=22)

# 3. 特征工程 - 特征预处理
transfer = StandardScaler() # 标准化
x_train = transfer.fit_transform(x_train)
x_test = transfer.fit_transform(x_test)

# 4. 机器学习-KNN
# 4.1 实例化一个估计器
estimator = KNeighborsClassifier(n_neighbors=5)
# 4.2 模型调优 -- 交叉验证,网格搜索
param_grid = {'n_neighbors': [1, 3, 5, 7]}
estimator = GridSearchCV(estimator, param_grid=param_grid, cv=5)
# 4.3 模型训练
estimator.fit(x_train, y_train)

# 5. 模型评估
# 5.1 预测值结果输出
y_pre = estimator.predict(x_test) # 预测值
print('预测值是:\n', y_pre)
print('预测值和真实值的对比:\n', y_pre==y_test)
# 5.2 准确率计算
score = estimator.score(x_test, y_test)
print('准确率为:\n', score)
# 5.3 查看交叉验证,网格搜索的一些属性
print('交叉验证中,得到的最好结果:\n', estimator.best_score_)
print('交叉验证中,得到的最好模型的参数:\n', estimator.best_params_)
print('交叉验证中,得到的模型结果是:\n', estimator.cv_results_)

输出:

预测值是:
 [0 2 1 1 1 1 1 1 1 0 2 1 2 2 0 2 1 1 1 1 0 2 0 1 1 0 1 1 2 1]
预测值和真实值的对比:
 [ True  True  True False  True  True  True False  True  True  True  True
  True  True  True  True  True  True False  True  True  True  True  True
 False  True False False  True False]
准确率为:
 0.7666666666666667
交叉验证中,得到的最好结果:
 0.9583333333333333
交叉验证中,得到的最好模型的参数:
 {'n_neighbors': 5}
交叉验证中,得到的模型结果是:
 {'mean_fit_time': array([0.00078802, 0.00040035, 0.00019875, 0.00040088]), 'std_fit_time': array([0.00039485, 0.00049033, 0.00039749, 0.00049097]), 'mean_score_time': array([0.00207067, 0.00200114, 0.00280037, 0.00239539]), 'std_score_time': array([1.03795303e-04, 2.66602721e-06, 4.03164015e-04, 4.85739388e-04]), 'param_n_neighbors': masked_array(data=[1, 3, 5, 7],
             mask=[False, False, False, False],
       fill_value='?',
            dtype=object), 'params': [{'n_neighbors': 1}, {'n_neighbors': 3}, {'n_neighbors': 5}, {'n_neighbors': 7}], 'split0_test_score': array([0.95833333, 0.95833333, 1.        , 1.        ]), 'split1_test_score': array([0.95833333, 0.91666667, 0.91666667, 0.91666667]), 'split2_test_score': array([0.95833333, 0.95833333, 1.        , 1.        ]), 'split3_test_score': array([0.875     , 0.875     , 0.91666667, 0.91666667]), 'split4_test_score': array([0.95833333, 0.95833333, 0.95833333, 0.95833333]), 'mean_test_score': array([0.94166667, 0.93333333, 0.95833333, 0.95833333]), 'std_test_score': array([0.03333333, 0.03333333, 0.0372678 , 0.0372678 ]), 'rank_test_score': array([3, 4, 1, 1])}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/409596.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Vue 卸载eslint

卸载依赖 npm uninstall eslint --save 然后 进入package.json中,删除残留信息。 否则在执行卸载后,运行会报错。 之后再起项目。

Vue+SpringBoot打造高校实验室管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、研究内容2.1 实验室类型模块2.2 实验室模块2.3 实验管理模块2.4 实验设备模块2.5 实验订单模块 三、系统设计3.1 用例设计3.2 数据库设计 四、系统展示五、样例代码5.1 查询实验室设备5.2 实验放号5.3 实验预定 六、免责说明 一、摘…

【vue3语法】开发使用创建项目等

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、vue3创建vue3v2函数式、v3组合式api响应式方法ref、reactive计算属性conputed监听属性wacthvue3 选项式生命周期父子通信父传子defineProps编译宏 子传父de…

Typora结合PicGo + 使用Github搭建个人免费图床

文章目录 一、国内图床比较二、使用Github搭建图床三、PicGo整合Github图床1、下载并安装PicGo2、设置图床3、整合jsDelivr具体配置介绍 4、测试5、附录 四、Typora整合PicGo实现自动上传 每次写博客时,我都会习惯在Typora写好,然后再复制粘贴到对应的网…

2.23日学习总结

今天刚学的01背包&#xff0c;套模板就可以解决这道题。 #include<iostream> using namespace std; int n,m,f[13000],w[3410],v[3410]; int main() {cin>>n>>m;for(int i1;i<n;i)scanf("%d %d",&w[i],&v[i]);for(int i1;i<n;i)for…

第三节:Vben Admin登录对接后端login接口

系列文章目录 第一节&#xff1a;Vben Admin介绍和初次运行 第二节&#xff1a;Vben Admin 登录逻辑梳理和对接后端准备 文章目录 系列文章目录前言一、Flask项目介绍二、使用步骤1.User模型创建2.迁移模型3. Token创建4. 编写蓝图5. 注册蓝图 三. 测试登录总结 前言 上一节&…

1904_ARM Cortex M系列芯片特性小结

1904_ARM Cortex M系列芯片特性小结 全部学习汇总&#xff1a; g_arm_cores: ARM内核的学习笔记 (gitee.com) ARM Cortex M系列的MCU用过好几款了&#xff0c;也涉及到了不同的内核。不过&#xff0c;关于这些内核的基本的特性还是有些不了解。从ARM的官方网站上找来了一个对比…

皓学IT:WEB05-Servlet

一、Servlet 1.1.概述 Servlet是SUN公司提供的一套规范&#xff0c;名称就叫Servlet规范&#xff0c;它也是JavaEE规范之一。我们可以像学习Java基础一样&#xff0c;通过API来学习Servlet。这里需要注意的是&#xff0c;在我们之前JDK的API中是没有Servlet规范的相关内容&am…

k8s笔记26--快速实现prometheus监控harbor

k8s笔记26--快速实现prometheus监控harbor 简介采集指标&配置grafana面板采集指标配置grafana面板 说明 简介 harbor是当前最流行的开源容器镜像仓库项目&#xff0c;被大量IT团队广泛应用于生产、测试环境的项目中。本文基于Harbor、Prometheus、Grafana介绍快速实现监控…

H5多用途的产品介绍展示单页HTML5静态网页模板

H5多用途的产品介绍展示单页HTML5静态网页模板 源码介绍&#xff1a;一款H5自适应多用途的产品介绍展示单页HTML静态网页模板&#xff0c;可用于团队官网、产品官网。 下载地址&#xff1a; https://www.changyouzuhao.cn/13534.html

MYSQL安装及卸载

目录 一、下载 二、解压 三、配置 1. 添加环境变量 2. 初始化MySQL 3. 注册MySQL服务 4. 启动MySQL服务 5. 修改默认账户密码 四、登录MySQL 五、卸载MySQL 一、下载 点开下面的链接&#xff1a;MySQL :: Download MySQL Community Server 点击Download 就可以下载对…

vim恢复.swp [BJDCTF2020]Cookie is so stable1

打开题目 扫描目录得到 关于 .swp 文件 .swp 文件一般是 vim 编辑器在编辑文件时产生的&#xff0c;当用 vim 编辑器编辑文件时就会产生&#xff0c;正常退出时 .swp 文件被删除&#xff0c;但是如果直接叉掉&#xff08;非正常退出&#xff09;&#xff0c;那么 .swp 文件就会…

爬虫入门五(Scrapy架构流程介绍、Scrapy目录结构、Scrapy爬取和解析、Settings相关配置、持久化方案)

文章目录 一、Scrapy架构流程介绍二、Scrapy目录结构三、Scrapy爬取和解析Scrapy的一些命令css解析xpath解析 四、Settings相关配置提高爬取效率基础配置增加爬虫的爬取效率 五、持久化方案 一、Scrapy架构流程介绍 Scrapy一个开源和协作的框架&#xff0c;其最初是为了页面抓取…

基于springboot+vue的音乐网站(前后端分离)

博主主页&#xff1a;猫头鹰源码 博主简介&#xff1a;Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战&#xff0c;欢迎高校老师\讲师\同行交流合作 ​主要内容&#xff1a;毕业设计(Javaweb项目|小程序|Pyt…

【其他】简易代码项目记录

1. KeypointDetection 1.1. CharPointDetection 识别字符中的俩个关键点。 1.2. Facial-keypoints-detection 用于检测人脸的68个关键点示例。 1.3. Hourglass-facekeypoints 使用基于论文Hourglass 的模型实现人体关键点检测。 1.4. Realtime-Action-Recognition containing:…

15.4K Star,超强在线编辑器

Hi&#xff0c;骚年&#xff0c;我是大 G&#xff0c;公众号「GitHub指北」会推荐 GitHub 上有趣有用的项目&#xff0c;一分钟 get 一个优秀的开源项目&#xff0c;挖掘开源的价值&#xff0c;欢迎关注。 今天推荐一款非常棒的开源实时协作编辑器&#xff0c;可用于多人同时编…

复旦大学EMBA联合澎湃科技:共议科技迭代 创新破局

1月18日&#xff0c;由复旦大学管理学院、澎湃新闻、厦门市科学技术局联合主办&#xff0c;复旦大学EMBA项目、澎湃科技承办的“君子知道”复旦大学EMBA前沿论坛在厦门成功举办。此次论坛主题为“科技迭代 创新破局”&#xff0c;上海、厦门两地的政策研究专家、科学家、科创企…

JavaScript流程控制

文章目录 1. 顺序结构2. 分支结构2.1 if 语句2.2 if else 双分支语句2.3 if else if 多分支语句三元表达式 2.4 switch 语句switch 语句和 if else if语句区别 3. 循环结构3.1 for 循环断点调试 3.2 双重 for 循环3.3 while 循环3.4 do while 循环3.5 contiue break 关键字 4. …

3.WEB渗透测试-前置基础知识-快速搭建渗透环境(上)

上一个内容&#xff1a;2.WEB渗透测试-前置基础知识-web基础知识和操作系统-CSDN博客 1.安装虚拟机系统 linux Kali官网下载地址&#xff1a; https://www.kali.org/get-kali/#kali-bare-metal Centos官网下载地址&#xff1a; https://www.centos.org/download/ Deepin官网下…

Flask基础学习4

19-【实战】问答平台项目结构搭建_剪_哔哩哔哩_bilibili 参考如上大佬的视频教程&#xff0c;本博客仅当学习笔记&#xff0c;侵权请联系删除 问答发布的web前端页面实现 register.html {% extends base.html %}{% block head %}<link rel"stylesheet" href&q…
最新文章