过拟合与欠拟合

一、模型选择

1、问题导入

2、训练误差与泛化误差

3、验证数据集和测试数据集

4、K-折交叉验证

       一般在没有足够多数据时使用。

二、过拟合与欠拟合

1、过拟合

过拟合的定义:

       当学习器把训练样本学的“太好”了的时候,很可能已经把训练样本自身的一些特点当作了所有潜在样本都会具有的一般性质,这样就会导致泛化性能下降,这种现象称为过拟合。具体表现就是最终模型在训练集上效果好;在测试集上效果差。模型泛化能力弱。

过拟合的原因:

  • 训练数据中噪音干扰过大,使得学习器认为部分噪音是特征从而扰乱学习规则。
  • 建模样本选取有误,例如训练数据太少,抽样方法错误,样本label错误等,导致样本不能代表整体。
  • 模型不合理,或假设成立的条件与实际不符。
  • 特征维度/参数太多,导致模型复杂度太高。

2、欠拟合

欠拟合的定义:

       欠拟合是指对训练样本的一般性质尚未学好。在训练集及测试集上的表现都不好。

欠拟合的原因:

  • 模型复杂度过低
  • 特征量过少

3、模型容量

4、数据复杂度

三、代码解释

       我们可以通过多项式拟合来探索这些概念。

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

1、生成数据集

       给定$x$,我们将使用以下三阶多项式来生成训练和测试数据的标签:

$y = 5 + 1.2x - 3.4\frac{x^2}{2!} + 5.6 \frac{x^3}{3!} + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.1^2).$

       噪声项$\epsilon$服从均值为0且标准差为0.1的正态分布。在优化的过程中,我们通常希望避免非常大的梯度值或损失值。这就是我们将特征从$x^i$调整为$\frac{x^i}{i!}$的原因,这样可以避免很大的$i$带来的特别大的指数值。我们将为训练集和测试集各生成100个样本。

max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = np.zeros(max_degree)  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])     # 注意true_w.shape = (max_degree,),只是后面16个都为0
features = np.random.normal(size=(n_train + n_test, 1))
np.random.shuffle(features)     # features.shape = (n_train + n_test, 1)
# 使用 numpy.power 函数将 features 的每个元素分别与 np.arange(max_degree)(从0到max_degree-1的数组)进行幂运算。
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))    # ploy_feature.shape = (n_train + n_test, max_degree)
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)! 第i列数据除以i的阶乘
# labels的维度:(n_train+n_test,)
labels = np.dot(poly_features, true_w)  # labels.shape = (n_train + n_test,)  通过将多项式特征 poly_features 与真实系数 true_w 相乘得到标签。
labels += np.random.normal(scale=0.1, size=labels.shape)    # 向标签数据添加服从正态分布的噪声
# 下面4个分别对应: w x x^n/n! y=w*(x^n/n!)
true_w, features, poly_features, labels = [torch.tensor(x, dtype=       # 多项式:"polynomial"
    torch.float32) for x in [true_w, features, poly_features, labels]]  # for训练遍历列表[true_w, features, poly_features, labels]中的4个元素,将NumPy ndarray转换为tensor

2、对模型进行训练和测试

       首先让我们实现一个函数来评估模型在给定数据集上的损失。

def evaluate_loss(net, data_iter, loss):
    """评估给定数据集上模型的损失"""
    metric = d2l.Accumulator(2)  # 损失的总和,样本数量
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]

现在定义训练函数。

def train(train_features, test_features, train_labels, test_labels,
          num_epochs=400):
    loss = nn.MSELoss(reduction='none')     # 使用均方根损失
    input_shape = train_features.shape[-1]  # input_shape=20
    # 不设置偏置,因为我们已经在多项式中实现了它
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    batch_size = min(10, train_labels.shape[0])
    train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),
                               batch_size, is_train=False)
    trainer = torch.optim.SGD(net.parameters(), lr=0.01)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',
                            xlim=[1, num_epochs], ylim=[1e-3, 1e2],
                            legend=['train', 'test'])
    for epoch in range(num_epochs):
        d2l.train_epoch_ch3(net, train_iter, loss, trainer)
        if epoch == 0 or (epoch + 1) % 20 == 0:
            animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),
                                     evaluate_loss(net, test_iter, loss)))
    print('weight:', net[0].weight.data.numpy())

3、拟合结果分析

(1)三阶多项式函数拟合(正常)

       我们将首先使用三阶多项式函数,它与数据生成函数的阶数相同。结果表明,该模型能有效降低训练损失和测试损失。学习到的模型参数也接近真实值$w = [5, 1.2, -3.4, 5.6]$

