【LM、LLM】浅尝二叉树在前馈神经网络上的应用

前言

随着大模型的发展,模型参数量暴涨,以Transformer的为组成成分的隐藏神经元数量增长的越来越多。因此,降低前馈层的推理成本逐渐进入视野。前段时间看到本文介绍的相关工作还是MNIST数据集上的实验,现在这个工作推进到BERT上面来了,再次引起兴趣记录一下。该工作将前馈神经基于二叉树结构进行改装,加速前向传播的速度,称为:快速前馈网络(FFF),然后应用FFF,取代BERT中的前馈网络(FF),实现12个神经元加速推理。

快速前馈网络算法概述

快速前馈网络(Fast Feedforward Network,FFF)是由两部分组成的:节点网络集合 N \mathcal{N} N 和叶子网络集合 L \mathcal{L} L

  • 节点网络集合 N \mathcal{N} N 包含了一组节点网络,每个节点网络都是一个 < dim ⁡ I , n , 1 > \left<\dim_I,n,1\right> dimI,n,1-前馈网络,并在输出上增加了一个 sigmoid 激活函数。这些节点网络按照平衡的可微分二叉树的形式排列,其中 N m + 1 , 2 n N_{m+1,2n} Nm+1,2n N m + 1 , 2 n + 1 N_{m+1,2n+1} Nm+1,2n+1 N m , n N_{m,n} Nm,n 的子节点。

  • 叶子网络集合 L \mathcal{L} L 包含了一组叶子网络,每个叶子网络都是一个 < dim ⁡ I , ℓ , dim ⁡ O > \left<\dim_I,\ell,\dim_O\right> dimI,,dimO-前馈网络。叶子网络没有子节点,它们的输出直接作为 FFF 的输出。

前向传播过程由下面算法控制。

算法的输入包括一个输入样本 ι \iota ι 和根节点 N 0 , 0 N_{0,0} N0,0,输出为该样本在 FFF 中的输出。

算法定义了两个函数: F o r w a r d T Forward_T ForwardT F o r w a r d I {Forward}_I ForwardI。其中, F o r w a r d T {Forward}_T ForwardT 函数用于计算节点的输出,而 F o r w a r d I {Forward}_I ForwardI 函数用于计算节点的指示值(indicator value)。

  • F o r w a r d T {Forward}_T ForwardT 函数中,如果当前节点是叶子节点,则直接调用该节点的前馈传播函数 N m , n ( ι ) N_{m,n}(\iota) Nm,n(ι) 来计算输出。否则,首先计算当前节点的输出 c m , n = N m , n ( ι ) c_{m,n}=N_{m,n}(\iota) cm,n=Nm,n(ι),然后递归地调用 F o r w a r d T {Forward}_T ForwardT 函数来计算当前节点的两个子节点的输出,并将它们加权相加作为当前节点的输出。
  • F o r w a r d I {Forward}_I ForwardI 函数中,如果当前节点是叶子节点,则直接调用该节点的前馈传播函数 N m , n ( ι ) N_{m,n}(\iota) Nm,n(ι) 来计算输出。否则,首先计算当前节点的输出 c m , n = N m , n ( ι ) c_{m,n}=N_{m,n}(\iota) cm,n=Nm,n(ι),然后根据输出值的大小决定选择哪个子节点进行递归计算。


传统前馈神经网络

快速前馈神经网络

与传统的前馈神经网络算法相比,该算法的主要区别在于引入了一个计算节点的指示值。指示值表示了当前节点的输出是否大于等于阈值(这里的阈值为0.5),根据指示值的大小来确定选择哪个子节点进行计算。这种方式可以大大减少计算量,提高前向传播的效率。同时,FFF 是一种具有平衡二叉树结构的前馈神经网络,其中节点网络和叶子网络分别用于处理中间层和输出层的计算。通过利用二叉树结构和递归计算,FFF 可以实现快速的前向传播。

UltraFastBERT

