深度学习训练中的显存溢出问题分析与优化:以UNet图像去噪为例

最近在训练一个基于 Tiny-UNet 的图像去噪模型时,我遇到了经典但棘手的错误:
RuntimeError: CUDA out of memory。本文记录了我如何从复现、分析,到逐步优化并成功解决该问题的全过程,希望对深度学习开发者有所借鉴。

  • 训练数据:SIDD 小型图像去噪数据集

  • 模型结构:简化版 U-Net(Tiny-UNet)

  • class UNetDenoiser(nn.Module):def __init__(self):super(UNetDenoiser, self).__init__()# Encoderself.enc1 = self.conv_block(3, 64)self.enc2 = self.conv_block(64, 128)self.enc3 = self.conv_block(128, 256)self.pool = nn.MaxPool2d(2)# Bottleneckself.bottleneck = self.conv_block(256, 512)# Decoderself.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)self.dec3 = self.conv_block(512, 256)self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)self.dec2 = self.conv_block(256, 128)self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)self.dec1 = self.conv_block(128, 64)# Outputself.final = nn.Conv2d(64, 3, kernel_size=1)def conv_block(self, in_channels, out_channels):return nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.ReLU(inplace=True))def forward(self, x):# Encodere1 = self.enc1(x)            # [B, 64, H, W]e2 = self.enc2(self.pool(e1))  # [B, 128, H/2, W/2]e3 = self.enc3(self.pool(e2))  # [B, 256, H/4, W/4]# Bottleneckb = self.bottleneck(self.pool(e3))  # [B, 512, H/8, W/8]# Decoderd3 = self.up3(b)           # [B, 256, H/4, W/4]d3 = self.dec3(torch.cat([d3, e3], dim=1))d2 = self.up2(d3)          # [B, 128, H/2, W/2]d2 = self.dec2(torch.cat([d2, e2], dim=1))d1 = self.up1(d2)          # [B, 64, H, W]d1 = self.dec1(torch.cat([d1, e1], dim=1))return self.final(d1)

    源代码:
     

    # train_denoiser.py
    import os
    import math
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms
    from torchvision.utils import save_image
    from PIL import Image# --- 数据集定义 ---
    class DenoisingDataset(Dataset):def __init__(self, noisy_dir, clean_dir, transform=None):self.noisy_paths = sorted([os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir)])self.clean_paths = sorted([os.path.join(clean_dir, f) for f in os.listdir(clean_dir)])self.transform = transform if transform else transforms.ToTensor()def __len__(self):return len(self.noisy_paths)def __getitem__(self, idx):noisy_img = Image.open(self.noisy_paths[idx]).convert("RGB")clean_img = Image.open(self.clean_paths[idx]).convert("RGB")return self.transform(noisy_img), self.transform(clean_img)# --- 简单 CNN 去噪模型 ---
    # class SimpleDenoiser(nn.Module):
    #     def __init__(self):
    #         super(SimpleDenoiser, self).__init__()
    #         self.encoder = nn.Sequential(
    #             nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
    #             nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()
    #         )
    #         self.decoder = nn.Sequential(
    #             nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
    #             nn.Conv2d(64, 3, 3, padding=1)
    #         )
    #
    #     def forward(self, x):
    #         x = self.encoder(x)
    #         x = self.decoder(x)
    #         return x
    class UNetDenoiser(nn.Module):def __init__(self):super(UNetDenoiser, self).__init__()# Encoderself.enc1 = self.conv_block(3, 64)self.enc2 = self.conv_block(64, 128)self.enc3 = self.conv_block(128, 256)self.pool = nn.MaxPool2d(2)# Bottleneckself.bottleneck = self.conv_block(256, 512)# Decoderself.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)self.dec3 = self.conv_block(512, 256)self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)self.dec2 = self.conv_block(256, 128)self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)self.dec1 = self.conv_block(128, 64)# Outputself.final = nn.Conv2d(64, 3, kernel_size=1)def conv_block(self, in_channels, out_channels):return nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.ReLU(inplace=True))def forward(self, x):# Encodere1 = self.enc1(x)            # [B, 64, H, W]e2 = self.enc2(self.pool(e1))  # [B, 128, H/2, W/2]e3 = self.enc3(self.pool(e2))  # [B, 256, H/4, W/4]# Bottleneckb = self.bottleneck(self.pool(e3))  # [B, 512, H/8, W/8]# Decoderd3 = self.up3(b)           # [B, 256, H/4, W/4]d3 = self.dec3(torch.cat([d3, e3], dim=1))d2 = self.up2(d3)          # [B, 128, H/2, W/2]d2 = self.dec2(torch.cat([d2, e2], dim=1))d1 = self.up1(d2)          # [B, 64, H, W]d1 = self.dec1(torch.cat([d1, e1], dim=1))return self.final(d1)# --- PSNR 计算函数 ---
    def calculate_psnr(img1, img2):mse = torch.mean((img1 - img2) ** 2)if mse == 0:return float("inf")return 20 * torch.log10(1.0 / torch.sqrt(mse))# --- 主训练过程 ---
    def train_denoiser():noisy_dir = r"F:\SIDD数据集\archive\SIDD_Small_sRGB_Only\noisy"clean_dir = r"F:\SIDD数据集\archive\SIDD_Small_sRGB_Only\clean"batch_size = 1num_epochs = 50lr = 0.0005device = torch.device("cuda" if torch.cuda.is_available() else "cpu")dataset = DenoisingDataset(noisy_dir, clean_dir, transform=transforms.ToTensor())dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# model = SimpleDenoiser().to(device)# 替换为 UNetmodel = UNetDenoiser().to(device)criterion = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr=lr)for epoch in range(num_epochs):model.train()total_loss = 0.0total_psnr = 0.0for noisy, clean in dataloader:noisy, clean = noisy.to(device), clean.to(device)denoised = model(noisy)loss = criterion(denoised, clean)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()total_psnr += calculate_psnr(denoised, clean).item()avg_loss = total_loss / len(dataloader)avg_psnr = total_psnr / len(dataloader)print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {avg_loss:.4f}, PSNR: {avg_psnr:.2f} dB")# 保存模型os.makedirs("weights", exist_ok=True)torch.save(model.state_dict(), "weights/denoiser.pth")print("模型已保存为 weights/denoiser.pth")if __name__ == "__main__":train_denoiser()
    

  • 显卡:8GB 显存的 RTX GPU