# 从多项式特征中选择前4个维度,即1,x,x^2/2!,x^3/3!
# poly_features每行的前4个值是线性相关的,因为此时已经进行了幂运算、阶乘运算,就差乘true_w了,也正因如此前面才能使用单线性层nn.Linear拟合
train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:])

(2)线性函数拟合(欠拟合)

       让我们再看看线性函数拟合,减少该模型的训练损失相对困难。在最后一个迭代周期完成后,训练损失仍然很高。当用来拟合非线性模式(如这里的三阶多项式函数)时,线性模型容易欠拟合。不过这里的欠拟合是因为数据没给全(因为数据没给全而导致特征不明显),前4项weight的后面2项模型学不到,人为强制使模型欠拟合了,而不是模型容量不够。

# 从多项式特征中选择前2个维度,即1和x
train(poly_features[:n_train, :2], poly_features[n_train:, :2],
      labels[:n_train], labels[n_train:])

(3)高阶多项式函数拟合(过拟合) 

       现在,让我们尝试使用一个阶数过高的多项式来训练模型。在这种情况下,没有足够的数据用于学到高阶系数应该具有接近于零的值。因此,这个过于复杂的模型会轻易受到训练数据中噪声的影响。虽然训练损失可以有效地降低,但测试损失仍然很高。结果表明,复杂模型对数据造成了过拟合。其实就是给的数据太多(后面16项多余了),模型把不该学的(后面16项的weight)给学到了,于是就造成了过拟合。

# 从多项式特征中选取所有维度
train(poly_features[:n_train, :], poly_features[n_train:, :],
      labels[:n_train], labels[n_train:], num_epochs=1500)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/245843.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

JavaSE语法之六:类和对象(超全!!!)

文章目录 一、面向对象的初步认识1. 什么是面向对象2. 面向对象与面向过程 二、类和对象三、类的定义和使用四、类的实例化五、this引用六、对象的构成及初始化1. 如何初始化对象2. 构造方法3. 默认初始化4. 就地初始化 一、面向对象的初步认识 1. 什么是面向对象 Java中一切…

实验02:RIP配置

1.实验目的: 了解路由选择协议(Routing Protocol)的基本原理及分类;掌握RIP协议的基本原理;实现RIP协议;掌握路由器配置及路由表查看的基本命令。 2.实验内容: 建立拓扑结构;配置…

【已解决】ModuleNotFoundError: No module named ‘taming‘

问题描述 Traceback (most recent call last) <ipython-input-14-2683ccd40dcb> in <module> 16 from omegaconf import OmegaConf 17 from PIL import Image ---> 18 from taming.models import cond_transformer, vqgan 19 import taming.modu…

美团、阿里、快手、百度 | NLP暑期算法实习复盘

面试锦囊之面经分享系列&#xff0c;持续更新中 后台回复『面试』加入讨论组交流噢 背景 211CS本港三DS硕&#xff0c;硕士research的方向是NLP&#xff0c;目标是找任何方向的算法实习。 本科做开发为主没有算法经验&#xff0c;没有top比赛&#xff0c;没有过算法实习&…

KUKA机器人如何在程序中编辑等待时间?

KUKA机器人如何在程序中编辑等待时间&#xff1f; 如下图所示&#xff0c;如何实现在P1点和P2点之间等待设定的时间&#xff1f; 如下图所示&#xff0c;可以直接输入wait sec 2&#xff08;等待2秒&#xff09;&#xff0c; 如下图所示&#xff0c;再次选中该程序后&#…

网络基础——路由协议及ensp操作

目录 一、路由器及路由表 1.路由协议&#xff1a; 2.路由器转发原理&#xff1a; 3.路由表&#xff1a; 二、静态路由优缺点及特殊静态路由默认路由 1.静态路由的优缺点&#xff1a; 2.下一跳地址 3.默认路由 三、静态路由配置 四、补充备胎 平均负载 五、补充&…

