论文辅助笔记:TEMPO 之 utils.py

0 导入库

from typing import Tuple
import random
import numpy as np
import torch
from statsmodels.tsa.seasonal import STL

1 EarlyStopping

  • 提供了一个早停机制,用于在模型训练过程中监控验证集上的损失
  • 如果损失停止改进,则停止训练

1.1 __init__

class EarlyStopping:def __init__(self, patience=7, verbose=False, delta=0):self.patience = patience#早停的容忍度,如果连续 patience 次验证损失没有改善,则停止训练。self.verbose = verbose#决定是否输出详细信息self.counter = 0#记录连续未改善验证损失的次数self.best_score = None#用于存储目前为止最佳的验证损失分数self.early_stop = False#一个布尔值,指示是否应该停止训练self.val_loss_min = np.Inf#存储目前为止最小的验证损失self.delta = delta#一个阈值,用于决定损失的改善幅度

1.2 __call__ 在训练过程中监控验证损失

def __call__(self, val_loss, model, path):score = -val_lossif self.best_score is None:self.best_score = scoreself.save_checkpoint(val_loss, model, path)#如果这是第一次调用 __call__,初始化 best_score 为 score 并保存模型。elif score < self.best_score + self.delta:self.counter += 1print(f"EarlyStopping counter: {self.counter} out of {self.patience}")if self.counter >= self.patience:self.early_stop = True'''如果 score < self.best_score + self.delta,则说明损失没有显著改善增加 counter 并检查是否超过 patience,如果超过则停止训练'''else:self.best_score = scoreself.save_checkpoint(val_loss, model, path)self.counter = 0'''如果 score > self.best_score + self.delta,更新 best_score 并保存模型然后将 counter 重置为零'''

1.3 save_checkpoint 在验证损失降低时保存模型

