【域适应】深度域适应常用的距离度量函数实现

关于

深度域适应中,有一类方法是实现目标域和源域的特征对齐,特征对齐的衡量函数主要包括MMD,MK-MMD,A-distance,CORAL loss, Wasserstein distance等等。本文总结了常用的特征变换对齐的函数定义。

工具

Python

方法实现

MMD 多核函数定义
# Compute MMD (maximum mean discrepancy) using numpy and scikit-learn.

import numpy as np
from sklearn import metrics


def mmd_linear(X, Y):
    """MMD using linear kernel (i.e., k(x,y) = <x,y>)
    Note that this is not the original linear MMD, only the reformulated and faster version.
    The original version is:
        def mmd_linear(X, Y):
            XX = np.dot(X, X.T)
            YY = np.dot(Y, Y.T)
            XY = np.dot(X, Y.T)
            return XX.mean() + YY.mean() - 2 * XY.mean()

    Arguments:
        X {[n_sample1, dim]} -- [X matrix]
        Y {[n_sample2, dim]} -- [Y matrix]

    Returns:
        [scalar] -- [MMD value]
    """
    delta = X.mean(0) - Y.mean(0)
    return delta.dot(delta.T)


def mmd_rbf(X, Y, gamma=1.0):
    """MMD using rbf (gaussian) kernel (i.e., k(x,y) = exp(-gamma * ||x-y||^2 / 2))

    Arguments:
        X {[n_sample1, dim]} -- [X matrix]
        Y {[n_sample2, dim]} -- [Y matrix]

    Keyword Arguments:
        gamma {float} -- [kernel parameter] (default: {1.0})

    Returns:
        [scalar] -- [MMD value]
    """
    XX = metrics.pairwise.rbf_kernel(X, X, gamma)
    YY = metrics.pairwise.rbf_kernel(Y, Y, gamma)
    XY = metrics.pairwise.rbf_kernel(X, Y, gamma)
    return XX.mean() + YY.mean() - 2 * XY.mean()


def mmd_poly(X, Y, degree=2, gamma=1, coef0=0):
    """MMD using polynomial kernel (i.e., k(x,y) = (gamma <X, Y> + coef0)^degree)

    Arguments:
        X {[n_sample1, dim]} -- [X matrix]
        Y {[n_sample2, dim]} -- [Y matrix]

    Keyword Arguments:
        degree {int} -- [degree] (default: {2})
        gamma {int} -- [gamma] (default: {1})
        coef0 {int} -- [constant item] (default: {0})

    Returns:
        [scalar] -- [MMD value]
    """
    XX = metrics.pairwise.polynomial_kernel(X, X, degree, gamma, coef0)
    YY = metrics.pairwise.polynomial_kernel(Y, Y, degree, gamma, coef0)
    XY = metrics.pairwise.polynomial_kernel(X, Y, degree, gamma, coef0)
    return XX.mean() + YY.mean() - 2 * XY.mean()


if __name__ == '__main__':
    a = np.arange(1, 10).reshape(3, 3)
    b = [[7, 6, 5], [4, 3, 2], [1, 1, 8], [0, 2, 5]]
    b = np.array(b)
    print(a)
    print(b)
    print(mmd_linear(a, b))  # 6.0
    print(mmd_rbf(a, b))  # 0.5822
    print(mmd_poly(a, b))  # 2436.5
 A-distance 函数定义
# Compute A-distance using numpy and sklearn
# Reference: Analysis of representations in domain adaptation, NIPS-07.

import numpy as np
from sklearn import svm

def proxy_a_distance(source_X, target_X, verbose=False):
    """
    Compute the Proxy-A-Distance of a source/target representation
    """
    nb_source = np.shape(source_X)[0]
    nb_target = np.shape(target_X)[0]

    if verbose:
        print('PAD on', (nb_source, nb_target), 'examples')

    C_list = np.logspace(-5, 4, 10)

    half_source, half_target = int(nb_source/2), int(nb_target/2)
    train_X = np.vstack((source_X[0:half_source, :], target_X[0:half_target, :]))
    train_Y = np.hstack((np.zeros(half_source, dtype=int), np.ones(half_target, dtype=int)))

    test_X = np.vstack((source_X[half_source:, :], target_X[half_target:, :]))
    test_Y = np.hstack((np.zeros(nb_source - half_source, dtype=int), np.ones(nb_target - half_target, dtype=int)))

    best_risk = 1.0
    for C in C_list:
        clf = svm.SVC(C=C, kernel='linear', verbose=False)
        clf.fit(train_X, train_Y)

        train_risk = np.mean(clf.predict(train_X) != train_Y)
        test_risk = np.mean(clf.predict(test_X) != test_Y)

        if verbose:
            print('[ PAD C = %f ] train risk: %f  test risk: %f' % (C, train_risk, test_risk))

        if test_risk > .5:
            test_risk = 1. - test_risk

        best_risk = min(best_risk, test_risk)

    return 2 * (1. - 2 * best_risk)
 CORAL loss函数定义