微软Microsoft二面面试题分享通过总结(不是标准答案分享

误打误撞 我写的shitty代码 当年面试算法开发岗竟然通过了 Background 先说下背景&#xff0c;软件工程本科毕业之后&#xff0c;当年8月到北欧读两年制硕士。面试发生在当年的11月&#xff0c;微软哥本哈根&#xff0c;location在丹麦的哥本哈根lingby&#xff08;是不是这么…

伦敦银和纽约银该pick谁?

伦敦银和纽约银不仅是全球最重要的两个白银市场&#xff0c;更是两种截然不同的交易模式&#xff0c;前者是指在伦敦市场上以美元/盎司计价的现货白银&#xff0c;后者是指在纽约商品交易所交易、以美元/盎司计价的白银期货。 如果大家需要在这两种白银投资方式中作出取舍&…

常见的设计模式以及实现方法总结

目录 代码中使用的设计模式总结 前言常见的23种设计模式Singleton模式&#xff08;单例模式&#xff09;理论Spring中创建的Bean Prototype模式&#xff08;原型模式&#xff09;理论Spring中创建的Bean Builder模式&#xff08;构造器模式&#xff09;理论Builder实现了构造器…

JVM之堆学习

一、Java虚拟机内存结构图 二、堆的介绍 1. 前面学习的程序计数器&#xff0c;虚拟机栈和本地方法栈都是线程私有的&#xff0c;堆是线程共享的&#xff1b; 2. 通过 new 关键字&#xff0c;创建的对象都会使用堆内存&#xff0c;其特点是&#xff1a; 它是线程共享的&#x…

Landsat7_C2_ST数据集2019年1月-2022年12月

简介&#xff1a; Landsat7_C2_ST数据集是经大气校正后的地表温度数据&#xff0c;属于Collection2的二级数据产品&#xff0c;以开尔文为单位测量地球表面温度&#xff0c;是全球能量平衡研究和水文模拟中的重要地球物理参数。地表温度数据还有助于监测作物和植被健康状况&am…

单片机——通信协议(FPGA+c语言应用之spi协议解析篇)

引言 串行外设接口(SPI)是微控制器和外围IC&#xff08;如传感器、ADC、DAC、移位寄存器、SRAM等&#xff09;之间使用最广泛的接口之一。本文先简要说明SPI接口&#xff0c;然后介绍ADI公司支持SPI的模拟开关与多路转换器&#xff0c;以及它们如何帮助减少系统电路板设计中的数…

宏景eHR SQL注入漏洞复现

0x01 产品简介 宏景eHR人力资源管理软件是一款人力资源管理与数字化应用相融合&#xff0c;满足动态化、协同化、流程化、战略化需求的软件。 0x02 漏洞概述 宏景eHR app_check_in/get_org_tree.jsp接口处存在SQL注入漏洞&#xff0c;未经过身份认证的远程攻击者可利用此漏洞…

JVM的五大分区

1.方法区 方法区主要用来存储已在虚拟机加载的类的信息、常量、静态变量以及即时编译器编译后的代码信息。该区域是被线程共享的。 2.虚拟机栈 虚拟机栈也就是我们平时说的栈内存&#xff0c;它是为java方法服务的。每个方法在执行的 时候都会创建一个栈帧&#xff0c;用于存…

SpringCloud面试题及答案(最新50道大厂版,持续更新)

在Java开发中&#xff0c;Spring Cloud作为微服务架构的关键组成部分&#xff0c;为了帮助广大Java技术爱好者和专业开发人员深入理解Spring Cloud&#xff0c;本文《SpringCloud面试题及答案&#xff08;最新50道大厂版&#xff0c;持续更新&#xff09;》提供了最前沿、最实用…

C#比较两个list集合类的差异

C#中List中自带的差集计算方法 List 继承了Enumerable &#xff0c;Enumerable 中有一个Except方法 它有两个实现&#xff1a; 第一个实现是通过使用默认的相等比较器对值进行比较&#xff0c;生成两个序列的差集。 第二个实现是通过使用指定的 IEqualityComparer 对值进行…

一分钟带你了解电容

电容器中的电容究竟是怎么定义的&#xff1f; 一个电容器&#xff0c;如果带1库的电量时两级间的电势差是1伏&#xff0c;这个电容器的电容就是1法拉&#xff0c;即&#xff1a;CQ/U 。但电容的大小不是由Q&#xff08;带电量&#xff09;或U&#xff08;电压&#xff09;决定…

[C错题本]转义字符/指针与首元素/运算

\a响铃 \b退格 \f换页 \r回车 \t水平制表 \v垂直制表 \单引号 \"双引号 \\反斜杠 \0dd八进制&#xff08;0-7&#xff09; \xdd(0-f)注意x一定不能大写 而且十六进制千万不能写\0xint main() {char s[]"ABCD", *p;for (p s 1; p < s 4; p)printf("%s…

简记修复改etc下profile失败的补救措施

现象 下午配置环境变量一个小小的失误&#xff0c;把etc文件夹下的profile改崩了&#xff0c;导致很多基本命令都用不了&#xff0c;服务器出现了下面这种情况。 [rootxxxx ~]# vi /etc/profile -bash: vi: command not found [rootxxxxx~]# vi -bash: vi: command not found…

今天我们深刻认识一下 Java虚拟机的程序计数器

1、为什么需要程序计数器 为了保证程序(在操作系统中理解为进程)能够连续地执行下去&#xff0c;CPU必须具有某些手段来确定下一条指令的地址。而程序计数器正是起到这种作用&#xff0c;所以通常又称为指令计数器。在程序开始执行前&#xff0c;必须将它的起始地址&#xff0c…
最新文章