【三维几何学习】从零开始网格上的深度学习-3:Transformer篇(Pytorch)

本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052

从零开始网格上的深度学习-3:Transformer篇

  • 引言
  • 一、概述
  • 二、核心代码
    • 2.1 位置编码
    • 2.2 网络框架
  • 三、基于Transformer的网格分类
    • 3.1 分类结果
    • 3.2 全部代码

引言

本文主要内容如下:

  • 简述网格上的位置编码
  • 参考点云上的Transformer-1:PCT:Point cloud transformer,构造网格分类网络

一、概述

在这里插入图片描述

个人认为对于三角形网格来说,想要将Transformer应用到其上较为重要的一步是位置编码。三角网格在3D空间中如何编码每一个元素的位置,能尽可能保证的泛化性能? 以xyz坐标为例,最好是模型经过对齐的预处理,使朝向一致。或者保证网格水密的情况下使用谱域特征,如热核特征。或者探索其他位置编码等等… 上图为一个外星人x坐标的位置编码可视化

  • 使用简化网格每一个面直接作为一个Token即可,高分辨率的网格(考虑输入特征计算、训练数据对齐等)并不适合深度学习(个人认为)
  • 直接应用现有的Tranformer网络框架、自注意力模块等,细节或参数需要微调

二、核心代码

2.1 位置编码

使用每一个网格面的中心坐标作为位置编码,计算代码在DataLoader中

  • 需要平移到坐标轴原点,并进行尺度归一化
# xyz
xyz_min = np.min(vs[:, 0:3], axis=0)
xyz_max = np.max(vs[:, 0:3], axis=0)
xyz_move = xyz_min + (xyz_max - xyz_min) / 2
vs[:, 0:3] = vs[:, 0:3] - xyz_move
# scale
scale = np.max(vs[:, 0:3])
vs[:, 0:3] = vs[:, 0:3] / scale
# 面中心坐标
xyz = []
for i in range(3):
    xyz.append(vs[faces[:, i]])
xyz = np.array(xyz)  # 转为np
mean_xyz = xyz.sum(axis=0) / 3

2.2 网络框架

在这里插入图片描述

  • 参考上图PCT框架,修改了部分细节,如减少了Attention模块数量等

在这里插入图片描述

  • 参考上图自注意力模块,个人感觉图中应该有误. 从一个共享权重的Linear里出来了 Q 、 K 、 V Q、K、V QKV三个矩阵,但 V V V的维度和 Q 、 K Q、K QK不一致,少画了一个Linear?
class SA(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.q_conv.weight = self.k_conv.weight
        self.v_conv = nn.Conv1d(channels, channels, 1, bias=False)
        self.trans_conv = nn.Conv1d(channels, channels, 1)
        self.after_norm = nn.BatchNorm1d(channels)
        self.act = nn.GELU()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x_q = self.q_conv(x).permute(0, 2, 1)
        x_k = self.k_conv(x)
        x_v = self.v_conv(x)
        energy = x_q @ x_k
        attention = self.softmax(energy)
        attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))
        x_r = x_v @ attention
        x_r = self.act(self.after_norm(self.trans_conv(x - x_r)))
        x = x + x_r
        return x


class TriTransNet(nn.Module):
    def __init__(self, dim_in, classes_n=30):
        super().__init__()
        self.conv_fea = FaceConv(6, 128, 4)
        self.conv_pos = FaceConv(3, 128, 4)
        self.bn_fea = nn.BatchNorm1d(128)
        self.bn_pos = nn.BatchNorm1d(128)
        self.sa1 = SA(128)
        self.sa2 = SA(128)
        self.gp = nn.AdaptiveAvgPool1d(1)
        self.linear1 = nn.Linear(256, 128, bias=False)
        self.bn1 = nn.BatchNorm1d(128)
        self.linear2 = nn.Linear(128, classes_n)
        self.act = nn.GELU()

    def forward(self, x, mesh):
        x = x.permute(0, 2, 1).contiguous()
        # 位置编码 放到DataLoader中比较好
        pos = [m.xyz for m in mesh]
        pos = np.array(pos)
        pos = torch.from_numpy(pos).float().to(x.device).requires_grad_(True)

        batch_size, _, N = x.size()
        x = self.act(self.bn_fea(self.conv_fea(x, mesh).squeeze(-1)))
        pos = self.act(self.bn_pos(self.conv_pos(pos, mesh).squeeze(-1)))
        x1 = self.sa1(x + pos)
        x2 = self.sa2(x1 + pos)
        x = torch.cat((x1, x2), dim=1)
        x = self.gp(x)
        x = x.view(batch_size, -1)
        x = self.act(self.bn1(self.linear1(x)))
        x = self.linear2(x)
        return x

