[Machine Learning] 多任务学习

文章目录

  • 基于参数的MTL模型 (Parameter-based MTL Models)
  • 基于特征的MTL模型 (Feature-based MTL Models)
    • 基于特征的MTL模型 I:
    • 基于特征的MTL模型 II:
  • 基于特征和参数的MTL模型 (Feature- and Parameter-based MTL Models)


多任务学习 (Multi-task Learning, MTL) 是一种同时学习多个相关问题的方法,它通过利用这些问题之间的相关性来进行学习。


在单任务学习 (Single-Task Learning, STL) 中,每个任务有一个独立的模型,这些模型分别学习不同的任务。这里,每个任务(Task 1, Task 2, Task 3, Task 4)都有它自己的输入和独立的神经网络模型。这些模型不会共享学习到的特征或表示,它们是完全独立的。

在多任务学习中,一个单一的模型共同学习多个任务。模型共享输入层和可能还有一些隐藏层,但在最后,可以有特定于任务的输出层。通过这种方式,模型可以学习到在多个任务间共通的、有用的表示,这可以提升模型在各个任务上的性能,特别是当这些任务相关时。多任务学习还有助于提高数据利用率和学习效率,因为相同的数据和模型参数被用来解决多个问题。

这幅图用来说明的关键点是,在多任务学习中,我们期望通过任务之间的相关性来提升性能,而在单任务学习中,每个任务都是孤立地学习,无法从其他任务中学习到的信息中受益。

当任务彼此独立时,多任务学习与单任务学习相比并无优势。

对于数据不足的问题,当有多个相关任务且每个任务的训练样本有限时,多任务学习是一个很好的解决方案。

设定有 m m m个学习任务 { T i } i = 1 m \{T_i\}_{i=1}^m {Ti}i=1m,其中所有任务或其子集彼此相关,多任务学习旨在通过使用 m m m个任务中包含的知识来帮助提高模型对 T i \mathcal{T}_i Ti的学习。任务 T i \mathcal{T}_i Ti伴随着一个训练集 D i = { x j i , y j i } j = 1 n i D_i = \{ x_j^i, y_j^i \}_{j=1}^{n_i} Di={xji,yji}j=1ni

我们的任务是为 { T i } i = 1 m \{T_i\}_{i=1}^m {Ti}i=1m学习假设。

在MTL中,我们考虑线性假设函数,表示为 h ( x ) = w T x h(x) = w^T x h(x)=wTx。对于 m m m 个不同但相关的任务,即 { T i } i = 1 m \{T_i\}^m_{i=1} {Ti}i=1m,我们定义 w i w^i wi 为第 i i i 个任务的假设,其中 i = 1 , … , m i = 1, \ldots, m i=1,,m

MTL的经验风险最小化算法表示为:

min ⁡ W = [ w 1 , … , w m ] 1 m ∑ i = 1 m 1 n i ∑ j = 1 n i ℓ ( x j i , y j i , w i ) \min\limits_{W=[w^1,\ldots,w^m]} \frac{1}{m} \sum_{i=1}^{m} \frac{1}{n_i} \sum_{j=1}^{n_i} \ell (x^i_j, y^i_j, w^i) W=[w1,,wm]minm1i=1mni1j=1ni(xji,yji,wi)

MTL模型通常由两个主要组件组成:参数共享和特征变换。参数共享是指在多个任务间共享模型参数,这样可以使不同任务互相借鉴彼此的信息,从而提高学习效率。特征变换则是指对输入数据进行变换,以找到一个更适合所有任务的表示方式。

基于参数的MTL模型 (Parameter-based MTL Models)

在这种方法中,我们考虑多个相关的任务,并且假设每个任务的假设 w i w^i wi可以表示为一个共同的基础参数 w 0 w_0 w0加上一个特定任务的偏差 Δ w i \Delta w^i Δwi。这个模型的形式化为:

min ⁡ w 0 , Δ W = [ Δ w 1 , … , Δ w m ] 1 m ∑ i = 1 m 1 n i ∑ j = 1 n i ℓ ( x j i , y j i , w 0 + Δ w i ) \min_{w_0,\Delta W = [\Delta w^1, \ldots, \Delta w^m]} \frac{1}{m} \sum_{i=1}^{m} \frac{1}{n_i} \sum_{j=1}^{n_i} \ell(x^i_j, y^i_j, w_0 + \Delta w^i) w0,ΔW=[Δw1,,Δwm]minm1i=1mni1j=1ni(xji,yji,w0+Δwi)

这里的 ℓ \ell 是损失函数, x j i x^i_j xji y j i y^i_j yji是第 i i i个任务的第 j j j个训练样本及其标签。

这样,第 i i i个任务的模型参数可以表示为 w i = w 0 + Δ w i w^i = w_0 + \Delta w^i wi=w0+Δwi。全局参数 w 0 w_0 w0捕获了所有任务之间的共性,而 Δ w i \Delta w^i Δwi则捕获了任务特有的特性。我们的优化目标是最小化所有任务的总损失,同时尽可能地使得各任务参数相互接近,这通常通过添加一个正则化项 ∥ Δ W ∥ F 2 \|\Delta W\|_F^2 ∥ΔWF2来实现:

min ⁡ w 0 , Δ W = [ Δ w 1 , … , Δ w m ] 1 m ∑ i = 1 m 1 n i ∑ j = 1 n i ℓ ( x j i , y j i , w 0 + Δ w i ) + λ ∥ Δ W ∥ F 2 \min_{w_0,\Delta W = [\Delta w^1, \ldots, \Delta w^m]} \frac{1}{m} \sum_{i=1}^{m} \frac{1}{n_i} \sum_{j=1}^{n_i} \ell(x^i_j, y^i_j, w_0 + \Delta w^i) + \lambda \|\Delta W\|_F^2 w0,ΔW=[Δw1,,Δwm]minm1i=1mni1j=1ni(xji,yji,w0+Δwi)+λ∥ΔWF2

这个模型更好,因为它鼓励多任务学习算法具有更强的相关性。

另一个模型使用秩约束:

min ⁡ W = [ w 1 , … , w m ] 1 m ∑ i = 1 m 1 n i ∑ j = 1 n i ℓ ( x j i , y j i , w i ) + λ  rank ( W ) \min\limits_{W=[w^1,\ldots,w^m]} \frac{1}{m} \sum\limits_{i=1}^{m} \frac{1}{n_i} \sum\limits_{j=1}^{n_i} \ell(x^i_j, y^i_j, w^i) + \lambda \text{ rank}(W) W=[w1,,wm]minm1i=1mni1j=1ni(xji,yji,wi)+λ rank(W)

基于特征的MTL模型 (Feature-based MTL Models)

在基于特征的MTL模型中,假设是从训练样例中学到的:

给定一组数据 D i = { x j i , y j i } j = 1 n i \mathcal{D}_i = \{ x_j^{i}, y_j^{i} \}_{j=1}^{n_i} Di={xji,yji}j=1ni

我们希望通过特征映射使得任务之间更加相关。即,我们希望找到一个投影矩阵 P P P,使得 D i \mathcal{D}_i Di变换为 D i = { P T x j i , y j i } j = 1 n i \mathcal{D}_i = \{ P^T x_j^{i}, y_j^{i} \}_{j=1}^{n_i} Di={PTxji,yji}j=1ni

基于特征的MTL模型 I:

min ⁡ W , P 1 m ∑ i = 1 m 1 n i ∑ j = 1 n i ℓ ( P T x j i , y j i , w i ) + λ rank ( W )  s.t.  P P T = I \min_{W,P} \frac{1}{m} \sum_{i=1}^{m} \frac{1}{n_i} \sum_{j=1}^{n_i} \ell(P^T x_j^{i}, y_j^{i},w^i) + \lambda \text{rank}(W) \text{ s.t. } PP^T = I W,Pminm1i=1mni1j=1ni(PTxji,yji,wi)+λrank(W) s.t. PPT=I

