3D Gaussian Splatting代码中的train和render两个文件代码解读

现在来聊一聊训练和渲染是如何进行的

training

train.py
line 31
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):# 初始化第一次迭代的索引为0first_iter = 0# 准备输出和日志记录器tb_writer = prepare_output_and_logger(dataset)# 初始化高斯模型,参数为数据集的球谐函数(SH)级别gaussians = GaussianModel(dataset.sh_degree)# 创建场景对象,包含数据集和高斯模型scene = Scene(dataset, gaussians)# 设置高斯模型的训练配置gaussians.training_setup(opt)# 加载检查点(如果有),恢复模型参数和设置起始迭代次数if checkpoint:(model_params, first_iter) = torch.load(checkpoint)gaussians.restore(model_params, opt)# 设置背景颜色,如果数据集背景为白色,则设置为白色([1, 1, 1]),否则为黑色([0, 0, 0])bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]# 将背景颜色转换为CUDA张量,以便在GPU上使用background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")# 创建两个CUDA事件,用于记录迭代开始和结束的时间iter_start = torch.cuda.Event(enable_timing=True)iter_end = torch.cuda.Event(enable_timing=True)# 初始化视点堆栈为空viewpoint_stack = None# 用于记录指数移动平均损失的变量,初始值为0.0ema_loss_for_log = 0.0# 创建进度条,用于显示训练进度,从起始迭代数到总迭代数progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")# 增加起始迭代数,以便从下一次迭代开始first_iter += 1for iteration in range(first_iter, opt.iterations + 1):# 尝试连接网络GUI,如果当前没有连接if network_gui.conn == None:network_gui.try_connect()# 如果已经连接网络GUI,处理接收和发送数据while network_gui.conn != None:try:# 初始化网络图像字节为Nonenet_image_bytes = None# 从网络GUI接收数据custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()# 如果接收到自定义相机数据,则进行渲染if custom_cam != None:# 使用自定义相机数据、当前的高斯模型、管道和背景颜色进行渲染net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]# 将渲染结果转为字节格式,并转换为内存视图net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())# 发送渲染结果到网络GUI,并附带数据集的源路径network_gui.send(net_image_bytes, dataset.source_path)# 如果需要进行训练,并且当前迭代次数小于总迭代次数,或不需要保持连接,则退出循环if do_training and ((iteration < int(opt.iterations)) or not keep_alive):breakexcept Exception as e:# 如果出现异常,断开网络连接network_gui.conn = None# 记录当前迭代的开始时间,用于计算每次迭代的持续时间iter_start.record()# 更新学习率gaussians.update_learning_rate(iteration)# 每1000次迭代增加一次SH级别,直到达到最大度if iteration % 1000 == 0:gaussians.oneupSHdegree()# 随机选择一个相机视角if not viewpoint_stack:viewpoint_stack = scene.getTrainCameras().copy()# 从相机视角堆栈中随机弹出一个相机视角viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))# 渲染if (iteration - 1) == debug_from:pipe.debug = True# 如果设置了随机背景颜色,则生成一个随机背景颜色,否则使用预定义的背景颜色bg = torch.rand((3), device="cuda") if opt.random_background else background# 使用选定的相机视角、高斯模型、渲染管道和背景颜色进行渲染render_pkg = render(viewpoint_cam, gaussians, pipe, bg)# 提取渲染结果、视点空间点张量、可见性过滤器和半径image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]# 计算损失gt_image = viewpoint_cam.original_image.cuda()  # 获取地面真实图像Ll1 = l1_loss(image, gt_image)  # 计算L1损失# 计算总损失,结合L1损失和结构相似性损失(SSIM)loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))loss.backward()  # 反向传播计算梯度# 记录当前迭代的结束时间,用于计算每次迭代的持续时间iter_end.record()# 在不需要计算梯度的上下文中进行操作with torch.no_grad():# 更新进度条和日志ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log  # 更新指数移动平均损失if iteration % 10 == 0:# 每10次迭代更新一次进度条progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})progress_bar.update(10)if iteration == opt.iterations:progress_bar.close()# 记录训练报告并保存training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))if iteration in saving_iterations:# 在指定的迭代次数保存高斯模型print("\n[ITER {}] Saving Gaussians".format(iteration))scene.save(iteration)# 密集化操作if iteration < opt.densify_until_iter:# 跟踪图像空间中的最大半径,用于修剪gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)# 在指定的迭代范围和间隔内进行密集化和修剪if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:size_threshold = 20 if iteration > opt.opacity_reset_interval else Nonegaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)# 在指定的间隔内或满足特定条件时重置不透明度if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):gaussians.reset_opacity()# 优化器步骤if iteration < opt.iterations:gaussians.optimizer.step()  # 更新模型参数gaussians.optimizer.zero_grad(set_to_none=True)  # 清空梯度# 保存检查点if iteration in checkpoint_iterations:print("\n[ITER {}] Saving Checkpoint".format(iteration))torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")