问题定位

我们从报错堆栈中看到:

e3 = self.enc3(self.pool(e2))
RuntimeError: CUDA out of memory. Tried to allocate 746.00 MiB

说明问题发生在模型第三层 encoder(enc3)前的 pooling 后,这说明:

  1. 当前的输入尺寸、batch size 占用了太多显存;

  2. 或者模型本身结构太重;

  3. 又或者显存未被合理管理(例如碎片化)。

分析与优化过程

第一步:降低 batch size

原始 batch size 设置为 16,直接触发爆显存。

我们尝试逐步调小 batch size:

batch_size = 6  # 从16降低到6

观察显存变化,发现仍有波动。为更稳定,设置为 4 或动态适配:

batch_size = min(8, torch.cuda.get_device_properties(0).total_memory // estimated_sample_size)

 发现同样的错误,显存不知。分析可能是网络参数太大了,或者训练过程没有启动内存优化。导致的内存不足,这些可以通过策略进行改进,达到训练的目的。

第二步:开启 cuDNN 自动优化

torch.backends.cudnn.benchmark = True

cuDNN 会根据不同卷积输入尺寸自动寻找最优算法,可能减少显存使用。

第三步:开启混合精度训练 AMP(Automatic Mixed Precision)

from torch.cuda.amp import autocast, GradScalerscaler = GradScaler()with autocast():output = model(input)loss = criterion(output, target)scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
  • autocast() 自动在部分层使用 float16,提高速度并减小显存压力;

  • GradScaler 确保在 float16 条件下梯度依然稳定。

实测显存使用降低近 30%,OOM 问题明显缓解!

但以上训练的预加载时间太慢,显卡占有率过低,有点显卡当前没有任务----“偷懒”的意思。可能是数据的加载或者显存抖动造成的。

第四步:优化 DataLoader 性能(间接缓解显存抖动)

 dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,num_workers=num_workers, pin_memory=True)
  • num_workers 启用多进程加载数据;

  • pin_memory=True 启用固定内存,更快传输到 GPU。