这个损失函数计算的是映射后的特征与目标值之间的误差,并加入了正则化项以控制权重矩阵W的复杂度。损失函数以 ℓ ( P T x i j , y i j , w i ) \ell(P^T x_i^j, y_i^j, w^i) (PTxij,yij,wi) 表示, x i j x_i^j xij 是第i个任务的第j个样本的特征, y i j y_i^j yij 是对应的目标值, w i w^i wi 是第i个任务的权重向量, P P P 是一个投影矩阵,使得通过 P T x j i P^T x_j^{i} PTxji变换后的特征可以更好地为多个任务服务。

λ \lambda λ 是正则化项的权重, rank ( W ) \text{rank}(W) rank(W) 是权重矩阵的秩,用于控制模型的复杂度。

基于特征的MTL模型 II:

这是一个共享隐藏层的神经网络架构,其中隐藏层的节点可以被看作是特征提取器。

对应的优化问题考虑了一个共享参数 w 0 w_0 w0 和针对每个任务的调整参数 Δ w i \Delta w_i Δwi

这个模型的目标是最小化包含共享参数和任务特定调整的损失函数,并通过 λ ∣ ∣ Δ W ∣ ∣ F 2 \lambda ||\Delta W||_F^2 λ∣∣ΔWF2 正则化每个任务的参数调整量。

隐藏层对于所有任务来说是共享的,这意味着模型可以学习通用的特征表示,而输出层则是特定于任务的。

基于特征和参数的MTL模型 (Feature- and Parameter-based MTL Models)

min ⁡ w 0 , Δ W , P 1 m ∑ i = 1 m 1 n i ∑ j = 1 n i ℓ ( P T x j i , y j i , w 0 + Δ w i ) + λ ∥ Δ W ∥ F 2  s.t.  P P T = I \min_{w_0, \Delta W,P} \frac{1}{m} \sum_{i=1}^{m} \frac{1}{n_i} \sum_{j=1}^{n_i} \ell(P^T x_j^{i}, y_j^{i},w_0 + \Delta w^i) + \lambda \|\Delta W\|_F^2 \text{ s.t. } PP^T = I w0,ΔW,Pminm1i=1mni1j=1ni(PTxji,yji,w0+Δwi)+λ∥ΔWF2 s.t. PPT=I

该模型旨在找到一个跨任务共享的特征投影( P P P)和一组针对所有任务优化的参数( w 0 w_0 w0 Δ W ΔW ΔW)。

  • 目标函数: min ⁡ w 0 , Δ W , P \min_{w_0, \Delta W,P} minw0,ΔW,P 表示我们的目标是最小化关于 w 0 w_0 w0(共享参数)、 Δ W \Delta W ΔW(任务特定参数变化)和 P P P(特征投影矩阵)的某个函数。

  • 任务平均: 1 m ∑ i = 1 m \frac{1}{m} \sum_{i=1}^{m} m1i=1m 表示我们考虑 m m m 个不同的任务,并对这些任务的结果取平均。

  • 任务内平均**: 对于每个任务 i i i 1 n i ∑ j = 1 n i \frac{1}{n_i} \sum_{j=1}^{n_i} ni1j=1ni 用于对该任务中的 n i n_i ni 个样本进行平均。

  • 损失函数: ℓ ( P T x j i , y j i , w 0 + Δ w i ) \ell(P^T x_j^{i}, y_j^{i},w_0 + \Delta w^i) (PTxji,yji,w0+Δwi) 是损失函数,用于量化模型预测 P T x j i P^T x_j^{i} PTxji(经过特征转换的输入)和真实标签 y j i y_j^{i} yji 之间的差异,同时考虑共享参数 w 0 w_0 w0 和任务特定参数的调整 Δ w i \Delta w^i Δwi

  • 正则化项: λ ∥ Δ W ∥ F 2 \lambda \|\Delta W\|_F^2 λ∥ΔWF2 是正则化项,用于防止过拟合。它通过控制任务特定参数变化的大小(使用Frobenius范数)来实现。

  • 约束条件: P P T = I PP^T = I PPT=I 是一个约束条件,确保投影矩阵 P P P 是正交的。这有助于保持映射后的特征间的独立性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/130257.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