render

现在是渲染的这个文件进行方式,首先是主文件里单张图片的渲染和整个数据集的渲染方法:

render.py
line 24
# 渲染一组视角并保存渲染结果和对应的真实图像
def render_set(model_path, name, iteration, views, gaussians, pipeline, background):# 定义渲染结果和真实图像的保存路径render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")# 创建保存路径,如果路径不存在makedirs(render_path, exist_ok=True)makedirs(gts_path, exist_ok=True)# 遍历每个视角进行渲染for idx, view in enumerate(tqdm(views, desc="Rendering progress")):# 渲染图像rendering = render(view, gaussians, pipeline, background)["render"]# 获取对应的真实图像gt = view.original_image[0:3, :, :]# 保存渲染结果和真实图像到指定路径torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))# 渲染训练集和测试集的图像,并保存结果
def render_sets(dataset: ModelParams, iteration: int, pipeline: PipelineParams, skip_train: bool, skip_test: bool):with torch.no_grad():# 初始化高斯模型和场景gaussians = GaussianModel(dataset.sh_degree)scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)# 设置背景颜色bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")# 如果不跳过训练集渲染,则渲染训练集的图像if not skip_train:render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)# 如果不跳过测试集渲染,则渲染测试集的图像if not skip_test:render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)

但是这两个方法都是外层函数,并没有展示渲染如何进行参数传递和具体操作,在以下代码中才是最关键的内容:

gaussian_renderer\__init__.py
line 18
def render(viewpoint_camera, pc: GaussianModel, pipe, bg_color: torch.Tensor, scaling_modifier=1.0, override_color=None):"""渲染场景。参数:viewpoint_camera - 摄像机视角pc - 高斯模型pipe - 管道参数bg_color - 背景颜色张量,必须在GPU上scaling_modifier - 缩放修饰符,默认为1.0override_color - 覆盖颜色,默认为None"""# 创建一个全零张量,用于使PyTorch返回2D(屏幕空间)均值的梯度screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0try:screenspace_points.retain_grad()  # 保留梯度信息except:pass# 设置光栅化配置tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)  # 计算视角的X轴正切tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)  # 计算视角的Y轴正切raster_settings = GaussianRasterizationSettings(image_height=int(viewpoint_camera.image_height),  # 图像高度image_width=int(viewpoint_camera.image_width),  # 图像宽度tanfovx=tanfovx,  # 视角X轴正切tanfovy=tanfovy,  # 视角Y轴正切bg=bg_color,  # 背景颜色scale_modifier=scaling_modifier,  # 缩放修饰符viewmatrix=viewpoint_camera.world_view_transform,  # 世界视图变换矩阵projmatrix=viewpoint_camera.full_proj_transform,  # 投影变换矩阵sh_degree=pc.active_sh_degree,  # 球谐函数度数campos=viewpoint_camera.camera_center,  # 摄像机中心prefiltered=False,  # 预过滤debug=pipe.debug  # 调试模式)rasterizer = GaussianRasterizer(raster_settings=raster_settings)  # 初始化光栅化器means3D = pc.get_xyz  # 获取3D均值means2D = screenspace_points  # 获取2D均值opacity = pc.get_opacity  # 获取不透明度# 如果提供了预计算的3D协方差,则使用它。如果没有,则从光栅化器的缩放/旋转中计算。scales = Nonerotations = Nonecov3D_precomp = Noneif pipe.compute_cov3D_python:cov3D_precomp = pc.get_covariance(scaling_modifier)  # 计算3D协方差else:scales = pc.get_scaling  # 获取缩放rotations = pc.get_rotation  # 获取旋转# 如果提供了预计算的颜色,则使用它们。否则,如果需要在Python中预计算SH到颜色的转换,则进行转换。shs = Nonecolors_precomp = Noneif override_color is None:if pipe.convert_SHs_python:shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1) ** 2)dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)  # 计算颜色else:shs = pc.get_features  # 获取球谐函数特征else:colors_precomp = override_color  # 覆盖颜色# 将可见的高斯体光栅化为图像,并获取它们在屏幕上的半径。rendered_image, radii = rasterizer(means3D=means3D,means2D=means2D,shs=shs,colors_precomp=colors_precomp,opacities=opacity,scales=scales,rotations=rotations,cov3D_precomp=cov3D_precomp)# 那些被视锥剔除或半径为0的高斯体是不可见的。# 它们将被排除在用于分裂标准的值更新之外。return {"render": rendered_image,  # 渲染图像"viewspace_points": screenspace_points,  # 视图空间点"visibility_filter": radii > 0,  # 可见性过滤器"radii": radii  # 半径}