虽然不直接节省显存,但显著减少显存峰值抖动(尤其在小 batch 训练时)。

第五步:检查图像输入尺寸是否太大

原始图像尺寸为 512×512:

transform = transforms.Compose([transforms.Resize((256, 256)),  # 降低分辨率transforms.ToTensor()
])

最终训练代码结构

我们将上述策略集成到了 train.py 脚本中(如下),包括:

  • Dataset & Dataloader 加速

  • 混合精度训练

  • cuDNN 优化

  • 实时 PSNR 显示

  • 自动保存模型权重

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm# --- 启用 cuDNN 自动优化 ---
torch.backends.cudnn.benchmark = True# --- 数据集定义 ---
class DenoisingDataset(Dataset):def __init__(self, noisy_dir, clean_dir, transform=None):self.noisy_paths = sorted([os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir)])self.clean_paths = sorted([os.path.join(clean_dir, f) for f in os.listdir(clean_dir)])self.transform = transform if transform else transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor()])def __len__(self):return len(self.noisy_paths)def __getitem__(self, idx):noisy_img = Image.open(self.noisy_paths[idx]).convert("RGB")clean_img = Image.open(self.clean_paths[idx]).convert("RGB")return self.transform(noisy_img), self.transform(clean_img)# --- Tiny UNet 模型 ---
class TinyUNet(nn.Module):def __init__(self):super(TinyUNet, self).__init__()self.enc1 = self.conv_block(3, 16)self.enc2 = self.conv_block(16, 32)self.enc3 = self.conv_block(32, 64)self.pool = nn.MaxPool2d(2)self.bottleneck = self.conv_block(64, 128)self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)self.dec3 = self.conv_block(128, 64)self.up2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)self.dec2 = self.conv_block(64, 32)self.up1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)self.dec1 = self.conv_block(32, 16)self.final = nn.Conv2d(16, 3, kernel_size=1)def conv_block(self, in_channels, out_channels):return nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.ReLU(inplace=True))def forward(self, x):e1 = self.enc1(x)e2 = self.enc2(self.pool(e1))e3 = self.enc3(self.pool(e2))b = self.bottleneck(self.pool(e3))d3 = self.up3(b)d3 = self.dec3(torch.cat([d3, e3], dim=1))d2 = self.up2(d3)d2 = self.dec2(torch.cat([d2, e2], dim=1))d1 = self.up1(d2)d1 = self.dec1(torch.cat([d1, e1], dim=1))return self.final(d1)# --- PSNR 计算 ---
def calculate_psnr(img1, img2):mse = torch.mean((img1 - img2) ** 2)if mse == 0:return float("inf")return 20 * torch.log10(1.0 / torch.sqrt(mse))# --- 训练函数 ---
def train_denoiser():noisy_dir = r"F:\SIDD数据集\archive\SIDD_Small_sRGB_Only\noisy"clean_dir = r"F:\SIDD数据集\archive\SIDD_Small_sRGB_Only\clean"batch_size = 6num_epochs = 50lr = 0.0005num_workers = 4device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"Using device: {device}")transform = transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor()])dataset = DenoisingDataset(noisy_dir, clean_dir, transform=transform)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,num_workers=num_workers, pin_memory=True)model = TinyUNet().to(device)criterion = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr=lr)scaler = GradScaler()  # AMP 梯度缩放器os.makedirs("weights", exist_ok=True)for epoch in range(num_epochs):model.train()total_loss = 0.0total_psnr = 0.0for noisy, clean in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):noisy = noisy.to(device, non_blocking=True)clean = clean.to(device, non_blocking=True)optimizer.zero_grad()with autocast():  # 混合精度推理denoised = model(noisy)loss = criterion(denoised, clean)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()total_loss += loss.item()total_psnr += calculate_psnr(denoised.detach(), clean).item()avg_loss = total_loss / len(dataloader)avg_psnr = total_psnr / len(dataloader)print(f"✅ Epoch [{epoch+1}/{num_epochs}] - Loss: {avg_loss:.4f}, PSNR: {avg_psnr:.2f} dB")torch.save(model.state_dict(), f"weights/tiny_unet_epoch{epoch+1}.pth")print("🎉 模型训练完成,所有权重已保存至 weights/ 目录")if __name__ == "__main__":train_denoiser()

