论文辅助笔记:TEMPO 之 utils.py

0 导入库

from typing import Tuple
import random
import numpy as np
import torch
from statsmodels.tsa.seasonal import STL

1 EarlyStopping

  • 提供了一个早停机制,用于在模型训练过程中监控验证集上的损失
  • 如果损失停止改进,则停止训练

1.1 __init__

class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        #早停的容忍度,如果连续 patience 次验证损失没有改善,则停止训练。

        self.verbose = verbose
        #决定是否输出详细信息


        self.counter = 0
        #记录连续未改善验证损失的次数


        self.best_score = None
        #用于存储目前为止最佳的验证损失分数

        self.early_stop = False
        #一个布尔值,指示是否应该停止训练


        self.val_loss_min = np.Inf
        #存储目前为止最小的验证损失

        self.delta = delta
        #一个阈值,用于决定损失的改善幅度

1.2 __call__ 在训练过程中监控验证损失

def __call__(self, val_loss, model, path):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            #如果这是第一次调用 __call__,初始化 best_score 为 score 并保存模型。
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
            '''
            如果 score < self.best_score + self.delta,则说明损失没有显著改善
            
            增加 counter 并检查是否超过 patience,如果超过则停止训练
            '''
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0
            '''
            如果 score > self.best_score + self.delta,更新 best_score 并保存模型
            然后将 counter 重置为零
            '''

1.3 save_checkpoint 在验证损失降低时保存模型

def save_checkpoint(self, val_loss, model, path):
        if self.verbose:
            print(
                f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ..."
            )
        torch.save(model.state_dict(), path + "/" + "checkpoint.pth")
        #使用 torch.save() 保存模型的状态字典
        self.val_loss_min = val_loss
        

2 StandardScaler

实现数据标准化

2.1 __init__

class StandardScaler:
    def __init__(self):
        self.mean = 0.0
        self.std = 1.0

2.2  fit

计算并更新 self.meanself.std

def fit(self, data):
        self.mean = data.mean(0)
        self.std = data.std(0)

 2.3  transform

   将数据转换为标准化形式

def transform(self, data):
        mean = (
            torch.from_numpy(self.mean).type_as(data).to(data.device)
            if torch.is_tensor(data)
            else self.mean
        )
        std = (
            torch.from_numpy(self.std).type_as(data).to(data.device)
            if torch.is_tensor(data)
            else self.std
        )
        '''
        mean 和 std 的类型转换:
            根据 data 是 torch.Tensor 还是 numpy 数组
            将 self.mean 和 self.std 转换为相应类型,以确保类型匹配
        '''
        return (data - mean) / std

 2.4 inverse_transform

将标准化后的数据还原

    def inverse_transform(self, data):
        mean = (
            torch.from_numpy(self.mean).type_as(data).to(data.device)
            if torch.is_tensor(data)
            else self.mean
        )
        std = (
            torch.from_numpy(self.std).type_as(data).to(data.device)
            if torch.is_tensor(data)
            else self.std
        )

        '''
        mean 和 std 的类型转换:
            根据 data 是 torch.Tensor 还是 numpy 数组
            将 self.mean 和 self.std 转换为相应类型,以确保类型匹配
        '''

        if data.shape[-1] != mean.shape[-1]:
            mean = mean[-1:]
            std = std[-1:]


        return (data * std) + mean
        '''
        通过 (data * std) + mean 将标准化后的数据还原为原始形式
        '''

3 decompose

使用STL,将时间序列分解为趋势、季节性和残差成分

