【自制C++深度学习推理框架】Layer的设计思路

Layer的设计思路

Layer的抽象

如果将深度学习中的所有层分为两类, 那么肯定是"带权重"的层和"不带权重"的层。

基于层的共性,我们定义了一个Layer的基类,提供了一些基本接口,并可以通过继承和多态机制实现不同类型的Layer。

具体来说,该类包括以下几个成员函数:

  1. 构造函数 Layer(std::string layer_name),用于创建一个Layer对象并设置该层的名称。

  2. virtual ~Layer() = default,虚析构函数,在派生类中可以通过override关键字重新定义。

  3. virtual InferStatus Forward(const std::vector<std::shared_ptr<Tensor<float>>> &inputs, std::vector<std::shared_ptr<Tensor<float>>> &outputs) ,前向传播函数,将输入tensor作为参数,计算输出tensor。

  4. virtual const std::vector<std::shared_ptr<Tensor<float>>> &weights() const, 返回当前层的权重数组。

  5. virtual const std::vector<std::shared_ptr<Tensor<float>>> &bias() const, 返回当前层的偏置数组。

  6. virtual void set_weights(const std::vector<std::shared_ptr<Tensor<float>>> &weights),设置当前层的权重数组。

  7. virtual void set_bias(const std::vector<std::shared_ptr<Tensor<float>>> &bias),设置当前层的偏置数组。

  8. virtual void set_weights(const std::vector<float> &weights),将权重数据类型转换为shared_ptr后调用上述函数。

  9. virtual void set_bias(const std::vector<float> &bias),将偏置数据类型转换为shared_ptr后调用上述函数。

  10. virtual const std::string &layer_name() const,返回当前层的名称。

而成员变量只有一个,即

  • std::string layer_name_,Layer的名称

为什么定义成虚函数

在神经网络中,不同的层具有不同的结构和运算方式,因此需要不同的函数来实现它们。使用虚函数的方法可以将这些不同的函数封装到一个基类中,并通过多态机制来实现不同类型的层的动态绑定。

具体来说,当使用基类指针或引用调用虚函数时,程序会根据对象的动态类型(即实际指向的派生类类型)来选择相应的函数实现。这就使得不同类型的层可以通过共同的接口进行调用,从而提高了代码的可维护性和可扩展性。

此外,使用虚函数还可以方便地定义抽象类,即只声明虚函数但不提供实现的类。这可以为子类提供一个规范化的接口,要求其必须重写某些接口以满足特定的需求。这种机制可以有效避免在大型工程中出现微小的差错而导致底层实现不符合最终需求的问题。

带权重Layer的实现

我们把Layer基类来表示不带参数的Layer,并且通过继承该Layer基类的方式来定义了一个带参数的层ParamLayer子类,在ParamLayer中定义了成员变量bias和weights。

ParamLayer是具有可调参数的神经网络层实现,包括初始化权重和偏置的函数、重载读写权重和偏置的函数,以及保存权重和偏置的成员变量。

具体来说,该类包括以下几个成员函数和成员变量:

  1. 构造函数 ParamLayer(const std::string &layer_name),用于创建一个ParamLayer对象并设置该层的名称。

  2. void InitWeightParam(const uint32_t param_count, const uint32_t param_channel, const uint32_t param_height, const uint32_t param_width),用于初始化权重参数。

  3. void InitBiasParam(const uint32_t param_count, const uint32_t param_channel, const uint32_t param_height, const uint32_t param_width),用于初始化偏置参数。

  4. const std::vector<std::shared_ptr<Tensor<float>>> &weights() const override,重载虚函数weights(),返回保存权重参数的成员变量weights_

  5. const std::vector<std::shared_ptr<Tensor<float>>> &bias() const override,重载虚函数bias(),返回保存偏置参数的成员变量bias_

  6. void set_weights(const std::vector<float> &weights) override,重载虚函数set_weights(),将权重数据类型转换为shared_ptr后存储在成员变量weights_中。

  7. void set_bias(const std::vector<float> &bias) override,重载虚函数set_bias(),将偏置数据类型转换为shared_ptr后存储在成员变量bias_中。

  8. void set_weights(const std::vector<std::shared_ptr<Tensor<float>>> &weights) override,重载虚函数set_weights(),将参数复制到成员变量weights_中。

  9. void set_bias(const std::vector<std::shared_ptr<Tensor<float>>> &bias) override,重载虚函数set_bias(),将参数复制到成员变量bias_中。

  10. 成员变量std::vector<std::shared_ptr<Tensor<float>>> weights_,保存ParamLayer的权重参数。

  11. 成员变量std::vector<std::shared_ptr<Tensor<float>>> bias_,保存ParamLayer的偏置参数。

