LLM(七)| Mamba:LLM新架构的浅探

       目前大型语言模型(LLM)领域发展如火如荼,本文将重点探索在单个消费级GPU上可以有效运行的小型模型(≤7B个参数)。

        我们将从以下几个方面重点介绍基于新架构的语言模型:🐍Mamba模型(https://github.com/state-spaces/mamba):

  • 与基础模型对话
  • 使用Huggingface Trainer进行指令跟随微调
  • 从速度和输出质量方面在benchmark上评估Mamba,并将其与TinyLlama进行比较

一、🐍Mamba简介

        Mamba是LLM的一种新架构,与Transformers等传统模型相比,它能够更有效地处理长序列。它利用选择性状态空间模型(SSM),根据内容动态过滤和处理信息,允许模型选择性地记住或忽略输入的部分。Mamba在处理速度和缩放能力方面有了显著改进,尤其是在较长序列的情况下。

       但Mamba真正与众不同的地方是什么?让我们与Mamba进行深入互动体验来测试一下。

二、Mamba模型聊天

        由于Mamba还不是Huggingface平台的一部分,所以使用它稍微复杂一些。虽然当前的基本实现提供了熟悉的from_pretrained方法和生成的基本参数,但一些功能(如repeation_chamine)是不可用的。此外,我们不能使用像 text-generation-webui(https://github.com/oobabooga/text-generation-webui)这样的工具。因此,为了使用Mamba,我们将使用Python代码进行推理。我已经尽可能简单地编写了代码。

首先,让我们加载模型。

import torchfrom mamba_ssm.models.mixer_seq_simple import MambaLMHeadModelfrom transformers import AutoTokenizer, TrainingArguments# Load modelmodel = MambaLMHeadModel.from_pretrained(  "state-spaces/mamba-1.4b",   device="cuda",   dtype=torch.bfloat16)# Load Tokenizertokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

使用简单Prompt完成续写任务

       在不进行微调的情况下,测试Mamba模型最简单方法是进行对话。例如:

prompt=\"""A conversation between a user and a smart AI assistant.### User: Hello!### Assistant:"""prompt_tokenized=tokenizer(prompt, return_tensors="pt").to("cuda")# from https://github.com/state-spaces/mamba/blob/main/benchmarks/benchmark_generation_mamba_simple.py#L54output_tokenized = model.generate(    input_ids=prompt_tokenized["input_ids"],     max_length=70,    cg=True,    output_scores=True,    enable_timing=False,    temperature=0.7,    top_k=40,    top_p=0.1,    )output=tokenizer.decode(output_tokenized[0])print(output)

A conversation between a user and a smart AI assistant.

### User: Hello!### Assistant: Hello!

### User: I’m hungry.### Assistant: I’m hungry.

### User: I’m thirsty.### Assistant: I’m thirsty.

### User: I’m tired.

Prompt tuning:具有重新样式的上下文内签名(URIAL)的未调整LLM

     接下来,我们将探索一种更高级的方法。最近,一篇研究论文(https://arxiv.org/abs/2312.01552)强调,只要给出正确的提示,基本的语言模型实际上可以在对话中表现得很好。下面展示一个例子:


Below is a list of conversations between a human and an AI assistant (you). Users place their queries under "# Query:", and your responses are under "# Answer:". You are a helpful, respectful, and honest assistant. You should always answer as helpfully as possible while ensuring safety. Your answers should be well-structured and provide detailed information. They should also have an engaging tone. Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful. Your response must be socially responsibly, and thus you can reject to answer some controversial topics.

# Query: Hello!
# Answer: Hello!

# Query: How are you?
# Answer: I'm fine.

# Query: Explain quantum physics to me like i am 5 years old
# Answer: I can't explain quantum physics to you.

# Query: What is the meaning of life?
# Answer: The meaning of life is to live it.

# Query: What is the meaning of life?
# Answer: The meaning of life is to live it.

# Query: What is the meaning of life?
# Answer: The meaning of life is to live it.

Mixtral

      为了扩展更多视角,博主还使用fireworks.ai在相同提示下测试了新发布的Mixtral Mixtral Experts模型,并观察回复的差异。

Below is a list of conversations between a human and an AI assistant (you). Users place their queries under “# Query:”, and your responses are under “# Answer:”. You are a helpful, respectful, and honest assistant. You should always answer as helpfully as possible while ensuring safety. Your answers should be well-structured and provide detailed information. They should also have an engaging tone. Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful. Your response must be socially responsibly, and thus you can reject to answer some controversial topics.

# Query: Hello!
# Answer: Hello!

# Query: How are you?
# Answer: I’m fine.

# Query: Explain quantum physics to me like i am 5 years old
# Answer: Quantum physics is the study of the smallest particles in the universe. These particles are so small that they behave in strange ways that we can’t see with our eyes. For example, they can be in two places at the same time, and they can disappear and reappear. Quantum physics is a very complicated subject, but it is also very important because it helps us understand how the universe works.

# Query: What is the meaning of life?
# Answer: The meaning of life is different for everyone. Some people believe that life is about finding happiness and fulfillment, while others believe that life is about serving a higher purpose. Ultimately, the meaning of life is something that each person must discover for themselves.

       总的来说,Mamba基座模式不适合对话。现在,让我们尝试通过微调来改进它。

三、微调Mamba

       接下来,我们对Mamba进行微调,将其转变为一个通用的聊天机器人。为此,我们使用高质量的ChatML多轮对话数据集Open Assistant数据集(https://huggingface.co/datasets/OpenAssistant/oasst_top1_2023-08-25)。

此微调过程包括几个步骤

  • Tokenizing数据集
  • 定义collate函数
  • 使Mamba适应Hugging Face Trainer,由于Mamba独特的架构,需要修改一些代码。

3.1 加载数据集并对其tokenize

from datasets import load_datasetdataset=load_dataset("OpenAssistant/oasst_top1_2023-08-25")

该数据集有13k条样本,并且已经划分好了训练集和测试集:

DatasetDict({    train: Dataset({        features: ['text'],        num_rows: 12947    })    test: Dataset({        features: ['text'],        num_rows: 690    })})

       数据集中的大多数对话(92%)有少于1000个tokens组成。因此,在我们的tokenize过程中,将每个会话截断为1024个tokens就足够了。

import os def tokenize(element):    return tokenizer(        element["text"],        truncation=True,        max_length=1024,        add_special_tokens=False,    )dataset_tokenized = dataset.map(    tokenize,     batched=True,     num_proc=os.cpu_count(),    # multithreaded    remove_columns=["text"]     # don't need this anymore, we have tokens from here on)

3.2 定义collate函数

       在我们将数据集传入Trainer之前,由于并非所有对话的长度都相同,我们必须将它们分批分组,我们需要定义pad_token。

tokenizer.pad_token = tokenizer.eos_token# collate function - to transform list of dictionaries [ {input_ids: [123, ..]}, {.. ] to single batch dictionary { input_ids: [..], labels: [..], attention_mask: [..] }def collate(elements):    tokenlist=[e["input_ids"] for e in elements]    tokens_maxlen=max([len(t) for t in tokenlist])    input_ids,labels = [],[]    for tokens in tokenlist:        pad_len=tokens_maxlen-len(tokens)        # pad input_ids with pad_token, labels with ignore_index (-100) and set attention_mask 1 where content otherwise 0        input_ids.append( tokens + [tokenizer.pad_token_id]*pad_len )           labels.append( tokens + [-100]*pad_len )        batch={        "input_ids": torch.tensor(input_ids),        "labels": torch.tensor(labels),    }    return batch

PS:由于Mamba没有使用注意力机制,因此批次中不包含注意力掩码。

3.3 准备Mamba🤗Trainer

       目前,Mamba还没有被添加到Hugging Face生态系统中。标准的Hugging Face Trainer需要一个包括labels的向前函数,而Mamba没有。

        为了解决这个问题,我们需要实现一个临时解决方案,通过使用monkey补丁向模型添加一个新的前向函数。这不是最优雅的方法,但在Mamba成为Hugging Face transformer库的一部分之前,这是一个临时的解决方案。

# monkey patch MambaLMHeadModel.forward def forward_with_loss(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, labels = None):    """    "position_ids" is just to be compatible with Transformer generation. We don't use it.    num_last_tokens: if > 0, only return the logits for the last n tokens    """    hidden_states = self.backbone(input_ids, inference_params=inference_params)    if num_last_tokens > 0:        hidden_states = hidden_states[:, -num_last_tokens:]    lm_logits = self.lm_head(hidden_states)        # Source: https://github.com/huggingface/transformers/blob/80377eb018c077dba434bc8e7912bcaed3a64d09/src/transformers/models/llama/modeling_llama.py#L1196    from torch.nn import CrossEntropyLoss    if labels is not None:        logits = lm_logits        # Shift so that tokens < n predict n        shift_logits = logits[..., :-1, :].contiguous()        shift_labels = labels[..., 1:].contiguous()        # Flatten the tokens        loss_fct = CrossEntropyLoss()        # shift_logits = shift_logits.view(-1, self.config.vocab_size)        shift_logits = shift_logits.view(-1, self.backbone.embedding.weight.size()[0])        shift_labels = shift_labels.view(-1)        # Enable model parallelism        shift_labels = shift_labels.to(shift_logits.device)        loss = loss_fct(shift_logits, shift_labels)        return (loss,)       else:        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])        return CausalLMOutput(logits=lm_logits)MambaLMHeadModel.forward=forward_with_loss# patch MambaLMHeadModelMambaLMHeadModel.forward=forward_with_loss# (re)load model model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-1.4b", device="cuda", dtype=torch.bfloat16)

       或者,您可以使用优秀的训练器axolotl(https://github.com/OpenAccess-AI-Collective/axolotl)或使用mamba-chat(https://github.com/havenhq/mamba-chat)进行训练。

四、训练Mamba模型

from transformers import Trainer, TrainingArgumentsbs=4        # batch sizega_steps=1  # gradient acc. stepsepochs=3steps_per_epoch=len(dataset_tokenized["train"])//(bs*ga_steps)lr=0.0005args = TrainingArguments(    output_dir="out",    per_device_train_batch_size=bs,    per_device_eval_batch_size=bs,    evaluation_strategy="steps",    logging_steps=1,    eval_steps=steps_per_epoch,    save_steps=steps_per_epoch,    gradient_accumulation_steps=ga_steps,    num_train_epochs=epochs,    lr_scheduler_type="constant",    learning_rate=lr,    group_by_length=True,    bf16=True,                  # mixed precision training    save_safetensors=False,     # saving will fail without this)trainer = Trainer(    model=model,    tokenizer=tokenizer,    args=args,    data_collator=collate,    train_dataset=dataset_tokenized["train"],    eval_dataset=dataset_tokenized["test"],)trainer.train()

learning_rate:可能是这里最重要的一个超参数。正如您将在下一节中看到的,我最初选择的learning_rate=0.0005很差。

首先,让我们看看这个微调结果如何(剧透:糟糕)以及如何修复它。

五、评价Mamba模型

聊天机器人的评估很难,因为结果很难衡量。

       什么是好的会话/指令跟随模式?这个问题的解决方案不止一种。在有人想出如何正确应用这样的东西之前,我们将不得不依赖基准(https://github.com/EleutherAI/lm-evaluation-harness)测试、聊天机器人竞技场(https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard)和人工智能裁判(https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard)。

六、基准

      Mamba的作者发表了一份使用EleutherAI/lm评估工具收集(https://github.com/EleutherAI/lm-evaluation-harness)的数字表。

       我对这些数字也持怀疑态度。但由于我对Mamba没有任何经验,我以他们为起点,看看微调是否朝着正确的方向发展。

       我们实际上是在用这种微调来破坏模型。正如我在下面试图说服你的那样,0.0005的学习率(LR)太高了。

从哪里开始?

      我不清楚用于预训练Mamba的实际学习率。在论文中,作者陈述了以下内容:

       这是否意味着Mamba-1.4b是以5x0.0002即0.001的峰值LR进行预训练的?不知道。

第二次尝试:以较低的学习率进行微调

       另一个学习率较低的微调试验,我决定将学习率降低10倍至0.00005(而不是0.0005)。

       LR越低,损失越低?看起来没有错,重新运行一下来看看效果:

这一次我们正朝着正确的方向前进。

尝试了不同的方法来改进它,改变LR、训练轮数和数据集——以下没有一个能给我更好的数字。

  • Open Assistant(OA)数据集:3x10e-5和2x10e-5的较低LR;
  • OA数据集:更多训练轮数;
  • 另一个数据集:HuggingFaceH4/ultrachat_200k。令人惊讶的是,表现不佳。

七、mamba与🦙TinyLlama在生成质量和推理速度对比

      速度惊人。在10k个tokens的Prompt下,TinyLlama耗尽了内存(24 GB VRAM);而Mamba仅使用5 GB VRAM,并且以每秒100个tokens的速度生成。

八、mamba长上下文能力

Mamba能够用几GB的VRAM处理10k提示?

让我们看看实际输出是多少。

  • 将整本书粘贴到Prompt中(136K个tokens),让Mamba总结要点。结果是:垃圾,随机tokens;
  • 一篇关于铁人三项(3.2K个tokens)的随机文章(https://www.tri247.com/triathlon-features/interviews/lionel-sanders-championship-preview):它确实产生了英文文本,总结了10个要点,但重复且产生幻觉。

如果将文章减半(1.54K个tokens):结果要好得多!

       Mamba无法生成高质量内容的原因可能是因为它是用“仅”2048个tokens的上下文长度进行预训练的(第4.2.2节,Mamba论文)。因此,也许微调一个小型Mamba模型,比如Mamba-1.4b,可以释放它总结大型文本的潜力。

九、总结

  • 🐍Mamba速度快,可以处理大量tokens;
  • 目前微调有点棘手,期待集成到🤗transformer中;
  • 🦙TTinyLlama生成的文本比Mamba更好,大概是因为它经过了5倍数据量的预训练。

参考文献:

[1] https://medium.com/@geronimo7/mamba-a-shallow-dive-into-a-new-architecture-for-llms-54c70ade5957

[2] https://github.com/state-spaces/mamba

[3] https://github.com/geronimi73/mamba/blob/main/story-snippets.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/246826.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

HTTP 302错误:临时重定向

在Web开发中&#xff0c;HTTP状态码是用于表示Web服务器响应的各种状态。其中&#xff0c;HTTP 302错误表示临时重定向&#xff0c;这意味着请求的资源已被临时移动到其他位置&#xff0c;并且服务器已经提供了新的URL&#xff0c;以便客户端可以重新发送请求。 了解HTTP 302错…

「Verilog学习笔记」RAM的简单实现

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点&#xff0c;刷题网站用的是牛客网 timescale 1ns/1ns module ram_mod(input clk,input rst_n,input write_en,input [7:0]write_addr,input [3:0]write_data,input read_en,input [7:0]read_addr,output reg…

2023PCTF Double_SS

记录一下 ssrf配合 ssti的结合 首先开启环境 明显的ssrf 让我们访问 5555端口 使用http协议访问 url127.0.0.1:5555 告诉我们去访问 name 并且给我们key url127.0.0.1:5555/name 出现报错 说我们不是admin 然后我们往下看 我们使用file协议读取app/app.py urlfile:///app…

基于ssm的汽车服务商城系统设计与实现论文

摘 要 本课题是根据用户的需要以及网络的优势建立的一个基于Vue的汽车服务商城系统&#xff0c;来更好的为用户提供服务。 本基于Vue的汽车服务商城系统应用Java技术&#xff0c;MYSQL数据库存储数据&#xff0c;基于SSMVue框架开发。在网站的整个开发过程中&#xff0c;首先对…

Python自动化:selenium常用方法总结

使用的Python版本为3.8&#xff0c;selenium版本为4.15.2 Python自动化:selenium常用方法总结 1. 三种等待方式2. 浏览器操作3. 8种查找元素的方法4. 高级事件 1. 三种等待方式 强制等待 使用模块time下的sleep()实现等待效果隐式等待 使用driver.implicitly_wait()方法&#…

大数据云计算——使用Prometheus-Operator进行K8s集群监控

大数据云计算——使用Prometheus-Operator进行K8s集群监控 一、 背景 在非operator配置的普罗中我们监控k8s集群都是通过配置configmap进行服务发现和指标拉取。切换到prometheus-operator难免会有些使用问题。不少用户已经习惯底层配置自动发现的方式。当过渡到servicemonit…

【docker】常用命令

启动docker服务 systemctl start docker 停止docker服务 systemctl stop docker 重启docker服务 systemctl restart docker 查看docker服务状态 systemctl status docker 设置开机启动docker服务 systemctl enable docker 设置关闭开机启动docker服务 systemctl disable …

Excel实现字母+数字拖拉自动递增,步长可更改

目录 1、带有字母的数字序列自增加&#xff08;步长可变&#xff09; 2、仅字母自增加 3、字母数字同时自增 1、带有字母的数字序列自增加&#xff08;步长可变&#xff09; 使用Excel通常可以直接通过拖拉的方式&#xff0c;实现自增数字&#xf…

02基于matlab的卡尔曼滤波

基于matlab的卡尔曼滤波&#xff0c;可更改状态转移方程&#xff0c;控制输入&#xff0c;观测方程&#xff0c;设置生成的信号的噪声标准差&#xff0c;设置状态转移方差Q和观测方差R等参数&#xff0c;程序已调通&#xff0c;需要直接拍下。

luttuce(RedisTempate)实现hash expire lua脚本

话不多说先放脚本&#xff1a; local argv ARGV local length #argv if length > 0 then local unpackArgs {} for i 1, length - 1 dotable.insert(unpackArgs, argv[i]) end if redis.call(exists, KEYS[1]) 1 thenredis.call(del, KEYS[1])redis.call(hset, KEYS[…

前端自定义icon的方法(Vue项目)

第一步&#xff1a;进入在线的编辑器进行设计 好用&#xff1a;百度字体编辑器 比如先导入有个ttf文件 添加新字体 双击每个模块进入编辑区域 更改相应的信息&#xff0c;比如name 编辑完了进行导出文件(各种格式就行了)就行了 第二步&#xff1a;在项目中asset文件储存这些文…

十指波教育怎么样,课程是最新的吗

我们所有的课程内容&#xff0c;每年都会更新。 可以看一下我们的B站&#xff1a; 但是说到底&#xff0c;我们做私教服务 和那些传统的培训还是有很大区别。 传统培训机构不管线上还是线下主要是在卖一套课程的课程&#xff0c;可能是直播或者面授&#xff0c;你花钱买到的…

亚马逊云科技助力泡泡玛特快速部署全球弹性资源,打造国潮出海文化

企业全球化的终极目标就是品牌出海。1978年伴随着改革开放&#xff0c;中国企业开始放眼望世界输出中国产品&#xff0c;经过多年锤炼后&#xff0c;中国企业如TCL、泡泡玛特在不同的行业重塑版图&#xff0c;对外输出中国品牌&#xff0c;赢得了全球市场&#xff0c;中国企业实…

代码随想录算法训练营第52天| 300.最长递增子序列 674. 最长连续递增序列 718. 最长重复子数组

JAVA代码编写 300.最长递增子序列 给你一个整数数组 nums &#xff0c;找到其中最长严格递增子序列的长度。 子序列 是由数组派生而来的序列&#xff0c;删除&#xff08;或不删除&#xff09;数组中的元素而不改变其余元素的顺序。例如&#xff0c;[3,6,2,7] 是数组 [0,3,1…

Talk | 上海交通大学魏思哲: CoBEVFlow-解决车-车/路协同感知的时序异步问题

本期为TechBeat人工智能社区第556期线上Talk。 北京时间12月14日(周四)20:00&#xff0c;上海交通大学硕士生—魏思哲的Talk已准时在TechBeat人工智能社区开播&#xff01; 他与大家分享的主题是: “CoBEVFlow-解决车-车/路协同感知的时序异步问题”&#xff0c;介绍了他的团队…

“机器人V2.0时代已来”-任务规划难题迎刃而解,世界因机器人改变而翻转!

01-VILA背景简介 2022年&#xff0c;Michael Ahn, Anthony Brohan等人提出“Do as i can, not as i say: Grounding language in robotic affordances”算法。本文指出虽然大型语言模型可以编码关于世界的丰富语义知识&#xff0c;而这些知识对旨在对用自然语言表达的高级、时…

初探栈溢出(上)

0x01 HEVD介绍 HEVD全称为HackSys Ex treme Vulnerable Drive&#xff0c;是一个项目&#xff0c;故意设计包含多种漏洞的驱动程序&#xff0c;旨在帮助安全爱好者来提升他们在内核层面的漏洞利用能力。 说白了&#xff0c;是一个内核漏洞的靶场。 项目地址&#xff1a;htt…

做数据分析为何要学统计学(10)——什么是回归分析

​回归分析&#xff08;regression analysis)是量化两种或两种以上因素/变量间相互依赖关系的统计分析方法。回归分析根据因素的数量&#xff0c;分为一元回归和多元回归分析&#xff1b;按因素之间依赖关系的复杂程度&#xff0c;可分为线性回归分析和非线性回归分析。我们通过…

没有数据线,在手机上查看电脑备忘录怎么操作

在工作中&#xff0c;电脑和手机是我最常用的工具。我经常需要在电脑上记录一些重要的工作事项&#xff0c;然后又需要在手机上查看这些记录&#xff0c;以便随时了解工作进展。但是&#xff0c;每次都需要通过数据线来传输数据&#xff0c;实在是太麻烦了。 有一次&#xff0…

探秘AI赋能的未来世界:CyberAI深度学习技术助力变革

CyberAI平台概述 随着AI技术的极速发展&#xff0c;AI能力正在助力产业加速场景化落地。CyberAI是数新网络面向开发者和企业的一站式AI数据科学平台&#xff0c;提供交互式和可视化建模服务&#xff0c;算法模型全生命周期管理。平台可帮助开发者快速开发AI应用&#xff0c;解…