关于神经网络的权重信息和特征图的可视化

目录

1. 介绍

2. 隐藏层特征图的可视化

2.1 AlexNet 网络

2.2 forward 

2.3 隐藏层特征图可视化

2.4 测试代码

3. 训练参数的可视化

3.1 从网络里面可视化参数

3.1.1 测试代码

3.1.2 参数的字典信息

3.1.3 参数可视化

3.2 从保存的权重参数文件(.pth)里面可视化参数


1. 介绍

神经网络中间的隐藏层往往是不可见的,之前的网络都是输入图像的可视化,或者输出结果分类或者分割的可视化。

然而中间隐藏层的输出也是可以可视化的,因为输出的信息就是一个多维的数组,而图像也是用数组表示的。

2. 隐藏层特征图的可视化

这里用AlexNet网络进行演示

2.1 AlexNet 网络

代码:

import torch.nn as nn


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )

    def forward(self, x):
        outputs = []
        for name, module in self.features.named_children():
            x = module(x)       # forward
            if 'Conv2d' in str(module):     # 只打印卷积层的输出
                outputs.append(x)

        return outputs

AlexNet网络的结构如图:

 

2.2 forward 

这里AlexNet 网络的forward过程和之前定义的有所区别

如下:

 

这里的named_children() 会返回两个值,name是模块的名称,module是模块本身

调试信息如下:

也就是说:name是第一个名称,module是第二个模块本身

所以这里的forward在features里面传递

注意这里没有classifier,这里是self.fetures里面的name_children

 

2.3 隐藏层特征图可视化

如图,这里将卷积层的输出保存在outputs里面,然后再return

 

因为AlexNet 有五个卷积层,所以这里会显示五张特征图,如下:

 

 

 

 

 

2.4 测试代码

代码如下:

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import torch
from alexnet_model import AlexNet
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from torchvision import transforms


# 预处理
transformer = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 实例化模型
model = AlexNet(num_classes=5)
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))

# load image
img = Image.open("./tulips.png")
img = transformer(img)
img = torch.unsqueeze(img, dim=0)

out_put = model(img)    # forward自动调用,所有会返回里面的outputs
for feature_map in out_put:
    im = np.squeeze(feature_map.detach().numpy())   # 去掉 batch维度,detach去掉梯度传播
    im = np.transpose(im, [1, 2, 0])                # change channels

    num_feature = im.shape[2]       # plt展示的size
    n = int(num_feature ** 0.5)+1

    plt.figure()      # 展示特征图
    for i in range(num_feature):    # 改成 12,就只显示12张特征图
        plt.subplot(n, n, i+1)
        plt.axis('off')
        plt.imshow(im[:, :, i], cmap='gray')
    plt.show()

其中:

model会自动调用里面的forward,因此直接接收返回值就行了

其次,这里将batch维度删去了,然后将channel返回到最后一个位置,所以打印的时候,需要将图像的每一个channel输出 im[: ,: ,i]

 

3. 训练参数的可视化

当模型训练好的时候,参数是可以显示的。这里可以将AlexNet 实例化,也可以不需要建立模型,直接从参数文件 .pth里面载入也可以。这里演示两种方法

3.1 从网络里面可视化参数

网络的结构如下

3.1.1 测试代码

代码:

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import torch
from alexnet_model import AlexNet
import matplotlib.pyplot as plt
import numpy as np


# 实例化模型
model = AlexNet(num_classes=5)
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))

weights_keys = model.state_dict().keys()    # 获取训练参数字典里面keys
for key in weights_keys:
    # remove num_batches_tracked para(in bn)
    if "num_batches_tracked" in key:    # bn层也有参数
        continue

    # [卷积核个数,卷积核的深度, 卷积核 h,卷积核 w]
    weight_value = model.state_dict()[key].numpy()  # 返回 key 里面具体的值

    # mean, std, min, max
    weight_mean = weight_value.mean()
    weight_std = weight_value.std()
    weight_min = weight_value.min()
    weight_max = weight_value.max()
    print("{} layer:mean:{}, std:{}, min:{}, max:{}".format(key,weight_mean,weight_std,weight_min,weight_max))

    # 绘制参数的直方图
    plt.close()
    weight_vec = np.reshape(weight_value, [-1])
    plt.hist(weight_vec, bins=50)   # 将 min-max分成50份
    plt.title(key)
    plt.show()

