从零开始训练Codebook:基于ViT的图像重建实践

完整代码在文末,可以一键运行。

在这里插入图片描述

1. 核心原理

Codebook是一种离散表征学习方法,其核心思想是将连续特征空间映射到离散的码本空间。我们的实现方案包含三个关键组件:

1.1 ViT编码器

class ViTEncoder(nn.Module):def __init__(self, codebook_dim=512):super().__init__()self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")self.proj = nn.Linear(768, codebook_dim)def forward(self, x):outputs = self.vit(x).last_hidden_statepatch_embeddings = outputs[:, 1:, :]  # 移除CLS tokenreturn self.proj(patch_embeddings)
  • 使用预训练的ViT-Base模型提取图像特征
  • 移除CLS token,保留196个图像块特征
  • 线性投影调整特征维度适配Codebook

1.2 Codebook量化层

class Codebook(nn.Module):def __init__(self, num_embeddings=1024, embedding_dim=512):super().__init__()self.codebook = nn.Embedding(num_embeddings, embedding_dim)def quantize(self, z):# 计算L2距离distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)# 最近邻查找indices = torch.argmin(distances, dim=1)return indices, self.codebook(indices)
  • 使用可学习的Embedding层存储离散码本
  • 通过L2距离计算实现最近邻查找
  • 支持EMA更新(代码中已注释部分)

1.3 ViT解码器