最后得到的训练文件,这里我设置的50次训练迭代:

测试模型的推理效果

 原去带噪声图片:

去噪后(可以看到这里仍然有bug,肉眼看效果并不是很好,需要进一步优化,考虑到模型的泛化性):

总结:处理 CUDA OOM 的思路模板

  1. 先查 batch size,这是最常见爆显存原因;

  2. 确认输入尺寸是否太大或未 resize;

  3. 启用 AMP,简单又高效;

  4. 合理设计模型结构(Tiny UNet > ResUNet);

  5. 使用 Dataloader 加速,避免数据传输抖动;

  6. 手动清理缓存防止 PyTorch 持有多余内存;

  7. 查看 PyTorch 显存使用报告,加上:

print(torch.cuda.memory_summary())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/76764.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FramePack V2版 - 支持首尾帧生成,支持LoRA,支持批量,支持50系显卡,一个强大的AI视频生成软件 本地一键整合包下载

FramePack 是斯坦福大学主导开发的视频生成框架,是一种用于视频生成的下一帧(下一帧部分)预测神经网络结构,可以逐步生成视频。FramePack 主要开发者之一,就是业内大名鼎鼎的张吕敏大佬,AI领域的“赛博佛祖…

STM32 HAL 通用定时器延时函数

使用通用定时器TIM3,实现ms、us延时。 delay.c #include "delay.h" #include "stm32f1xx_hal.h"TIM_HandleTypeDef htim3;/*** brief 初始化定时器3用于延时* param 无* retval 无*/ void Delay_Init(void) {TIM_ClockConfigTypeDef sClock…

软件功能测试和非功能测试有什么区别和联系?

软件测试是保障软件质量的核心环节,而软件功能测试和非功能测试作为测试领域的两大重要组成部分,承担着不同但又相互关联的职责。 软件功能测试指的是通过验证软件系统的各项功能是否按照需求规格说明书来正确实现,确保软件的功能和业务流程…

使用Java调用TensorFlow与PyTorch模型:DJL框架的应用探索

在现代机器学习的应用场景中,Python早已成为广泛使用的语言,尤其是在深度学习框架TensorFlow和PyTorch的开发和应用中。尽管Java在许多企业级应用中占据一席之地,但因为缺乏直接使用深度学习框架的能力,往往使得Java开发者对机器学…

Docker安装beef-xss

新版的kali系统中安装了beef-xss会因为环境问题而无法启动,可以使用Docker来安装beef-xss,节省很多时间。 安装步骤 1.启动kali虚拟机,打开终端,切换到root用户,然后执行下面的命令下载beef的docker镜像 wget https:…

metasploit(2)生成dll木马

声明!本文章所有的工具分享仅仅只是供大家学习交流为主,切勿用于非法用途,如有任何触犯法律的行为,均与本人及团队无关!!! 一、dll文件基本概念 DLL 是一种包含可由多个程序同时使用的代码和数…

5V 1A充电标准的由来与技术演进——从USB诞生到智能手机时代的电力革命

点击下面图片带您领略全新的嵌入式学习路线 🔥爆款热榜 88万阅读 1.6万收藏 一、起源:USB标准与早期电力传输需求 1. USB的诞生背景 1996年,由英特尔、微软、IBM等公司组成的USB-IF(USB Implementers Forum)发布了…

使用Python设置excel单元格的字体(font值)

一、前言 通过使用Python的openpyxl库,来操作excel单元格,设置单元格的字体,也就是font值。 把学习的过程分享给大家。大佬勿喷! 二、程序展示 1、新建excel import openpyxl from openpyxl.styles import Font wb openpyxl.…

【设计模式】深入解析代理模式(委托模式):代理模式思想、静态模式和动态模式定义与区别、静态代理模式代码实现

代理模式 代理模式,也叫委托模式。 Spring AOP 是基于动态代理来实现 AOP 的 定义 为其他对象提供一种代理 以控制对这个对象的访问。它的作用就是通过提供一个代理类,让我们在调用目标方法的时候,不再是直接对目标方法进行调用,而…