# Compute CORAL loss using pytorch
# Reference: DCORAL: Correlation Alignment for Deep Domain Adaptation, ECCV-16.

import torch

def CORAL_loss(source, target):
    d = source.data.shape[1]
    ns, nt = source.data.shape[0], target.data.shape[0]
    # source covariance
    xm = torch.mean(source, 0, keepdim=True) - source
    xc = xm.t() @ xm / (ns - 1)

    # target covariance
    xmt = torch.mean(target, 0, keepdim=True) - target
    xct = xmt.t() @ xmt / (nt - 1)

    # frobenius norm between source and target
    loss = torch.mul((xc - xct), (xc - xct))
    loss = torch.sum(loss) / (4*d*d)
    return loss

# Another implementation:
# Two implementations are the same. Just different formulation format.
# def CORAL(source, target):
#     d = source.size(1)
#     ns, nt = source.size(0), target.size(0)
#     # source covariance
#     tmp_s = torch.ones((1, ns)).to(DEVICE) @ source
#     cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)

#     # target covariance
#     tmp_t = torch.ones((1, nt)).to(DEVICE) @ target
#     ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)

#     # frobenius norm
#     loss = torch.norm(cs - ct, p='fro').pow(2)
#     loss = loss / (4 * d * d)
#     return loss
Wasserstein loss函数定义
import math
import torch
import torch.linalg as linalg

def calculate_2_wasserstein_dist(X, Y):
    '''
    Calulates the two components of the 2-Wasserstein metric:
    The general formula is given by: d(P_X, P_Y) = min_{X, Y} E[|X-Y|^2]
    For multivariate gaussian distributed inputs z_X ~ MN(mu_X, cov_X) and z_Y ~ MN(mu_Y, cov_Y),
    this reduces to: d = |mu_X - mu_Y|^2 - Tr(cov_X + cov_Y - 2(cov_X * cov_Y)^(1/2))
    Fast method implemented according to following paper: https://arxiv.org/pdf/2009.14075.pdf
    Input shape: [b, n] (e.g. batch_size x num_features)
    Output shape: scalar
    '''

    if X.shape != Y.shape:
        raise ValueError("Expecting equal shapes for X and Y!")

    # the linear algebra ops will need some extra precision -> convert to double
    X, Y = X.transpose(0, 1).double(), Y.transpose(0, 1).double()  # [n, b]
    mu_X, mu_Y = torch.mean(X, dim=1, keepdim=True), torch.mean(Y, dim=1, keepdim=True)  # [n, 1]
    n, b = X.shape
    fact = 1.0 if b < 2 else 1.0 / (b - 1)

    # Cov. Matrix
    E_X = X - mu_X
    E_Y = Y - mu_Y
    cov_X = torch.matmul(E_X, E_X.t()) * fact  # [n, n]
    cov_Y = torch.matmul(E_Y, E_Y.t()) * fact

    # calculate Tr((cov_X * cov_Y)^(1/2)). with the method proposed in https://arxiv.org/pdf/2009.14075.pdf
    # The eigenvalues for M are real-valued.
    C_X = E_X * math.sqrt(fact)  # [n, n], "root" of covariance
    C_Y = E_Y * math.sqrt(fact)
    M_l = torch.matmul(C_X.t(), C_Y)
    M_r = torch.matmul(C_Y.t(), C_X)
    M = torch.matmul(M_l, M_r)
    S = linalg.eigvals(M) + 1e-15  # add small constant to avoid infinite gradients from sqrt(0)
    sq_tr_cov = S.sqrt().abs().sum()

    # plug the sqrt_trace_component into Tr(cov_X + cov_Y - 2(cov_X * cov_Y)^(1/2))
    trace_term = torch.trace(cov_X + cov_Y) - 2.0 * sq_tr_cov  # scalar

    # |mu_X - mu_Y|^2
    diff = mu_X - mu_Y  # [n, 1]
    mean_term = torch.sum(torch.mul(diff, diff))  # scalar

    # put it together

代码获取

相关项目和问题,欢迎沟通交流。

参考文献

Yan H, Ding Y, Li P, Wang Q, Xu Y, Zuo W. Mind the class weight bias: Weighted maximum mean discrepancy for unsupervised domain adaptation. InProceedings of the IEEE conference on computer vision and pattern recognition 2017 (pp. 2272-2281).

