深度学习_16_权重衰退调整过拟合

所谓过拟合即模型复杂度较高,但用于训练数据集过于简单,最后导致模型将过多无用渣质作为学习对象

这个在上篇 深度学习_15_过拟合&欠拟合 已经详细介绍,以下便不再赘述。

上篇提到要想解决过拟合现象可以试着降低模型复杂度,又或者用复杂度匹配的数据集训练模型,但是数据集一般都是采集好的,无法更改,更改模型复杂度又可能涉及到更改模型本身,所以为了解决上述困扰,这里使用权重衰退方式调整模型,以达到让模型正常拟合

本次调整的代码对象仍是上篇实例代码

理论:

在这里插入图片描述
权重衰退采用限制模型本身w的数据取值范围,使模型的参数平方和小于θ,从而使模型的损失降低,从而限制模型的复杂度

在这里插入图片描述
上篇提到过测试集合的损失在达到低谷后又会反弹,这是模型过拟合的表现,原因就是模型学习了其他的杂质,而我们利用上述权重衰退使得模型的参数平方和小于θ,这样模型想去学习新的参数,就会受到这个θ大小的限制,导致其无法学习新的θ,从而模型的大小被我们限制了起来,而杂质部分参数都会以0呈现,也就是模型前四个维度会趋近真正模型,后面杂质维度会趋近于0

θ大小也要把控好,若θ太小,模型会欠拟合

上述终究是理论,得用数学公式表示出来,也就等效于下方:

在这里插入图片描述

朴素的说下面两个公式等价
在这里插入图片描述
在这里插入图片描述
证明λ趋近0即相当于θ趋近无穷,对损失没有限制,λ趋近无穷则整体损失会变大,模型是往损失减小的方向学习,所以会间接导致w会变小,这与θ变小限制w的大小效果相同

在这里插入图片描述

实例代码:

对上篇代码增添了权重衰退法则,并写了两个不同的版本,两版本本质区别在于用了不同的优化模型手段,效果差别不大

版本1代码:

import math
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

# 生成随机的数据集
max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = torch.zeros(max_degree)
true_w[0:4] = torch.Tensor([5, 1.2, -3.4, 5.6])

# 生成特征
features = torch.randn((n_train + n_test, 1))
permutation_indices = torch.randperm(features.size(0))
# 使用随机排列的索引来打乱features张量(原地修改)
features = features[permutation_indices]
poly_features = torch.pow(features, torch.arange(max_degree).reshape(1, -1))
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)

# 生成标签
labels = torch.matmul(poly_features, true_w)
labels += torch.normal(0, 0.1, size=labels.shape)

def evaluate_loss(net, data_iter, loss):
    metric = d2l.Accumulator(2)
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]
def l2_penalty(w):
    w = w[0].weight
    return torch.sum(w.pow(2)) / 2
# 修改后的训练函数
def train(train_features, test_features, train_labels, test_labels, lambd,
          num_epochs=400):
    loss = nn.MSELoss()  # 损失函数
    input_shape = train_features.shape[-1]
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    batch_size = min(10, train_labels.shape[0])

    train_iter = d2l.load_array((train_features, train_labels.reshape(-1, 1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1, 1)),
                               batch_size, is_train=False)
    trainer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=lambd)  # 优化器调整模型

    # 用于存储测试损失的列表
    test_losses = []
    train_losses = []
    total_loss = 0
    total_samples = 0
    for epoch in range(num_epochs):
        for X, y in train_iter:
            trainer.zero_grad()  # 删除之前的梯度
            out = net(X)
            l = loss(out, y) + lambd * l2_penalty(net)
            l.backward()  # 将梯度传递回模型
            trainer.step()  # 梯度更新
            total_loss += l.sum().item()  # 统计所有元素损失
            total_samples += y.numel()    # 统计个数
            # 将当前的损失值添加到列表中
        a = total_loss / total_samples  # 本次训练的平均损失
        train_losses.append(a)  # 存
        test_loss = evaluate_loss(net, test_iter, loss)  #  本次训练的测试损失
        test_losses.append(test_loss)
        total_loss = 0
        total_samples = 0
        print(f"Epoch {epoch + 1}/{num_epochs}:")
        print(f"训练损失: {a:.4f}  测试损失: {test_loss:.4f}")
       # print(f"  训练损失:  测试损失: {loss(out, y):.4f}")
    print(net[0].weight)
    # 假设 test_losses 是已经计算出的测试损失值列表
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='train', color='blue', linestyle='-', marker='.')
    plt.plot(test_losses, label='test', color='purple', linestyle='--', marker='.')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.title('Test Loss over Epochs')
    plt.legend()
    plt.grid(True)
    plt.ylim(0, 1)  # 设置y轴的范围从0.01到100
    plt.show()


