俄罗斯套娃 (Matryoshka) 嵌入模型概述

在这篇博客中,我们将向你介绍俄罗斯套娃嵌入的概念,并解释为什么它们很有用。我们将讨论这些模型在理论上是如何训练的,以及你如何使用 Sentence Transformers 来训练它们。

除此之外,我们还会告诉你怎么用这种像套娃一样的俄罗斯套娃嵌入模型,并且我们会比较一下这种模型和普通嵌入模型的不同。最后,我们邀请你来玩一下我们的互动演示,看看这些模型有多厉害。

理解嵌入 (embedding)

嵌入是自然语言处理中最通用的工具之一,使从业者能够解决大量任务。本质上,嵌入是一个更复杂数字对象的数值表示,如文本、图像、音频等。

f981027ac6ef93157af0a82b533a1745.png
嵌入模型

嵌入模型总是会产生相同固定大小的嵌入。然后,你可以通过计算相应嵌入的相似性来计算复杂数字对象的相似性!

ff0573f378760aa92729a86af103c172.png
嵌入相似性

这种技术 (嵌入) 在许多领域都有应用,它是推荐系统、信息检索、零样本学习或少量样本学习、异常检测、相似性搜索、释义检测、聚类、分类等领域的基础。

🪆 俄罗斯套娃 (Matryoshka) 嵌入

随着研究的进展,新的最先进的 (文本) 嵌入模型开始产生具有越来越高的输出维度,即每个输入文本都使用更多的值来表示。尽管这提高了性能,但以下游任务 (如搜索或分类) 的效率为代价。

因此,Kusupati 等人 (2022) 受到启发,创造了即使嵌入尺寸合理缩小也不会在性能上遭受太大损失的嵌入模型。

8acef68f86b4039db2c046866ea70059.png
俄罗斯套娃模型

这些俄罗斯套娃嵌入模型经过训练,使得这些小的截断嵌入仍然有用。简而言之,俄罗斯套娃嵌入模型可以产生各种尺寸的有用嵌入。

🪆 俄罗斯套娃

对于不熟悉的人来说,“Matryoshka 娃娃”,也称为“俄罗斯套娃”,是一组大小递减的木制娃娃,相互嵌套。类似地,俄罗斯套娃嵌入模型旨在将更重要的信息存储在早期的维度中,将不太重要的信息存储在后面的维度中。俄罗斯套娃嵌入模型的这一特点允许我们截断模型产生的原始 (大) 嵌入,同时仍保留足够的信息以在下游任务上表现良好。

129389b2e13eda693dd3978fe469be72.gif

俄罗斯套娃模型

为什么使用🪆 俄罗斯套娃嵌入模型?

这种可变尺寸的嵌入模型对从业者来说非常有价值,例如:

  1. 筛选和重新排序: 不必在完整嵌入上执行你的下游任务 (例如,最近邻搜索),你可以缩小嵌入到更小的尺寸,并非常高效地“筛选”你的嵌入。之后,你可以使用它们的完整维度处理剩余的嵌入。

  2. 权衡: 俄罗斯套娃模型将允许你根据所需的存储成本、处理速度和性能来扩展你的嵌入解决方案。

🪆 俄罗斯套娃嵌入模型是如何训练的?

理论上

俄罗斯套娃表示学习 (MRL) 方法几乎可以适用于所有嵌入模型训练框架。通常,嵌入模型的一个训练步骤涉及为你的训练批次 (例如文本) 产生嵌入,然后使用一些损失函数创建一个代表产生嵌入质量的损失值。优化器会在训练过程中调整模型权重以减少损失值。

对于俄罗斯套娃嵌入模型,一个训练步骤还涉及为你的训练批次产生嵌入,但是然后你使用一些损失函数来确定不仅仅是全尺寸嵌入的质量,还有各种不同维度性下的嵌入质量。例如,输出维度性为 768、512、256、128 和 64。每个维度性的损失值加在一起,得到最终的损失值。然后,优化器将尝试调整模型权重以降低这个损失值。

