NeRF学习——NeRF-Pytorch的源码解读

学习 github 上 NeRF 的 pytorch 实现项目(https://github.com/yenchenlin/nerf-pytorch)的一些笔记

1 参数

部分参数配置:

  1. 训练参数:

    这段代码是在设置一些命令行参数,这些参数用于控制NeRF(Neural Radiance Fields)的训练选项。具体来说:

    • netdepth:神经网络的层数。默认值为8

    • netwidth:每层的通道数。默认值为256

    • netdepth_fine:精细网络的层数。默认值为8

    • netwidth_fine:精细网络每层的通道数。默认值为256

    • N_rand:批量大小(每个梯度步骤的随机光线数)。默认值为 32 × 32 × 4 32 \times 32 \times 4 32×32×4

    • lrate:学习率。默认值为5e-4

    • lrate_decay:指数学习率衰减(在1000步中)。默认值为250

    • chunk:并行处理的光线数,如果内存不足,可以减少这个值。默认值为1024*32

    • netchunk:并行通过网络发送的点数,如果内存不足,可以减少这个值。默认值为1024*64

    • no_batching:是否只从一张图像中取随机光线

    • no_reload:是否不从保存的检查点重新加载权重

    • ft_path:用于重新加载粗网络的特定权重npy文件。默认值为None

    • precrop_iters:在中心裁剪上训练的步数。默认值为0。如果这个值大于0,那么在训练的开始阶段,模型将只在图像的中心部分进行训练,这可以帮助模型更快地收敛

    • precrop_frac:用于中心裁剪的图像的比例。默认值为0.5。这个值决定了在进行中心裁剪时,应该保留图像的多少部分。例如,如果这个值为0.5,那么将保留图像中心的50%

  2. 渲染参数:

    • N_samples:每条光线的粗采样数。默认64

    • N_importance:每条光线的额外精细采样数(分层采样)。默认0

    • perturb:设置为0表示没有抖动,设置为1表示有抖动。抖动可以增加采样点的随机性。默认1

    • use_viewdirs:是否使用完整的5D输入,而不是3D。5D输入包括3D位置和2D视角

    • i_embed:设置为0表示使用默认的位置编码,设置为-1表示不使用位置编码。默认0

    • multires:位置编码的最大频率的对数(用于3D位置)。默认10

    • multires_views:位置编码的最大频率的对数(用于2D方向)。默认4

      我们设置 d = 10 d=10 d=10 用于位置坐标 ϕ ( x ) ϕ(\bf x) ϕ(x) ,所以输入是60维的向量; d = 4 d=4 d=4 用于相机位姿 ϕ ( d ) ϕ(\bf d) ϕ(d) 对应的则是24维

    • raw_noise_std:添加到 sigma_a 输出的噪声的标准偏差,用于正则化 sigma_a 输出。默认0

    • render_only:如果设置,那么不进行优化,只加载权重并渲染出 render_poses 路径

    • render_test:如果设置,那么渲染测试集,而不是 render_poses 路径

    • render_factor:降采样因子,用于加速渲染。设置为4或8可以快速预览。默认0

  3. LLFF(Light Field Photography)数据集:

    • factor:LLFF图像的降采样因子。默认值为8。这个值决定了在处理LLFF图像时,应该降低多少分辨率

    • no_ndc:是否不使用归一化设备坐标(NDC)。如果在命令行中指定了这个参数,那么其值为True。这个选项应该在处理非前向场景时设置

    • lindisp:是否在视差中线性采样,而不是在深度中采样。如果在命令行中指定了这个参数,那么其值为True

    • spherify:是否处理球形360度场景。如果在命令行中指定了这个参数,那么其值为True

    • llffhold:每N张图像中取一张作为LLFF测试集。默认值为8。这个值决定了在处理LLFF数据集时,应该把多少图像作为测试集

      # 加载数据时,每隔args.llffhold个图像取一张图形
      i_test = np.arange(images.shape[0])[::args.llffhold]
      

2 大致过程

2.1 加载LLFF数据
  1. load_llff_data 函数返回五个值:images(图像),poses(姿态),bds(深度范围),render_poses(渲染姿态)和i_test(测试图像索引)

    • hwf是从poses中提取的图像的高度宽度焦距
    images, poses, bds, render_poses, i_test = load_llff_data(.....)
    hwf = poses[0,:3,-1]
    poses = poses[:,:3,:4]
    
  2. 将图像数据集划分为三个部分:训练集(i_train)、验证集(i_val)和测试集(i_test

    # 每隔args.llffhold个图像取一张做测试集
    i_test = np.arange(images.shape[0])[::args.llffhold]
    # 验证集 = 测试集
    i_val = i_test
    # 所有不在测试集和验证集中的图像
    i_train = np.array([i for i in np.arange(int(images.shape[0])) if
                    (i not in i_test and i not in i_val)])
    
2.2 创建神经网络模型
  1. 将采样点坐标和观察坐标通过位置编码 get_embedder 成63维和27维
  2. 实例化NeRF模型和NeRF精细模型
  3. 创建网络查询函数 network_query_fn() ,用于运行网络
  4. 创建 Adam 优化器
  5. 加载检查点(如果有),即从检查点中重新加载模型和优化器状态
  6. 创建用于训练和测试的渲染参数 render_kwargs_trainrender_kwargs_test
  7. 根据数据集类型(只有LLFF才行)和参数确定是否使用NDC
2.3 准备光线

使用批处理:

  1. 对于每一个姿态,使用get_rays_np函数获取光线原点和方向( ro+rd ),然后将所有的光线堆叠起来,得到rays
  2. 将射线的原点和方向与图像的颜色通道连接起来( ro+rd+rgb
  3. 对张量进行重新排列和整形,只保留训练集中的图像
  4. 对训练数据进行随机重排
2.4 训练迭代
  1. 设置训练迭代次数 N_iters = 200000 + 1

  2. 开始进行训练迭代

    • 准备光线数据:在每次迭代中,从rays_rgb中取出一批(批处理)光线数据,数量为参数值N_rand,并准备好目标值 target_s

      如果完成一个了周期(i_batch >= rays_rgb.shape[0] ),则对数据进行打乱

    • 渲染:使用渲染函数 render()

    • 计算损失:计算渲染结果的损失。这里使用了均方误差损失函数 img2mse() 来计算图像损失
      L = ∑ r ∈ R ∥ C ^ c ( r ) − C ( r ) ∥ 2 2 + ∥ C ^ f ( r ) − C ( r ) ∥ 2 2 \mathcal{L} = \sum_{\mathbf{r} \in \mathcal{R}} \left\| \hat{C}^c(\mathbf{r}) - C(\mathbf{r}) \right\|_2^2 + \left\| \hat{C}^f(\mathbf{r}) - C(\mathbf{r}) \right\|_2^2 L=rR C^c(r)C(r) 22+ C^f(r)C(r) 22

      img2mse = lambda x, y : torch.mean((x - y) ** 2)
      
    • 反向传播:进行反向传播,并执行优化

    • 更新学习率:这里采用指数衰减的学习率调度策略,学习率在每个一定的步骤(decay_steps)内以一定的速率(decay_rate)衰减

  3. 根据参数设置的频率输出相关状态、视频和测试集

3 神经网络模型

模型结构如下:

image-20240316162459526

  • 应用 ReLU 激活函数

  • 采样点坐标和观察坐标通过位置编码成63维和27维

  • 中间有一个跳跃连接在第四次 256->256 的线性层

    跳跃连接可以将某一层的输入直接传递到后面的层,从而避免梯度消失和表示瓶颈,提高网络的性能

4 体积渲染

4.1 render()

渲染主函数是调用 render() 函数:

def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
                  near=0., far=1.,
                  use_viewdirs=False, c2w_staticcam=None,
                  **kwargs):

其有两种用法:

  1. 测试用:

    rgb, disp, acc, _ = render(H, W, K, 
                               chunk=chunk, 
                               c2w=c2w[:3,:4], 
                               **render_kwargs)
    

    c2w=c2w[:3,:4] 意味着光线的起点和方向是由函数内部通过相机参数计算得出的

    这个只在 render_path() 函数中用到,其在给定相机路径下渲染图像

    • 不训练只渲染时直接渲染时
    • 定期输出结果时
  2. 训练用:

    rgb, disp, acc, extras = render(H, W, K, 
                                    chunk=args.chunk, 
                                    rays=batch_rays,
                                    verbose=i < 10, 
                                    retraw=True,
                                    **render_kwargs_train)
    

    rays=batch_rays 意味着光线的起点和方向是预先计算好的,而不是由函数内部通过相机参数计算得出

    这个只在训练迭代时用到:Core optimization loop 中,对从rays_rgb中取出一批(批处理)光线进行渲染,得到的 rgb 值与 target_s (也来自预先计算好的 rays_rgb )计算 loss,来进行神经网络的训练

4.2 batchify_rays()

在主函数 render() 中,渲染工作是调用的 batchify_rays()

主要目的是将大量的光线分批处理,以避免在渲染过程中出现内存溢出(OOM)的问题

4.3 render_rays()

分批处理函数 batchify_rays() 中的渲染操作是由 render_rays() 进行,其是真正的渲染操作的函数

def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False,
                pytest=False):

其参数:光线批次(ray_batch)、网络函数(network_fn)、网络查询函数(network_query_fn)、样本数量(N_samples)等等

返回:一个字典 ,包含了 RGB 颜色映射、视差映射、累积不透明度等信息

其大致过程为:

  1. 从光线批次中提取出光线的起点、方向、视线方向以及近远边界

    • 根据是否进行线性分布采样,计算出每个光线上的采样点的深度值

    • 若设置扰动( perturb ),则在每个采样间隔内进行分层随机采样

  2. 函数计算出每个采样点在空间中的位置

    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]
    
  3. 然后使用 network_query_fn() 对每个采样点进行预测,得到原始的预测结果 raw

  4. 使用 raw2outputs()(请看下一节4.4) 函数将原始预测结果转换为 RGB 颜色映射、视差映射、累积不透明度等输出

  5. 若分层采样 N_importance > 0,调用 sample_pdf() 分层采样,并将这些额外的采样点传递给精细网络 network_fine 进行预测

  6. 最后,函数返回一个字典,包含了所有的输出结果

