神经网络的解释方法之CAM、Grad-CAM、Grad-CAM++、LayerCAM

原理优点缺点
GAP将多维特征映射降维为一个固定长度的特征向量①减少了模型的参数量;②保留更多的空间位置信息;③可并行计算,计算效率高;④具有一定程度的不变性①可能导致信息的损失;②忽略不同尺度的空间信息
CAM利用最后一个卷积层的特征图×权重(用GAP代替全连接层,重新训练,经过GAP分类后概率最大的神经元的权重效果已经很不错需要修改原模型的结构,导致需要重新训练该模型,大大限制了使用场景,如果模型已经上线了,或着训练的成本非常高,我们几乎是不可能为了它重新训练的。
Grad-CAM最后一个卷积层的特征图×权重(通过对特征图梯度的全局平均来计算权重①解决了CAM的缺点,适用于任何卷积神经网络;②利用特征图的梯度,可视化结果更准确和精细
Grad-CAM++1. 定位更准确
2. 更适合同类多目标的情况

GAP全局平均池化

论文:Network In Network

GAP (Global Average Pooling,全局平均池化),在上述论文中提出,用于避免全连接层的过拟合问题。全局平均池化就是对整个特征映射应用平均池化。

图1:将原本h × w × d的三维特征图,具体大小为6 × 6 × 3,经过GAP池化为1 × 1 × 3 输出值。也就是每一个channel的h × w 平均池化为一个值。特征图经过 GAP 处理后每一个特征图包含了不同类别的信息。 

GAP平均池化的操作步骤如下:

  1. 经过卷积操作和激活函数后,得到最后一个卷积层的特征图。
  2. 对每个通道的特征图进行平均池化,即计算每个通道上所有元素的平均值。这将每个通道的特征图转化为一个标量值。
  3. 将每个通道的标量值组合成一个特征向量。这些标量值的顺序与通道的顺序相同。
  4. 最终得到的特征向量可以作为分类器的输入,用于进行图像分类。

CAM

论文:Learning Deep Features for Discriminative Localization

原理:利用最后一个卷积层的特征图与经过GAP分类后概率最大的神经元权重进行叠加。

图2:解释了在CNN中使用全局平均池化(GAP)生成类激活映射(CAM)的过程:

经过最后一层卷积操作之后,得到的特征图包含多个channel,如图1中的不同颜色的3个channel,也就是在GAP之前所对应的不同的channel特征图,f_{k}就表示第k个channel的特征图。然后经过GAP处理后每个channel的特征图包含了不同类别的信息,w_{k}就表示分类概率最大的神经元(图2黑色神经元)所对应连接的第k个神经元的权重。

Grad-CAM 

Grad-CAM的前身是 CAM,CAM 的基本的思想是求分类网络某一类别得分对高维特征图 (卷积层的输出) 的偏导数,从而可以得到该高维特征图每个通道对该类别得分的权值;而高维特征图的激活信息 (正值) 又代表了卷积神经网络的所感兴趣的信息,加权后使用热力图呈现得到 CAM。

原理:Grad-CAM的关键思想是将输出类别的梯度(相对于特定卷积层的输出)与该层的输出相乘,然后取平均,得到一个“粗糙”的热力图。这个热力图可以被放大并叠加到原始图像上,以显示模型在分类时最关注的区域。

具体步骤如下:

  1. 选择网络的最后一个卷积层,因为它既包含了高级特征,也保留了空间信息。
  2. 前向传播图像到网络,得到你想解释的类别的得分。
  3. 计算此得分相对于我们选择的卷积层输出的梯度。
  4. 对于该卷积层的每个通道,使用上述梯度的全局平均值对该通道进行加权。
  5. 结果是一个与卷积层的空间维度相同的加权热力图。

因为热力图关心的是对分类有正面影响的特征,所以在线性组合的技术上加上了ReLU,以移除负值 。

w_{k}^{c}第 k 个特征图对应于类别 c 的权重,
A^{k}表示:第 k 个特征图,
Z表示特征图的像素个数,
y^{c}表示: 第c类得分的梯度,
A_{ij}^{k}表示: 第 k个特征图中坐标( i , j )位置处的像素值;

Grad-CAM代码:

import torch
import cv2
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
 
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model  # 要进行Grad-CAM处理的模型
        self.target_layer = target_layer  # 要进行特征可视化的目标层
        self.feature_maps = None  # 存储特征图
        self.gradients = None  # 存储梯度
        
        # 为目标层添加钩子,以保存输出和梯度
        target_layer.register_forward_hook(self.save_feature_maps)
        target_layer.register_backward_hook(self.save_gradients)
 
    def save_feature_maps(self, module, input, output):
        """保存特征图"""
        self.feature_maps = output.detach()
 
    def save_gradients(self, module, grad_input, grad_output):
        """保存梯度"""
        self.gradients = grad_output[0].detach()
 
    def generate_cam(self, image, class_idx=None):
        """生成CAM热力图"""
        # 将模型设置为评估模式
        self.model.eval()
        
        # 正向传播
        output = self.model(image)
        if class_idx is None:
            class_idx = torch.argmax(output).item()
 
        # 清空所有梯度
        self.model.zero_grad()
 
        # 对目标类进行反向传播
        one_hot = torch.zeros((1, output.size()[-1]), dtype=torch.float32)
        one_hot[0][class_idx] = 1
        output.backward(gradient=one_hot.cuda(), retain_graph=True)
 
        # 获取平均梯度和特征图
        pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3])
        activation = self.feature_maps.squeeze(0)
        for i in range(activation.size(0)):
            activation[i, :, :] *= pooled_gradients[i]
        
        # 创建热力图
        heatmap = torch.mean(activation, dim=0).squeeze().cpu().numpy()
        heatmap = np.maximum(heatmap, 0)
        heatmap /= torch.max(heatmap)
        heatmap = cv2.resize(heatmap, (image.size(3), image.size(2)))
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        
        # 将热力图叠加到原始图像上
        original_image = self.unprocess_image(image.squeeze().cpu().numpy())
        superimposed_img = heatmap * 0.4 + original_image
        superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
        
        return heatmap, superimposed_img
 
    def unprocess_image(self, image):
        """反预处理图像,将其转回原始图像"""
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = (((image.transpose(1, 2, 0) * std) + mean) * 255).astype(np.uint8)
        return image
 