实际上,这鼓励模型在嵌入的开始部分前置最重要的信息,这样如果嵌入被截断,这些信息将得以保留。

在 Sentence Transformers 中

Sentence Tranformers 是一个常用于训练嵌入模型的框架,它最近实现了对俄罗斯套娃模型的支持。使用 Sentence Transformers 训练俄罗斯套娃嵌入模型非常基础: 不是仅在全尺寸嵌入上应用一些损失函数,我们也在嵌入的截断部分应用同样的损失函数。

例如,如果一个模型的原始嵌入维度为 768,现在它可以被训练为 768、512、256、128 和 64。这些损失值将加在一起,可以选择性地给予一些权重:

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CoSENTLoss, MatryoshkaLoss

model = SentenceTransformer("microsoft/mpnet-base")

base_loss = CoSENTLoss(model=model)
loss = MatryoshkaLoss(
    model=model,
    loss=base_loss,
    matryoshka_dims=[768, 512, 256, 128, 64],
    matryoshka_weight=[1, 1, 1, 1, 1],
)

model.fit(
    train_objectives=[(train_dataset, loss)],
    ...,
)

使用 MatryoshkaLoss 进行训练并不会显著增加训练时间。

参考文献:

  • MatryoshkaLoss

  • CoSENTLoss

  • SentenceTransformer

  • SentenceTransformer.fit

  • Matryoshka Embeddings - Training

请查看以下完整脚本,了解如何在实际应用中使用 MatryoshkaLoss :

  • matryoshka_nli.py: 此示例使用 MultipleNegativesRankingLossMatryoshkaLoss 结合,利用自然语言推理 (NLI) 数据训练一个强大的嵌入模型。这是对 NLI 文档的改编。

  • matryoshka_nli_reduced_dim.py: 此示例使用 MultipleNegativesRankingLossMatryoshkaLoss 结合,训练一个最大输出维度为 256 的小型嵌入模型。它使用自然语言推理 (NLI) 数据进行训练,这是对 NLI 文档的改编。

  • matryoshka_sts.py: 此示例使用 CoSENTLossMatryoshkaLoss 结合,在 STSBenchmark 数据集的训练集上训练一个嵌入模型。这是对 STS 文档的改编。

如何使用 🪆俄罗斯套娃嵌入模型?

理论上

实际上,从俄罗斯套娃嵌入模型获取嵌入的方式与从普通嵌入模型获取嵌入的方式相同。唯一的区别在于,在接收到嵌入后,我们可以选择将它们截断为更小的维度。请注意,如果嵌入已经归一化,那么在截断后它们将不再归一化,因此你可能需要重新归一化。截断后,你可以直接将它们应用于你的用例,或者存储它们以便稍后使用。毕竟,在你的向量数据库中使用较小的嵌入应该会带来相当大的速度提升!请记住,尽管处理较小嵌入以进行下游任务 (检索、聚类等) 会更快,但从模型获取较小嵌入的速度与获取较大嵌入的速度一样快。

在 Sentence Transformers 中

在 Sentence Transformers 中,你可以像加载普通模型一样加载俄罗斯套娃嵌入模型,并使用 SentenceTransformers.encode 进行推理。获取嵌入后,我们可以将它们截断到我们所需的尺寸,如果需要,我们还可以对它们进行归一化。让我们尝试使用我使用 matryoshka_nli.pymicrosoft/mpnet-base 训练的模型:

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

model = SentenceTransformer("tomaarsen/mpnet-base-nli-matryoshka")

matryoshka_dim = 64
embeddings = model.encode(
    [
        "The weather is so nice!",
        "It's so sunny outside!",
        "He drove to the stadium.",
    ]
)
embeddings = embeddings[..., :matryoshka_dim] # Shrink the embedding dimensions
print(embeddings.shape)
# => (3, 64)

# Similarity of the first sentence to the other two:
similarities = cos_sim(embeddings[0], embeddings[1:])
print(similarities)
# => tensor([[0.8910, 0.1337]])