ParamLayer通过继承Layer类实现了一些共同接口,并在此基础上扩展了更多函数和成员,可以方便地实现带有参数的神经网络层。

Layer的注册机制

为了实现注册和创建神经网络层,并在运行时动态地生成不同类型的神经网络层,定义了两个类:LayerRegisterer和LayerRegistererWrapper。

具体来说,LayerRegisterer类提供了三个静态函数和一个静态成员变量:

  1. typedef ParseParameterAttrStatus (*Creator)(const std::shared_ptr<RuntimeOperator> &op, std::shared_ptr<Layer> &layer):定义了一个函数指针类型Creator,用于指向具体神经网络层的函数。

  2. typedef std::map<std::string, Creator> CreateRegistry:定义了一个映射类型CreateRegistry,用于保存层类型和对应创建函数的映射关系。

  3. static void RegisterCreator(const std::string &layer_type, const Creator &creator):将层类型和创建函数的映射关系注册到CreateRegistry中。

  4. static std::shared_ptr<Layer> CreateLayer(const std::shared_ptr<RuntimeOperator> &op):根据输入的op对象创建相应的神经网络层。

  5. static CreateRegistry &Registry():返回当前已经注册的所有层类型和创建函数的映射关系。

RuntimeOperator是计算图的某个计算节点,里面保存了计算节点所需的参数等信息,具体介绍请看3.Graph.md

而LayerRegistererWrapper类则提供了一个构造函数,用于将某一种类型的神经网络层和其创建函数注册到LayerRegisterer中,如下所示。

class LayerRegistererWrapper {
 public:
  LayerRegistererWrapper(const std::string &layer_type, const LayerRegisterer::Creator &creator) {
    LayerRegisterer::RegisterCreator(layer_type, creator);
  }
};

