根据DCT特征训练CNN

记录一次改代码的挣扎经历:
        看了几篇关于DCT频域的深度模型文献,尤其是21年FcaNet:基于DCT 的attention model,咱就是说想试试将我模型的输入改为分组的DCT系数,然后就开始下面的波折了。

第一次尝试:

        我直接调用了库函数,然后出现问题了:这个库函数是应用在numpy数组上,得在CPU上处理。

from scipy.fftpack import dct, idct
...
dct_block = dct(dct(block, axis=2, norm='ortho'), axis=3, norm='ortho')   # [B,C,k,k]
...
block = idct(idct(dct_block, axis=2, norm='ortho'), axis=3, norm='ortho')    # [B,C,k,k]

第二次尝试:
        好吧,我先把数据调回CPU,处理后,再调回GPU,又有新问题了:这样做(将block从GPU转移至CPU)torch类型张量转换为numpy数组时,torch张量的梯度无法保存。

# 图像分块
...
# 将块转移到 CPU
block_cpu = block.cpu()        # [B,C,k,k]
# 在 CPU 上对块应用 DCT
dct_block_np = dct(dct(block_cpu.numpy(), axis=2, norm='ortho'), axis=3, norm='ortho')   # [B,C,k,k]
# 将结果传输回 GPU
dct_block = torch.from_numpy(dct_block_np).to(image.device)     # [B,C,k,k]...# 将块转移到 CPU
dct_block_cpu = dct_block.cpu()
# 在 CPU 上对块应用逆 DCT
block_np = idct(idct(dct_block_cpu.numpy(), axis=2, norm='ortho'), axis=3, norm='ortho')
# 将结果传输回 GPU
block = torch.from_numpy(block_np).to(dct_block.device)    # [B,C,k,k]

 第三次尝试:

        根据报错提醒,我进行以下改进,将block_cpu.numpy -> block_cpu.detach.numpy(),即忽略掉torch类型张量带着的梯度信息,哈哈,这样一改,梯度就丢失了,模型就不能反向传播进行更新训练了。

# 图像分块
...
# 将块转移到 CPU
block_cpu = block.cpu()        # [B,C,k,k]
# 在 CPU 上对块应用 DCT
dct_block_np = dct(dct(block_cpu.numpy(), axis=2, norm='ortho'), axis=3, norm='ortho')   # [B,C,k,k]
# 将结果传输回 GPU
dct_block = torch.from_numpy(dct_block_np).to(image.device)     # [B,C,k,k]...# 将块转移到 CPU
dct_block_cpu = dct_block.cpu()
# 在 CPU 上对块应用逆 DCT
block_np = idct(idct(dct_block_cpu.detach.numpy(), axis=2, norm='ortho'), axis=3, norm='ortho')
# 将结果传输回 GPU
block = torch.from_numpy(block_np).to(dct_block.device)    # [B,C,k,k]

第四次尝试:
        CPU上库函数不好用,那我自己写(借鉴)DCT变换的函数嘛,DCT就是输入k*k图像关于k*k个余弦基函数的加权和嘛:

 别人写的的8 x 8d的DCT和IDCT的实现:


class DCT8X8(nn.Module):""" Discrete Cosine TransformationInput:image(tensor): batch x height x widthOutput:dcp(tensor): batch x height x width"""def __init__(self):super(DCT8X8, self).__init__()tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)for x, y, u, v in itertools.product(range(8), repeat=4):tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)alpha = np.array([1. / np.sqrt(2)] + [1] * 7)self.tensor = nn.Parameter(torch.from_numpy(tensor).float())self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())def forward(self, image):image = image - 128result = self.scale * torch.tensordot(image, self.tensor, dims=2)result.view(image.shape)return resultclass IDCT8X8(nn.Module):""" Inverse discrete Cosine TransformationInput:dcp(tensor): batch x height x widthOutput:image(tensor): batch x height x width"""def __init__(self):super(IDCT8X8, self).__init__()alpha = np.array([1. / np.sqrt(2)] + [1] * 7)self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)for x, y, u, v in itertools.product(range(8), repeat=4):tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)self.tensor = nn.Parameter(torch.from_numpy(tensor).float())def forward(self, image):image = image * self.alpharesult = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128result.view(image.shape)return result

我根据上述改的任意block_size的DCT和IDCT:

class DCTCustom(nn.Module):"""Customizable Discrete Cosine TransformationInput:image(tensor): batch x height x widthOutput:dct(tensor): batch x height x width"""def __init__(self, input_size=8):super(DCTCustom, self).__init__()self.input_size = input_sizetensor = np.zeros((input_size, input_size, input_size, input_size), dtype=np.float32)for x, y, u, v in itertools.product(range(input_size), repeat=4):tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / (2 * input_size)) * np.cos((2 * y + 1) * v * np.pi / (2 * input_size))alpha = np.array([1. / np.sqrt(2)] + [1] * (input_size - 1))self.tensor = nn.Parameter(torch.from_numpy(tensor).float())self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())def forward(self, image):image = image - 128result = self.scale * torch.tensordot(image, self.tensor, dims=2)result = result.view(image.shape)  # Corrected linereturn resultclass IDCTCustom(nn.Module):""" Inverse discrete Cosine TransformationInput:dcp(tensor): batch x height x widthOutput:image(tensor): batch x height x width"""def __init__(self, block_size=8):super(IDCTCustom, self).__init__()self.block_size = block_size# Compute alpha coefficientsalpha = np.array([1. / np.sqrt(2)] + [1] * (block_size - 1))self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())# Compute tensor for IDCTtensor = np.zeros((block_size, block_size, block_size, block_size), dtype=np.float32)for x, y, u, v in itertools.product(range(block_size), repeat=4):tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / (2 * block_size)) * np.cos((2 * v + 1) * y * np.pi / (2 * block_size))self.tensor = nn.Parameter(torch.from_numpy(tensor).float())def forward(self, image):if image.shape[-2] % self.block_size != 0 or image.shape[-1] % self.block_size != 0:raise ValueError("Input dimensions must be divisible by the block size.")# Apply IDCTimage = image * self.alpharesult = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128result = result.view(image.shape)return result

        不出意外的话,问题又出现了,我对一个torch.ones((2,3,k,k))的张量进行DCT,再IDCT恢复。当k=8时(即block_size=8x8)时,能够完全恢复,但当k!=8(=16、32)时,经IDCT后无法恢复原始输入,懵。

第五次尝试(hh):
        突然!我发现了torch内置的DCT函数!可以再GPU上实现DCT。

torch-dct · PyPI

import torch_dct as dct# 图像分块    # [B,C,H,W]...        # [B,C,k,k]# dctblock = dct.dct_2d(block)     # [B,C,k,k]...# idctblock = dct.idct_2d(block)        # [B,C,k,k]

 然后又有问题了:
        我的模型开始训练后,我发现我的每个epoch的loss都为NAN...

        然后我打印了DCT输出,发现DCT系数长这个样子,CNN不高兴好好训练吧。

        我们再想想办法将输入数据归一化到范围[0, 1]或[-1, 1]之间,再喂给CNN吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/579844.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入解析 Flink CDC 增量快照读取机制

一、Flink-CDC 1.x 痛点 Flink CDC 1.x 使用 Debezium 引擎集成来实现数据采集,支持全量加增量模式,确保数据的一致性。然而,这种集成存在一些痛点需要注意: 一致性通过加锁保证:在保证数据一致性时,Debez…

运算符的结合性(形神兼备)

运算符的结合性(形神兼备) 在编译原理中,产生式就是权威。表达式如果以某产生式进行语法分析,那么就只能按照它的方式进行表达,且不能具有二义性。但是,在表达式中有时会涉及打括号的问题。很多时候&#…

20231226在Firefly的AIO-3399J开发板上在Android11下调通后摄像头ov13850

20231226在Firefly的AIO-3399J开发板上在Android11下调通后摄像头ov13850 2023/12/26 8:22 开发板:Firefly的AIO-3399J【RK3399】 SDK:rk3399-android-11-r20211216.tar.xz【Android11】 Android11.0.tar.bz2.aa【ToyBrick】 Android11.0.tar.bz2.ab And…

TypeScript:箭头函数

在TypeScript中,箭头函数是一种简洁的函数定义方式。以下是一些使用箭头函数的例子: 基本的箭头函数: const add (x: number, y: number) > {return x y; };单个参数的箭头函数可以省略括号: const square (x: number) >…

如何配置TLSv1.2版本的ssl

1、tomcat配置TLSv1.2版本的ssl 如下图所示&#xff0c;打开tomcat\conf\server.xml文件&#xff0c;进行如下配置&#xff1a; 注意&#xff1a;需要将申请的tomcat版本的ssl认证文件&#xff0c;如server.jks存放到tomcat\conf\ssl_file\目录下。 <Connector port"1…

Linux介绍、安装、常见命令

Linux介绍 Linux是一种开源的操作系统&#xff0c;其内核由林纳斯托瓦兹&#xff08;Linus Torvalds&#xff09;在1991年开始开发。与其他常见的操作系统如Windows和Mac OS不同&#xff0c;Linux是一个开放、自由的系统&#xff0c;可以免费使用、修改和分发。 Linux的核心特…

如何区分ChatGPT 3.5与ChatGPT 4:洞悉智能对话的新时代

如何区分ChatGPT 3.5与ChatGPT 4&#xff1a;洞悉智能对话的新时代 随着人工智能技术的快速发展&#xff0c;OpenAI持续推出更加强大和精准的模型&#xff0c;以改善和扩展用户体验。在聊天机器人领域&#xff0c;特别是OpenAI的ChatGPT系列&#xff0c;每一次迭代都带来了显著…

企业级实战项目:基于 pycaret 自动化预测公司是否破产