UltraFastBERT,一种BERT变体,在推理过程中使用0.3%的神经元,同时表现 与类似的BERT模型相当。UltraFastBERT选择性地使用4095个神经元中的12个(有选择的执行矩阵乘法(CMM))进行每层推理。这是通过用快速前馈网络(FFFs)取代前馈网络来实现的。

FFF_BMM代码

import torch
from torch import nn
import math

class FFF(nn.Module):
	def __init__(self, input_width: int, depth: int, output_width: int, *args, **kwargs):
		super().__init__(*args, **kwargs)

		self.input_width = input_width
		self.depth = depth
		self.output_width = output_width

		self.n_nodes = 2 ** (depth + 1) - 1
		self.initialise_weights()

	def initialise_weights(self):
		init_factor_l1 = 1.0 / math.sqrt(self.input_width)
		init_factor_l2 = 1.0 / math.sqrt(self.depth + 1)
		self.w1s = nn.Parameter(torch.empty(self.n_nodes, self.input_width).uniform_(-init_factor_l1, +init_factor_l1), requires_grad=True)
		self.w2s = nn.Parameter(torch.empty(self.n_nodes, self.output_width).uniform_(-init_factor_l2, +init_factor_l2), requires_grad=True)

	def forward(self, x):
		# the shape of x is (batch_size, input_width)
		# retrieve the indices of the relevant elements
		batch_size = x.shape[0]
		current_nodes = torch.zeros((batch_size,), dtype=torch.long, device=x.device)
		all_nodes = torch.zeros(batch_size, self.depth+1, dtype=torch.long, device=x.device)
		all_logits = torch.empty((batch_size, self.depth+1), dtype=torch.float, device=x.device)

		for i in range(self.depth+1):
			all_nodes[:, i] = current_nodes
			plane_coeffs = self.w1s.index_select(dim=0, index=current_nodes)			# (batch_size, input_width)
			plane_coeff_score = torch.bmm(x.unsqueeze(1), plane_coeffs.unsqueeze(-1))	# (batch_size, 1, 1)
			plane_score = plane_coeff_score.squeeze(-1).squeeze(-1) 					# (batch_size,)
			all_logits[:, i] = plane_score
			plane_choices = (plane_score >= 0).long()									# (batch_size,)

			current_nodes = current_nodes * 2 + plane_choices + 1						# (batch_size,)

		# get the weights
		selected_w2s = self.w2s.index_select(0, index=all_nodes.flatten()) \
			.view(batch_size, self.depth+1, self.output_width)	# (batch_size, depth+1, output_width)

		# forward pass
		mlp1 = torch.nn.functional.gelu(all_logits)				# (batch_size, depth+1)
		mlp2 = torch.bmm(mlp1.unsqueeze(1), selected_w2s) 		# (batch_size, output_width)
		
		# done
		return mlp2
	

从代码可以看出,与传统的批矩阵乘法(BMM)不同的是,在forward中,基于二叉树的结构,通过迭代计算节点的索引和权重,使用激活函数(GeLU)对结果进行处理,并最终得到输出。

结果

在推理过程中仅使用0.3%的神经元,同时表现与类似的BERT模型相当(下游任务没有降很多点);实现78倍CPU加速,实现40倍PyTorch加速。

总结

该工作很有趣,将传统前馈神经网络定义成一棵二叉树,其叶子是小型神经网络,在每个非叶子节点处都有一个微小的神经网络(单个神经元也可以工作)来决定走哪条路径取决于在输入上。在训练期间,它们对所选路径进行加权平均值,从而得出树的所有叶子(在输入上评估为神经网络)的总加权平均值,但在推理过程中,它们可以只遵循投票最高的分支,从而得出建议的结果指数加速。并且,基于FFF的思想,将工作推到BERT这种语言模型上,初步证明了大模型的前馈层的神经元并不是都需要参与推理。

文章及公开的代码还介绍了条件矩阵乘法的详细细节,因此感兴趣可以深入研究一下。

参考文献