# 选择多项式特征中的前4个维度
train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:], 0)

版本2代码:

import math
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

# 生成随机的数据集
max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = torch.zeros(max_degree)
true_w[0:4] = torch.Tensor([5, 1.2, -3.4, 5.6])

# 生成特征
features = torch.randn((n_train + n_test, 1))
permutation_indices = torch.randperm(features.size(0))
# 使用随机排列的索引来打乱features张量(原地修改)
features = features[permutation_indices]
poly_features = torch.pow(features, torch.arange(max_degree).reshape(1, -1))
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)

# 生成标签
labels = torch.matmul(poly_features, true_w)
labels += torch.normal(0, 0.1, size=labels.shape)


# 以下是你原来的训练函数,没有修改
def evaluate_loss(net, data_iter, loss):
    metric = d2l.Accumulator(2)
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]


def l2_penalty(w):
    w = w[0].weight
    return torch.sum(w.pow(2)) / 2


def train(train_features, test_features, train_labels, test_labels, lambd,
          num_epochs=400):
    loss = d2l.squared_loss
    input_shape = train_features.shape[-1]
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    batch_size = min(10, train_labels.shape[0])

    train_iter = d2l.load_array((train_features, train_labels.reshape(-1, 1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1, 1)),
                               batch_size, is_train=False)

    # 用于存储训练和测试损失的列表
    train_losses = []
    test_losses = []
    total_loss = 0
    total_samples = 0
    for epoch in range(num_epochs):
        for X, y in train_iter:
            out = net(X)
            l = loss(out, y) + lambd * l2_penalty(net)

            # 反向传播和优化器更新
            l.sum().backward()
            d2l.sgd(net.parameters(), lr=0.01, batch_size= batch_size)
            total_loss += l.sum().item()  # 统计所有元素损失
            total_samples += y.numel()  # 统计个数
        a = total_loss / total_samples  # 本次训练的平均损失
        train_losses.append(a)
        test_loss = evaluate_loss(net, test_iter, loss)
        test_losses.append(test_loss)
        total_loss = 0
        total_samples = 0
        print(f"Epoch {epoch + 1}/{num_epochs}:")
        print(f"训练损失: {a:.4f}   测试损失: {test_loss:.4f} ")
    print(net[0].weight)

    # 绘制损失曲线
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='train', color='blue', linestyle='-', marker='.')
    plt.plot(test_losses, label='test', color='purple', linestyle='--', marker='.')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.title('Loss over Epochs')
    plt.legend()
    plt.grid(True)
    plt.ylim(0, 100)  # 设置y轴的范围从0.01到100
    plt.show()


# 选择多项式特征中的前4个维度
train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:], 0)

代码分析:

def l2_penalty(w):
    w = w[0].weight
    return torch.sum(w.pow(2)) / 2

取模型参数平方和除以2

            out = net(X)
            l = loss(out, y) + lambd * l2_penalty(net)

            # 反向传播和优化器更新
            l.sum().backward()
            d2l.sgd(net.parameters(), lr=0.01, batch_size= batch_size)

用模型算出结果,然后计算损失,算出梯度并返还到模型中,利用sgd优化算法,更新模型参数
没有用torch本身优化函数所以l.sum().backward()不能简写