模型链接: tomaarsen/mpnet-base-nli-matryoshka

请随意尝试使用不同的 matryoshka_dim 值,并观察这对相似度的影响。你可以通过在本地运行这段代码,在云端运行 (例如使用 Google Colab),或者查看 演示 来进行实验。

参考文献:

  • SentenceTransformer

  • SentenceTransformer.encode

  • util.cos_sim

  • Matryoshka Embeddings - 推理

点击这里查看如何使用 Nomic v1.5 Matryoshka 模型
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import torch.nn.functional as F

model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

matryoshka_dim = 64
embeddings = model.encode(
    [
        "search_query: What is TSNE?",
        "search_document: t-distributed stochastic neighbor embedding (t-SNE) is a statistical method for visualizing high-dimensional data by giving each datapoint a location in a two or three-dimensional map.",
        "search_document: Amelia Mary Earhart was an American aviation pioneer and writer.",
    ],
    convert_to_tensor=True,
)
# The Nomic team uses a custom architecture, making them recommend Layer Normalization before truncation
embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
embeddings[..., :matryoshka_dim] # Shrink the embedding dimensions

similarities = cos_sim(embeddings[0], embeddings[1:])
# => tensor([[0.7154, 0.4468]])
  • 模型链接: nomic-ai/nomic-embed-text-v1.5

结果

现在我们已经介绍了俄罗斯套娃模型,让我们来看看我们可以从俄罗斯套娃嵌入模型与常规嵌入模型中实际期待的绩效表现。为了这个实验,我训练了两个模型:

  • tomaarsen/mpnet-base-nli-matryoshka: 通过运行 matryoshka_nli.pymicrosoft/mpnet-base 进行训练。

  • tomaarsen/mpnet-base-nli: 通过运行修改版的 matryoshka_nli.py 进行训练,其中训练损失仅为 MultipleNegativesRankingLoss ,而不是在 MultipleNegativesRankingLoss 之上的 MatryoshkaLoss 。我也使用 microsoft/mpnet-base 作为基础模型。

这两个模型都在 AllNLI 数据集上进行了训练,该数据集是 SNLI 和 MultiNLI 数据集的拼接。我使用多种不同的嵌入维度在这些模型上评估了 STSBenchmark 测试集。结果绘制在下面的图表中:

be092cb87bc112051cafc00507bcd74c.png
results

在上面的图表中,你可以看到俄罗斯套娃模型在所有维度上都达到了比标准模型更高的 Spearman 相似度,这表明俄罗斯套娃模型在此任务上是优越的。

此外,俄罗斯套娃模型的性能下降速度比标准模型要慢得多。这在第二个图表中清晰显示,该图表显示了相对于最大性能的嵌入维度的性能。即使嵌入大小只有 8.3%,俄罗斯套娃模型也保持了 98.37% 的性能,远高于标准模型的 96.46%。这些发现表明,通过俄罗斯套娃模型截断嵌入可以:

  1. 显著加快下游任务 (如检索) 的速度;

  2. 显著节省存储空间,而且不会对性能产生显著影响。

演示

在这个演示中,你可以动态缩小 nomic-ai/nomic-embed-text-v1.5 俄罗斯套娃嵌入模型的输出维度,并观察它如何影响检索性能。所有的嵌入都是在浏览器中使用 🤗 Transformers.js 进行计算的。

12e86b399c68f996d1366e9377e3ae87.png  

https://xenova-adaptive-retrieval-web.static.hf.space

参考文献

  • Kusupati, A., Bhatt, G., Rege, A., Wallingford, M., Sinha, A., Ramanujan, V., … & Farhadi, A. (2022). Matryoshka representation learning. Advances in Neural Information Processing Systems, 35, 30233-30249. https://arxiv.org/abs/2205.13147

  • Matryoshka Embeddings — Sentence-Transformers documentation. (n.d.). https://sbert.net/examples/training/matryoshka/README.html

  • UKPLab. (n.d.). GitHub. https://github.com/UKPLab/sentence-transformers

  • Unboxing Nomic Embed v1.5: Resizable Production Embeddings with Matryoshka Representation Learning. (n.d.). https://blog.nomic.ai/posts/nomic-embed-matryoshka

