在 PyTorch 中实现可解释的神经网络模型

动动发财的小手,点个赞吧!

alt

目的

深度学习系统缺乏可解释性对建立人类信任构成了重大挑战。这些模型的复杂性使人类几乎不可能理解其决策背后的根本原因。

深度学习系统缺乏可解释性阻碍了人类的信任。

为了解决这个问题,研究人员一直在积极研究新的解决方案,从而产生了重大创新,例如基于概念的模型。这些模型不仅提高了模型的透明度,而且通过在训练过程中结合高级人类可解释的概念(如“颜色”或“形状”),培养了对系统决策的新信任感。因此,这些模型可以根据学习到的概念为其预测提供简单直观的解释,从而使人们能够检查其决策背后的原因。这还不是全部!它们甚至允许人类与学习到的概念进行交互,让我们能够控制最终的决定。

基于概念的模型允许人类检查深度学习预测背后的推理,并让我们重新控制最终决策。

这篇博文[1]中,我们将深入研究这些技术,并为您提供使用简单的 PyTorch 接口实现最先进的基于概念的模型的工具。通过实践经验,您将学习如何利用这些强大的模型来增强可解释性并最终校准人类对您的深度学习系统的信任。

概念瓶颈模型

在这个介绍中,我们将深入探讨概念瓶颈模型。这模型在 2020 年国际机器学习会议上发表的一篇论文中介绍,旨在首先学习和预测一组概念,例如“颜色”或“形状”,然后利用这些概念来解决下游分类任务:

alt

通过遵循这种方法,我们可以将预测追溯到提供解释的概念,例如“输入对象是一个{apple},因为它是{spherical}和{red}。”

概念瓶颈模型首先学习一组概念,例如“颜色”或“形状”,然后利用这些概念来解决下游分类任务。

实现

为了说明概念瓶颈模型,我们将重新审视著名的 XOR 问题,但有所不同。我们的输入将包含两个连续的特征。为了捕捉这些特征的本质,我们将使用概念编码器将它们映射为两个有意义的概念,表示为“A”和“B”。我们任务的目标是预测“A”和“B”的异或 (XOR)。通过这个例子,您将更好地理解概念瓶颈如何在实践中应用,并见证它们在解决具体问题方面的有效性。

我们可以从导入必要的库并加载这个简单的数据集开始:

import torch
import torch_explain as te
from torch_explain import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

x, c, y = datasets.xor(500)
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(x, c, y, test_size=0.33, random_state=42)

接下来,我们实例化一个概念编码器以将输入特征映射到概念空间,并实例化一个任务预测器以将概念映射到任务预测:

concept_encoder = torch.nn.Sequential(
    torch.nn.Linear(x.shape[1], 10),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(108),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(8, c.shape[1]),
    torch.nn.Sigmoid(),
)
task_predictor = torch.nn.Sequential(
    torch.nn.Linear(c.shape[1], 8),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(81),
)
model = torch.nn.Sequential(concept_encoder, task_predictor)

然后我们通过优化概念和任务的交叉熵损失来训练网络:

optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_form_c = torch.nn.BCELoss()
loss_form_y = torch.nn.BCEWithLogitsLoss()
model.train()
for epoch in range(2001):
    optimizer.zero_grad()

    # generate concept and task predictions
    c_pred = concept_encoder(x_train)
    y_pred = task_predictor(c_pred)

    # update loss
    concept_loss = loss_form_c(c_pred, c_train)
    task_loss = loss_form_y(y_pred, y_train)
    loss = concept_loss + 0.2*task_loss

    loss.backward()
    optimizer.step()

训练模型后,我们评估其在测试集上的性能:

c_pred = concept_encoder(x_test)
y_pred = task_predictor(c_pred)

concept_accuracy = accuracy_score(c_test, c_pred > 0.5)
task_accuracy = accuracy_score(y_test, y_pred > 0)

现在,在几个 epoch 之后,我们可以观察到概念和任务在测试集上的准确性都非常好(~98% 的准确性)!

由于这种架构,我们可以通过根据输入概念查看任务预测器的响应来为模型预测提供解释,如下所示:

c_different = torch.FloatTensor([01])
print(f"f({c_different}) = {int(task_predictor(c_different).item() > 0)}")

c_equal = torch.FloatTensor([11])
print(f"f({c_different}) = {int(task_predictor(c_different).item() > 0)}")

这会产生例如 f([0,1])=1 和 f([1,1])=0 ,如预期的那样。这使我们能够更多地了解模型的行为,并检查它对于任何相关概念集的行为是否符合预期,例如,对于互斥的输入概念 [0,1] 或 [1,0],它返回的预测y=1。

概念瓶颈模型通过将预测追溯到概念来提供直观的解释。

