目录
推理代码:
EnvLight 代码:
推理代码:
sky_model = self.models["Sky"]outputs["rgb_sky"] = sky_model(image_info)outputs["rgb_sky_blend"] = outputs["rgb_sky"] * (1.0 - outputs["opacity"])
EnvLight 代码:
import torch# 定义环境光类(EnvLight),继承自 torch.nn.Module
class EnvLight(torch.nn.Module):def __init__(self, class_name: str, resolution: int = 1024, device: torch.device = torch.device("cuda"), **kwargs):# 初始化函数,接收类名、分辨率、设备(默认 GPU)以及其他关键字参数super().__init__()# 设置类的前缀,方便后续参数管理self.class_prefix = class_name + "#"# 设置设备(默认为 GPU)self.device = device# 定义 OpenGL 转换矩阵,将世界坐标系转换为 OpenGL 坐标系# 该矩阵的作用是转换方向向量self.to_opengl = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32, device="cuda")# 定义基础光照参数:初始化为一个 6 x resolution x resolution 的全 0.5 张量,# 每个光照样本有 3 个值(RGB)。该参数是可训练的(requires_grad=True)self.base = torch.nn.Parameter(0.5 * torch.ones(6, resolution, resolution, 3, requires_grad=True),)def forward(self, image_info: ImageInfo):# 前向传播函数,接受一个 ImageInfo 类型的输入(包含射线信息)# 获取传入图像信息中的方向向量(viewdirs),表示视角方向directions = image_info.rays.viewdirs# 将方向向量从世界坐标系转换到 OpenGL 坐标系directions = (directions.reshape(-1, 3) @ self.to_opengl.T).reshape(*directions.shape)# 重新调整方向向量的内存布局为连续的,以便后续操作directions = directions.contiguous()# 获取方向向量的前缀尺寸,用于后续的形状调整prefix = directions.shape[:-1]# 如果前缀尺寸不是三维(即 [B, H, W]),则将方向向量重塑为 [1, 1, -1, 3]# 目的是将其转换为适合批量处理的形状if len(prefix) != 3: # reshape to [B, H, W, -1]directions = directions.reshape(1, 1, -1, directions.shape[-1])# 使用 dr.texture 函数计算光照(dr 是某个光照计算库)# `self.base[None, ...]` 代表基础光照纹理,`directions` 是输入的方向向量# `filter_mode="linear"` 表示纹理的过滤模式,`boundary_mode="cube"` 表示纹理的边界模式light = dr.texture(self.base[None, ...], directions, filter_mode="linear", boundary_mode="cube")# 将输出的光照结果 reshaped 为适合的形状light = light.view(*prefix, -1)return lightdef get_param_groups(self):# 获取模型参数分组,返回一个字典# 这里我们将所有参数归为一个组,键为 "class_name + all"return {self.class_prefix + "all": self.parameters(),}