利用java语言,怎样开发和利用各种开源库和内部/自定义框架,实现“提取-转换-加载”(ETL)流程的自动化

一、ETL 架构设计的核心要素​ 在企业级数据处理场景中,ETL(Extract-Transform-Load)流程自动化是数据仓库、数据湖建设的核心环节。基于 Java 生态的技术栈,我们可以构建分层解耦的 ETL 架构,主要包含以下四层结构&am…

2023蓝帽杯初赛内存取证-8

也是用到pslist模块,加上grep过滤”chrome“即可: vol.py --plugin/opt/volatility/plugins -f memdump.mem --profile Win7SP1x64 pslist | grep "chrome" 第一个是PID,第二个是PPID,第三个是线程数,第四个…

【C语言】动态内存的常见错误

前言&#xff1a; 在上章节中讲解了动态内存的概念和管理的核心函数。 在本章节继续为大家介绍动态内存的常见错误&#xff0c;让大家更好的理解运用。 补充&#xff1a;使用内存函数需要头文件<stdlib.h> 对NULL指针的解引用操作 当使用malloc、calloc或realloc等函…

uniapp-x 二维码生成

支持X&#xff0c;二维码生成&#xff0c;支持微信小程序&#xff0c;android&#xff0c;ios&#xff0c;网页 - DCloud 插件市场 免费的单纯用爱发电的

Linux内核之文件驱动随笔

前言 近期需要实现linux系统文件防护功能&#xff0c;故此调研了些许知识&#xff0c;如何实现文件防护功能从而实现针对文件目录防护功能。当被保护的目录&#xff0c;禁止增删改操作。通过内核层面实现相关功能&#xff0c;另外在通过跟应用层面交互从而实现具体的业务功能。…

利用大模型实现地理领域文档中英文自动化翻译

一、 背景描述 在跨国性企业日常经营过程中&#xff0c;经常会遇到专业性较强的文档翻译的需求&#xff0c;例如法律文书、商务合同、技术文档等&#xff1b;以往遇到此类场景&#xff0c;企业内部往往需要指派专人投入数小时甚至数天来整理和翻译&#xff0c;效率低下&#x…

鸿蒙Flutter仓库停止更新?

停止更新 熟悉 Flutter 鸿蒙开发的小伙伴应该知道&#xff0c;Flutter 3.7.12 鸿蒙化 SDK 已经在开源鸿蒙社区发布快一年了&#xff0c; Flutter 3.22.x 的鸿蒙化适配一直由鸿蒙突击队仓库提供&#xff0c;最近有小伙伴反馈已经 2 个多月没有停止更新了&#xff0c;不少人以为停…

(七)深入了解AVFoundation-采集:采集系统架构与 AVCaptureSession 全面梳理

引言 在 iOS 开发中&#xff0c;AVFoundation 是构建音视频功能的强大底层框架。而在音视频功能中&#xff0c;“采集”往往是最基础也是最关键的一环。从摄像头捕捉图形、到麦克风获取声音&#xff0c;构建一条高效且稳定的采集链是开发高质量音视频应用的前提。 本系列将逐…

QML ShaderEffect(着色器效果)组件

ShaderEffect 是 QML 中用于实现自定义着色器效果的组件&#xff0c;允许开发者使用 GLSL 着色器语言创建图形效果。 核心属性 基本属性 属性类型默认值说明fragmentShaderstring""片段着色器代码vertexShaderstring""顶点着色器代码blendingbooltrue是…

基于javaweb的SSM教材征订与发放管理系统设计与实现(源码+文档+部署讲解)

技术范围&#xff1a;SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、小程序、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容&#xff1a;免费功能设计、开题报告、任务书、中期检查PPT、系统功能实现、代码编写、论文编写和辅导、论文…

大模型学习笔记------Llama 3模型架构之分组查询注意力(GQA)

大模型学习笔记------Llama 3模型架构之分组查询注意力&#xff08;GQA&#xff09; 1、分组查询注意力&#xff08;GQA&#xff09;的动机2、 多头注意力&#xff08;Multi-Head Attention, MHA&#xff09;3、 多查询注意力 (Multi-Query Attention&#xff0c;MQA)4、 分组查…