淹没在准确性与可解释性的权衡中

概念瓶颈模型的主要优势之一是它们能够通过揭示概念预测模式来为预测提供解释,从而使人们能够评估模型的推理是否符合他们的期望。

然而,标准概念瓶颈模型的主要问题是它们难以解决复杂问题!更一般地说,他们遇到了可解释人工智能中众所周知的一个众所周知的问题,称为准确性-可解释性权衡。实际上,我们希望模型不仅能实现高任务性能,还能提供高质量的解释。不幸的是,在许多情况下,当我们追求更高的准确性时,模型提供的解释往往会在质量和忠实度上下降,反之亦然。

在视觉上,这种权衡可以表示如下:

alt

可解释模型擅长提供高质量的解释,但难以解决具有挑战性的任务,而黑盒模型以提供脆弱和糟糕的解释为代价来实现高任务准确性。

为了在具体设置中说明这种权衡,让我们考虑一个概念瓶颈模型,该模型应用于要求稍高的基准,即“三角学”数据集:

x, c, y = datasets.trigonometry(500)
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(x, c, y, test_size=0.33, random_state=42)

在该数据集上训练相同的网络架构后,我们观察到任务准确性显着降低,仅达到 80% 左右。

概念瓶颈模型未能在任务准确性和解释质量之间取得平衡。

这就引出了一个问题:我们是永远被迫在准确性和解释质量之间做出选择,还是有办法取得更好的平衡?

Reference

[1]

Source: https://towardsdatascience.com/implement-interpretable-neural-models-in-pytorch-6a5932bdb078

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/30539.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

c++Qt Creator调用 python 完整版 + 解决bug过程

文章目录 创建项目配置python环境导入Python库其他坑点Python.h 头文件报错ModuleNotFoundError: No module named encodings’ 完美解决找不到python文件 成功! 文章首发于我的个人博客:欢迎大佬们来逛逛 创建项目 选择创建 qmake 项目: …

【C++】vector的模拟实现

目录 1.vector的结构2.构造函数2.1 无参构造2.2 以迭代器区间作为参数的构造函数2.3 构造n个value值 3.拷贝构造3.1 传统写法3.2 现代写法 4.赋值重载5.迭代器失效问题5.1 reserve和resize5.2 insert 5.3 erase4. 整体代码(包含迭代器、析构函数等) 1.ve…

springboot实验室管理系统-计算机毕设 附源码86757

springboot实验室管理系统 摘 要 验室管理系统是将实验室的分析仪器通过计算机网络连起来,采用科学的管理思想和先进的数据库技术,实现以实验室为核心的整体环境的全方位管理。它集用户管理,实验室信息管理,实验室预约管理&#x…

Java设计模式——策略模式

1. 策略模式简介 策略模式: 策略模式是一种行为型模式, 它将对象和行为分开, 将行为定义为一个行为接口和具体行为的实现 策略模式最大的特点是行为的变化, 行为之间可以相互替换 每个if判断都可以理解为一个策略. 本模式是的算法可独立于使用它的用户而变化 2. 模式结构 策略…

Flink 学习七 Flink 状态(flink state)

Flink 学习七 Flink 状态(flink state) 1.状态简介 流式计算逻辑中,比如sum,max; 需要记录和后面计算使用到一些历史的累计数据, 状态就是:用户在程序逻辑中用于记录信息的变量 在Flink 中 ,状态state 不仅仅是要记录状态;在程序运行中如果失败,是需要重新恢复,所以这个状态…

Java实训第七天——2023.6.13

文章目录 一、用Visual Studio Code写一个计算器二、同一个js被多个html引用三、js操作css四、DOM对象属性的操作案例五、js解析json 一、用Visual Studio Code写一个计算器 功能&#xff1a;实现简单的加减乘除 <!DOCTYPE html> <html lang"en"> <…

LeetCode 2481. 分割圆的最少切割次数

【LetMeFly】2481.分割圆的最少切割次数 力扣题目链接&#xff1a;https://leetcode.cn/problems/minimum-cuts-to-divide-a-circle/ 圆内一个 有效切割 &#xff0c;符合以下二者之一&#xff1a; 该切割是两个端点在圆上的线段&#xff0c;且该线段经过圆心。该切割是一端…

mapbox-gl 点位编辑功能

文章目录 前言方式一&#xff1a;借助 Marker添加自定义icon添加POI图层&#xff0c;绑定对应事件基于Marker交互创建自定义Marker编辑 / 创建POI 方式二&#xff1a;采用 mapbox-gl-draw 插件总结 前言 矢量在线编辑是gis常用的编辑功能&#xff0c;兴趣点&#xff08;POI&am…

kettle开发-Day38-超好用自定义数据处理组件