千帆SDK开源到GitHub,开发者可免费下载使用!

目录 一、SDK的优势 二、千帆SDK:快速落地LLM应用 三、如何快速上手千帆SDK 1、SDK快速启动 快速安装 平台鉴权 如何获取AK/SK 以“Chat 对话”为调用示例 2. SDK进阶指引 3. 通过Langchain接入千帆SDK 为什么选择Langchain 开源社区 千帆社区 好消息&…

高斯过程回归 | GPR高斯过程回归

高斯过程回归(Gaussian Process Regression, GPR)是一种强大的非参数回归方法,它通过假设数据是从一个高斯过程中生成的来预测新的数据点。 高斯过程是一种定义在连续输入空间上的随机过程,其中任何有限集合的观测值都呈多变量高斯分布。 实现GPR的Python代码import numpy …

Redis快速入门

1.说说什么是Redis? Redis 是互联网技术领域中使用最广泛的存储中间件,它是 Remote Dictionary Service 三个单词中加粗字母的组合。你别说,组合起来后念着挺自然的。 Redis 以超高的性能、完美的文档、简洁的源码著称,国内外很多大型互联网…

在GORM中使用并发

一个全面的指南,如何安全地使用GORM和Goroutines进行并发数据处理 效率是现代应用程序开发的基石,而并发在实现效率方面发挥着重要作用。GORM,这个强大的Go对象关系映射库,使开发人员能够通过Goroutines embrace并行性。在本指南…

FPGA UDP RGMII 千兆以太网(3)ODDR

1 xilinx原语 在 7 系列 FPGA 中实现 RGMII 接口需要借助 5 种原语,分别是:IDDR、ODDR、IDELAYE2、ODELAYE2(A7 中没有)、IDELAYCTRL。其中,IDDR和ODDR分别是输入和输出的双边沿寄存器,位于IOB中。IDELAYE2和ODELAYE2,分别用于控制 IO 口输入和输出延时。同时,IDELAYE2 …

SPSS曲线回归

前言: 本专栏参考教材为《SPSS22.0从入门到精通》,由于软件版本原因,部分内容有所改变,为适应软件版本的变化,特此创作此专栏便于大家学习。本专栏使用软件为:SPSS25.0 本专栏所有的数据文件请点击此链接下…

【PWN · ret2csu】[HNCTF 2022 WEEK2]ret2csu

记一道ret2csu 一、题目 二、思路 1.ret2csu用write泄露write的真实地址->泄露libc->获得system的真实地址 2.ret2csu用read写/bin/sh字符串到bss段上 3.ret2csu用write将system的真实地址写到bss段上 4.ret2csu调用system 三、exp from pwn import * from pwn impo…

LeetCode(2)移除元素【数组/字符串】【简单】

目录 1.题目2.答案3.提交结果截图 链接: 27. 移除元素 1.题目 给你一个数组 nums 和一个值 val,你需要 原地 移除所有数值等于 val 的元素,并返回移除后数组的新长度。 不要使用额外的数组空间,你必须仅使用 O(1) 额外空间并 原…

后端架构选择:构建安全强大的知识付费小程序平台

构建知识付费小程序平台需要考虑后端架构,确保系统安全性、性能和可扩展性。以下是一些常见的后端技术和最佳实践,能帮助您构建强大且安全的知识付费小程序平台。 1. 服务器端语言和框架选择 选择流行、成熟的后端语言和框架,如Node.js、P…

机器学习---多分类SVM、支持向量机分类