4.4 raw2outputs()

其将模型的原始预测转换为语义上有意义的值,主要基于论文中离散形式的积分方程实现:

累积不透明度函数 C ^ ( r ) \hat{C}(r) C^(r) 的估计公式如下:

C ^ ( r ) = ∑ i = 1 N T i ( 1 − exp ⁡ ( − σ i δ i ) ) c i \hat{C}(r) = \sum_{i=1}^{N} T_i (1 - \exp(-\sigma_i \delta_i)) c_i C^(r)=i=1NTi(1exp(σiδi))ci

其中,

  • N N N 是样本点的数量,
  • T i = exp ⁡ ( − ∑ j = 1 i − 1 σ j δ j ) T_i = \exp \left( - \sum_{j=1}^{i-1} \sigma_j \delta_j \right) Ti=exp(j=1i1σjδj) 是权重系数
  • δ i = t i + 1 − t i \delta_i = t_{i+1} - t_i δi=ti+1ti 表示相邻样本之间的距离
  • c i c_i ci 是颜色值
  • σ i \sigma_i σi 是不透明度值(体积密度)

根据代码,我们可以得出以下关系:

  • c i c_i ci 对应着 rgb = torch.sigmoid(raw[...,:3]),表示颜色值
  • σ i \sigma_i σi 对应着 raw[...,3],表示不透明度值