目录 前言&#xff1a; 一、半斤八两&#xff0c;都不太行 1、表输入&#xff0c;速度快&#xff0c;但不稳妥 2、稳的一批&#xff0c;但是慢的像蜗牛 二、各诉衷肠&#xff0c;合作共赢 1、表输入&#xff0c;高效数据插入 2、插入更新&#xff0c;一个都不能少 三、表输…

express的使用(四) nodejs转发表单到后台

原文链接 搬砖的林小白-express的使用(四) 个人博客地址&#xff0c;求关注&#xff0c;也希望大家在里面批评我的不足之处 看前提示 本篇所讲述的内容是node端转发前端发送过来的表单到第三方中&#xff0c;应用的场景有很多&#xff0c;如我们经常做的将文件存储到七牛云或…

Scala学习笔记

累了&#xff0c;基础配置不想写了&#xff0c;直接抄了→Scala的环境搭建 这里需要注意的是&#xff0c;创建新项目时&#xff0c;不要用默认的Class类&#xff0c;用Object&#xff0c;原因看→scala中的object为什么可以直接运行 一、Scala简介 1.1 图解Scala和Java的关系 1…

大数据测试基本知识

常用大数据框架结构 1.大数据测试常用到的软件工具 工具推荐&#xff0c;对于测试数据构造工具有&#xff1a;Datafaker、DbSchema、Online test data generator等&#xff1b;ETL测试工具有&#xff1a;RightData、QuerySurge等&#xff1b;数据质量检查工具&#xff1a;great…

MySQL-SQL存储过程/触发器详解(上)

♥️作者&#xff1a;小刘在C站 ♥️个人主页&#xff1a; 小刘主页 ♥️努力不一定有回报&#xff0c;但一定会有收获加油&#xff01;一起努力&#xff0c;共赴美好人生&#xff01; ♥️学习两年总结出的运维经验&#xff0c;以及思科模拟器全套网络实验教程。专栏&#xf…

Three.js--》实现3d地月模型展示

目录 项目搭建 初始化three.js基础代码 创建月球模型 添加地球模型 添加模型标签 今天简单实现一个three.js的小Demo&#xff0c;加强自己对three知识的掌握与学习&#xff0c;只有在项目中才能灵活将所学知识运用起来&#xff0c;话不多说直接开始。 项目搭建 本案例还…

《离散数学》:代数系统和图论导论

一、代数系统 代数系统是数学中的一个重要概念&#xff0c;它涉及一组对象以及定义在这些对象上的运算规则。代数系统可以是抽象的&#xff0c;也可以是具体的。 在抽象代数中&#xff0c;代数系统通常由一组元素和一组操作&#xff08;或称为运算&#xff09;组成。这些操作…

【MySQL新手入门系列四】:手把手教你MySQL数据查询由入门到学徒

SQL语言是与数据库交互的机制&#xff0c;是关系型数据库的标准语言。SQL语言可以用于创建、修改和查询关系数据库。SQL的SELECT语句是最重要的命令之一&#xff0c;用于从指定表中查询数据。在此博客中&#xff0c;我们将进一步了解SELECT语句以及WHERE子句以及它们的重要性。…

vue进阶-vue-route

Vue Router 是 Vue.js 的官方路由。它与 Vue.js 核心深度集成&#xff0c;让用 Vue.js 构建单页应用变得轻而易举。 本章只做学习记录&#xff0c;详尽的内容一定要去官网查看api文档 Vue Router-Vue.js 的官方路由 1. 路由的基本使用 1.1 安装vue-router npm install vue-…

SpringCloud Eureka注册中心高可用集群配置(八)

当注册中心扛不住高并发的时候&#xff0c;这时候 要用集群来扛&#xff1b; 我们再新建两个module microservice-eureka-server-2002 microservice-eureka-server-2003 第一步&#xff1a; pom.xml 把依赖加下&#xff1a; <dependencies> <dependency…

golang 协程的实现原理

核心概念 要理解协程的实现, 首先需要了解go中的三个非常重要的概念, 它们分别是G, M和P, 没有看过golang源代码的可能会对它们感到陌生, 这三项是协程最主要的组成部分, 它们在golang的源代码中无处不在. G (goroutine) G是goroutine的头文字, goroutine可以解释为受管理的…

Prompt 范式产业实践分享!基于飞桨 UIE-X 和 Intel OpenVINO 实现跨模态文档信息抽取

近期 Prompt 范式备受关注&#xff0c;实际上&#xff0c;其思想在产业界已经有了一些成功的应用案例。中科院软件所和百度共同提出了大一统诸多任务的通用信息抽取技术 UIE&#xff08;Universal Information Extraction&#xff09;。截至目前&#xff0c;UIE 系列模型已发布…
最新文章