3.1.2 参数的字典信息

网络参数是一个字典存在的,这里打印里面的key值,

odcit(ordered dictionary) 是一个有序字典

这里序号0,3,6....是因为网络只有这里层才有权重

3.1.3 参数可视化

这里显示的有点多,所以只展示部分的

 

 

 

 

输出控制台的信息:

3.2 从保存的权重参数文件(.pth)里面可视化参数

代码:

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import torch
import matplotlib.pyplot as plt
import numpy as np


# create model
model_weight_path = "./AlexNet.pth"
param_dict = torch.load(model_weight_path)
weights_keys = param_dict.keys()

for key in weights_keys:
    # remove num_batches_tracked para(in bn)
    if "num_batches_tracked" in key:
        continue

    # [卷积核个数,卷积核的深度, 卷积核 h,卷积核 w]
    weight_value = param_dict[key].numpy()

    # mean, std, min, max
    weight_mean = weight_value.mean()
    weight_std = weight_value.std()
    weight_min = weight_value.min()
    weight_max = weight_value.max()
    print("{} layer:mean:{}, std:{}, min:{}, max:{}".format(key,weight_mean,weight_std,weight_min,weight_max))

    # 绘制参数的直方图
    plt.close()
    weight_vec = np.reshape(weight_value, [-1])
    plt.hist(weight_vec, bins=50)   # 将 min-max分成50份
    plt.title(key)
    plt.show()

其中载入参数的字典文件如下:

 

打印的信息:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/8069.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

汉诺塔与二进制、满二叉树的千丝万缕

汉诺塔(Tower of Hanoi)源于印度传说中,大梵天创造世界时造了三根金钢石柱子,其中一根柱子自底向上叠着64片黄金圆盘。大梵天命令婆罗门把圆盘从下面开始按大小顺序重新摆放在另一根柱子上。并且规定,在小圆盘上不能放大圆盘,在三…

数据挖掘(2.3)--数据预处理

目录 三、数据集成和转换 1.数据集成 2.数据冗余性 2.1 皮尔森相关系数 2.2卡方检验 3.数据转换 四、数据的规约和变换 1.数据归约 2数据离散化 三、数据集成和转换 1.数据集成 数据集成是将不同来源的数据整合并一致地存储起来的过程。 不同来源的数据可能有不同…

【ESP32+freeRTOS学习笔记之“ESP32环境下使用freeRTOS的特性分析(2-多核环境中的任务)”】

目录1、ESP32的双核对称多处理SMP概念2、涉及任务task的特殊性2.1 创建任务的特殊函数2.2 xTaskCreatePinnedToCore()函数的解释3、任务的删除4、总结1、ESP32的双核对称多处理SMP概念 最初的FreeRTOS(以下简称Vanilla FreeRTOS)…

线性表——顺序表

文章目录一:线性表二:顺序表1:概念与结构1:静态顺序表2:动态顺序表2:动态顺序表的代码实现1:结构2:接口实现1:初始化2:释放内存3:检查容量4&#…

Linux下最小化安装CentOS-7.6(保姆级)

文章目录安装包开始安装一、 新建一个虚拟机二、配置安装CentOS7.6二、开始安装CentOS三、配置CentOS并下载基本信息安装包 链接:https://pan.baidu.com/s/1DodB-kDy1yiNQ7B5IxwYyg 提取码:p19i 开始安装 一、 新建一个虚拟机 1、 打开VMWare&#x…

刷题笔记【5】| 快速刷完67道剑指offer(Java版)

本文已收录于专栏🌻《刷题笔记》文章目录前言🎨 1、合并两个有序链表题目描述思路一(递归)思路二(双指针)🎨 2、树的子结构题目描述思路一(递归)🎨 3、二叉树…

Redis分布式锁系列

1.压力测试出的内存泄漏及解决(可跳过) 使用jmeter对查询产品分类列表接口进行压力测试,出现了堆外内存溢出异常。 我们设置的虚拟机堆内存100m,并不是堆外内存100m 产生堆外内存溢出:OutOfDirectMemoryError 原因是…

某大厂面试题:说一说Java、Spring、Dubbo三者SPI机制的原理和区别

