大模型自定义算子优化方案学习笔记:CUDA算子定义、算子编译、正反向梯度实现

01算子优化的意义

随着大模型应用的普及以及算力紧缺,下一步对于计算性能的追求一定是技术的核心方向。因为目前大模型的计算逻辑是由一个个独立的算子或者说OP正反向求导实现的,底层往往调用的是GPU提供的CUDA的驱动程序。如果不能对于整个计算过程学习并了解,对于性能优化领域无非是隔靴搔痒,今天也是抽一点时间读了下网上的一些文档和CUDA的文档,整理了学习材料。

首先说下为什么要自定义算子,无非是两个原因,

(1)TF、PyTorch等提供的原生算子不满足需求,需要通过底层接口,比如CUDA层更灵活的实现个性化算子

(2)目前提供的算子性能不足,没有很好的利用到GPU的并行计算优势,有优化空间

接着性能优化的问题说,因为GPU与CPU最大的区别是计算单元占据了绝大部分的体积(图中绿色部分),而控制单元较少。

自定义手写算子可以更好地利用绿色的计算单元的并行化优势。大的思路是Grid包含Block,Block包含Thread。于是首先算子需要把计算逻辑拆分成Thread,让程序可以并行化的运行起来,然后有机的管理各个Block的执行节奏,解决好异步和同步问题,就可以让芯片的计算效率最大化。

Grid of Thread Blocks

02实现流程

整个自定义算子优化其实可以分为CUDA算子定义、算子编译、正方向梯度实现几个步骤。

1、CUDA算子定义

需要定义以下几个文件:

(1)function.cpp:python层和CUDA层的衔接,实现手写的C++ CUDA算子可以被python代码调用

(2)function.h:CUDA函数声明文件

(3)function.cu:CUDA函数的逻辑实现

这里比较核心的文件就是.cu文件,构建的时候主要做两个事:一个是建设Kernel函数,因为只有Kernel函数是在GPU端执行,执行完之后要将控制权给到控制函数,这里要控制好异步、同步的问题。二是在kernel函数中需要通过循环函数定义每个Thread以及每个Block的工作,真正让计算并行化

.cpp文件可以通过pybind函数实现,这个函数主要解决的是C++代码和Python绑定的问题,项目地址:GitHub - pybind/pybind11: Seamless operability between C++11 and Python

2.编译和执行

import torch
from torch.utils.cpp_extension import load
cuda_module = load(name="function",
                   extra_include_paths=["include"],
                   sources=["function.cpp", "function.cu"],
                   verbose=True)

接着就是执行过程中的编译,通过load函数底层会执行JIT(Just in time)的动态编译模式调用.cpp和.cu文件,底层运行的是Ninjia编译器,通过调用nvcc编译.so文件

[1/2] nvcc -c function.cu -o function.cuda.o
[2/3] c++ -c function.cpp -o function.o
[3/3] c++ function.o function.cuda.o -shared -o function.so

3.正反向梯度实现

以上两步实现了自定义算子的逻辑,可以通过手写CUDA算子并在python框架层调用,如果要满足建模需求,还需要实现正方向梯度。具体的做法是在建模的函数中实现forward和backward函数,定义导数作为输出。

以上大概就是手写算子优化的简单流程,仅当学习笔记。

参考:

【1】熬了几个通宵,我写了份CUDA新手入门代码 - 知乎

【2】CUDA C++ Programming Guide

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/252500.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu 常用命令之 cp 命令用法介绍

cp命令在Ubuntu系统中用于复制文件或目录。它的基本格式是cp [选项] 源文件或目录 目标文件或目录。 以下是一些常用的cp命令选项 -i:在覆盖目标文件之前将给出提示。-r或-R:递归复制,用于目录的复制操作。-v:详细模式&#xff…

地平线前端实习一面复盘(加深对var的理解+展开运算符+平拍数组)

目录 前言一,var的作用二,展开运算符三,平拍数组总结 前言 地平线的面试,有提示,很专业,体验很好。 可惜后面未收到消息,但还是要做复盘。收获还是很大的。 一,var的作用 且看下…

MySQL中EXPLAIN执行计划的分析

一. 执行计划能告诉我们什么? SQL如何使用索引联接查询的执行顺序查询扫描的数据函数 二. 执行计划中的内容 SQL执行计划的输出可能为多行,每一行代表对一个数据库对象的操作 1. ID列 ID列中的如果数据为一组数字,表示执行SELECT语句的顺…

网络基础(十二):ACL与NAT