def save_checkpoint(self, val_loss, model, path):if self.verbose:print(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...")torch.save(model.state_dict(), path + "/" + "checkpoint.pth")#使用 torch.save() 保存模型的状态字典self.val_loss_min = val_loss

2 StandardScaler

实现数据标准化

2.1 __init__

class StandardScaler:def __init__(self):self.mean = 0.0self.std = 1.0

2.2  fit

计算并更新 self.meanself.std

def fit(self, data):self.mean = data.mean(0)self.std = data.std(0)

 2.3  transform

   将数据转换为标准化形式

def transform(self, data):mean = (torch.from_numpy(self.mean).type_as(data).to(data.device)if torch.is_tensor(data)else self.mean)std = (torch.from_numpy(self.std).type_as(data).to(data.device)if torch.is_tensor(data)else self.std)'''mean 和 std 的类型转换:根据 data 是 torch.Tensor 还是 numpy 数组将 self.mean 和 self.std 转换为相应类型,以确保类型匹配'''return (data - mean) / std

 2.4 inverse_transform

将标准化后的数据还原

    def inverse_transform(self, data):mean = (torch.from_numpy(self.mean).type_as(data).to(data.device)if torch.is_tensor(data)else self.mean)std = (torch.from_numpy(self.std).type_as(data).to(data.device)if torch.is_tensor(data)else self.std)'''mean 和 std 的类型转换:根据 data 是 torch.Tensor 还是 numpy 数组将 self.mean 和 self.std 转换为相应类型,以确保类型匹配'''if data.shape[-1] != mean.shape[-1]:mean = mean[-1:]std = std[-1:]return (data * std) + mean'''通过 (data * std) + mean 将标准化后的数据还原为原始形式'''

3 decompose

使用STL,将时间序列分解为趋势、季节性和残差成分

def decompose(x: torch.Tensor, period: int = 7
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:#x:输入的一维时间序列,类型为 torch.Tensor,形状为 (1, seq_len)x = x.squeeze(0).cpu().numpy()'''首先调用 squeeze(0) 将 x 的第一个维度去掉然后通过 cpu().numpy() 将 x 转换为 numpy 数组,以便 STL 分解函数使用'''decomposed = STL(x, period=period).fit()'''调用 STL(x, period=period).fit() 对 x 进行分解,并返回分解结果 decomposed其中包含了 trend(趋势)、seasonal(季节性)和 resid(残差)成分'''trend = decomposed.trend.astype(np.float32)seasonal = decomposed.seasonal.astype(np.float32)residual = decomposed.resid.astype(np.float32)'''将 decomposed 中的各个成分转换为 numpy 数组,并转为 float32 类型'''return (torch.from_numpy(trend).unsqueeze(0),torch.from_numpy(seasonal).unsqueeze(0),torch.from_numpy(residual).unsqueeze(0),)'''将它们转换为 torch.Tensor并使用 unsqueeze(0) 将其包装为 (1, seq_len) 的张量,以匹配输入张量的形状'''

4 set_seed

为 Python 中的各种随机生成器设置种子

def set_seed(seed):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/5684.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C语言】指针篇- 深度解析Sizeof和Strlen:热门面试题探究(5/5)

&#x1f308;个人主页&#xff1a;是店小二呀 &#x1f308;C语言笔记专栏&#xff1a;C语言笔记 &#x1f308;C笔记专栏&#xff1a; C笔记 &#x1f308;喜欢的诗句:无人扶我青云志 我自踏雪至山巅 文章目录 一、简单介绍Sizeof和Strlen1.1 Sizeof1.2 Strlen函数1.3 Sie…

快速建站介绍

随着在线业务和电子商务的规模不断扩大&#xff0c;初创公司、个人网站和小型企业都需要快速地搭建自己的网站&#xff0c;以便更好地展示自己、推广产品和服务&#xff0c;并实现在线交易。快速建站已成为在线业务发展的一种主流方式&#xff0c;因为它能够快速地创建一个响应…

uniapp 自定义 App启动图

由于uniapp默认的启动界面太过普通 所以需要自定义个启动图 普通的图片不可以过不了苹果的审核 所以使用storyboard启动图 生成 storyboard 的网站&#xff1a;初雪云-提供一站式App上传发布解决方案

Java学习第02天-类型转换、运算符

目录 类型转换 自动类型转换 表达式的自动类型转换 强制类型转换 运算符 基本运算符 案例解答 连接字符串 自增自减运算符 面试习题 赋值运算符 比较运算符 逻辑运算符 基本逻辑运算符 短路逻辑运算符 三元运算符 基础知识 拓展案例 运算符优先级 读取用户…

UNeXt: a Low-Dose CT denoising UNet model with the modified ConvNeXt block

UNeXt&#xff1a;采用改进的ConvNeXt块的低剂量CT去噪UNet模型 论文链接&#xff1a;https://ieeexplore.ieee.org/document/10095645 项目链接&#xff1a;没找到 Abstract 近几十年来&#xff0c;临床医生广泛使用计算机断层扫描(CT)进行医学诊断。医疗辐射有潜在危险&am…

数据结构与算法-构建二叉树

构建二叉树 已知前序遍历与中序遍历或已知后序遍历和中序遍历可以构建唯一的二叉树 根据前序遍历与中序遍历建树 class Tree_Node():def __init__(self,val):self.val valself.left Noneself.right None # 构建二叉树 # 根据前序遍历与中序遍历构建二叉树 # 前序遍历[3,9…

77、贪心-买卖股票的最佳时机

思路 具体会导致全局最优&#xff0c;这里就可以使用贪心算法。方式如下&#xff1a; 遍历每一位元素找出当前元素最佳卖出是收益是多少。然后依次获取最大值&#xff0c;就是全局最大值。 这里可以做一个辅助数组&#xff1a;右侧最大数组&#xff0c;求右侧最大数组就要从…

ADS1.2中的代码debug的时候不出来代码的解决办法

我总觉得ADS1.2这个软件挺奇怪的&#xff0c;一阵一阵的&#xff0c;我遇到了很多奇怪的问题&#xff0c;这里记录一下吧。 1、新建文件的时候&#xff0c;没法选择这个add to project 解决办法&#xff1a;如果没有已存在的.mcp文件&#xff0c;就先新建project&#xff0c;然…

项目运行到手机端

运行到真机 手机和点到连在同一个wifi网络下面点击hbuiler上面的预览得到一个&#xff0c;network的网址这个时候去在手机访问&#xff0c;那么就可以访问网页了 跨域处理 这个时候可能会访问存在跨域问题 将uniapp的H5版本运行到真机进行调试&#xff0c;主要涉及到跨域问题…

java-Spring-mvc-(请求和响应)

目录 &#x1f4cc;HTTP协议 超文本传输协议 请求 Request 响应 Response &#x1f3a8;请求方法 GET请求 POST请求 &#x1f4cc;HTTP协议 超文本传输协议 HTTP协议是浏览器与服务器通讯的应用层协议&#xff0c;规定了浏览器与服务器之间的交互规则以及交互数据的格式…

【机器学习】HQ-Edit引领图像编辑新潮流

科技新纪元&#xff1a;HQ-Edit引领图像编辑新潮流 一、HQ-Edit的诞生&#xff1a;一场技术的革命二、技术实现与优势&#xff1a;强大的编辑能力和精准的匹配三、应用前景与实例展示&#xff1a;InstructPix2Pix的突破 在数字化时代&#xff0c;图像编辑技术正以前所未有的速度…

M3D-NCA: Robust 3D Segmentation with Built-In Quality Control论文速读

文章目录 M3D-NCA: Robust 3D Segmentation with Built-In Quality Control摘要方法实验结果 M3D-NCA: Robust 3D Segmentation with Built-In Quality Control 摘要 这是关于医学图像分割的一篇论文的结构化总结&#xff1a; 背景和挑战&#xff1a; 医学图像分割依赖于大型…

Kubernetes自动伸缩的主要类型有哪些?

Kubernetes中的自动伸缩功能主要有三种类型&#xff0c;分别针对不同的资源管理和应用场景。具体如下&#xff1a; Pod水平自动伸缩&#xff08;HPA&#xff09;&#xff1a;这是最常用的自动伸缩类型&#xff0c;它通过监控Pod的CPU利用率、内存利用率或自定义指标来增加或减…

【热闻速递】Google 裁撤 Python研发团队

&#x1f308;个人主页: 鑫宝Code &#x1f525;热门专栏: 闲话杂谈&#xff5c; 炫酷HTML | JavaScript基础 ​&#x1f4ab;个人格言: "如无必要&#xff0c;勿增实体" 文章目录 【&#x1f525;热闻速递】Google 裁撤 Python研发团队引入研究结论 【&#x1f5…

Spring中AOP原理

Spring中AOP原理 在Spring框架中&#xff0c;AOP&#xff08;Aspect-Oriented Programming&#xff0c;面向切面编程&#xff09;是一种重要的编程范式&#xff0c;它可以帮助我们实现对代码的横切关注点进行统一管理和处理。在Spring中&#xff0c;AOP的实现主要依赖 于两个核…

2405C++,部分解析数格

原文 如果一个很大的数格串,然后用户只想解析其中的一个字段,一般需要遍历所有串全部解析所有字段,这样效率就很低了. 如果可部分解析数格字段,就可避免全部解析了,从而获得更好的性能. iguana已增加了支持部分解析数格的特征,比如这样一个数格对象: struct some_test_t {i…

xyctf ez_rand

[核心的代码就是这一部分&#xff0c;只要得到v4的值&#xff0c;也就是随机种子&#xff0c;那就可以把值弄出来了。所以我们需要做的就是爆破随机种子。 然后有一点是需要注意的&#xff0c;IDA这里显示的数据有可能是小端序的&#xff0c;所以我们需要export data&#xff…

DSP实时分析平台设计方案:924-6U CPCI振动数据DSP实时分析平台

6U CPCI振动数据DSP实时分析平台 一、产品概述 基于CPCI结构完成40路AD输入&#xff0c;30路DA输出的信号处理平台&#xff0c;处理平台采用双DSPFPGA的结构&#xff0c;DSP采用TI公司新一代DSP TMS320C6678&#xff0c;FPGA采用Xilinx V5 5VLX110T-1FF1136芯片&#xff…

向量的旋转矩阵

我们都知道&#xff0c;矩阵的乘法可以表示旋转。那么&#xff0c;这一理论的数学机理是什么呢&#xff1f;以及&#xff0c;这个旋转角度该怎么用矩阵表示呢&#xff1f; 本文用二维向量旋转来推导旋转矩阵的公式。假设&#xff0c;我们有一个向量P(x, y)&#xff0c;准备通过…

http和https 所有的请求头信息

http 所有的请求头信息 HTTP请求头信息包含了客户端向服务器发送请求时附带的各种细节信息,帮助服务器更好地处理请求。这些头部字段多种多样,用于说明请求的各个方面,如客户端信息、请求的内容类型、缓存策略等。以下是一些常见的HTTP请求头字段,但请注意,这并非所有可能…