【1】paper:Exponentially Faster Language Modelling,https://arxiv.org/abs/2311.10770
【2】code:https://github.com/pbelcak/fastfeedforward
【3】paper:Fast Feedforward Networks,https://arxiv.org/abs/2308.14711

【4】code:https://github.com/pbelcak/UltraFastBERT
【5】model:https://huggingface.co/pbelcak/UltraFastBERT-1x11-long

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/188100.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

linux 账号管理实例一,stdin,passwd复习

需求 账号名称全名次要用户组是否可登录主机密码 myuser1 1st usermygroup1yespasswordmyuser22st usermygroup1yespasswordmyuser33st user无nopassword 第一&#xff1a;用户&#xff0c;和用户组创建&#xff0c;并分配有效用户组&#xff08;初始用户组是passwd里…

Postman如何使用(二):Postman Collection的创建/使用/导出分享等

一、什么是Postman Collection&#xff1f; Postman Collection是可让您将各个请求分组在一起。 您可以将这些请求组织到文件夹中。中文经常将collection翻译成收藏夹。如果再下文中看到这样的翻译不要觉得意外。Postman Collection会使你的工作效率更上一层楼。Postman Colle…

7、独立按键控制LED状态

按键的抖动 对于机械开关&#xff0c;当机械触点断开、闭合时&#xff0c;由于机械触点的弹性作用&#xff0c;一个开关在闭合时不回马上稳定地接通&#xff0c;在断开时也不会一下子断开&#xff0c;所以在开关闭合及断开的瞬间会伴随一连串的抖动 #include <REGX52.H…

面试必问:如何快速定位BUG?BUG定位技巧及N板斧!

01 定位问题的重要性 很多测试人员可能会说&#xff0c;我的职责就是找到bug&#xff0c;至于找原因并修复&#xff0c;那是开发的事情&#xff0c;关我什么事&#xff1f; 好&#xff0c;我的回答是&#xff0c;如果您只想做一个测试人员最基本最本分的事情&#xff0c;那么可…

RabbitMQ快速学习之WorkQueues模型、三种交换机、消息转换器(SpringBoot整合)

文章目录 前言一、WorkQueues模型消息发送消息接收能者多劳 二、交换机类型1.Fanout交换机消息发送消息接收 2.Direct交换机消息接收消息发送 3.Topic交换机消息发送消息接收 三、编程式声明队列和交换机fanout示例direct示例基于注解 四、消息转换器总结 前言 WorkQueues模型…

visual stdio动态库的使用

导出类和使用方式 #ifndef PCH_H #define PCH_H// 添加要在此处预编译的标头 #include "framework.h"#ifdef _WIN32 #ifdef MYCLASS_EXPORTS #define MYCLASS_API __declspec(dllexport) #else #define MYCLASS_API __declspec(dllimport) #endif #else #define MYC…

『亚马逊云科技产品测评』活动征文|低成本搭建物联网服务器thingsboard

授权声明&#xff1a;本篇文章授权活动官方亚马逊云科技文章转发、改写权&#xff0c;包括不限于在 Developer Centre, 知乎&#xff0c;自媒体平台&#xff0c;第三方开发者媒体等亚马逊云科技官方渠道。 0. 环境 - ubuntu22&#xff08;注意4G内存勉强够&#xff0c;部署完…

大一统模型 Universal Instance Perception as Object Discovery and Retrieval 论文阅读笔记

Universal Instance Perception as Object Discovery and Retrieval 论文阅读笔记 一、Abstract二、引言三、相关工作实例感知通过类别名进行检索通过语言表达式的检索通过指代标注的检索 统一的视觉模型Unified Learning ParadigmsUnified Model Architectures 四、方法4.1 Pr…

【蓝桥杯省赛真题48】Scratch放大镜游戏 蓝桥杯scratch图形化编程 中小学生蓝桥杯省赛真题讲解

目录 scratch放大镜游戏 一、题目要求 编程实现 二、案例分析 1、角色分析

我叫:希尔排序【JAVA】