三、基于Transformer的网格分类

数据集是SHREC’11 可参考三角网格(Triangular Mesh)分类数据集 或 MeshCNN

3.1 分类结果

在这里插入图片描述在这里插入图片描述
准确率太低… 可以尝试改进的点:

  • 尝试不同的位置编码(谱域特征),不同的位置嵌入方式 (sum可改为concat)
  • 数据集较小的情况下Transformer略难收敛,加入更多CNN可加速且提升明显 (或者加入降采样)
  • 打印loss进行分析,是否欠拟合,尝试增加网络参数?

基于Transformer的网络在网格分割上的表现会很好,仅用少量参数即可媲美甚至超过基于面卷积的分割结果,个人感觉得益于其近乎全局的感受野…

3.2 全部代码

DataLoader代码请参考2:从零开始网格上的深度学习-1:输入篇(Pytorch)
FaceConv代码请参考3:从零开始网格上的深度学习-2:卷积网络CNN篇

import torch
import torch.nn as nn
import numpy as np
from CNN import FaceConv
from DataLoader_shrec11 import DataLoader
from DataLoader_shrec11 import Mesh


class SA(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.q_conv.weight = self.k_conv.weight
        self.v_conv = nn.Conv1d(channels, channels, 1, bias=False)
        self.trans_conv = nn.Conv1d(channels, channels, 1)
        self.after_norm = nn.BatchNorm1d(channels)
        self.act = nn.GELU()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x_q = self.q_conv(x).permute(0, 2, 1)
        x_k = self.k_conv(x)
        x_v = self.v_conv(x)
        energy = x_q @ x_k
        attention = self.softmax(energy)
        attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))
        x_r = x_v @ attention
        x_r = self.act(self.after_norm(self.trans_conv(x - x_r)))
        x = x + x_r
        return x


class TriTransNet(nn.Module):
    def __init__(self, dim_in, classes_n=30):
        super().__init__()
        self.conv_fea = FaceConv(6, 128, 4)
        self.conv_pos = FaceConv(3, 128, 4)
        self.bn_fea = nn.BatchNorm1d(128)
        self.bn_pos = nn.BatchNorm1d(128)
        self.sa1 = SA(128)
        self.sa2 = SA(128)
        self.gp = nn.AdaptiveAvgPool1d(1)
        self.linear1 = nn.Linear(256, 128, bias=False)
        self.bn1 = nn.BatchNorm1d(128)
        self.linear2 = nn.Linear(128, classes_n)
        self.act = nn.GELU()

    def forward(self, x, mesh):
        x = x.permute(0, 2, 1).contiguous()
        # 位置编码 放到DataLoader中比较好
        pos = [m.xyz for m in mesh]
        pos = np.array(pos)
        pos = torch.from_numpy(pos).float().to(x.device).requires_grad_(True)

        batch_size, _, N = x.size()
        x = self.act(self.bn_fea(self.conv_fea(x, mesh).squeeze(-1)))
        pos = self.act(self.bn_pos(self.conv_pos(pos, mesh).squeeze(-1)))
        x1 = self.sa1(x + pos)
        x2 = self.sa2(x1 + pos)
        x = torch.cat((x1, x2), dim=1)
        x = self.gp(x)
        x = x.view(batch_size, -1)
        x = self.act(self.bn1(self.linear1(x)))
        x = self.linear2(x)
        return x