然后,我们可以根据公式中的每个项逐一解释如何在代码中实现:

  1. δ i = t i + 1 − t i \delta_i = t_{i+1} - t_i δi=ti+1ti:计算相邻样本之间的距离。在代码中:

     dists = z_vals[...,1:] - z_vals[...,:-1]
    
  2. 1 − exp ⁡ ( − σ i δ i ) 1 - \exp(-\sigma_i \delta_i) 1exp(σiδi):计算每个样本的不透明度。在代码中:

    raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)
    
    alpha = raw2alpha(raw[...,3] + noise, dists)
    
  3. T i = exp ⁡ ( − ∑ j = 1 i − 1 σ j δ j ) T_i = \exp \left( - \sum_{j=1}^{i-1} \sigma_j \delta_j \right) Ti=exp(j=1i1σjδj)​:计算权重系数。在代码中:

    即对 1 − ( 1 − exp ⁡ ( − σ i δ i ) ) 1 - (1 - \exp(-\sigma_i \delta_i)) 1(1exp(σiδi)) 累乘

    torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
    
  4. C ^ ( r ) = ∑ i = 1 N T i ( 1 − exp ⁡ ( − σ i δ i ) ) c i \hat{C}(r) = \sum_{i=1}^{N} T_i (1 - \exp(-\sigma_i \delta_i)) c_i C^(r)=i=1NTi(1exp(σiδi))ci​​:计算累积不透明度。在代码中:

    w i = T i ( 1 − exp ⁡ ( − σ i δ i ) ) w_i = T_i(1 - \exp(-\sigma_i\delta_i)) wi=Ti(1exp(σiδi))

    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
    rgb_map = torch.sum(weights[...,None] * rgb, -2)  # [N_rays, 3]
    