1.我兄弟存在的问题 2.毛遂自荐 希尔排序提希尔(Donald Shell)于1959年提出的一种排序算法。 希尔排序&#xff0c;也称递减增量排序算法&#xff0c;是插入排序的一种更高效的改进版本。但希尔排序是非稳定排序算法。 希尔排序是基于插入排序的以下两点性质而提出改进方法的&…

【教学类-06-10】20231125(55格版)X-Y之间“乘法*题”(以1-9乘法口诀表为例)(随机抽取和正序抽取)

图片展示 &#xff08;随机打乱排序&#xff09; 正序&#xff08;每张都一样&#xff09; 背景需求&#xff1a; 2023年11月24日&#xff0c;准备了一些题目&#xff0c;分别给大4班孩子介绍“5以内加法、5以内减法、5以内加减混合”““10以内加法、10以内减法、10以内加减…

机器学习之自监督学习(四)MoCo系列翻译与总结(一)

Momentum Contrast for Unsupervised Visual Representation Learning Abstract 我们提出了“动量对比”&#xff08;Momentum Contrast&#xff0c;MoCo&#xff09;来进行无监督的视觉表示学习。从对比学习的角度来看&#xff0c;我们将其视为字典查找&#xff0c;通过构建…

移动机器人路径规划(七)--- 基于MDP的路径规划MDP-Based Planning

目录 1 什么是MDP-Based Planning 2 worst-case analysis for nondeterministic model 3 Expected Cost Planning 4 Real Time Dynamic Programming&#xff08;RTDP&#xff09; 1 什么是MDP-Based Planning 之前我们从起点到终点存在很多可执行路径&#xff0c;我们可以…

物联网后端个人第十二周总结

学习工作进度 物联网方面 1.模拟设备通过规则引擎将数据通过mqtt进行转发 在物联网平台上实现模拟设备通过规则引擎将数据通过mqtt进行转发已经全部完成了&#xff0c;所使用的物联网平台在这方面有不少的问题和bug&#xff0c;也可能是没有按照开发者的想法对平台进行使用才导…

基于微信小程序的员工宿舍报修系统

项目介绍 随着信息技术和网络技术的飞速发展&#xff0c;人类已进入全新信息化时代&#xff0c;传统管理技术已无法高效&#xff0c;便捷地管理信息。为了迎合时代需求&#xff0c;优化管理效率&#xff0c;各种各样的管理系统应运而生&#xff0c;各行各业相继进入信息管理时…

Leetcode—83.删除排序链表中的重复元素【简单】

2023每日刷题&#xff08;四十&#xff09; Leetcode—83.删除排序链表中的重复元素 实现代码 /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/ struct ListNode* deleteDuplicates(struct ListNode* head) {i…

张弛声音变现课,枪战电影高能量、快速节奏

在执行枪战片的声音配音任务时&#xff0c;配音员应该致力于传递出戏剧性的紧张氛围与动作场面的激烈感。枪战场景往往是高能量、快速节奏的&#xff0c;这就要求配音不仅要与视觉动作紧密结合&#xff0c;还要通过声音来增强动作的逼真度和观众的紧迫感。以下是针对枪战电影进…

Linux(CentOS7)上安装mysql

在CentOS中默认安装有MariaDB&#xff08;MySQL的一个分支&#xff09;&#xff0c;可先移除/卸载MariaDB。 yum remove mariadb // 查看是否存在mariadb rpm -qa|grep -i mariadb // 卸载 mariadb rpm -e --nodeps rpm -qa|grep mariadb yum安装 下载rpm // 5.6版本 wge…

日本运营商启动先进边缘云技术研发

摘要&#xff1a;日本运营商乐天移动最近启动了为 5G 之后的下一个通信标准开发边缘平台功能的研发工作。 乐天移动&#xff08;Rakuten Mobile&#xff09;表示&#xff0c;其面向下一代通信的先进边缘云技术研发&#xff08;R&D&#xff09;项目已被日本国家信息通信技术…

抖音权重查询源码H5源码

源码下载&#xff1a;123网盘
最新文章