在LayerRegisterer类中,通过维护一个键值对(<std::string, Creator>CreateRegistry,管理Layer注册表,在注册和查找Layer时都要先检查一下是否注册,如果未注册输出错误信息。

为什么要把成员函数定义为静态的

静态函数与类相关联,而不是与类的对象相关。因此,静态函数可以在没有创建类的实例的情况下调用,从而方便地提供一些辅助函数或管理函数,例如工厂方法、单例等。

LayerRegisterer和LayerRegistererWrapper中定义的所有函数都是静态的,主要原因是这些函数需要全局地维护层类型和创建函数的映射关系,并控制新层类型的注册和创建过程。使用静态函数可以使得这些功能在整个程序中被共享和访问,同时避免了由于对象实例的含糊不清而导致的编码错误。

另外需要注意的是,静态函数可以直接使用静态成员变量,不需要通过对象来访问,这使得这些静态函数可以更容易地协同工作,并兼顾了效率和灵活性。

阅读的代码

  • include
    • layer
      • abstract
        • layer_factory.hpp
        • layer.hpp
        • param_layer.hpp
  • source
    • layer
      • abstract
        • layer.cpp
        • layer_factory.cpp
        • param_layer.cpp

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/25184.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

全面解析Linux指令和权限管理

目录 一.指令再讲解1.时间相关的指令2.find等搜索指令与grep指令3.打包和压缩相关的指令4.一些其他指令与热键二.Linux权限1.Linux的权限管理2.文件类型与权限设置3.目录的权限与粘滞位 一.指令再讲解 1.时间相关的指令 date指令: date 用法&#xff1a;date [OPTION]… [FOR…

如何在Linux中更改SSH端口?

SSH&#xff08;Secure Shell&#xff09;是一种安全的远程登录协议&#xff0c;它允许您通过网络远程连接到Linux系统并进行管理操作。默认情况下&#xff0c;SSH使用22端口进行通信。然而&#xff0c;为了增强系统的安全性&#xff0c;有时候我们需要更改SSH端口&#xff0c;…

linux 找回root密码(CentOS7.6)

linux 找回root密码(CentOS7.6) 首先&#xff0c;启动系统&#xff0c;进入开机界面&#xff0c;在界面中按“e”进入编辑界面。如图 2. 进入编辑界面&#xff0c;使用键盘上的上下键把光标往下移动&#xff0c;找到以““Linux16”开头内容所在的行数”&#xff0c;在行的最后…

C4D R26 渲染学习笔记 建模篇(2):手动建模

文章目录 前文回顾介绍篇建模篇 手动建模建模快捷键手动模型快捷键大全常用操作N系快捷键K系快捷键U系快捷键 结尾 前文回顾 介绍篇 C4D R26 渲染学习笔记&#xff08;1&#xff09;&#xff1a;C4D版本选择和初始UI框介绍 C4D R26 渲染学习笔记&#xff08;2&#xff09;&am…

Dubbo高可用

1.zookeeper宕机与dubbo直连 1.1.现象&#xff1a;zookeeper注册中心宕机&#xff0c;还可以消费dubbo暴露的服务。 原因&#xff1a; 监控中心宕掉不影响使用&#xff0c;只是丢失部分采样数据数据库宕掉后&#xff0c;注册中心仍能通过缓存提供服务列表查询&#xff0c;但…

软考A计划-试题模拟含答案解析-卷十二

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例 &#x1f449;关于作者 专注于Android/Unity和各种游戏开发技巧&#xff0c;以及各种资源分享&am…

AB Test数学原理及金融风控应用

1 什么是AB Test AB测试是一种常用的实验设计方法&#xff0c;用于比较两个或多个不同处理或策略的效果&#xff0c;以确定哪个处理或策略在某个指标上表现更好。在AB测试中&#xff0c;将随机选择一部分用户或样本&#xff0c;将其分为两个或多个组&#xff0c;每个组应用不同…

Java: IO流

1.定义 IO流:存储和读取数据的解决方案 用于读写文件中的数据&#xff08;可以读写文件&#xff0c;或网络中的数据...) 2.IO流的分类 1.按着流的方向 1.输入流&#xff1a;读取 2.输出流&#xff1a;写出 2.按照操作文件类型 1.字节流&#xff1a;所有类型文件 体系&…

机器学习-5 朴素贝叶斯算法

朴素贝叶斯算法 算法概述数理统计学处理的信息古典学派和贝叶斯学派的争论贝叶斯定理朴素贝叶斯分类训练朴素贝叶斯&#xff1a;朴素假设案例&#xff1a;预测打网球拉普拉斯平滑技术小结 算法流程与步骤算法应用sklearn中的朴素贝叶斯朴素贝叶斯的使用算法实例 算法概述 数理…

【服务器】使用Nodejs搭建HTTP web服务器

Yan-英杰的主页 悟已往之不谏 知来者之可追 C程序员&#xff0c;2024届电子信息研究生 目录 前言 1.安装Node.js环境 2.创建node.js服务 3. 访问node.js 服务 4.内网穿透 4.1 安装配置cpolar内网穿透 4.2 创建隧道映射本地端口 5.固定公网地址 [TOC] 转载自内网穿透…

一个完整的APP定制开发流程是怎样的?

随着移动互联网的发展&#xff0c;越来越多的 APP应用软件进入人们的生活&#xff0c;让我们的生活更便捷、更舒适。而随着互联网技术的进步&#xff0c;移动互联网应用软件开发行业也越来越成熟&#xff0c;为了适应市场需求&#xff0c;各种功能强大、性能良好的 APP应用软件…

C/C++ ---- 内存管理

目录 C/C内存分布 常见区域介绍 经典习题&#xff08;读代码回答问题&#xff09; 选择题 填空题 C语言内存管理方式 malloc/free calloc realloc C内存管理方式 new和delete操作内置类型 new和delete操作自定义类型 operator new和operator delete函数 new和dele…

C++11常用的一部分新特性

C11 统一的列表初始化&#xff5b;&#xff5d;初始化std::initializer_list 声明autodecltypenullptr STL中一些变化新容器已有容器的新接口 右值引用和移动语义左值引用和右值引用右值引用使用场景和意义右值引用引用左值及其一些更深入的使用场景分析完美转发 新的类功能默认…

opencv_c++学习(二十四)

一、积分图像 积分图像是对原图像进行积分操作的算法。如上图左所示&#xff0c;不同颜色代表不同区域。当我们想求取一个像素点的积分值时&#xff0c;我们需要求取该点左上方区域的数据之和&#xff0c;如P0的积分值是浅蓝色区域的数据之和。 P1的积分值为蓝色和橙色区域的数…

网页JS自动化脚本(八)使用网页专属数据库indexedDB进行数据收集

我们在网页上进行的活动,往往都需要进行收集一些简单的数据,但是因为浏览器的安全原因,浏览器基本上是无法与本地的操作系统直接产生数据交互的,这本来就是一个由于安全问题生产的无解问题,在浏览器里面是内置了几种数据库的,其中一种就是indexedDB,可以用来储存一些非常小的数…

C++进阶 —— 线程库(C++11新特性)

十&#xff0c;线程库 thread类的简单介绍 在C11之前涉及多线程问题&#xff0c;都是和平台相关的&#xff0c;如windows和Linux下各有自己的接口&#xff0c;这使代码的可移植性较差&#xff1b;C11中最重要的特性就是对线程进行支持&#xff0c;使得C在并行编程时不需要依赖…

Axure教程—水平方向多色图(中继器)

本文将教大家如何用AXURE制作动态水平方向多色图 一、效果介绍 如图&#xff1a; 预览地址&#xff1a;https://l83ucp.axshare.com 下载地址&#xff1a;https://download.csdn.net/download/weixin_43516258/87822666 二、功能介绍 简单填写中继器内容即可生成动态水平多色…

Linux-模拟一个简单的shell

什么是shell外壳&#xff1f;就是操作系统给我们的一个命令行解释器&#xff0c;在Linux系统中&#xff0c;它的shell叫做bash。 那么bash本质是什么呢&#xff1f; 本质就是一个文件&#xff0c;一个进程。 万物皆文件 每个操作系统的shell都是很复杂的&#xff0c;想要…

【Matter】使用chip tool在ESP32-C3上进行matter开发

文章目录 使用chip tool在ESP32-C3上进行matter开发前提准备编译 chip-tool1.激活esp-matter环境2.编译matter所需环境3.构建CHIP TOOL chip-tool client 调试设备说明1.基于 BLE 调试2.通过IP与设备配对3.Trust store4.忘记当前委托的设备 使用chip-tool点灯1.matter环境激活2…

linuxOPS基础_Linux系统的文件目录结构及用途

linux系统文件目录结构 Linux 系统不同于 Windows&#xff0c;没有 C 盘、D 盘、E 盘那么多的盘符&#xff0c;只有一个根目录&#xff08;/&#xff09;&#xff0c;所有的文件&#xff08;资源&#xff09;都存储在以根目录&#xff08;/&#xff09;为树根的树形目录结构中…
最新文章