最终,代码返回估计的 RGB 颜色、视差图、累积权重、权重以及估计的距离图

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/463080.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

智慧城市与数字孪生:共创未来城市的智慧生活

目录 一、智慧城市与数字孪生的概念与特点 二、智慧城市与数字孪生共创智慧生活的路径 1、城市规划与建设的智能化 2、城市管理与服务的智慧化 3、城市安全与应急管理的智能化 三、智慧城市与数字孪生面临的挑战与对策 四、智慧城市与数字孪生的发展趋势与展望 1、技术…

redis中List和hash数据类型

list类型是用来存储多个有序的字符串的&#xff0c;列表当中的每一个字符看做一个元素&#xff0c;一个列表当中可以存储一个或者多个元素&#xff0c;redis的list支持存储2^32-1个元素。redis可以从列表的两端进行插入&#xff08;pubsh&#xff09;和弹出&#xff08;pop&…

游戏引擎中网络游戏的基础

一、前言 网络游戏所面临的挑战&#xff1a; 一致性&#xff1a;如何在所有的主机内都保持一样的表现可靠性&#xff1a;网络传输有可能出现丢包安全性&#xff1a;反作弊&#xff0c;反信息泄漏。多样性&#xff1a;不同设备之间链接&#xff0c;比如手机&#xff0c;ipad&a…

专升本 C语言笔记-07 逗号运算符

1.逗号表达式的用法 就是用逗号隔开的多个表达式。逗号表达式&#xff0c;从左向右依次执行。 2.逗号表达式的特性 2.1.当没有括号时&#xff0c;第一个表达式为整个表达式的值。 代码 int x 3,y 5,a 0; a x,y; printf("a %d",a); 说明:因为逗号优先级最低,会…

OpenCV4.9.0开源计算机视觉库在 Linux 中安装

返回目录&#xff1a;OpenCV系列文章目录&#xff08;持续更新中......&#xff09; 上一篇&#xff1a;OpenCV 环境变量参考 下一篇&#xff1a;将OpenCV与gcc和CMake结合使用 引言&#xff1a; OpenCV是一个开源的计算机视觉库&#xff0c;由英特尔公司所赞助。它是一个跨…

确保云原生部署中的网络安全

数字环境正在以惊人的速度发展&#xff0c;组织正在迅速采用云原生部署和现代化使用微服务和容器构建的应用程序&#xff08;通常运行在 Kubernetes 等平台上&#xff09;&#xff0c;以推动增长。 无论我们谈论可扩展性、效率还是灵活性&#xff0c;对于努力提供无与伦比的用…

源码|批量执行invokeAll()多选一invokeAny()

ExecutorService中定义了两个批量执行任务的方法&#xff0c;invokeAll()和invokeAny()&#xff0c;在批量执行或多选一的业务场景中非常方便。invokeAll()在所有任务都完成&#xff08;包括成功/被中断/超时&#xff09;后才会返回&#xff0c;invokeAny()在任意一个任务成功&…

校园博客系统 |基于springboot框架+ Mysql+Java的校园博客系统设计与实现(可运行源码+数据库+设计文档)

推荐阅读100套最新项目 最新ssmjava项目文档视频演示可运行源码分享 最新jspjava项目文档视频演示可运行源码分享 最新Spring Boot项目文档视频演示可运行源码分享 目录 前台功能效果图 管理员功能登录前台功能效果图 系统功能设计 数据库E-R图设计 lunwen参考 摘要 研究…

MySQL语法分类 DDL(1)

DDL&#xff08;1&#xff09;(操作数据库、表) 数据库操作(CRUD) C(Create):创建 //指定字符集创建 create database db_1 character set utf8;//避免重复创建数据库报错可以用一下命令 create database if not exists db_1 character set utf8;R(Retrieve):查询 //查询所…