Sun B, Saenko K. Deep coral: Correlation alignment for deep domain adaptation. InComputer Vision–ECCV 2016 Workshops: Amsterdam, The Netherlands, October 8-10 and 15-16, 2016, Proceedings, Part III 14 2016 (pp. 443-450). Springer International Publishing.

Shen J, Qu Y, Zhang W, Yu Y. Wasserstein distance guided representation learning for domain adaptation. InProceedings of the AAAI conference on artificial intelligence 2018 Apr 29 (Vol. 32, No. 1).

Ben-David S, Blitzer J, Crammer K, Pereira F. Analysis of representations for domain adaptation. Advances in neural information processing systems. 2006;19.

Kramer O, Kramer O. Scikit-learn. Machine learning for evolution strategies. 2016:45-53.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/536613.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Vue3学习04 组件通信

Vue3学习04 组件通信 组件通信props 父 ↔ 子自定义事件 子 > 父mitt 任意组件间通信v-model 父↔子$attrs 祖↔孙$refs、$parent案例的完整代码ref注意点 provide、inject 祖↔孙piniaslot① 默认插槽② 具名插槽③ 作用域插槽 组件通信 Vue3组件通信和Vue2的区别&#xf…

LangChain的RAG实践

1. 什么是RAG RAG的概念最先在2020年由Facebook的研究人员在论文《Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks》中提出来。在这篇论文中他们提出了两种记忆类型&#xff1a; 基于预训练模型&#xff08;当时LLM的概念不像现在这么如日中天&#xff0…

CV每日论文--2024.4.11

1、InternLM-XComposer2-4KHD: A Pioneering Large Vision-Language Model Handling Resolutions from 336 Pixels to 4K HD 中文标题&#xff1a;InternLM-XComposer2-4KHD&#xff1a;开创性的大型视觉语言模型&#xff0c;可处理从 336 像素到 4K 高清的分辨率 简介&#x…

OJ 变长编码 【C】

又是跌跌撞撞完成的一道题&#xff0c;我对于位运算和进制转化这块知识点太欠缺了&#xff0c;写了这么久c的题目也没用过几次 知识点 1.取出低七位bit 使用&位运算符 与0x7F可以取出当前数的二进制最低七位&#xff0c;这里即使是整数参与运算&#xff0c;也会自动被转换…

社交革命的引领者:探索Facebook的创新策略

1. 引言&#xff1a;社交媒体的崛起 社交媒体的兴起标志着信息时代的到来&#xff0c;它不仅改变了人们的生活方式&#xff0c;也影响着整个社会结构。作为社交媒体的先驱者&#xff0c;Facebook以其创新的策略和领先的技术&#xff0c;成为了这场社交革命的引领者。从2004年马…

Shenandoah GC算法

概述 最早由Red Hat公司发起&#xff0c;目标是利用现代多核CPU的优势&#xff0c;减少大堆内存在GC时产生的停顿时间。随OpenJDK 12一起发布&#xff0c;暂停时间不依赖于堆的大小&#xff1b;这意味着无论堆的大小如何&#xff0c;暂停时间都是差不多的。 Shenandoah最初的…

[C++][算法基础]图中点的层次(树图BFS)

给定一个 n 个点 m 条边的有向图&#xff0c;图中可能存在重边和自环。 所有边的长度都是 1&#xff0c;点的编号为 1∼n。 请你求出 1 号点到 n 号点的最短距离&#xff0c;如果从 1 号点无法走到 n 号点&#xff0c;输出 −1。 输入格式 第一行包含两个整数 n 和 m。 接…

【MCU开发规范】:MCU的性能测试

MCU的性能测试 前序性能评判方法MIPSCoreMark EEMBC其他参考 前序 我们平时做MCU开发时&#xff0c;前期硬件选型&#xff08;选那颗MCU&#xff09;基本由硬件工程师和架构决定&#xff0c;到软件开发时只是被动的开发一些具体功能&#xff0c;因此很少参与MCU的选型。 大部分…

Ant Desgin Vue Tree Tab 个性化需求

背景 个人对前端不是很熟&#xff0c;或者说过目就忘&#xff0c;但是对前端还要求不少&#xff0c;这就难搞了。 使用的前端是Mudblazor和ant design vue, Mudblazor 还没有开始搞&#xff0c;现在先用ant design vue&#xff0c;版本是vue3&#xff0c; ant design vue 4版…

4.11学习总结