最值得关注的光栅化器,如果转到定义去查看,其实会发现它就是第二期里讲forward的代码,只是这里面用python写了变量的调用,实际的操作方式还是在cu文件里面。所以在此就不多做赘述,可以看上一期博客里面对forward的解读。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/39876.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

wordpress企业主题和wordpress免费主题

农业畜牧养殖wordpress主题 简洁大气的农业畜牧养殖wordpress主题&#xff0c;农业农村现代化&#xff0c;离不开新农人、新技术。 https://www.jianzhanpress.com/?p3051 SEO优化wordpress主题 简洁的SEO优化wordpress主题&#xff0c;效果好不好&#xff0c;结果会告诉你…

第一后裔The First Descendant延迟、卡顿、无法联机?

The First Descendant第一后裔游戏中还设计了多种辅助攻击手段&#xff0c;它们如同角色手中的魔法&#xff0c;为战斗增添了无数可能性。这些辅助攻击手段或能造成范围伤害&#xff0c;或能减速敌人&#xff0c;甚至能召唤出强大的支援力量。最近有玩家反映&#xff0c;遇到了…

Windows条件竞争提权漏洞复现(CVE-2024-300889)

漏洞原理 当内核将当前令牌对象的 _AUTHZBASEP_SECURITY_ATTRIBUTES_INFORMATION 复制到用户模式时&#xff0c;错误位于函数 AuthzBasepCopyoutInternalSecurityAttributes 内部&#xff0c;该模式的结构如下&#xff1a; //0x30 bytes (sizeof) struct _AUTHZBASEP_SECURIT…

科研工具|从图片中提取曲线数据

最近水哥在做一个项目时需要用到一篇论文中的数据&#xff0c;而这数据是作者的实验数据&#xff0c;且年代较为久远&#xff0c;联系原作者要一份数据也不太现实&#xff0c;因而只能从论文的图片中提取数据了。 目前市面上有很多小软件可以实现这方面的功能&#xff0c;比如…

DVT:华为提出动态级联Vision Transformer,性能杠杠的 | NeurIPS 2021

论文主要处理Vision Transformer中的性能问题&#xff0c;采用推理速度不同的级联模型进行速度优化&#xff0c;搭配层级间的特征复用和自注意力关系复用来提升准确率。从实验结果来看&#xff0c;性能提升不错 来源&#xff1a;晓飞的算法工程笔记 公众号 论文: Not All Image…

智慧应急管理平台:数字孪生,让防汛救灾更科学高效

近期全国各地暴雨频发&#xff0c;城市排水系统面临着前所未有的挑战&#xff0c;应急防涝已成为城市管理中不可或缺的一环。在这个信息化、智能化的时代&#xff0c;数字孪生技术以其独特的优势&#xff0c;为应急领域带来了革命性的变革。数字孪生&#xff0c;作为现实世界在…

揭秘:学校教室采用数码管同步时钟的原因-讯鹏电子钟

在学校的教室里&#xff0c;我们常常会看到数码管同步时钟的身影。究竟是什么原因让它成为学校教室的宠儿呢&#xff1f;让我们一同来探究其中的奥秘。 数码管同步时钟具有极高的准确性。对于学校这样一个对时间管理要求严格的场所&#xff0c;准确的时间是保障教学秩序的基石。…

SwinIR: Image Restoration Using Swin Transformer(ICCV 2021)含代码复现

目录 一、Introduction 1 Motivation 2 Contribution 二、原理分析 1 Network Architecture 1&#xff09;Shallow feature extraction 2) deep feature extraction 3) image reconsruction modules 4) loss function 2 Residual Swin Transformer Block 三、实验结果…

没有调用memcpy却报了undefined reference to memcpy错误

现象 在第5行出现了&#xff0c;undefined reference to memcpy’ 1 static void printf_x(unsigned int val) 2{ 3 char buffer[32]; 4 const char lut[]{0,1,2,3,4,5,6,7,8,9,A,B,C,D,E,F}; 5 char *p buffer; 6 while (val || p buffer) { 7 *(p) …