电源适配器

电源适配器 1. 选购指南2. 接口测量方法3. 电源接口4. 抗干扰磁环&#xff0c;稳定输出References 1. 选购指南 插头尺度相同&#xff0c;供电电压 (V) 相同&#xff0c;电流 (A) > 原来的电流 (A) INPUT (输入)&#xff0c;OUTPUT (输出) 2. 接口测量方法 3. 电源接口 外…

ARM和AMD介绍

一、介绍 ARM 和 AMD 都是计算机领域中的知名公司&#xff0c;它们在不同方面具有重要的影响和地位。 ARM&#xff08;Advanced RISC Machine&#xff09;&#xff1a;ARM 公司是一家总部位于英国的公司&#xff0c;专注于设计低功耗、高性能的处理器架构。ARM 架构以其精简指…

HCIP—BGP邻居关系建立实验

BGP的邻居称为&#xff1a;IBGP对等体 EBGP对等体 1.EBGP对等体关系&#xff1a; 位于 不同自治系统 的BGP路由器之间的BGP对等体关系 EBGP对等体一般使用 直连建立 对等体关系&#xff0c;EBGP邻居之间的报文 TTL中值设置为1 两台路由器之间建立EBGP对等体关系&#xff0…

Python Web开发记录 Day12:Django part6 用户登录

名人说&#xff1a;东边日出西边雨&#xff0c;道是无晴却有晴。——刘禹锡《竹枝词》 创作者&#xff1a;Code_流苏(CSDN)&#xff08;一个喜欢古诗词和编程的Coder&#x1f60a;&#xff09; 目录 1、登录界面2、用户名密码校验3、cookie与session配置①cookie与session②配置…

【机器学习-02】矩阵基础运算---numpy操作

在机器学习-01中&#xff0c;我们介绍了关于机器学习的一般建模流程&#xff0c;并且在基本没有数学公式和代码的情况下&#xff0c;简单介绍了关于线性回归的一般实现形式。不过这只是在初学阶段、为了不增加基础概念理解难度所采取的方法&#xff0c;但所有的技术最终都是为了…

【01】htmlcssgit

01-前端干货-html&css 防脱发神器 一图胜千言 使用border-box控制尺寸更加直观,因此,很多网站都会加入下面的代码 * {margin: 0;padding: 0;box-sizing: border-box; }颜色的 alpha 通道 颜色的 alpha 通道标识了色彩的透明度,它是一个 0~1 之间的取值,0 标识完全…

C语言之快速排序

目录 一 简介 二 代码实现 快速排序基本原理&#xff1a; C语言实现快速排序的核心函数&#xff1a; 三 时空复杂度 A.时间复杂度 B.空间复杂度 C.总结&#xff1a; 一 简介 快速排序是一种高效的、基于分治策略的比较排序算法&#xff0c;由英国计算机科学家C.A.R. H…

【Machine Learning】Suitable Learning Rate in Machine Learning

一、The cases of different learning rates: In the gradient descent algorithm model: is the learning rate of the demand, how to determine the learning rate, and what impact does it have if it is too large or too small? We will analyze it through the follow…

HCIP—OSPF课后练习一

本实验模拟了一个企业网络场景&#xff0c;R1、R2、R3为公司总部网络的路由器&#xff0c;R4、R5分别为企业分支机构1和分支机构2的路由器&#xff0c;并且都采用双上行方式与企业总部相连。整个网络都运行OSPF协议&#xff0c;R1、R2、R3之间的链路位于区域0&#xff0c;R4与R…

Redis和Mysql的数据一致性问题

在高并发的场景下&#xff0c;大量的请求直接访问Mysql很容易造成性能问题。所以我们都会用Redis来做数据的缓存&#xff0c;削减对数据库的请求的频率。 但是&#xff0c;Mysql和Redis是两种不同的数据库&#xff0c;如何保证不同数据库之间数据的一致性就非常关键了。 1、导…

并查集Disjoint Set

并查集的概念 并查集是一种简单的集合表示&#xff0c;它支持以下三种操作 1. make_set(x)&#xff0c;建立一个新集合&#xff0c;唯一的元素是x 2. find_set(x)&#xff0c;返回一个指针&#xff0c;该指针指向包含x的唯一集合的代表&#xff0c;也就是x的根节点 3. union_…
最新文章