大家好,我是三友~~ 今天来跟大家聊一聊Java、Spring、Dubbo三者SPI机制的原理和区别。 其实我之前写过一篇类似的文章,但是这篇文章主要是剖析dubbo的SPI机制的源码,中间只是简单地介绍了一下Java、Spring的SPI机制,并没有进行深…

SQL——数据查询DQL

基本语句、时间查询(当天、本周,本月,上一个月,近半年的数据)。 目录 1 查询语句基本结构 2 where 子句 3 条件关系运算符 4 条件逻辑运算符 5 like 子句 6 计算列 7 as 字段取别名 8 distinct 清除重复行 9 …

linux mysql

安装 下载包 wget https://cdn.mysql.com/archives/mysql-8.0/mysql-8.0.31-1.el8.x86_64.rpm-bundle.tar解压 tar -zxvf mysql-8.0.31-1.el8.x86_64.rpm-bundle.tar -C /usr/local/mysql安装openssl-devel插件 yum install openssl-devel安装rpm包 使用rpm -ivh安装图中r…

【Unity项目实战】从零手戳一个背包系统

首先我们下载我们的人物和背景资源,因为主要是背包系统,所以人物的移动和场景的搭建这里我们就不多讲了,我这里直接提供基础项目源码给大家去使用就行 基础项目下载地址: 链接: https://pan.baidu.com/s/1o7_RW_QQ1rrAbDzT69ApRw 提取码: 8s95 顺带说一下,这里用到了uni…

AttributeError: module transformers has no attribute LLaMATokenizer解决方案

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

AES加密

来源:链接: b站up主可厉害的土豆 (据评论区说图片中有计算错误,但是过程是对的。只是了解过程问题不大,专门研究这一块的话自己可以看视频算一下) 准备 首先,明文是固定长度 16字节 128位。 密钥长度可以…

TCP协议一

TCP数据报格式 TCP通信时序 下图是一次TCP通讯的时序图。TCP连接建立断开。包含大家熟知的三次握手和四次握手。 在这个例子中,首先客户端主动发起连接、发送请求,然后服务器端响应请求,然后客户端主动关闭连接。两条竖线表示通讯的两端&…

houjie-cpp面向对象

houjie 面向对象 面向对象(上) const 在一个函数后面放const,这个只能修饰成员函数,告诉编译器这个成员函数不会改数据 const还是属于函数签名的一部分。 引用计数:涉及到共享的东东,然后当某个修改的时候&…

Mysql的学习与巩固:一条SQL查询语句是如何执行的?

前提 我们经常说,看一个事儿千万不要直接陷入细节里,你应该先鸟瞰其全貌,这样能够帮助你从高维度理解问题。同样,对于MySQL的学习也是这样。平时我们使用数据库,看到的通常都是一个整体。比如,你有个最简单…

【Paper】2019_Resilient Consensus Through Asynchronous Event-based Communication

Wang Y, Ishii H. Resilient consensus through asynchronous event-based communication[C]//2019 American Control Conference (ACC). IEEE, 2019: 1842-1847. 文章目录I. INTRODUCTIONII. EVENT-BASED RESILIENT CONSENSUS PROBLEMA. Preliminaries on graphsB. Event-base…

基于Java+ SpringBoot+Vue 的网上图书商城管理系统(毕业设计,附源码,教程)

您好,我是程序员徐师兄,今天为大家带来的是 基于Java SpringBootVue 的网上图书商城管理系统(毕业设计,附源码,教程)。 😁 1.Java 毕业设计专栏,毕业季咱们不慌忙,几百款…

电脑桌面图标间距突然变大怎么恢复

1. WindowsR打开 > 输入regedit 按住WindowsR打开运行,输入regedit并点击确定。 2. 双击Control Panel 双击展开HKEY_CURRENT_USER,双击展开Control Panel,双击展开Desktop。 3. 更改间距 点击打开WindowMetrics, 双击打开…

两年外包生涯,给我后面入职字节跳动奠定了基础.....

我是一位软件测试工程师。从大学毕业后,我进入了一家外包公司,在那里工作了两年时间。虽然我在公司中得到了不少锻炼和经验,但是我一直渴望能够进入一家更加专业的公司,接触更高端、更有挑战性的项目。 于是,我开始主…
最新文章