Llama 架构分析

从代码角度进行Llama 架构分析

  • Llama 架构分析
    • 前言
    • Llama 架构分析
      • 分词
      • 网络主干
        • DecoderLayer
          • Attention
          • MLP
      • 下游任务
        • 因果推理
        • 文本分类

Llama 架构分析

前言

Meta 开发并公开发布了 Llama系列大型语言模型 (LLM),这是一组经过预训练和微调的生成文本模型,参数规模从 70 亿到 700 亿不等。

在大多数任务中,LLaMA-13B要比GPT-3(175B)的性能要好,LLaMA-65B和组好的模型Chinchilla-70B以及PaLM-540B的实力相当。

Llama 架构分析

分词

分词部分主要做的是利用文本分词器对文本进行分词

在这里插入图片描述

tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
text = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(text, return_tensors="pt")

网络主干

主干网络部分主要是将分词得到的input_ids输入到embedding层中进行文本向量化,得到hidden_states(中间结果),然后输入到layers层中,得到hidden_states(中间结果),用于下游任务。

在这里插入图片描述

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
DecoderLayer

主干网络的layers层就是由多个DecoderLayer组成的,由num_hidden_layers参数决定,一般我们说的模型量级就取决于这个数量,7b的模型DecoderLayer层的数量是32。

DecoderLayer层中又包含了Attention层和MLP层,主要的一个思想是利用了残差结构。

如下图所示,分为两个部分

第一部分

  • 首先,将hidden_states(文本向量化的结构)进行复制,即残差
  • 归一化
  • 注意力层
  • 残差相加

第二部分

  • 首先将第一部分得到的hidden_states进行复制,即残差
  • 归一化
  • MLP层
  • 残差相加

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

#复制一份
residual = hidden_states
#归一化
hidden_states = self.input_layernorm(hidden_states)

#注意力层
hidden_states, self_attn_weights, present_key_value = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    padding_mask=padding_mask,
)
#加上残差
hidden_states = residual + hidden_states

#复制一份
residual = hidden_states
#归一化
hidden_states = self.post_attention_layernorm(hidden_states)
#mlp
hidden_states = self.mlp(hidden_states)
#加上残差
hidden_states = residual + hidden_states

outputs = (hidden_states,)

if output_attentions:
    outputs += (self_attn_weights,)

if use_cache:
        outputs += (present_key_value,)

return outputs
Attention

进行位置编码,让模型更好的捕捉上下文信息

在这里插入图片描述

#经过线性层
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

#多头注意力形状变换
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]

#计算cos、sin
#计算旋转位置嵌入
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

#计算权重
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

#加上掩码
attn_weights = attn_weights + attention_mask
#计算softmax
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)

attn_output = self.o_proj(attn_output)

MLP

mlp层的主要作用是应用非线性激活函数和线性投影。

  • 首先将attention层得到的结果经过两个线性层得到gate_proj和up_proj
  • gate_proj经过激活函数,再和up_proj相乘
  • 最后经过一个线性层得到最后的结果

在这里插入图片描述

self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

下游任务

因果推理

所谓因果推理,就是回归任务。

在这里插入图片描述

self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
文本分类

即分类任务

在这里插入图片描述

self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/251211.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

NVIDIA A100 PCIE 40GB k8s-device-plugin install in kubernetes

文章目录 1. 目标2. 简介2.1 英伟达 A100 技术规格2.2 架构优势2.3 显卡跑分对比2.4 英伟达 A100 与 kubernetes 3. 安装 NVIDIA A100 GPU 40G 硬件4. NVIDIA R450 datacenter driver5. NVIDIA Container Toolkit6. 创建 runtimeclass5. MIG Strategies6. 配置仓库7. 下载镜像8…

深度学习——第6章 浅层神经网络(NN)

第6章 浅层神经网络(NN) 目录 6.1 神经网络模型概述 6.2 神经网络正向传播 6.3 神经网络反向传播 6.4 W和b的初始化 6.5 总结 上一课主要介绍了一些神经网络必备的基础知识,包括Sigmoid激活函数、损失函数、梯度下降和计算图。这些知识对…

Linux 中使用 docker 安装 Elasticsearch 及 Kibana

Linux 中使用 docker 安装 Elasticsearch 及 Kibana 安装 Elasticsearch 和 Kibana安装分词插件 ik_smart 安装 Elasticsearch 和 Kibana 查看当前运行的镜像及本地已经下载的镜像,确认之前没有安装过 ES 和 Kibana 镜像 docker ps docker images从远程镜像仓库拉…

Domino万物可订阅

大家好,才是真的好。 如果你还不知道什么是RSS,从V站截图一份放到这里供大家参考: 其实,Domino上也可以很简单地发布RSS站点,以供内部或外部用户订阅。 前面其实我们说了不少关于Notes客户端的RSS订阅功能&#xff…

Redis设计与实现之字符串哈希表列表

目录 一、字符串 1、字符串编码 2、编码的选择 二、哈希表 1、字典编码的哈希表 2、压缩列表编码的哈希表 3、编码的选择 4、哈希命令的实现 三、列表 1、 编码的选择 2、 列表命令的实现 3、阻塞的条件 4、 阻塞 5、 阻塞因 LPUSH 、RPUSH 、LINSERT 等添加命令而…