🤗 宝子们可以戳 阅读原文 查看文中所有的外部链接哟!


英文原文: https://hf.co/blog/matryoshka

原文作者: Tom Aarsen, Joshua, Omar Sanseviero

译者: innovation64

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/435766.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Vue】vue3 在图片上渲染 OCR 识别后的文本框、可复制文本组件

需求 后面返回解析后的文本和四角坐标,在图片上渲染成框,并且可复制。图片还可以缩放、拖拽 实现 这里要重点讲下关于OCR文本框的处理: 因为一些文字可能是斜着放的,所有我们要特殊处理,根据三角函数来计算出它的偏…

分布式ID生成策略-雪花算法Snowflake

分布式ID生成策略-雪花算法Snowflake 一、其他分布式ID策略1.UUID2.数据库自增与优化2.1 优化1 - 共用id自增表2.2 优化2 - 分段获取id 3.Reids的incr和incrby 二、雪花算法Snowflake1.雪花算法的定义2.基础雪花算法源码解读3.并发1000测试4.如何设置机房和机器id4.雪花算法时钟…

短剧系统开发:一种新型的娱乐方式

一、引言 随着科技的快速发展,人们的生活方式也在逐渐改变。在娱乐领域,短剧作为一种新型的娱乐方式,正在受到越来越多人的喜爱。短剧以其短小精悍、情节紧凑、易于观看等特点,迅速占领了市场。因此,开发一款短剧系统…

基于STC12C5A60S2系列1T 8051单片机的TM1638键盘数码管模块的数码管显示应用

基于STC12C5A60S2系列1T 8051单片机的TM1638键盘数码管模块的数码管显示应用 STC12C5A60S2系列1T 8051单片机管脚图STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式及配置STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式介绍TM1638键盘数码管模块概述TM1638键盘数码管…

pytorch什么是梯度

目录 1.导数、偏微分、梯度1.1 导数1.2 偏微分1.3 梯度 2. 通过梯度求极小值3. learning rate3. 局部最小值4. Saddle point鞍点 1.导数、偏微分、梯度 1.1 导数 对于yx 2 2 2 的导数,描述了y随x值变化的一个变化趋势,导数是个标量反应的是变化的程度&…

NoSQL--3.MongoDB配置(Linux版)

目录 2.2 Linux环境下操作 2.2.1 传输MongoDB压缩包到虚拟机: 2.2.2 启动MongoDB服务: 2.2 Linux环境下操作 2.2.1 传输MongoDB压缩包到虚拟机: (笔者使用XShell传输) 如果不想放在如图的路径,删除操作…

基于springboot+vue实现学校田径运动会系统项目【项目源码+论文说明】计算机毕业设计

基于springbootvue实现学校田径运动会系统演示 摘要 随着互联网普及率的提高,互联网与人们日常生活的关系越来越密切,越来越多学校也正在着力建设自己的信息化管理系统,学校根据自身的发展及社会发展的需要,开始将传统的运动会成…

Golang模糊测试实践

模糊测试可以简单快速的自动化构建测试用例,尽量遍历各种可能的输入场景,从而保证函数代码覆盖尽可能多的边缘场景。Go原生内置了模糊测试的支持,如果善加利用,可以有效提升Go代码的质量。原文: Fuzz Testing in Golang 题图由Lex…

Hadoop配置日志的聚集——jobhistory不显示任务问题

问题: 一开始job history是正常的,配置了日志的聚集以后不管做什么任务都不显示任务,hdfs是正常运行,而且根据配置步骤都重启过了。 下面先po出日志聚集的操作步骤,再讲问题 1.配置yarn-site.xml cd $HADOOP_HOME/e…

0基础跨考408|一战上岸复盘及经验分享