def visualize_gradcam(model, input_image_path, target_layer):
    """可视化Grad-CAM热力图"""
    # 加载图像
    img = Image.open(input_image_path)
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    input_tensor = preprocess(img).unsqueeze(0).cuda()
 
    # 创建GradCAM
    gradcam = GradCAM(model, target_layer)
    heatmap, result = gradcam.generate_cam(input_tensor)
 
    # 显示图像和热力图
    plt.figure(figsize=(10,10))
    plt.subplot(1,2,1)
    plt.imshow(heatmap)
    plt.title('热力图')
    plt.axis('off')
    plt.subplot(1,2,2)
    plt.imshow(result)
    plt.title('叠加后的图像')
    plt.axis('off')
    plt.show()
 
# 以下是示例代码,显示如何使用上述代码。
# 首先,你需要加载你的模型和权重。
# model = resnet20()
# model.load_state_dict(torch.load("path_to_your_weights.pth"))
# model.to('cuda')
 
# 然后,调用`visualize_gradcam`函数来查看结果。
# visualize_gradcam(model, "path_to_your_input_image.jpg", model.layer3[-1])

 Grad-CAM++

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/112645.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【mfc/VS2022】计图实验:绘图工具设计知识笔记3

实现类对串行化的支持 如果要用CArchive类保存对象的话,那么这个对象的类必须支持串行化。一个可串行化的类通常有一个Serialize成员函数。要想使一个类可串行化,要经历以下5个步骤: 1、从CObject派生类 2、重写Serialize成员函数 3、使用DE…

Ubuntu MySQL客户端功能介绍(mysql-client)mysql命令(mysql客户端命令)数据库导出、数据库导入

文章目录 Ubuntu MySQL客户端(mysql-client)功能介绍MySQL客户端与服务端服务器端(MySQL Server)客户端(MySQL Client) 安装MySQL客户端连接到MySQL服务器(mysql -h host -u user -p)执行SQL查询批处理模式…

wordpress上传限制2M修改为256M的两种方式

方式一:修改php.ini 上传文件限制大小主要是php的php.ini配置决定的,所以只要找到php的配置文件,并且修改里面的配置即可,linux查看php的版本和配置文件位置的命令: php -i | grep "Configuration File" 一…

AI智能语音识别模块(二)——基于Arduino的语音控制MP3播放器

文章目录 简介离线语音控制模块Mini MP3模块0.96寸 OLED模块实验准备安装库接线定义主要程序实验效果注意事项总结 简介 在前面一篇文章里我们对AI智能语音识别模块进行了介绍,并对离线语音模组下载固件的过程进行了一个简单描述,不知道大家还记不记得&…

用前端框架Bootstrap和Django实现用户注册页面

01-新建一个名为“mall_backend”的Project 命令如下: CD E:\Python_project\P_001\myshop-test E: django-admin startproject mall_backend02-新建应用并注册应用 执行下面条命令依次创建需要的应用: CD E:\Python_project\P_001\myshop-test\mall…

uniapp如何使用mumu模拟器

模拟器安装 下载地址:MuMu模拟器 模拟器相关设置 1.在设置-显示中选中手机版,设置手机分辨率 2.设置-关于手机-版本号快速点击,将其设置为开发者模式 3.选择多开器 4.打开hbuilderx,找到adb设置 5.配置adb路径及端口号&#x…

Servlet 初始化参数(web.xml和@WebServlet)

1、通过web.xml方式 <?xml version"1.0" encoding"UTF-8"?> <web-app xmlns"http://xmlns.jcp.org/xml/ns/javaee"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://xmlns.jcp.org/xm…

MFC实现堆栈窗口:多个子界面可任意切换