def decompose(
    x: torch.Tensor, period: int = 7
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    
    #x:输入的一维时间序列,类型为 torch.Tensor,形状为 (1, seq_len)
    x = x.squeeze(0).cpu().numpy()
    '''
    首先调用 squeeze(0) 将 x 的第一个维度去掉
    然后通过 cpu().numpy() 将 x 转换为 numpy 数组,以便 STL 分解函数使用
    '''


    decomposed = STL(x, period=period).fit()
    '''
    调用 STL(x, period=period).fit() 对 x 进行分解,并返回分解结果 decomposed
    
    其中包含了 trend(趋势)、seasonal(季节性)和 resid(残差)成分
    '''


    trend = decomposed.trend.astype(np.float32)
    seasonal = decomposed.seasonal.astype(np.float32)
    residual = decomposed.resid.astype(np.float32)
    '''
    将 decomposed 中的各个成分转换为 numpy 数组,并转为 float32 类型
    '''


    return (
        torch.from_numpy(trend).unsqueeze(0),
        torch.from_numpy(seasonal).unsqueeze(0),
        torch.from_numpy(residual).unsqueeze(0),
    )
    '''
    将它们转换为 torch.Tensor
    并使用 unsqueeze(0) 将其包装为 (1, seq_len) 的张量,以匹配输入张量的形状
    '''

4 set_seed

为 Python 中的各种随机生成器设置种子

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/586479.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

3.9设计模式——Strategy 策略模式(行为型)

意图 定义一系列的算法&#xff0c;把它们一个个封装起来&#xff0c;并且使他们可以相互替换此模式使得算法可以独立于使用它们的客户而变化 结构 Strategy&#xff08;策略&#xff09;定义所有支持的算法的公共入口。Context使用这个接口来调用某ConcreteStrategy定义的方…

手撕spring框架(2)

相关系列 java中spring底层核心原理解析&#xff08;1&#xff09;-CSDN博客 java中spring底层核心原理解析(2)-CSDN博客 手撕spring框架&#xff08;1&#xff09;-CSDN博客 依赖注入原理 依赖注入(Dependency Injection&#xff0c;简称DI)是一种设计模式&#xff0c;它允许我…

DS高阶:图论基础知识

一、图的基本概念及相关名词解释 1.1 图的基本概念 图是比线性表和树更为复杂且抽象的结&#xff0c;和以往所学结构不同的是图是一种表示型的结构&#xff0c;也就是说他更关注的是元素与元素之间的关系。下面进入正题。 图是由顶点集合及顶点间的关系组成的一种数据结构&…

深入浅出DBus-C++:Linux下的高效IPC通信

目录标题 1. DBus简介2. DBus-C的优势3. 安装DBus-C4. 使用DBus-C初始化和连接到DBus定义接口和方法发送和接收信号 5. dbus-cpp 0.9.0 的安装6. 创建一个 DBus 服务7. 客户端的实现8. 编译和运行你的应用9. 瑞芯微&#xff08;Rockchip&#xff09;的 Linux 系统通常会自带 db…

上位机开发PyQt(五)【Qt Designer】

PyQt5提供了一个可视化图形工具Qt Designer&#xff0c;文件名为designer.exe。如果在电脑上找不到&#xff0c;可以用如下命令进行安装&#xff1a; pip install PyQt5-tools 安装完毕后&#xff0c;可在如下目录找到此工具软件&#xff1a; %LOCALAPPDATA%\Programs\Python\…

智能体可靠性的革命性提升,揭秘知识工程领域的参考架构新篇章

引言&#xff1a;知识工程的演变与重要性 知识工程&#xff08;Knowledge Engineering&#xff0c;KE&#xff09;是一个涉及激发、捕获、概念化和形式化知识以用于信息系统的过程。自计算机科学和人工智能&#xff08;AI&#xff09;历史以来&#xff0c;知识工程的工作流程因…

【Web】2024XYCTF题解(全)

目录 ezhttp ezmd5 warm up ezMake ez?Make εZ?мKε? 我是一个复读机 牢牢记住&#xff0c;逝者为大 ezRCE ezPOP ezSerialize ezClass pharme 连连看到底是连连什么看 ezLFI login give me flag baby_unserialize ezhttp 访问./robots.txt 继…

ChatGPT向付费用户推“记忆”功能,可记住用户喜好 | 最新快讯

4月30日消息&#xff0c;人工智能巨头OpenAI宣布&#xff0c;其开发的聊天机器人ChatGPT将在除欧洲和韩国以外的市场全面上线“记忆”功能。这使得聊天机器人能够“记住”ChatGPT Plus付费订阅用户的详细信息&#xff0c;从而提供更个性化的服务。 OpenAI早在今年2月就已经宣布…

无缝对接配电自动化:IEC104转OPC UA网关解决方案

随着水电厂自动化发展的要求&#xff0c;具有一定规模的梯级水电站越来越多&#xff0c;为了实现水电站的无人值班(少人值守)&#xff0c;并考虑到节能控制&#xff0c;电厂采用了集中监控。集中监控关注的是整个电网的安全稳定运行及电压、频率和整个电网的电力需求&#xff0…

C++必修:类与对象(二)

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ &#x1f388;&#x1f388;养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; 所属专栏&#xff1a;C学习 贝蒂的主页&#xff1a;Betty’s blog 1. 构造函数 1.1. 定义 构造函数是一个特殊的成员函数&#xff0c;名字与类名相…

软件工程物联网方向嵌入式系统复习笔记--如何控制硬件

5-如何控制硬件 #mermaid-svg-of9KvkxJqwLjSYzH {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-of9KvkxJqwLjSYzH .error-icon{fill:#552222;}#mermaid-svg-of9KvkxJqwLjSYzH .error-text{fill:#552222;stroke:#552…

vue3步骤条带边框点击切换高亮

如果是div使用clip-path: polygon(0% 0%, 92% 0%, 100% 50%, 92% 100%, 0% 100%, 8% 50%);进行裁剪加边框没实现成功。目前这个使用svg完成带边框的。 形状可自行更改path 标签里的 :d“[num ! 1 ? ‘M 0 0 L 160 0 L 176 18 L 160 38 L 0 38 L 15.5 18 Z’ : ‘M 0,0 L 160,0…

Docker: 如何不新建容器 修改运行容器的端口

目录 一、修改容器的映射端口 二、解决方案 三、方案 一、修改容器的映射端口 项目需求修改容器的映射端口 二、解决方案 停止需要修改的容器 修改hostconfig.json文件 重启docker 服务 启动修改容器 三、方案 目前正在运行的容器 宿主机的3000 端口 映射 容器…

2024最新版JavaScript逆向爬虫教程-------基础篇之常用的编码与加密介绍(python和js实现)

目录 一、编码与加密原理1.1 ASCII 编码1.2 详解 Base641.2.1 Base64 的编码过程和计算方法1.2.2 基于编码的反爬虫设计1.2.3 Python自带base64模块实现base64编码解码类封装 1.3 MD5消息摘要算法1.3.1 MD5 介绍1.3.2 Python实现md5以及其他常用消息摘要算法封装 1.4 对称加密与…

【GitHub】github学生认证,在vscode中使用copilot的教程

github学生认证并使用copilot教程 写在最前面一.注册github账号1.1、注册1.2、完善你的profile 二、Github 学生认证注意事项&#xff1a;不完善的说明 三、Copilot四、在 Visual Studio Code 中安装 GitHub Copilot 扩展4.1 安装 Copilot 插件4.2 配置 Copilot 插件&#xff0…

Java设计模式 _结构型模式_组合模式

一、组合模式 1、组合模式 组合模式&#xff08;Composite Pattern&#xff09;是这一种结构型设计模式。又叫部分整体模式。组合模式依据树形结构来组合对象&#xff0c;用来表示部分以及整体层次关系。即&#xff1a;创建了一个包含自己对象组的类&#xff0c;该类提供了修改…

Educational Codeforces Round 165 (Rated for Div. 2)[A~D]

这场签到很快那会rank1400吧&#xff0c;但到c就写不动了&#xff0c;最后排名也是3000 左右&#xff0c;可见很多人其实都不会写dp。快速签到也很重要啊&#xff01;&#xff01; A. Two Friends Problem - A - Codeforces 题目大意&#xff1a; M有n个朋友&#xff0c;编号…

【Java】java实现文件上传和下载(上传到指定路径/数据库/minio)

目录 上传到指定路径 一、代码层级结构 二、文件上传接口 三、使用postman进行测试&#xff1b; MultipartFile接收前端传递的文件&#xff1a;127.0.0.1:8082/path/uploadFile part接收前端传递的文件&#xff1a;127.0.0.1:8082/path/uploadFileByRequest 接收前端传递…

基于大模型的智能案件询问系统

一、数据库层面 1、document表 这个表是用来存储文件信息的。具体字段含义如下&#xff1a; 1. id&#xff1a;文件的唯一标识&#xff0c;整型&#xff0c;自增。 2. name&#xff1a;文件名称&#xff0c;字符串类型&#xff0c;最大长度为255个字符。 3. type&#xff1a…

宠物领养|基于SprinBoot+vue的宠物领养管理系统(源码+数据库+文档)

宠物领养目录 基于Spring Boot的宠物领养系统的设计与实现 一、前言 二、系统设计 三、系统功能设计 1前台 1.1 宠物领养 1.2 宠物认领 1.3 教学视频 2后台 2.1宠物领养管理 2.2 宠物领养审核管理 2.3 宠物认领管理 2.4 宠物认领审核管理 2.5 教学视频管理 四、…
最新文章