if __name__ == '__main__':
    # 输入
    data_train = DataLoader(phase='train')         # 训练集
    data_test = DataLoader(phase='test')           # 测试集
    print('#train meshes = %d' % len(data_train)) # 输出训练模型个数
    print('#test  meshes = %d' % len(data_test))  # 输出测试模型个数

    # 网络
    net = TriTransNet(data_train.input_n, data_train.class_n)    # 创建网络 以及 优化器
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999))
    net = net.cuda(0)
    loss_fun = torch.nn.CrossEntropyLoss(ignore_index=-1)
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print('[Net] Total number of parameters : %.3f M' % (num_params / 1e6))
    print('-----------------------------------------------')

    # 迭代训练
    for epoch in range(1, 201):
        print('---------------- Epoch: %d -------------' % epoch)
        for i, data in enumerate(data_train):
            # 前向传播
            net.train(True)        # 训练模式
            optimizer.zero_grad()  # 梯度清零
            face_features = torch.from_numpy(data['face_features']).float()
            face_features = face_features.to(data_train.device).requires_grad_(True)
            labels = torch.from_numpy(data['label']).long().to(data_train.device)
            out = net(face_features, data['mesh'])     # 输入到网络

            # 反向传播
            loss = loss_fun(out, labels)
            loss.backward()
            optimizer.step()              # 参数更新

        # 测试
        net.eval()
        acc = 0
        for i, data in enumerate(data_test):
            with torch.no_grad():
                # 前向传播
                face_features = torch.from_numpy(data['face_features']).float()
                face_features = face_features.to(data_test.device).requires_grad_(False)
                labels = torch.from_numpy(data['label']).long().to(data_test.device)
                out = net(face_features, data['mesh'])
                # 计算准确率
                pred_class = out.data.max(1)[1]
                correct = pred_class.eq(labels).sum().float()
            acc += correct
        acc = acc / len(data_test)
        print('epoch: %d, TEST ACC: %0.2f' % (epoch, acc * 100))


  1. PCT:Point cloud transformer ↩︎

  2. 从零开始网格上的深度学习-1:输入篇(Pytorch) ↩︎

  3. 从零开始网格上的深度学习-2:卷积网络CNN篇 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/2124.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

linux中写定时任务

场景:我们生产环境中有大量的日志记录,但是我们的磁盘没有太大,需要定时清理磁盘 文章目录crond 定时任务详解安装定时任务crontab服务启动与关闭crontab操作crontab 命令test.sh查看日志丢弃linux中的执行日志Linux进入nano模式方式一方式二…

Unreal Engine 网络系统(四):UEC++的RPC

目录 行为同步 On Server:服务端的RPC代码 On Client:客户端的RPC代码 NetMulticast:广播的RPC代码 属性同步 行为同步 借助UFUNCTION进行函数标记 UFUNCTION(Server):声明一个在客户端调用,在服务端执行的函数U…

测试老鸟都在用的接口抓包常用工具以及接口测试工具都有哪些?

目录 接口 接口测试的重要性 常用抓包工具 常用接口测试工具 接口 接口测试是测试系统组件间接口的一种测试。接口测试主要用于检测外部系统与系统之间以及内部各个子系统之间的交互点。测试的重点是要检查数据的交换,传递和控制管理过程,以及系统间…

pkg打包node项目到linux中运行

首先看一下pkg的一些基本操作 pkg打包node项目为exe_node静态项目 导出exe_疆~的博客-CSDN博客由于win7最高只支持node13.14.0,而pkg不支持node13,为了既兼容win7,又能使用pkg打包,故使用node12.22.11。新建node_global和node_ca…

这一次,吃了Redis的亏,也败给了GPT

关注【离心计划】,一起离开地球表面 背景 组内有一个系统中有一个延迟任务的需求,关于延迟任务常见的做法有时间轮、延迟MQ还有Redis Zset等方案,关于时间轮,这边小苏有一个大学时候做的demo: https://github.com/JA…

好用的5款国产低代码平台介绍

一、云程低代码平台 云程低代码平台是一款基于springboot、vue.js技术的企业级低代码开发平台,平台采用模型驱动、高低码融合、开放扩展等设计理念,基于业务建模、流程建模、表单建模、报表建模、大屏建模等可视化建模工具,通过拖拉拽零代码方…

安装flume

flume最主要的作用就是实时读取服务器本地磁盘的数据,将数据写入到hdfs中架构:开始安装一,上传压缩包,解压并更名解压:[rootsiwen install]# tar -zxf apache-flume-1.9.0-bin.tar.gz -C ../soft/[rootsiwen install]#…

太强了,英伟达面对ChatGPT还有这一招...