基础阶段‼️ 王道的四本书的选择题部分要都做完、订正完。 王道的四门视频课要一轮刷完(或者题主在B站看了其他的老师,这其实也是算一轮的,只要题主是认真学习了的,题主说自己不知道看什么课,王道就好了)…

kibana配置 dashbord,做可视化展示

一、环境介绍 这里我使用的kibana版本为7.17版本。 语言选择为中文。 需要已经有es,已经有kibana,并且都能正常访问。 二、背景介绍 kibana的可视化界面,可以配置很多监控统计界面。非常方便,做数据的可视化展示。 这篇文章&…

【四】【SQL Server】如何运用SQL Server中查询设计器通关数据库期末查询大题

数据库学生选择1122 数据库展示 course表展示 SC表展示 student表展示 数据库学生选课1122_3 第十一题 第十二题 第十三题 第十四题 第十五题 数据库学生选课1122_4 第十六题 第十七题 第十八题 第十九题 第二十题 数据库学生选课1122_5 第二十一题 第二十二题 结尾 最后&…

恒驰上云规划实施解决方案上线华为云官网

华为云与伙伴共同打造联合解决方案 已成为更多企业的数字化转型利器 1月恒驰上云规划实施解决方案 完成上市宣讲并正式上架华为云官网 恒驰上云规划实施解决方案能力全景图:融合厂商云服务能力,一站式高效云迁移 从深入了解企业的本地IT环境、业务特点…

查看kafka消息消费堆积情况

查看主题命令 展示topic列表 ./kafka-topics.sh --list --zookeeper zookeeper_ip:2181描述topic ./kafka-topics.sh --describe --zookeeper zookeeper_ip:2181 --topic topic_name查看topic某分区偏移量最大(小)值 ./kafka-run-class.sh kafka.too…

Git——Upload your open store

0.default config ssh-keygen -t rsa #之后一路回车,当前目录.ssh/下产生公私钥 cat ~/.ssh/id_rsa.pub #复制公钥到账号 git config --global user.email account_email git config --global user.name account_name1. 上传一个公开仓库 查看当前分支: git branc…

JavaSE——基础小项目-模拟ATM系统(项目主要目标、技术选型、架构搭建、具体实现、完整代码注释)

目录 项目主要目标 技术选型 面向对象编程 使用集合容器 程序流程控制 使用常见API 系统架构搭建与欢迎页设计 Account ATM Test 用户开户功能实现 录入账户名称与性别 录入账户密码与取现额度 生成新卡号 存入账户 登录功能实现 登录后操作实现 退出账户 存…

python基础(11)《Allure报告中的组件用法》

使用 官方教程:https://docs.qameta.io/allure 入门 想要看到allure报告,需要做2个步骤: 1、pytest执行时关联allure:pytest命令带上--alluredir 结果存放目录或--alluredir结果存放目录; 2、打开执行报告&#xff…

通过勒索病毒攻击案例,思考勒索病毒攻击现象与趋势

前言 2019年针对企业的勒索病毒攻击越来越多,仿佛全球都在被勒索,基本上每天都会有关于勒索病毒攻击的案例被曝光,勒索病毒攻击已经成为全球最大的网络安全威胁,同时也被国际刑警组织认定为全球危害最大的网络犯罪组织活动&#…

nginx代理参数proxy_pass

proxy_pass参数用于配置反向代理,指定客户端请求被转发到后端服务器,后端地址可以是域名、ip端口URI 代理后端报错提示本地找不到CSS文件、JavaScript文件或图片 例如: nginx :10.1.74.109 后端服务:http://10.1.74.…

钡铼技术R40工业路由器连接智慧交通助力城市智慧化建设

随着信息技术与交通行业的深度融合,智慧交通作为智慧城市的重要组成部分,正在全球范围内加速推进。在此进程中,钡铼技术推出的R40工业路由器以其独特的4G WiFi一体化设计,成为连接智慧交通各环节,助力城市智慧化建设的…
最新文章