一.IO流 一.java中IO的初步了解 (一).概念: Java中I/O操作主要是指使用Java进行输入&#xff0c;输出操作. Java所有的I/O机制都是基于数据流进行输入输出&#xff0c;这些数据流表示了字符或者字节数据的流动序列。Java的I/O流提供了读写数据的标准方法。任何Java中表示数据…

Excel·VBA二维数组S形排列

与之前的文章《ExcelVBA螺旋数组函数》将一维数组转为二维螺旋数组 本文将数组转为S形排列的二维数组&#xff0c;类似考场座位S形顺序 Function S形排列(ByVal arr, ByVal num_rows&, ByVal num_cols&, Optional ByVal mode$ "row")将数组arr转为num_rows…

必须掌握的这4种缓存模式

概述 在系统架构中&#xff0c;缓存可谓提供系统性能的简单方法之一&#xff0c;稍微有点开发经验的同学必然会与缓存打过交道&#xff0c;起码也实践过。 如果使用得当&#xff0c;缓存可以减少响应时间、减少数据库负载以及节省成本。但如果缓存使用不当&#xff0c;则可能…

有趣的css - 动态雷达扫描

大家好&#xff0c;我是 Just&#xff0c;这里是「设计师工作日常」&#xff0c;今天分享的是使用 css 实现一个动态的雷达扫描&#xff0c;快学起来吧&#xff01; 《有趣的css》系列最新实例通过公众号「设计师工作日常」发布。 目录 整体效果核心代码html 代码css 部分代码…

当然IP总流量卵化手14无线天线上实际操作夏令营【第9期】月入5w 上百万爆款打造 (74节)

在2023年&#xff0c;我依照导师的”项目销售”策略&#xff0c;成功地实现了超过100万的纯利润。在当前经济低迷的大环境下&#xff0c;许多大型企业纷纷裁员&#xff0c;这使得许多人面临着找不到满意工作的困境。与此同时&#xff0c;由于疫情引发的口罩需求&#xff0c;使得…

算法刷题Day31 | 455.分发饼干、376. 摆动序列、53. 最大子数组和

目录 0 引言1 分发饼干1.1 我的解题1.2 更好的解题 2 摆动序列2.1 我的解题2.2 我的错误原因&#xff08;GPT分析&#xff09;2.3 改进 3 最大子数组和3.1 我的解题 &#x1f64b;‍♂️ 作者&#xff1a;海码007&#x1f4dc; 专栏&#xff1a;算法专栏&#x1f4a5; 标题&…

爬虫实战:我国城市的地铁数据以及分析

文章目录 1 引言2 项目背景3 技术栈和工具选择4 数据爬取4.1 爬虫设计4.2 代码实现4.3 数据保存4.4 关键点分析 5 数据处理与分析5.1 数据清洗5.2 数据分析5.3 关键点分析 6 完整代码以及结果展示7 小分享 1 引言 本文将指导你如何通过Python从高德地图爬取中国城市地铁站数据…

5G-A有何能耐?5G-A三载波聚合技术介绍

2024年被称作5G-A元年。5G-A作为5G下一阶段的演进技术&#xff0c;到底有何能耐呢&#xff1f; 三载波聚合&#xff08;3CC&#xff09;被认为是首个大规模商用的5G-A技术&#xff0c;将带来手机网速的大幅提升。 █ 什么是3CC 3CC&#xff0c;全称叫3 Component Carriers…

前端js基础知识(八股文大全)

一、js的数据类型 值类型(基本类型)&#xff1a;数字(Number)、字符串&#xff08;String&#xff09;、布尔(Boolean)、对空&#xff08;Null&#xff09;、未定义&#xff08;Undefined&#xff09;、Symbol,大数值类型(BigInt) 引用数据类型&#xff1a;对象(Object)、数组…

HNUST湖南科技大学嵌入式开发板使用-2024

目录 1.需要准备的软件(版本必须相同)꒰ঌ( ⌯ ⌯)໒꒱ 2.下载链接地址⌯▾⌯ 3.软件安装教程 4.安装好了&#xff0c;正常情况会是什么样子呢&#xff1f;(๑•̌.•๑) 4.1.拆入第一个接口(串口com接口是用来上传代码的ฅ˙Ⱉ˙ฅ) 4.2.拆入第三个接口&#xff08;SWD Jlink口…

android android.permission.MANAGE_EXTERNAL_STORAGE使用

android11 及以上版本&#xff0c;如果release版本要读取外部存储公共目录&#xff0c;即sdcard公共目录&#xff0c;需要在androidManifest.xml下申明 <uses-permission android:name"android.permission.MANAGE_EXTERNAL_STORAGE" /> 如果要发版到海外&…
最新文章