其他不再赘述,上篇已经写的很明白了

过拟合:

在这里插入图片描述
很明显测试集在损失达到最低后又上升,这是过拟合现象

利用权重衰退调整过拟合

在这里插入图片描述
在上述过拟合条件下将λ设为0.006即可缓解过拟合现象
顺带提一下,train损失比test损失高出的部分有训练损失多加的lambd * l2_penalty(net)部分和loss(out, y)升高部分,总体来说是两者制衡效果的体现

正常拟合

在这里插入图片描述

可看出上述调整过后的过拟合test损失与正常拟合的test损失及其接近,说明调成效果不错

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/429683.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python web框架fastapi中间件与CORS详细教学

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 所属专栏:Fastapi 景天的主页:景天科技苑 文章目录 fastapi中间件与CORS1、中间件1.创建中间件方法2.中间件里面添加响应头…

抖音视频评论批量采集软件|视频下载工具

《轻松搞定!视频评论批量采集软件,助您高效工作》 在短视频这个充满活力和创意的平台上,了解用户评论是了解市场和观众心声的重要途径之一。为了帮助您快速获取大量视频评论数据,我们推出了一款操作便捷、功能强大的软件&#xff…

第一弹:Flutter安装和配置

目标: 1)配置Flutter开发环境 2)创建第一个Flutter Demo项目 Flutter中文开发者网站: https://flutter.cn/ 一、配置Flutter开发环境 Flutter开发环境已经提供集成IDE开发环境,因此需要配置开发环境的时候&#xf…

【STM32】STM32学习笔记-读写内部FLASH 读取芯片ID(49)