基于循环神经网络的一维信号降噪方法(简单版本,Python)

代码非常简单。 import torch import torch.nn as nn from torch.autograd import Variable from scipy.io.wavfile import write #need install pydub module #pip install pydub import numpy as np import pydub from scipy import signal import IPython import matplot…

C语言学习记录(十二)——指针与数组及字符串

文章目录 前言一、指针和数组二、指针和二维数组**行指针(数组指针)** 三、 字符指针和字符串四、指针数组 前言 一个学习嵌入式的小白~ 有问题评论区或私信指出~ 提示&#xff1a;以下是本篇文章正文内容&#xff0c;下面案例可供参考 一、指针和数组 在C语言中 &#xff0…

AI降重,不再难:降AI率的实用技巧大揭秘

如何有效降低AIGC论文的重复率&#xff0c;也就是我们说的aigc如何降重&#xff1f;AIGC疑似度过高确实是个比较愁人的问题。如果你用AI帮忙写了论文&#xff0c;就一定要在交稿之前做一下AIGC降重的检查。一般来说&#xff0c;如果论文的AIGC超过30%&#xff0c;很可能会被判定…

CAS操作

CAS 全称:Compare and swap,能够比较和交换某个寄存器中的值和内存中的值,看是否相等,如果相等,则把另外一个寄存器中的值和内存进行交换. (这是一个伪代码,所以这里的&address实际上是想要表示取出address中的值) 那么我们可以看到,CAS就是这样一个简单的交换操作,那么…

基于SpringBoot房屋租赁管理系统设计和实现(源码+LW+调试文档+讲解等)

&#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN作者、博客专家、全栈领域优质创作者&#xff0c;博客之星、平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌&#x1f497; Java精品实战案例《1000套》 2025-2026年最值得选择的Java毕业设计选题大全&#xff…

新火种AI|国产大模型展开决战,是资本游戏还是技术革命?

作者&#xff1a;一号 编辑&#xff1a;美美 资本角逐与技术革新&#xff0c;国产大模型的双线战场已然开启。 随着人工智能技术的不断进步&#xff0c;国产大模型正迅速成为行业关注的焦点。在这个由数据驱动的时代&#xff0c;资本的注入和技术创新的加速&#xff0c;让国…

Python28-6 随机森林

随机森林算法详细介绍 1. 理论背景 随机森林&#xff08;Random Forest&#xff09;是一种由Leo Breiman和Adele Cutler在2001年提出的集成学习方法。它结合了多个决策树的预测结果&#xff0c;以提高模型的准确性和鲁棒性。 2. 算法细节 随机森林的构建过程可以分为以下几…

Qt——升级系列(Level Eight):界面优化

目录 QSS 背景介绍 基本语法 QSS设置方式 指定控件样式设置 全局样式设置 从文件加载样式表 使用Qt Designer 编辑样式 选择器 选择器概况 子控件选择器 伪类选择器 样式属性 盒模型 控件样式示例 按钮 复选框、单选框 输入框 列表 菜单栏 登录界面 绘图 基本概念 绘制各种形…

[Go 微服务] Kratos 使用的简单总结

文章目录 1.Kratos 简介2.传输协议3.日志4.错误处理5.配置管理6.wire 1.Kratos 简介 Kratos并不绑定于特定的基础设施&#xff0c;不限定于某种注册中心&#xff0c;或数据库ORM等&#xff0c;所以您可以十分轻松地将任意库集成进项目里&#xff0c;与Kratos共同运作。 API -&…

Linux内网端口转公网端口映射

由于服务商做安全演练&#xff0c;把原先服务器内网的端口映射到外网端口全都关闭了&#xff0c;每次维护服务器特别麻烦&#xff0c;像数据库查询如果用原生的mysql 去连接&#xff0c;查询返回的结果乱了&#xff0c;非常不方便。 查了服务还是可以正常访问部分外网的&#x…

抖音外卖服务商入驻流程及费用分别是什么?入驻官方平台的难度大吗?

随着抖音关于新增《【到家外卖】内容服务商开放准入公告》的意见征集通知&#xff08;以下简称“通知”&#xff09;的发布&#xff0c;抖音外卖服务商入驻流程及费用逐渐成为众多创业者所关注和热议的话题。不过&#xff0c;就当前的讨论情况来看&#xff0c;这个话题似乎没有…