【MySQL】(DDL) 数据类型 和 表操作-修改 删除

目录 介绍: 1.数值类型 3.日期类型 修改表: 示列: 介绍: 在之前建表语句内,用到了 int cvarchar ,那么在mysql内除了 以上的数据类型 还有那些常见数据类型 mysql 中的数据类型有很多种 &#xff0c…

QML 自定义进度条组件开发

一、效果预览 二、介绍: 自定义的QML 屏幕亮度拖动进度条组件CusProgressBar 可跟鼠标移动 更改进度条样式 三、代码 import QtQuick 2.12 import QtQuick.Controls 2.12 import QtQuick.Controls.Material 2.12/***author:Zwj*csdn:来份煎蛋吧*date:2023/12/16*…

C++实现简单的猜数字小游戏

猜数字 小游戏介绍:猜数字游戏是令游戏机随机产生一个100以内的正整数,用户输入一个数对其进行猜测,需要你编写程序自动对其与随机产生的被猜数进行比较,并提示大了,还是小了,相等表示猜到了。如果猜到&…

Appium —— 初识移动APP自动化测试框架Appium

说到移动APP自动化测试,代表性的测试框架非Appium莫属,从今天开始我们将从APP结构解析、Appium框架学习、安卓/iOS自动化测试实战、自动遍历回归测试、自动化测试平台及持续集成,多个维度一起由浅入深的学废Appium 今天我们先来初步认识Appi…

nodejs+vue+微信小程序+python+PHP运动项目推荐系统-计算机毕业设计推荐

运动项目推荐系统的整体架构确定以后,再来看运动项目推荐系统的主要功能模块图。整体的功能模块包括前台和后台,前台只要实现了注册用户功能,主要的页面,包括首页,体育资讯,体育项目,公告信息等…

基于ASF-YOLO融合空间特征和尺度特征的新型注意力尺度序列融合模型开发构建医学场景下细胞分割检测识别系统,以【BCC、DSB2018数据集为基准】

作者提出了一种新的基于注意尺度序列融合的YOLO框架(ASF-YOLO),该框架结合了空间和尺度特征,实现了准确快速的细胞实例分割。基于YOLO分割框架,我们使用尺度序列特征融合(SSFF)模块来增强网络的…

pybind11:对比C++和Python解线性方程组的速度

前言 上篇博客介绍了如何在用pybind11实现ndarray和C数组的转换自由,pybind11:实现ndarray转C原生数组(没看过的朋友可以去看一看)下面我们以一个实际的算法例子演示一下如何使用这个技术,方便的实现 Python 调用 C 写…

基于linux系统的Tomcat+Mysql+Jdk环境搭建(三)centos7 安装Tomcat

Tomcat下载官网: Apache Tomcat - Which Version Do I Want? JDK下载官网: Java Downloads | Oracle 中国 如果不知道Tomcat的哪个版本应该对应哪个版本的JDK可以打开官网,点击Whitch Version 下滑,有低版本的,如…

Caused by: java.net.ConnectException: 拒绝连接: hadoop104/192.168.124.130:4142

项目场景:hadoop102接收消息,自定义拦截器,包含hello的发往hadoop103,不包含的发往hadoop104 报错原因: 原因1: 应该先开启接收方(服务端),hadoop103,hadoop104,最后开启hadoop10…

编译android的C版本Lua库

本文讲述如何使用android studio 编译最新版本的Lua开源库),请自行下载。 我们提供的Demo,可以自行下载,工程结构如下: 本文编译的是Lua 5.4.6的版本,编译采用cmake的方式,我们支持编译静态库和动态库(我们在这一讲里:“Lua与***C在Android上的互调”是使用静态库)…

智能优化算法应用:基于JAYA算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于JAYA算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于JAYA算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.JAYA算法4.实验参数设定5.算法结果6.参考文献7.MA…

大数据与深度挖掘:如何在数字营销中与研究互动

数字营销最吸引人的部分之一是对数据的内在关注。 如果一种策略往往有积极的数据,那么它就更容易采用。同样,如果一种策略尚未得到证实,则很难获得支持进行测试。 数字营销人员建立数据信心的主要方式是通过研究。这些研究通常分为两类&…

Vue3快速上手笔记

Vue3快速上手 1.Vue3简介 2020年9月18日,Vue.js发布3.0版本,代号:One Piece(海贼王)耗时2年多、2600次提交、30个RFC、600次PR、99位贡献者github上的tags地址:https://github.com/vuejs/vue-next/release…

Docker单机部署OceanBase

文章目录 说明机器软硬件要求指导文档本次部署环境说明 OceanBase单机部署(Docker)一:拉取 OceanBase 数据库相关镜像二:启动 OceanBase 数据库实例完整启动日志展示 三:连接实例遇到报错:没有mysql客户端 …

selenium-grid4.3.0两种模式记录

selenium-grid4.3.0两种模式记录 本文运行,需要提前配置好Java11以及安装好Chrom、Firefox、Safari其中一个浏览器,如果是Chrom、Firefox需要下载对应版本的驱动,并给 webdriver 配置环境变量,Safari浏览器Mac系统会自带&#xf…