00. 目录 文章目录 00. 目录01. FLASH概述02. 读写内部FLASH接线图03. 读写内部FLASH相关API04. 读写内部FLASH程序示例05. 读写芯片ID接线图06. 读写芯片ID程序示例07. 程序示例下载08. 附录 01. FLASH概述 STM32F10xxx内嵌的闪存存储器可以用于在线编程(ICP)或在程序中编程(…

Spark中读parquet文件是怎么实现的

背景 最近在整理了一下 spark对Parquet的写文件的过程,也是为了更好的理解和调优Spark相关的任务, 因为对于Spark来说,任何一个事情都不是独立的存在的,比如说parquet文件的rowgroup设置的大小对读写的影响,以及parqu…

数据删除

目录 数据删除 删除员工编号为 7369 的员工信息 删除若干个数据 删除公司中工资最高的员工 Oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 数据删除 删除数据就是指删除不再需要的数据 delete from 表名称 [where 删…

Java架构之路-架构应全面了解的技术栈和工作域

有时候我在想这么简单简单的东西,怎么那么难以贯通。比如作为一个架构师可能涉及的不单单是技术架构,还包含了项目管理,一套完整的技术架构也就那么几个技术栈,只要花点心思,不断的往里面憨实,总会学的会&a…

huggingface学习|controlnet实战:云服务器使用StableDiffusionControlNetPipeline生成图像

ControlNet核心基础知识 文章目录 一、环境配置和安装需要使用的库二、准备数据及相关模型三、参照样例编写代码(一)导入相关库(二)准备数据(以知名画作《戴珍珠耳环的少女》为例)(三&#xff0…

C++惯用法之RAII思想: 资源管理

C编程技巧专栏:http://t.csdnimg.cn/eolY7 目录 1.概述 2.RAII的应用 2.1.智能指针 2.2.文件句柄管理 2.3.互斥锁 3.注意事项 3.1.禁止复制 3.2.对底层资源使用引用计数法 3.3.复制底部资源(深拷贝)或者转移资源管理权(移动语义) 4.RAII的优势和挑战 5.总…

制造业数字化赋能:1核心2关键3层面4方向

随着科技的飞速发展,制造业正站在数字化转型的风口浪尖。数字化转型不仅关乎企业效率与利润,更决定了制造业在全球竞争中的地位。那么,在这场波澜壮阔的数字化浪潮中,制造业如何抓住机遇,乘风破浪?本文将从…

Maven终端命令生成Spring-boot项目并输出“helloworld“

1. 生成项目 mvn archetype:generate填写groupId和artifactId&#xff0c;其余默认即可 2. 修改pom.xml文件 将如下内容放入pom.xml文件内 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artif…

计网面试题整理上

1. 计算机网络的各层协议及作用&#xff1f; 计算机网络体系可以大致分为一下三种&#xff0c;OSI七层模型、TCP/IP四层模型和五层模型。 OSI七层模型&#xff1a;大而全&#xff0c;但是比较复杂、而且是先有了理论模型&#xff0c;没有实际应用。TCP/IP四层模型&#xff1a…

Git入门学习笔记

Git 是一个非常强大的分布式版本控制工具&#xff01; 在下载好Git之后&#xff0c;鼠标右击就可以显示 Git Bash 和 Git GUI&#xff0c;Git Bash 就像是在电脑上安装了一个小型的 Linux 系统&#xff01; 1. 打开 Git Bash 2. 设置用户信息&#xff08;这是非常重要的&…

使用GRU进行天气变化的时间序列预测

本文基于最适合入门的100个深度学习项目的学习记录&#xff0c;同时在Google clolab上面是实现&#xff0c;文末有资源连接 天气变化的时间序列的难点 天气变化的时间序列预测涉及到了一系列复杂的挑战&#xff0c;主要是因为天气系统的高度动态性和非线性特征。以下是几个主…

(十五)【Jmeter】取样器(Sampler)之HTTP请求

简述 操作路径如下: HTTP请求 (HTTP Sampler): 作用:模拟发送HTTP请求并获取响应。配置:设置URL、请求方法、请求参数等参数。使用场景:测试Web应用程序的HTTP接口性能。优点:支持多种HTTP方法和请求参数,适用于大多数Web应用程序测试。缺点:功能较为基础,对于复杂…

突发,Anthropic推出突破性Claude 3系列模型,性能超越GPT-4

&#x1f989; AI新闻 &#x1f680; 突发&#xff0c;Anthropic推出突破性Claude 3系列模型 摘要&#xff1a;人工智能创业公司Anthropic宣布推出其Claude 3系列大型语言模型&#xff0c;该系列包括Claude 3 Haiku、Claude 3 Sonnet和Claude 3 Opus三个子模型&#xff0c;旨…

第18章 Java I/O系统

第18章 Java I/O系统 18.1 File类 File类不仅仅指文件&#xff0c;还能代表一个目录下的一组文件。 18.1.1 目录列表器 public class Test {public static void main(String[] args) {File file new File("E:\\test");String[] strings file.list(new DirFilte…

安装Proxmox VE虚拟机平台

PVE是专业的虚拟机平台&#xff0c;可以利用它安装操作系统&#xff0c;如&#xff1a;Win、Linux、Mac、群晖等。 1. 下载镜像 访问PVE官网&#xff0c;下载最新的PVE镜像。 https://www.proxmox.com/en/downloads 2. 下载balenaEtcher balenaEtcher用于将镜像文件&#…

「滚雪球学Java」:多线程(章节汇总)

咦咦咦&#xff0c;各位小可爱&#xff0c;我是你们的好伙伴——bug菌&#xff0c;今天又来给大家普及Java SE相关知识点了&#xff0c;别躲起来啊&#xff0c;听我讲干货还不快点赞&#xff0c;赞多了我就有动力讲得更嗨啦&#xff01;所以呀&#xff0c;养成先点赞后阅读的好…

数据可视化原理-腾讯-3D网格热力图

在做数据分析类的产品功能设计时&#xff0c;经常用到可视化方式&#xff0c;挖掘数据价值&#xff0c;表达数据的内在规律与特征展示给客户。 可是作为一个产品经理&#xff0c;&#xff08;1&#xff09;如果不能够掌握各类可视化图形的含义&#xff0c;就不知道哪类数据该用…
最新文章