class ViTDecoder(nn.Module):def __init__(self):self.head = nn.Sequential(nn.ConvTranspose2d(768, 384, 4, 2, 1),nn.ReLU(),... # 更多上采样层nn.Conv2d(48, 3, 1))
  • 使用转置卷积逐步上采样
  • 最终输出224x224分辨率图像
  • 与编码器形成对称结构

2. 训练策略

2.1 多目标损失函数

total_loss = mse_loss + 0.1*percep_loss + codebook_loss + commitment_loss
  • MSE Loss: 像素级重建误差
  • Perceptual Loss: VGG16特征匹配
  • Codebook Loss: 码本向量优化
  • Commitment Loss: 编码器输出稳定性

2.2 优化技巧

opt = torch.optim.Adam([{'params': encoder.parameters()},{'params': decoder.parameters()},{'params': codebook.parameters(), 'lr': 1e-4}
], lr=3e-4)
  • 分层学习率设置
  • EMA指数平滑更新
  • 混合精度训练支持
  • 动态学习率调整

3. 完整训练流程

3.1 数据准备

transform_train = transforms.Compose([transforms.Resize(224),transforms.RandomCrop(224, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(...)
])
  • CIFAR-10数据集
  • 随机裁剪+翻转增强
  • Batch Size=4适配显存

3.2 训练监控

# TensorBoard记录
writer.add_scalar('Loss/total', total_loss.item(), global_step)
writer.add_image('Reconstruction', grid, global_step)# 控制台日志
print(f"[Epoch {epoch+1:03d}] Loss: {total_loss.item():.4f}")

完整代码

from transformers import ViTModel, ViTConfig
import torch.nn as nn
import torch
import time
from tqdm import tqdm
class ViTEncoder(nn.Module):def __init__(self, codebook_dim=512):super().__init__()# 加载预训练ViT-Base模型self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")# 调整输出维度匹配Codebookself.proj = nn.Linear(768, codebook_dim)  # 网页2/6中的线性嵌入策略def forward(self, x):outputs = self.vit(x).last_hidden_state  # [batch, num_patches+1, 768]patch_embeddings = outputs[:, 1:, :]     # 移除CLS tokenreturn self.proj(patch_embeddings)       # [batch, 196, 512]class Codebook(nn.Module):def __init__(self, num_embeddings=16384, embedding_dim=512):super().__init__()self.codebook = nn.Embedding(num_embeddings, embedding_dim)nn.init.normal_(self.codebook.weight)  # 网页1的EMA更新可在此扩展def quantize(self, z):"""量化输入特征向量参数:z: 输入特征 [batch, num_patches, embedding_dim]返回:indices: 最近邻码本索引 [batch, num_patches]quantized: 量化后的特征 [batch, num_patches, embedding_dim]"""# 重塑输入为二维矩阵 [batch*num_patches, embedding_dim]batch, num_patches, dim = z.shapez_flat = z.reshape(-1, dim)  # [batch*num_patches, dim]# 计算L2距离 ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)  # [batch*num_patches, 1]e_norm = torch.sum(self.codebook.weight ** 2, dim=1)  # [num_embeddings]dot_product = torch.matmul(z_flat, self.codebook.weight.t())  # [batch*num_patches, num_embeddings]distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)# 找到最近邻indices = torch.argmin(distances, dim=1)  # [batch*num_patches]indices = indices.reshape(batch, num_patches)  # 恢复原始形状quantized = self.codebook(indices)  # [batch, num_patches, dim]return indices, quantized
class ViTDecoder(nn.Module):def __init__(self, in_dim=512):super().__init__()# 反向映射ViT的patch嵌入self.proj = nn.Linear(in_dim, 768)config = ViTConfig()config.is_decoder = True  # 网页7中的解码器模式self.transformer = ViTModel(config).encoder  self.head = nn.Sequential(# 14x14 -> 28x28nn.ConvTranspose2d(768, 384, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 28x28 -> 56x56nn.ConvTranspose2d(384, 192, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 56x56 -> 112x112 nn.ConvTranspose2d(192, 96, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 112x112 -> 224x224nn.ConvTranspose2d(96, 48, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 最终调整到3通道nn.Conv2d(48, 3, kernel_size=1))def forward(self, x):x = self.proj(x)  # [batch, 196, 768]x = self.transformer(x).last_hidden_statex = x.permute(0, 2, 1).view(-1, 768, 14, 14)  # 恢复空间布局return self.head(x)  # 输出[1, 3, 224, 224]
# encoder = ViTEncoder()
# codebooker = Codebook()
# decoder = ViTDecoder()# data = torch.randn(1, 3, 224, 224)
# output = encoder(data)
# print(output.shape)
# indices, quantized = codebooker.quantize(output)
# print(indices.shape, quantized.shape)
# reconstructed = decoder(quantized)
# print(reconstructed.shape)from torchvision import transforms
import torchvision
import torch.nn.functional as F
# 数据增强和预处理
transform_train = transforms.Compose([transforms.Resize(224),  # 调整图像尺寸适配模型transforms.RandomCrop(224, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])transform_test = transforms.Compose([transforms.Resize(224),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# trainloader = torch.DataLoader(trainset, batch_size=64, shuffle=True)
# 加载CIFAR-10数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)batch_size = 4  # 增大batch size加速训练
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import vgg16# 初始化TensorBoard
writer = SummaryWriter('runs/codebook_experiment')# 改进的Codebook类(增加EMA更新)
class Codebook(nn.Module):def __init__(self, num_embeddings=1024, embedding_dim=512, commitment_cost=0.25, decay=0.99):super().__init__()self.codebook = nn.Embedding(num_embeddings, embedding_dim)nn.init.normal_(self.codebook.weight)self.commitment_cost = commitment_costself.decay = decayself.register_buffer('ema_cluster_size', torch.zeros(num_embeddings))self.ema_w = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))nn.init.normal_(self.ema_w)def quantize(self, z):# 重塑输入为二维矩阵 [batch*num_patches, embedding_dim]batch, num_patches, dim = z.shapez_flat = z.reshape(-1, dim)  # [batch*num_patches, dim]# 计算L2距离 ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)  # [batch*num_patches, 1]e_norm = torch.sum(self.codebook.weight ** 2, dim=1)  # [num_embeddings]dot_product = torch.matmul(z_flat, self.codebook.weight.t())  # [batch*num_patches, num_embeddings]distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)# 找到最近邻indices = torch.argmin(distances, dim=1)  # [batch*num_patches]indices = indices.reshape(batch, num_patches)  # 恢复原始形状quantized = self.codebook(indices)  # [batch, num_patches, dim]# 新增EMA更新# if self.training:#     with torch.no_grad():#         encodings = F.one_hot(indices, self.codebook.num_embeddings).float()#         self.ema_cluster_size = self.decay * self.ema_cluster_size + (1 - self.decay) * torch.sum(encodings, 0)#         n = torch.sum(self.ema_cluster_size)#         self.ema_cluster_size = ((self.ema_cluster_size + 1e-5) / (n + self.codebook.num_embeddings * 1e-5) * n)#         dw = torch.matmul(encodings.t(), z_flat)#         self.ema_w = nn.Parameter(self.ema_w * self.decay + (1 - self.decay) * dw)#         self.codebook.weight.data = self.ema_w / self.ema_cluster_size.unsqueeze(1)return indices, quantized
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 初始化组件
encoder = ViTEncoder().to(device)
codebook = Codebook(commitment_cost=0.25, decay=0.95).to(device)
decoder = ViTDecoder().to(device)
vgg = vgg16(pretrained=True).features[:16].eval().to(device)  # 用于感知损失# 优化器分开设置
opt = torch.optim.Adam([{'params': encoder.parameters()},{'params': decoder.parameters()},{'params': codebook.parameters(), 'lr': 1e-4}  # 更小的学习率
], lr=3e-4)# 训练循环
for epoch in range(100):avg_loss = 0start_time = time.time()  # 记录epoch开始时间for batch_idx, (images, _) in enumerate(tqdm(trainloader, desc=f"Epoch {epoch+1}", ncols=80)):images = images.to(device)# 前向传播z = encoder(images)indices, quantized = codebook.quantize(z)recon = decoder(quantized)# 多目标损失计算mse_loss = F.mse_loss(recon, images)# 感知损失(VGG特征匹配)with torch.no_grad():real_features = vgg(images)recon_features = vgg(recon)percep_loss = F.mse_loss(recon_features, real_features)# Codebook相关损失commitment_loss = codebook.commitment_cost * F.mse_loss(z.detach(), quantized)codebook_loss = F.mse_loss(z, quantized.detach())# 总损失total_loss = mse_loss + 0.1*percep_loss + codebook_loss + commitment_loss# 反向传播opt.zero_grad()total_loss.backward()opt.step()# 记录数据avg_loss += total_loss.item()if batch_idx % 50 == 0:# 记录TensorBoard数据writer.add_scalar('Loss/total', total_loss.item(), epoch*len(trainloader)+batch_idx)writer.add_scalars('Loss/components', {'mse': mse_loss.item(),'perceptual': percep_loss.item(),'codebook': codebook_loss.item(),'commitment': commitment_loss.item()}, epoch*len(trainloader)+batch_idx)# 保存重建样本comparison = torch.cat([images[:4], recon[:4]])grid = vutils.make_grid(comparison.cpu(), nrow=4, normalize=True)writer.add_image('Reconstruction', grid, epoch*len(trainloader)+batch_idx)# 打印epoch统计信息avg_loss /= len(trainloader)print(f"Epoch {epoch+1}: Avg Loss {avg_loss:.4f}")# 保存模型检查点if (epoch+1) % 10 == 0:torch.save({'encoder': encoder.state_dict(),'codebook': codebook.state_dict(),'decoder': decoder.state_dict(),'opt': opt.state_dict()}, f'checkpoint_epoch{epoch+1}.pth')writer.close()

通过本实践,我们实现了从特征提取到离散表征学习的完整流程。Codebook技术可广泛应用于图像压缩、生成模型等领域,期待读者在此基础上探索更多可能性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/75513.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大数据笔试题_第一阶段配套笔试题02

已知一个字符类型的日期&#xff1a;2022-01-20&#xff0c;请用SQL显示出此日期对应的下个月的月份&#xff0c;结果要求为Number类型&#xff08;202201&#xff09;。 参考答案 sql SELECT to_date(2022-01-20, yyyy-mm-dd) a1,add_months(to_date(2022-01-20, yyyy-mm-d…

C++实现对象单例模式

在 C 中实现单例模式有多种方法&#xff0c;以下是线程安全的现代 C 实现方式&#xff08;推荐 C11 及以上版本&#xff09;&#xff1a; 1. Meyers’ Singleton&#xff08;推荐&#xff09; class Singleton { public:// 删除拷贝构造和赋值运算符Singleton(const Singleto…

企业常用Linux服务搭建

1.需要两台centos 7服务器&#xff0c;一台部署DNS服务器&#xff0c;另一台部署ftp和Samba服务器。 2. 部署DNS 服务器​ #!/bin/bash# 更新系统 echo "更新系统..." sudo yum update -y# 安装 BIND 和相关工具 echo "安装 BIND 和相关工具..." sudo y…

UE5Actor模块源码深度剖析:从核心架构到实践应用

UE5 Actor模块源码深度剖析:从核心架构到实践应用 a. UE5 Actor模块架构概述 在UE5引擎中,Actor扮演着至关重要的角色,它是整个游戏世界中各类可交互对象的基础抽象。从本质上来说,所有能够被放置到关卡中的对象都属于Actor的范畴,像摄像机、静态网格体以及玩家起始位置…

DreamDiffusion代码学习及复现

论文解读在这里 File path | Description /pretrains ┣ &#x1f4c2; models ┃ ┗ &#x1f4dc; config.yaml ┃ ┗ &#x1f4dc; v1-5-pruned.ckpt┣ &#x1f4c2; generation ┃ ┗ &#x1f4dc; checkpoint_best.pth ┣ &#x1f4c2; eeg_pretain ┃ ┗ …

用Python实现TCP代理

依旧是Python黑帽子这本书 先附上代码&#xff0c;我在原书代码上加了注释&#xff0c;更好理解 import sys import socket import threading#生成可打印字符映射 HEX_FILTER.join([(len(repr(chr(i)))3) and chr(i) or . for i in range(256)])#接收bytes或string类型的输入…

Pyinstaller 打包flask_socketio为exe程序后出现:ValueError: Invalid async_mode specified

Pyinstaller 打包flask_socketio为exe程序后出现&#xff1a;ValueError: Invalid async_mode specified 一、详细描述问题描述 Traceback (most recent call last): File "app_3.py", line 22, in <module> File "flask_socketio\__init__.py"…

django REST framework(DRF)教程

Django DRF API Django 基本使用Django DRF序列化器Django DRF视图Django DRF常用功能Django 基本使用 前后端分离开发模式认识RestFulAPI回顾Django开发模式Django REST Framework初探前后端分离开发模式 前后端分离前:前端页面看到的效果都是由后端控制,即后端渲染HTML页面…

【Linux】Orin NX + Ubuntu22.04配置国内源

1、获取源 清华源 arm 系统的源,可以在如下地址获取到 https://mirror.tuna.tsinghua.edu.cn/help/ubuntu-ports/ 选择HTTPS,否则可能报错: 明文签署文件不可用,结果为‘NOSPLIT’(您的网络需要认证吗?)查看Orin NX系统版本 选择jammy的源 2、更新源 1)备份原配…

【含文档+PPT+源码】基于微信小程序的社交摄影约拍平台的设计与实现

项目介绍 本课程演示的是一款基于微信小程序的社交摄影约拍平台的设计与实现&#xff0c;主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的 Java 学习者。 1.包含&#xff1a;项目源码、项目文档、数据库脚本、软件工具等所有资料 2.带你从零开始部署运行本套系…

JDBC常用的接口

一、什么是JDBC JDBC是Java语言连接数据库的接口规范。 二、JDBC的体系 1、Java官方提供一个操作数据库的抽象接口 抽象接口有很多的接口和抽象类。 例如&#xff1a;Driver、Connection、Statement。 2、各个数据库厂商提供各自的Java实现类 需要各自实现具体的细节。 例如&am…

容器适配器-stack栈

C标准库不只是包含了顺序容器&#xff0c;还包含一些为满足特殊需求而设计的容器&#xff0c;它们提供简单的接口。 这些容器可被归类为容器适配器(container adapter)&#xff0c;它们是改造别的标准顺序容器&#xff0c;使之满足特殊需求的新容器。 适配器:也称配置器,把一…

[250403] HuggingFace 新增检查模型与电脑兼容性的功能 | Firefox 发布137.0 支持标签组

目录 Hugging Face 让寻找兼容的 AI 模型变得更容易Firefox 137 版本更新摘要 Hugging Face 让寻找兼容的 AI 模型变得更容易 Hugging Face 是一个流行的在线平台&#xff0c;用于访问开源人工智能 (AI) 工具和模型。该平台推出了一项有用的新功能&#xff0c;允许个人轻松检查…

.NET 创建MCP使用大模型对话二:调用远程MCP服务

在上一篇文章.NET 创建MCP使用大模型对话-CSDN博客中&#xff0c;我们简述了如何使用mcp client使用StdIo模式调用本地mcp server。本次实例将会展示如何使用mcp client模式调用远程mcp server。 一&#xff1a;创建mcp server 我们创建一个天气服务。 新建WebApi项目&#x…

Redis 中 Set(例如标签) 和 ZSet(例如排行榜) 的详细对比,涵盖定义、特性、命令、适用场景及总结表格

以下是 Redis 中 Set 和 ZSet 的详细对比&#xff0c;涵盖定义、特性、命令、适用场景及总结表格&#xff1a; 1. 核心定义 数据类型SetZSet&#xff08;Sorted Set&#xff09;定义无序的、唯一的字符串集合&#xff0c;元素不重复。有序的、唯一的字符串集合&#xff0c;每个…

解决Spring参数解析异常:Name for argument of type XXX not specified

前言 在开发 Spring Boot 应用时&#xff0c;我们常遇到类似 java.lang.IllegalArgumentException: Name for argument not specified 的报错。这类问题通常与方法参数名称的解析机制相关&#xff0c;尤其在使用 RequestParam、PathVariable 等注解时更为常见。 一、问题现象与…

刚刚,OpenAI开源PaperBench,重塑顶级AI Agent评测

今天凌晨1点&#xff0c;OpenAI开源了一个全新的AI Agent评测基准——PaperBench。 这个基准主要考核智能体的搜索、整合、执行等能力&#xff0c;需要对2024年国际机器学习大会上顶尖论文的复现&#xff0c;包括对论文内容的理解、代码编写以及实验执行等方面的能力。 根据O…

Golang封装Consul 服务发现库

以下是一个经过生产验证的 Consul 服务发现封装库,支持注册/注销、健康检查、智能发现等核心功能,可直接集成到项目中: package consulimport ("context""fmt""log""math/rand""net""os""sync"&quo…

自适应信号处理任务(过滤,预测,重建,分类)

自适应滤波 # signals creation: u, v, d N = 5000 n = 10 u = np.sin(np.arange(0, N/10., N/50000

PyTorch深度学习框架 的基础知识

目录 1.pyTorch检查是否安装成功 2.PyTorch的张量tensor 基础创建方式&#xff08;三种&#xff09; 2.2用列表创建tensor 2.2使用元组创建 tensor 2.3使用ndarray创建创建 tensor 2.4 快速创建tensor的常用方法 3.pyTorch中的张量tensor的常用属性 4. tensor中的基础数据…