目录 一、ACL 1、ACL的概述 2、ACL的分类 3、ACL的应用 4、ACL的组成和基本原理 ​编辑 5、ACL的配置 5.1配置基本ACL 5.2配置高级ACL 二、NAT 1、NAT的概述 2、NAT的分类 3、NAT的工作原理 4、静态NAT的配置 5、动态NAT的配置 6、NAPT(端口映射&am…

自动驾驶技术:驶向未来的智能之路

导言 自动驾驶技术正引领着汽车产业向着更安全、高效、智能的未来演进。本文将深入研究自动驾驶技术的核心原理、关键技术、应用场景以及对交通、社会的深远影响。 1. 简介 自动驾驶技术是基于先进传感器、计算机视觉、机器学习等技术的创新,旨在实现汽车在不需要人…

论文降重系统同义词替换功能的改进方向 快码论文

大家好,今天来聊聊论文降重系统同义词替换功能的改进方向,希望能给大家提供一点参考。 以下是针对论文重复率高的情况,提供一些修改建议和技巧,可以借助此类工具: 标题:论文降重系统同义词替换功能的改进方…

java21特性学习

jdk21下载地址 JDK21文件 JDK21是javaSE平台最新的长期支持版本。 Java SE Java Archive | Oracle JDK21版本说明 JDK 21 Release Notes, Important Changes, and Information JavaSE 版本字符串格式 Version-String Format JavaSE平台采用了基于时间的发布模型,JDK每六个…

虚拟化之安全虚拟化

虚拟化首次引入是在Armv7-A架构中。那时,Hyp模式(在AArch32中相当于EL2)仅在非安全状态下可用。当Armv8.4-A引入时,添加了对安全状态下EL2的支持作为一个可选特性。 当处理器支持安全EL2时,需要使用SCR_EL3.EEL2位从E…

HarmonyOS:使用MindSpore Lite引擎进行模型推理

场景介绍 MindSpore Lite 是一款 AI 引擎,它提供了面向不同硬件设备 AI 模型推理的功能,目前已经在图像分类、目标识别、人脸识别、文字识别等应用中广泛使用。 本文介绍使用 MindSpore Lite 推理引擎进行模型推理的通用开发流程。 基本概念 在进行开…

【elementui笔记:el-table表格的输入校验】

之前做得比较多的校验是在el-form表单里做的,但有时也遇到,需要在table内输入数据,然后校验输入的数据是否符合要求的情况。因此记录一下。 思路: 1.需要借助el-form的校验,el-table外层嵌套一层el-form,使…

Java数组(1)

我是南城余!阿里云开发者平台专家博士证书获得者! 欢迎关注我的博客!一同成长! 一名从事运维开发的worker,记录分享学习。 专注于AI,运维开发,windows Linux 系统领域的分享! 本…

离线无网络环境下配置Python/Anaconda环境踩过的坑

一、前言 如果你同样需要在无网络环境下安装Python环境,这篇博客是一个很好的参考,由于内网没有网络,因此不能使用conda install/pip install等在线下载安装方式,经过个人尝试,推荐以下两种方法。 二、离线安装python…

2023年陕西省安全员C证证考试题库及陕西省安全员C证试题解析

题库来源:安全生产模拟考试一点通公众号小程序 2023年陕西省安全员C证证考试题库及陕西省安全员C证试题解析是安全生产模拟考试一点通结合(安监局)特种作业人员操作证考试大纲和(质检局)特种设备作业人员上岗证考试大…

MIT6.S081-实验准备

实验全程在Vmware虚拟机 (镜像:Ubuntu-20.04-beta-desktop-amd64) 中进行 一、版本控制 1.1 将mit的实验代码克隆到本地 git clone git://g.csail.mit.edu/xv6-labs-2020 1.2 修改本地git配置文件 创建github仓库,记录仓库地址 我的仓库地址就是htt…

基于AT89C51单片机的LED点阵显示屏设计

点击链接获取Keil源码与Project Backups仿真图: [[https://download.csdn.net/download/qq_64505944/88637464?spm1001.2014.3001.5503]] **[源码获取] B 源码仿真图课程设计50 工程实训(三)课题设计 班级: …

【面试】Java最新面试题资深开发-Java中的垃圾回收机制

问题七:Java中的垃圾回收机制 请简要解释Java中的垃圾回收机制是如何工作的,以及它的优缺点。如果可能,请提供一些垃圾回收器的例子,以及它们在不同场景中的适用性。 Java垃圾回收机制 工作原理: Java垃圾回收机制…

linux(centos7)离线安装mysql-5.7.35-1.el7.x86_64.rpm-bundle.tar

1. 卸载mariadb相关rpm # 查找 rpm -qa|grep mariadb rpm -qa|grep mysql# 卸载 rpm -e --nodeps mariadb... rpm -e --nodeps mysql...2. 删除mysql相关文件 # 查找 find / -name mysql# 删除 rm -rf /var/lib/mysql...3. 查看是否有相关依赖,没有需安装 rpm -q…

考虑用序列化代理代替序列化实例

import java.io.*;// 用户类 class User implements Serializable {private String username;private String password;private String email;public User(String username, String password, String email) {this.username username;this.password password;this.email ema…

CentOS 7 部署 Nacos-2.3.0 (单机版)

CentOS 7 部署 Nacos-2.3.0 (单机版) 1. 下载 Nacos 安装包 历史版本:https://github.com/alibaba/nacos/releases/ 我选的是 2.3.0 版本,https://github.com/alibaba/nacos/releases/download/2.3.0/nacos-server-2.3.0.tar.g…

从传统型数据库到非关系型数据库

一 什么是数据库 数据库顾名思义保存数据的仓库,其本质是一个具有数据存储功能的复杂系统软件,数据库最终把数据保存在计算机硬盘,但数据库并不是直接读写数据在硬盘,而是中间隔了一层操作系统,通过文件系统把数据保存…
最新文章