1. 多分类SVM 1.1 基本思想 Grammer-singer多分类支持向量机的出发点是直接用超平面把样本空间划分成M个区域,其 中每个区域对应一个类别的输入。如下例,用从原点出发的M条射线把平面分成M个区域,下图画 出了M3的情形: 1.2 问题…

EXPLAIN详解(MySQL)

EXPLAIN概述 EXPLAIN语句提供MySQL如何执行语句的信息。EXPLAIN与SELECT, DELETE, INSERT, REPLACE和UPDATE语句一起工作。 EXPLAIN返回SELECT语句中使用的每个表的一行信息。它按照MySQL在处理语句时读取表的顺序列出了输出中的表。MySQL使用嵌套循环连接方法解析所有连接。…

探索未来,开启无限可能:打造智慧应用,亚马逊云科技大语言模型助您一臂之力

文章目录 什么是大模型?大模型训练方法亚马逊云科技推出生成式AI新工具 —— aws toolkit使用教程 总结 什么是大模型? 近期,生成式大模型是人工智能领域的研究热点。这些生成式大模型,诸如文心一言、文心一格、ChatGPT、Stable …

JMeter实现持续压测websocket

1、安装插件:JMeter WebSocket Samplers pjtr / JMeter WebSocket Samplers / Downloads — Bitbuckethttps://bitbucket.org/pjtr/jmeter-websocket-samplers/downloads/ 将下载的Jar包放在安装jmeter的/lib/ext路径下,重启生效 查看测试计划--》配置…

LeetCode(6)轮转数组【数组/字符串】【中等】

目录 1.题目2.答案3.提交结果截图 链接: 189. 轮转数组 1.题目 给定一个整数数组 nums,将数组中的元素向右轮转 k 个位置,其中 k 是非负数。 示例 1: 输入: nums [1,2,3,4,5,6,7], k 3 输出: [5,6,7,1,2,3,4] 解释: 向右轮转 1 步: [7,1…

使用LLama和ChatGPT为多聊天后端构建微服务

微服务架构便于创建边界明确定义的灵活独立服务。这种可扩展的方法使开发人员能够在不影响整个应用程序的情况下单独维护和完善服务。然而,若要充分发挥微服务架构的潜力、特别是针对基于人工智能的聊天应用程序,需要与最新的大语言模型(LLM&…

如何使用HadSky搭配内网穿透工具打造个人站点并公网访问

🌈个人主页:聆风吟 🔥系列专栏:Cpolar杂谈、数据结构、算法模板 🔖少年有梦不应止于心动,更要付诸行动。 文章目录 前言一. 网站搭建1.1 网页下载和安装1.2 网页测试1.3 cpolar的安装和注册 二. 本地网页发…

ARM IMX6ULL 基础学习记录 / ARM 寄存器介绍

编辑整理 by Staok。 本文大部分内容摘自“100ask imx6ull”开发板的配套资料(如《IMX6ULL裸机开发完全手册》等等),侵删。进行了精髓提取,方便日后查阅。过于基础的内容不会在此提及。如有错误恭谢指出! 注&#xf…

ChatGPT-4:OpenAI的革命性升级

在人工智能领域,OpenAI这家公司凭借其创新性的技术,成为了备受瞩目的领导者。他们最近发布的ChatGPT-4,以其卓越的语言处理能力和先进的模型架构,引领了语言模型领域的革命性升级。 ChatGPT-4的模型容量相较于前一版本有了显著的提…

LeetCode(5)多数元素【数组/字符串】【简单】

目录 1.题目2.答案3.提交结果截图 链接: 169. 多数元素 1.题目 给定一个大小为 n 的数组 nums ,返回其中的多数元素。多数元素是指在数组中出现次数 大于 ⌊ n/2 ⌋ 的元素。 你可以假设数组是非空的,并且给定的数组总是存在多数元素。 示…

classification_report分类报告的含义

classification_report分类报告 基础知识混淆矩阵(Confusion Matrix)TP、TN、FP、FN精度(Precision)准确率(Accuracy)召回率(Recall)F1分数(F1-score) classi…