大家好,我是 Jack。 今年可谓是 AI 元年,ChatGPT、AIGC、VITS 都火了一波。 我也先后发布了这几期视频: 这是一个大模型的时代,AI 能在文本、图像、音频等领域大放异彩,得益于大模型。而想要预训练大模型&#xff0c…

nodejs篇 express(1)

文章目录前言express介绍安装RESTful接口规范express的简单使用一个最简单的服务器,仅仅只需要几行代码便可以实现。restful规范的五种接口类型请求信息req的获取响应信息res的设置中间件的使用自定义中间件解决跨域nodejs相关其它内容前言 express作为nodejs必学的…

前缀树(字典树/Trie) -----Java实现

目录 一.前缀树 1.什么是前缀树 2.前缀树的举例 二.前缀树的实现 1.前缀树的数据结构 1.插入字符串 2.查找字符串 3.查找前缀 三.词典中最长的单词 1.题目描述 2.问题分析 3.代码实现 一.前缀树 1.什么是前缀树 字典树(Trie树)是一种树形…

机器学习——无监督学习

机器学习的分类一般分为下面几种类别:监督学习( supervised Learning )无监督学习( Unsupervised Learning )强化学习( Reinforcement Learning,增强学习)半监督学习( Semi-supervised Learning )深度学习(Deep Learning)Python Scikit-learn. http: // …

用Pytorch构建一个喵咪识别模型

本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052 目录 一、前言 二、问题阐述及理论流程 2.1问题阐述 2.2猫咪图片识别原理 三、用PyTorch 实现 3.1PyTorch介绍 3.2PyTorch 构建模型的五要素 3.3PyTorch 实现的步骤 3.3.…

app自动化测试——app自动化控制、常见控件定位方法

文章目录一、app自动化控制1、清理数据:2、启动:3、关闭:二、常见控件定位方法1、android知识2、ios 基础知识3、元素定位4、控件基础知识5、app dom 结构解析6、iOS 与 Android dom 结构的区别7、定位方法测试步骤三要素定位方式&#xff1a…

大环境不好,找工作太难?三面阿里,幸好做足了准备,已拿offer

三面大概九十分钟,问的东西很全面,需要做充足准备,就是除了概念以外问的有点懵逼了(呜呜呜)。回来之后把这些题目做了一个分类并整理出答案(强迫症的我狂补知识)分为软件测试基础、Python自动化…

超专业解析!10分钟带你搞懂Linux中直接I/O原理

我们先看一张图: 这张图大体上描述了 Linux 系统上,应用程序对磁盘上的文件进行读写时,从上到下经历了哪些事情。 这篇文章就以这张图为基础,介绍 Linux 在 I/O 上做了哪些事情。 文件系统 什么是文件系统 文件系统&#xff0…

docker版jxTMS使用指南:数据查询

本文讲解docker版jxTMS的数据查询,整个系列的文章请查看:docker版jxTMS使用指南 请按前文所述先做好相关的准备工作,然后多在helloWorld界面输入各种数据后点【点我】按钮,以多创建点数据来为查询做下准备。 分页查询 首先在we…

python网上选课系统django-PyCharm

学生选课信息管理系统,可以有效的对学生选课信息、学生个人信息、教师个人信息等等进行管理。 开发语言:Python 框架:django Python版本:python3.7.7 数据库:mysql 数据库工具:Navicat11 开发软件&#x…

RK3588平台开发系列讲解(NPU篇)NPU调试方法

平台内核版本安卓版本RK3588Linux 5.10Android 12文章目录 一、日志等级二、NPU 支持查询设置项沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇我们一起来看一下NPU的调试方法。 一、日志等级 NPU 的运行库会根据开发板上的系统环境变量输出一些日志信息或者生成…

操作系统(2.4.5)--管程机制

1.管程的定义 利用共享数据结构抽象地表示系统中的共享资源,而把对该共享数据结构实施的操作定义为一组过程进程对共享资源的申请、释放和其它操作,都是通过这组过程对共享数据结构的操作来实现的,这组过程还可以根据资源的情况,或…

yolov8训练筷子点数数据集

序言 yolov8发布这么久了,一直没有机会尝试一下,今天用之前自己制作的筷子点数数据集进行训练,并且记录一下使用过程以及一些常见的操作方式,供以后翻阅。 一、环境准备 yolov8的训练相对于之前的yolov5简单了很多,…
最新文章