本文系数据挖掘实战系列文章&#xff0c;我跟大家分享一个数据挖掘实战&#xff0c;与以往的数据实战不同的是&#xff0c;用自动机器学习方法完成模型构建与调优部分工作&#xff0c;深入理解由此带来的便利与效果。 1. Introduction 本文是一篇数据挖掘实战案例&#xff0c;…

【超图】SuperMap 模型处理自动化方案 ——目录

作者&#xff1a;taco 在支持客户的过程中&#xff0c;会有很多用户会想要实现自动化流程&#xff0c;并非按部就班的一步一步去搞数据&#xff0c;搞优化。总是想要一个按钮就实现所有数据的处理&#xff0c;发布&#xff0c;预览等功能。根据这种情况&#xff0c;尝试搞一些自…

uniapp APP应用程序iOS没有上架到苹果应用商店如何整包更新?

随着移动互联网的快速发展&#xff0c;uni-app 作为一种跨平台开发框架&#xff0c;受到了广泛欢迎。然而&#xff0c;有时候开发者可能会遇到一个问题&#xff1a;如何为已经发布到苹果应用商店的 uni-app APP 进行整包更新&#xff1f;尤其是当应用还没有上架到苹果应用商店时…

android 四大组件和handler、looper

Android 四大组件 Android 开发中&#xff0c;四大组件&#xff08;Four Major Components&#xff09;是指构成 Android 应用程序的四种基本组件。这些组件是活动&#xff08;Activity&#xff09;、服务&#xff08;Service&#xff09;、广播接收器&#xff08;Broadcast R…

Git配置和钩子使用

0 Preface/Foreword 1 Usage 1.1 参考 https://www.cnblogs.com/guge-94/p/11287535.html 1.2 基本配置 1.2.1 配置名字和邮箱 git config --global user.name "xxx" git config --global user.email "xxx" 1.3 客户端基本配置 1.3.1 core.editor gi…

Hadoop-3.3.4集群部分lib缺失问题

1.问题描述 (base) [hadoophadoop1 native]$ hadoop checknative 2023-12-25 14:20:21,615 INFO bzip2.Bzip2Factory: Successfully loaded & initialized native-bzip2 library system-native 2023-12-25 14:20:21,618 INFO zlib.ZlibFactory: Successfully loaded &…

nodejs进阶

文章目录 写在前面一、dependencies、devDependencies和peerDependencies区别&#xff1a;二、需要牢记的npm命令2.1 npm init2.2 npm config list2.3 npm配置镜像源 三、npm install 的原理四、package-lock.json的作用五、npm run 的原理六、npx6.1 npx是什么6.2 npx的优势6.…

在Spring Boot中使用Redis

在Spring Boot中使用Redis,需要先添加Redis的依赖,然后配置Redis连接,最后通过Spring提供的模板类操作Redis。下面是一个基本的使用指南。1. 添加Redis依赖 在你的Spring Boot项目的pom.xml文件中,添加Redis Starter依赖: <dependency><groupId>org.springf…

深信服技术认证“SCSA-S”划重点:文件上传与解析漏洞

为帮助大家更加系统化地学习网络安全知识&#xff0c;以及更高效地通过深信服安全服务认证工程师考核&#xff0c;深信服特别推出“SCSA-S认证备考秘笈”共十期内容&#xff0c;“考试重点”内容框架&#xff0c;帮助大家快速get重点知识~ 划重点来啦 *点击图片放大展示 深信服…

百度站长、SEO、收录,网站自动提交百度链接 vuejs

created: 2023-12-26T10:34:37 (UTC +08:00) tags: [后端] source: https://juejin.cn/post/7152431823853715492 author: 源字节1号 Vue网站自动提交百度链接 - 掘金 Excerpt 怎样才能使新更新的文章更快的被百度收录,是所有站长最头疼的事情之一。开源字节使用自动提交脚本…

chrome扩展程序开发之在目标页面运行自己的JS

原文地址&#xff1a;https://qdgithub.com/home/index/article/aid/247.html chrome 插件开发的入门介绍&#xff0c;实现利用 chrome 扩展实现在目标网页运行我们的 js 的功能。关于 chrome 扩展的详细内容&#xff0c;可以通过官网了解。 开发工具很简单&#xff0c;记事本…

[spark] DataFrame 的 checkpoint

在 Apache Spark 中&#xff0c;DataFrame 的 checkpoint 方法用于强制执行一个物理计划并将结果缓存到分布式文件系统&#xff0c;以防止在计算过程中临时数据丢失。这对于长时间运行的计算过程或复杂的转换操作是有用的。 具体来说&#xff0c;checkpoint 方法执行以下操作&…

从AMI镜像恢复AWS Amazon Linux 2实例碰到的VNC服务以及Chrome浏览器无法启动的问题

文章目录 小结问题及解决VNC服务无法启动Chrome浏览器无法启动 参考 小结 将Amazon Linux 2保存为AMI (Amazon Machine Images)后&#xff0c;恢复成EC2 Instance (实例)后&#xff0c;VNC服务以及Chrome浏览器无法启动&#xff0c;进行了解决。 问题及解决 如果要将一个EC2…