1、效果 在Qt中可使用QStackedWidget控件直接拖动布置即可实现&#xff0c;但在MFC中并未提供类似的控件&#xff0c;因此需要自己简单实现。 2、实现原理 实现原理比较简单&#xff0c;父级对话框在显示的区域部分&#xff0c;通过切换子对话框即可实现。子对话框去掉边框后…

香港服务器不稳定的几种情况

​  近年来&#xff0c;随着互联网的迅猛发展&#xff0c;香港作为一个重要的网络枢纽地区&#xff0c;扮演着连接中国内地和国际网络的重要角色。一些用户表示在使用香港服务器时可能会遇到不稳定的情况&#xff0c;导致访问困难、加载缓慢甚至无法连接。 为什么香港服务器会…

PostGreSQL:JSON|JSONB数据类型

JSON JSON 指的是 JavaScript 对象表示法&#xff08;JavaScript Object Notation&#xff09;JSON 是轻量级的文本数据交换格式JSON 独立于语言&#xff1a;JSON 使用 Javascript语法来描述数据对象&#xff0c;但是 JSON 仍然独立于语言和平台。JSON 解析器和 JSON 库支持许…

【Linux】Linux项目部署及更改访问端口号和jdk、tomcat、MySQL环境搭建的配置安装

目录 一、作用 二、配置 1、上传安装包 2、jdk 2.1、解压对应安装包 2.2、环境变量搭建 3、tomcat 3.1、解压对应安装包 3.2、启动 3.3、设置防火墙 3.4、设置开发端口 4、MySQL 三、后端部署 四、Linux部署项目 1、单体项目 五、修改端口访问 1、进入目录 2…

Mysql数据库 6.SQL语言 分组、分页查询

分组查询—group by 分组——就是将数据表中的记录按照指定的类进行分组 关键字——group by 语法 语法中加[]的是可有可无的&#xff0c;group by一般和having一起使用 select 分组字段/聚合函数 from 表名 [where 条件] group by 分组列名 [having 条件] [order by …

基于深度学习的口罩佩戴检测

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介二、功能三、基于深度学习的口罩佩戴检测四. 总结 一项目简介 基于深度学习的口罩佩戴检测是一种利用计算机视觉技术和深度学习算法进行口罩佩戴情况检测的…

【uniapp】html和css-20231031

我想用控件和样式来表达应该会更贴切&#xff0c;html和css的基础需要看看。 关于html&#xff1a;https://www.w3school.com.cn/html/html_layout.asp 关于css&#xff1a;https://www.w3school.com.cn/css/index.asp html让我们实现自己想要的布局&#xff08;按钮&#xff0…

1深度学习李宏毅

目录 机器学习三件事&#xff1a;分类&#xff0c;预测和结构化生成 2、一般会有经常提到什么是标签label&#xff0c;label就是预测值&#xff0c;在机器学习领域的残差就是e和loss​编辑3、一些计算loss的方法&#xff1a;​编辑​编辑 4、可以设置不同的b和w从而控制loss的…

nodejs+vue+elementui+python家电销售分析系统设计与实现-计算机毕业设计

系统按照用户的实际需求开发而来&#xff0c;贴近生活。从管理员通过正确的账号的密码进入系统&#xff0c;可以使用相关的系统应用。管理员总体负责整体系统的运行维护&#xff0c;统筹协调。 我们可以利用计算机技术来取代传统的管理模式&#xff0c;实现家电销售分析系统的…

赋能制造业高质量发展,释放采购数字化新活力——企企通亮相武汉2023国际智能制造创新论坛

摘要 “为应对成本上升、供应端不稳定、供应链上下游协同困难、决策无数据依据等问题&#xff0c;利用数字化手段降本增效、降低潜在风险十分关键。在AI等先进技术发展、供应链协同效应和降本诉求等机遇的驱动下&#xff0c;采购供应链数字化、协同化成为企业激烈竞争的优先选…

坚持#第420天~阿里云轻量服务器内存受AliYunDunMonito影响占用解决方法

阿里云轻量服务器内存受AliYunDunMonito影响占用解决方法&#xff0c;亲测有效&#xff1a; Mobax好卡啊&#xff0c;那就直接在阿里云后台操作即可&#xff0c;阿里云后台也可以上传文件。 Navicat mysql好卡啊&#xff0c;那就直接在阿里云后台最上面帮助的右边有个数据库&…

c++装饰器模式

前言 装饰器模式&#xff0c;就是可以对一个对象无限装饰一些东西&#xff0c;而且可以没有顺序。比如一个人可能只会说出他的名字&#xff0c;但是可以让他再说哈哈&#xff0c;可以说完哈哈之后再说哇哇。如何后面又不想装饰了&#xff0c;不需要改类原来的代码&#xff0c;…

基础课15——语音标注

语音数据标注是对语音数据进行处理和分析的过程&#xff0c;目的是让人工智能系统能够理解和识别语音中的信息。这个过程包括了对语音信号的预处理、特征提取、标注等步骤。 在语音数据标注中&#xff0c;标注员需要对语音数据进行分类、切分、转写等操作&#xff0c;让人工智…
最新文章