LSTM Siamese neural network

本文中的代码在Github仓库或Gitee仓库中可找到。

在这里插入图片描述

Hi, 你好。我是茶桁。

大家是否还记得,在「核心基础」课程中,我们讲过CNN以及LSTM。

卷积神经网络(CNN)已经在计算机视觉处理中得到广泛应用,不过,2017年开创性的Transformer神经网络的开创性使其称为一种可行的替代方案,Transformer是目前流行的ChatGPT的基础。它的主要缺点是需要大型数据集才能超越CNN同类产品。否则,在数据集有限的情况下,Transformer的性能回避CNN模型差。关于LSTM,它的设计目的是解决梯度消失问题,这个咱们在LSTM那一章中中有详细的解释,即在每次训练迭代过程中,权重和偏置无法有效更新。LSTM是一种递归神经网络,由存储单元组成,每个存储单元由输入门、输出门和遗忘门组成,位与隐层(Hidden Layer)/State之上。不过,与最新的Transformer相比,LSTM的信息保留时间较长。

因此,就深度学习而言,LSTM 的特性使其可以应用于自然语言处理和时间序列预测等领域。还有人提出了一种混合架构,如计算机视觉处理中的 LSTM-CNN 模型 1。本文的论点是,LSTM 模型本身的性质使其能够被训练并用于图像分类和对比目的,因此仅 LSTM 模型就足够了。

从理论上讲,分类模型会使用一个名为"CrossEntropyLoss"的函数来调整权重,以便模型在每次训练迭代时都能做出更准确的预测。另一方面,Siamese neural network使用另一个函数,它与"CrossEntropyLoss"有相似之处,但并不相同,被称为 “对比损失”。

对比损失的计算公式
L ( W , Y , x ⃗ 1 , x ⃗ 2 ) = ( 1 − Y ) 1 2 ( D W ) 2 + ( Y ) 1 2 { m a x ( 0 , m − D W ) } 2 \begin{align*} & L(W, Y, \vec x_1, \vec x_2 ) = \\ & (1-Y)\frac{1}{2}(D_W)^2 + (Y)\frac{1}{2} \{ max(0, m-D_W) \}^2 \end{align*} L(W,Y,x 1,x 2)=(1Y)21(DW)2+(Y)21{max(0,mDW)}2

以上是对比损失的计算公式。Y要么为0,要么为1,这取决于我们是在比较相似项目还是不相似项目。在本练习的例子中,如果我们比较一个手写数字1和另一个手写数字1,Y将为0,否则,如果我们比较一个手写数字1和另一个手写数字,例如5,那么Y将为1

上述Dw指的是两个向量之间的欧氏距离,即机器在处理两个图像时,两个图像都被转换成n维向量。两个向量之间的距离越近,两幅图像相似的可能性就越大,例如两个手写数字1产生的欧氏距离就越近,而数字1与数字0相比,不同数字产生的向量产生的欧氏距离就越大。max函数用于确定边距减去欧氏距离后的最大值和零值。

Dw(欧几里得距离)的计算

D w ( x ⃗ 1 , x ⃗ 2 ) = ∣ ∣ G W ( x ⃗ 1 ) − G W ( x ⃗ 2 ) ∣ ∣ 2 D_w(\vec x_1, \vec x_2) = ||G_W(\vec x_1) - G_W(\vec x_2) ||_2 Dw(x 1,x 2)=∣∣GW(x 1)GW(x 2)2

这段公式演示的是欧氏距离的计算,其中Gw是一个欧氏距离函数(在Python编码中,可以是cdist或pairwise_distance函数),用于计算Siamese neural network输出之间的欧氏距离,该函数基于Yann LeCun及其同事之前的工作2

因此,Siamese model可以增强分类模型,即它可以确定分类模型分类的图像与分类模型确定的同一类别中随机选择的图像之间的欧氏距离。直观上,同一类图像的欧氏距离很近。分类模型可能会无意中将图像分类错误。如果"1"被错误地分类为另一个数字,理论上,Siamese model在比较图像和错误图像类别的随机样本时,应该能检测到更大的欧氏距离。为了纠正错误分类,Siamese model还可以将图像与其他类别的图像随机样本进行比较,以确定欧氏距离最小的类别,从而得出正确的图像分类。

LSTM 图像分类模型

我们将会使用「MNIST数据集」进行训练和评估。MNIST数据集包含0到9的手写数字,其中 60,000个用于训练,其余10,000个用于评估。编码使用Python完成,并在我自己的M1上进行编译和运行。

import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F

import numpy as np
import torchvision as tv
import matplotlib.pyplot as plt
import datetime
from tqdm import tqdm

我们首先导入库和依赖项。随后下载MNIST数据集,并初始化训练和评估数据加载器。

# 下载 MNIST 数据集并初始化 dataloader
transform = transforms.Compose([transforms.ToTensor()])
ds_train = tv.datasets.MNIST(root="dataset/", train=True, download=True, transform=transform) 
ds_val = tv.datasets.MNIST(root="dataset/", train=False, download=True, transform=transform) 
train_ldr = torch.utils.data.DataLoader(ds_train, batch_size=50, shuffle=True, num_workers=2) 
evaluate_ldr = torch.utils.data.DataLoader(ds_val, batch_size=50, shuffle=False, num_workers=2) 

其实原本CPU训练就足够了,但是既然PyTorch已经支持M1的GPU运算,那我为什么不用呢,这将会使得我的训练速度加快,所以在定义LSTM模型的时候,我们需要动态生成Hidden state和Cell state,然后通过forward方法传入数据和动态生成的Hidden state和Cell state。

定义LSTM模型。Hidden size指的是每个LSTM单元的单元数。如果模型需要捕捉和执行更高层次的抽象,从而理解更复杂的模式和依赖关系,那么谨慎的做法是增加更多的层数(下面代码中的 n_layer)。类数(num_classes)指的是需要区分的项目的类数。在本例中,模型需要区分从0到9这10个手写数字,因此,直观地说,类数为10。

class LSTM(nn.Module):
    def __init__(self, input_len, hidden_size, num_classes, n_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.lstm = nn.LSTM(input_len, hidden_size, n_layers, batch_first=True)
        self.output_layer = nn.Linear(hidden_size, num_classes)

    def forward(self, X):
        # 动态生成Hidden states和Cell states
        batch_size = X.size(0)
        hidden_states = torch.zeros(self.n_layers, batch_size, self.hidden_size).to(X.device)
        cell_states = torch.zeros(self.n_layers, batch_size, self.hidden_size).to(X.device)
        
        # 通过forward方法传入数据和动态生成的Hidden states和Cell states
        output, (hide, cell) = self.lstm(X, (hidden_states, cell_states))
        output = self.output_layer(output[:, -1, :])
        return output

随后,我们将初始化一个LSTM模型。Hidden Size为 128,即每个LSTM单元有128个单元,在本练习中,我们使用3层。

# 初始化 LSTM 模型
lstm_class_model = LSTM(28, 128, 10, 3)

接下来,我们需要定义训练模型以及进行设备声明和转移,在M1中如果我需要使用mps,也就是GPU运算,那么我需要将模型和数据都转移到mps:0里进行处理。

# 训练模型
learning_rate = 0.001
loss_fn = nn.CrossEntropyLoss()  
opt = torch.optim.Adam(lstm_class_model.parameters(), lr=learning_rate)

# 声明device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
lstm_class_model.to(device)
loss_fn.to(device)

接下来,我们开始训练分类模型,注意在训练的时候,要讲述转移到mps

# random seeding
np.random.seed(1)  
torch.manual_seed(1)
print("\nLoading 60000 item training dataset")
print("\nCreating LSTM classification network")
print(lstm_class_model)

max_epoch = 50
arraylosses = []  

print("loss = Cross Entropy Loss")
print("optimizer = Adam")
print("maximum epochs = %3d " % max_epoch)
print("learning rate = %0.3f " % learning_rate)
print("\nStarting training")
lstm_class_model.train()  

for epoch in range(0, max_epoch):
    printlog('Epoch {0} / {1}'.format(epoch, max_epoch))
    ep_loss = 0
    loop = tqdm(enumerate(train_ldr), total=len(train_ldr), ncols=100)
    # for batch, (image, label) in enumerate(train_ldr):
    for i, batch in loop:
        features, labels = batch

        # 需要对图像进行重塑,使其适合LSTM模型, LSTM模型预期输入为3D数据  
        features = features.view(-1, 28, 28)

        features = features.to(device)
        labels = labels.to(device)

        preds = lstm_class_model(features)
        loss = loss_fn(preds, labels)
    
        # 损失求和
        ep_loss += loss.item()           
        opt.zero_grad()
        loss.backward()
        opt.step()
        if (i + 1) % 1200 == 0:
            # 使用数组来存储损失,以便绘制损失与时间的关系图
            arraylosses.append(ep_loss / 1200)    
            print("Epoch...{}".format(epoch + 1), "Cross entropy loss 1..{}".format(ep_loss / 1200))
print("Done ")

---
Loading 60000 item training dataset
...
Starting training
================================================================================2023-12-26 22:47:51
Epoch 0 / 50
100%|███████████████████████████████████████████████████████████| 1200/1200 [00:29<00:00, 46.84it/s]
Epoch...1 Cross entropy loss 1..0.40343018252790597
100%|███████████████████████████████████████████████████████████| 1200/1200 [00:30<00:00, 39.82it/s]
...
================================================================================2023-12-26 23:11:17
Epoch 49 / 50
100%|██████████████████████████████████████████████████████████▊| 1197/1200 [00:24<00:00, 57.22it/s]
Epoch...50 Cross entropy loss 1..0.0029736560586403962
100%|███████████████████████████████████████████████████████████| 1200/1200 [00:25<00:00, 47.63it/s]
Done 

漫长等待之后,我先是发现我的info里的epoch写错了,应该从第一个开始计算,那应该传入的参数是epoch+1。 好吧,这些都不重要,之后我对其做了一些修改。

在训练模型时,使用的学习率为0.001,并使用Adam优化器(一种在训练过程中调整模型参数以最小化损失函数的算法)。训练周期为50个epoch。批次大小是一个重要的超参数。较大的批次(可能超过100次)虽然会缩短训练时间,但会导致性能损失,因此需要调整学习率。这里使用的批次大小是50,即LSTM模型将一次处理50幅图像(转换为张量)。来自数据加载器的一批图像产生的形状为(50, 1, 28, 28),其中50代表一批图像的数量。为了让 LSTM模型处理图像,必须将这批图像重塑为(50, 28, 28)。如代码所示,使用reshape(-1,28,28)。这是因为LSTM只支持3D数据,如果传入4D数据,则会报错。

关于损失计算的一个小评论,作者的做法是按批次计算损失,即60000个样本有1200个批次。将损失除以60000个样本的总数量并没有错,只要损失在每个时间段都呈下降趋势,我们将曲线展示出来看看。

plt.plot(range(max_epoch), arraylosses)
plt.title("LSTM classification model training")
plt.xlabel("Epochs")
plt.ylabel("Losses")
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

随后,我们调用eval()方法。

lstm_class_model.eval()

下一步是创建用于训练Siamese神经网络(Siamese neural network)的数据加载器。数据集对训练模型极其重要,因此其设计的重要性无论如何强调都不为过。用于训练Siamese神经网络的数据集结构不同于用于训练分类模型的数据集结构,因为它需要同时生成两张随机图像和一个标志,在计算对比损失的上述公式中,标识被定义为Y。如果图像相似,标记为 0;如果不相似,标记为1

siamese_training_set = torch.utils.data.DataLoader(ds_train, batch_size=1)  

我们首先为MNIST数据集创建数据加载器,将批量大小设置为1,然后创建2个数组,分别用于存储图像和相应的标签。

# 图像数组可存储 60000 个图像
first_image_array = []  
# 标签数组,用于存储相应的图像标签
first_label_array = []  

for batch, (image, label) in enumerate(siamese_training_set):
    first_image_array.append(image)
    first_label_array.append(label)

图像数组和标签数组的大小为60000。不过,训练样本的大小随后会减半,变为30000个。

import random

tempimagearray = first_image_array
templabelarray = first_label_array
# 声明 2 组包含图像以及标签的数组
firstsetimagearray = []
firstsetlabelarray = []    
secondsetimagearray = []    
secondsetlabelarray = []

# 创建一个标识数组
flagarray = []
flag = 0

# 创建的数组大小为 30000
for i in range(30000): 
    # 从下 30000 个数组中随机生成一个数组位置
    num = random.randint(30000, 59999) 
    if first_label_array[i] == templabelarray[num]:
        # 评估随机生成的图像标签是否相似
        flag = 0        
    else: flag = 1
    # 将标识转换为张量进行处理
    flag = torch.tensor(flag, dtype=torch.float32).to(device) 
    firstsetimagearray.append(first_image_array[i])
    firstsetlabelarray.append(first_label_array[i])
    secondsetimagearray.append(tempimagearray[num])
    secondsetlabelarray.append(templabelarray[num])
    flagarray.append(flag)

Siamese模型的训练数据集从60000个减半为30000个,因为我们创建了两组图像数组,其中一组来自60000个数据集的前半部分,将输入第一个网络。我们使用随机方法从60000个数据集的后半部分随机生成数组索引,然后比较图像标签以确定它们是否相似,并根据结果创建一个标识(0或1),输入到标识数组中。

a = np.array(firstsetlabelarray)
b = np.array(firstsetimagearray)
# 使用NumPy数组将标签与相应图像堆叠在一起
c = np.array(secondsetlabelarray)   
# 创建一个二维数组
d = np.array(secondsetimagearray)   
firstsetarray = np.stack((a, b), axis=1)
secondsetarray = np.stack((c, d), axis=1)

然后,我们将图像和标签合并为一组。这样就创建了两组图像和标签组合数组。下一步是构建数据集,数据集将由数据加载器访问,用于训练。数据集架构有3个必须编码的基本功能:__init____len____getitem__

class Siamese_Training_Dataset(torch.utils.data.Dataset):
    # 现在我们将创建Siamese训练数据集类
    def __init__(self, firstsetarray, secondsetarray, flagarray):              
        self.dataset_size = len(firstsetarray)
        self.firstsetarray = firstsetarray
        self.secondsetarray = secondsetarray
        self.flagarray = flagarray

    def __len__(self):
        # 返回数组的大小,即3000
        return self.dataset_size 
        
    def __getitem__(self, index):
        image1 = self.firstsetarray[index][1]
        # 调整图像尺寸,以防万一
        image1 = image1.reshape(1, 28, 28)          
        label1 = self.firstsetarray[index][0]
        image2 = self.secondsetarray[index][1]
        image2 = image2.reshape(1, 28, 28)
        label2 = self.secondsetarray[index][0]
        flag = flagarray[index]
        return(image1, label1, image2, label2, flag)

# 创建数据集实例并用数组初始化
ds_siamese = Siamese_Training_Dataset(firstsetarray, secondsetarray, flagarray)

然后,我们用两组图像标签组合数组和标志数组初始化数据集,最后创建一个数据加载器实例。

# 从数据集创建数据加载器
siamese_dataloader = torch.utils.data.DataLoader(ds_siamese, batch_size=50, shuffle=True)  

随后,我们对Contrastive Loss类进行了编码。Contrastive Loss与cross entropy loss一样,在训练过程中对模型权重的调整起着重要作用。代码采用了James McCaffrey关于Siamese neural network的文章3

class ContrastiveLoss(nn.Module):
    def __init__(self, margin):
        # pre 3.3 语法
        super(ContrastiveLoss, self).__init__()
        # 边距或半径,这是一个可以定义的参数,定义为 2.0
        self.margin = margin  

    def forward(self, out1, out2, flag):                
        # flag = 0 意味着 out1 和 out2 应该是相同的
        # flag = 1 意味着 out1 和 out2 应该是不同的
        
        # 如前所述,计算2个输出向量之间的欧氏距离
        euclidean_distance = torch.nn.functional.pairwise_distance(out1, out2)  

        # 您可以选择按照 LeCun 的精确公式,乘以 1/2 损失值将减少一半                                        
        loss = torch.mean((1-flag) * torch.pow(euclidean_distance, 2) +
        (flag) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)) 
        
        return loss

下一步是创建Siamese LSTM model。

class LSTM_Siamese_network(nn.Module):
    def __init__(self, input_len, hidden_size, num_classes, n_layers):
        super(LSTM_Siamese_network, self).__init__()
        self.hidden_size = hidden_size 
        self.n_layers = n_layers
        
        self.lstm = nn.LSTM(input_len, hidden_size, n_layers, batch_first=True) 
        self.output_layer = nn.Linear(hidden_size, num_classes)

    def feed(self, X):
        batch_size = X.size(0)
        hidden_states = torch.zeros(self.n_layers, batch_size, self.hidden_size).to(X.device)
        cell_states = torch.zeros(self.n_layers, batch_size, self.hidden_size).to(X.device)
        
        output, (hide, cell) = self.lstm(X, (hidden_states, cell_states))
        output = self.output_layer(output[:, -1, :])
        return output
    
    # 这里的 LSTM Siamese Model与分类模型不同, 它被转入2个LSTM网络,并返回2个输出结果
    def forward(self, x1, x2):    
        out1 = self.feed(x1)        
        out2 = self.feed(x2)        
        return out1, out2

然后我们就可以训练LSTM siamese model了,不过别忘了将模型和数据放到mps里。

lstm_siamese_train = LSTM_Siamese_network(28, 128, 10, 3)
lstm_siamese_train.to(device)

np.random.seed(1) 
torch.manual_seed(1)
print("\nLoading 30000 item training dataset")
print("\nCreating LSTM Siamese network")
print(lstm_siamese_train)

# 创建损失值数组
arraylosses2 = []  

loss_fn2 = ContrastiveLoss(2.0)
loss_fn2.to(device)

opt2= torch.optim.Adam(lstm_siamese_train.parameters(), lr=learning_rate)  

print("loss = Contrastive Loss")
print("optimizer = Adam")
print("maximum epochs = %3d " % max_epoch)
print("learning rate = %0.3f " % learning_rate)
print("\nStarting training")
lstm_siamese_train.train()


for epoch in range(0, max_epoch):
    printlog('Epoch {0} / {1}'.format(epoch+1, max_epoch))
    con_loss = 0
    loop = tqdm(enumerate(siamese_dataloader), total=len(siamese_dataloader), ncols=100) 
    for i, batch in loop:
        feature1, label1, feature2, label2, flag = batch

        feature1 = feature1.reshape(-1, 28, 28).to(device)
        feature2 = feature2.reshape(-1, 28, 28).to(device)
        label1 = label1.to(device)
        label2 = label2.to(device)

        preds1, preds2 = lstm_siamese_train(feature1, feature2)
        loss = loss_fn2(preds1, preds2, flag)

        # 损失求和
        con_loss += loss.item()
        opt2.zero_grad()
        loss.backward()
        opt2.step()

        # 由于样本总数为 30000,批次总数 = 30000 / 50 = 600
        if (i + 1) % 600 == 0: 
            # 使用数组来存储损失,以便绘制损失与时间的关系图
            arraylosses2.append(con_loss / 600) 
            print(
                "Epoch...{}".format(epoch + 1),
                "Contrastive loss...{}".format(con_loss / 600),
            )
print("Done ")

---
Loading 30000 item training dataset
...
Starting training
================================================================================
2023-12-26 23:14:33
Epoch 1 / 50

  0%|                                                                       | 0/600 [00:00<?, ?it/s]
100%|█████████████████████████████████████████████████████████████| 600/600 [00:18<00:00, 33.18it/s]
Epoch...1 Contrastive loss...0.27657042890166245
...
================================================================================
2023-12-26 23:29:26
Epoch 50 / 50

100%|█████████████████████████████████████████████████████████████| 600/600 [00:18<00:00, 33.04it/s]
Epoch...50 Contrastive loss...0.004726815089738921
Done 

又是漫长的等待,倒杯水,上个厕所。这次我把info改过来了。

好,依然打印loss看看:

plt.plot(range(max_epoch), arraylosses2)
plt.xlabel("Epochs")
plt.ylabel("Losses")
plt.title("LSTM Siamese neural network training")
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

调用eval()

lstm_siamese_train.eval()

接下来,我们将测试图像分组到数组中,相同数字的图像被放入同一个数组中。共有 10个数组。

# 创建数据加载器,以创建存储测试的数组
mnist_siamese_set = torch.utils.data.DataLoader(ds_val, batch_size=1, shuffle=False) 

# 包含所有测试图像的数组
masterimagearray = [] 
masterimagelabels = []

# 创建可访问的 0 至 9 数组, 通过LSTM Siamese Network进行评估
testzeros = []
testones = []
testtwos = []    
testthrees = []  
testfours = []
testfives = []
testsixes = []
testsevens = []
testeights = []
testnines = []

for batch, (images, labels) in enumerate(mnist_siamese_set):
    images = images.to(device)
    labels = labels.to(device)
    masterimagearray.append(images)
    masterimagelabels.append(labels)

    if labels == 0:
        testzeros.append(images)
    elif labels == 1:
        testones.append(images)
    elif labels == 2:
        testtwos.append(images)
    elif labels == 3:
        testthrees.append(images)
    elif labels == 4:
        testfours.append(images)
    elif labels == 5:
        testfives.append(images)
    elif labels == 6:
        testsixes.append(images)
    elif labels == 7:
        testsevens.append(images)
    elif labels == 8:
        testeights.append(images)
    else:
        testnines.append(images)

随后,我们将数字数组合并为一个数组。

arrayoftestnumbers = [testzeros, testones, testtwos, testthrees, testfours, testfives, testsixes, testsevens, testeights, testnines]

Siamese model可用于分类,其依据是,与两幅不同类别的图像相比,同一类图像的欧氏距离较小。虽然Siamese model在概念上是两个输入之间的对比模型,但它仍然可以进行分类,而且正如随后所演示的那样,在LSTM模型对图像进行错误分类的某些情况下,它还可以充当校正器。对于Siamese model来说,已知的对比图像是必不可少的。对比图像的使用在一定程度上受到了医学界临床试验设计的启发,在医学界,评估某种特定方法是否有效的金标准是通过随机双盲临床试验来实现的。因此,这里的关键词是 “随机”。为了提高Siamese model正确识别图像的概率,该模型可以将图像与随机选取的10张(或更多)从 0到9每个数字的图像进行比较,然后计算欧氏距离的平均值。在本练习中,测试集图像被用作比较对象。这就解释了为什么要创建一个由每个数字的存储数组组成的大型数组。

def EvaluateSiamese(image):
    sumdist = []
    resultsarray = []
    euclid_distance = None
    for i in range(len(arrayoftestnumbers)):
        num = 0
        for ii in range(10):
            # 生成随机数的方式不会生成相同的随机数
            num = random.randint(num, num + 80)
            with torch.no_grad():
                out1, out2 = lstm_siamese_train(image.view(-1, 28, 28), arrayoftestnumbers[i][num].view(-1, 28, 28))
                # 计算欧几里得距离
                dist = torch.nn.functional.pairwise_distance(out1, out2) 
            # 追加到数组
            sumdist.append(dist)
        # 欧几里得距离平均值
        result = sum(sumdist) / 10 
        sumdist = []
        resultsarray.append(result)

    correctanswer = None

    for i in range(10):
        # 正确答案是欧氏距离小于1.0的答案
        if resultsarray[i] < 1.0: 
            correctanswer = i
            euclid_distance = resultsarray[i]
    # 用欧几里得距离返回正确答案
    return correctanswer, euclid_distance #returns the correct answer with euclidean distance

这里的函数包含了前面提到的Siamese model,它将相关图像与从0到9的10个相同数字的图像进行比较,并计算平均欧氏距离。确定测试图像是否与已知的测试比较图像属于同一类别的临界值是1.0。LSTM Siamese model有三种可能的预测结果–正确、不知道(计算出的与所有随机已知对比图像的平均欧氏距离大于 1.0和错误。

好了,让我们随机测试一下分类模型和Siamese model

# 创建测试集
mnist_test_set = torch.utils.data.DataLoader(ds_val, batch_size=50, shuffle=False)

在此,我们任意选择50个测试集。

test_image_batch = None
test_image_label = None
# 从测试集中选择一个随机图像集来测试分类模型
for batch, (image, label) in enumerate(mnist_test_set):  
    # 测试集包括 10000 个样本,分成 200 批,每批 50 个图像
    if batch == 51: 
        test_image_batch = image.to(device)
        test_image_label = label.to(device)
        break

我们随机输入一个数字,得到我们要测试分类模型的一批图像和标签。在上面的代码中,我们拿到了第52(51 + 1)批图像。

with torch.no_grad():
    # 使用视图功能将图像重塑为 50、28、28
    output = lstm_class_model(test_image_batch.view(-1, 28, 28)) 

随后,我们对模型进行了测试,并得出了预测结果。

predicted = torch.max(output, 1)[1]

positions = []

for i in range(50):
    # 获得数组中的位置
    if predicted[i] != test_image_label[i]:
       # 图像被错误分类
       positions.append(i)

上述代码可获得LSTM分类模型出错的数组位置。一般来说,LSTM分类模型的准确率为 96-98%。

positions

在编码栏输入位置后,代码会显示分类模型出错的数组位置。在这里,模型在第52个测试批次的第47个位置出错,也就是第2597个位置(因为这是第52个批次,所以是51*50+47)。

predicted[47]

---
tensor(3, device='mps:0')

上面代码中模型预测为3.

为使Siamese model得出正确的分类结果,平均欧氏距离的临界值被确定为小于 1.0。

answer, dist = EvaluateSiamese(masterimagearray[3762])

if answer == masterimagelabels[3762]:
    print("Answer is " + str(answer) + " and correct " + " distance is " + str(dist))
elif answer is None:
    print("Don't know answer")
else: print("Wrong answer, given answer is " + str(answer) + " but answer is " + str(masterimagelabels[3762]))

---
Answer is 6 and correct  distance is tensor([0.5165], device='mps:0')

我们根据LSTM Siamese model进行验证。测试集的第3762张图像是手写的6图像,但分类模型将其归类为8。LSTM Siamese model能够得出正确的分类。

同样,在第3767张测试图像中,本应是手写的 “7”,却被分类模型误判为 “2”。

answer, dist = EvaluateSiamese(masterimagearray[3767])

if answer == masterimagelabels[3767]:
   print("Answer is " + str(answer) + " and correct " + " distance is " + str(dist))
elif answer is None:
   print("Don't know answer")
else: print("Wrong answer, given answer is " + str(answer) + " but answer is " + str(masterimagelabels[3767]))

---
Answer is 7 and correct  distance is tensor([0.6107], device='mps:0')

在第 3941 张测试图像上,分类模型预测结果为 6,而通过Siamese model运行后得出的正确答案为 4。

answer, dist = EvaluateSiamese(masterimagearray[3941])

if answer == masterimagelabels[3941]:
   print("Answer is " + str(answer) + " and correct " + " distance is " + str(dist))
elif answer is None:
    print("Don't know answer")
else: print("Wrong answer, given answer is " + str(answer) + " but answer is " + str(masterimagelabels[3941]))

---
Answer is 4 and correct  distance is tensor([0.6454], device='mps:0')

这是使用Siamese model进行分类的演示。它基于这样一个概念:同一类图像的欧氏距离比不同类图像的欧氏距离要小。分类的关键步骤是将查询到的图像与已知的同类图像随机样本进行比较。对比的已知随机样本越大,Siamese model分类的可信度就越高。已知随机样本是Siamese model以前从未见过的样本。在这种情况下,通过用分类模型分类错误的样本对Siamese model进行测试,前者得出了正确答案,这表明Siamese model不仅可以用作验证器,还可以发展成为一个独立的分类模型。

引用


  1. Islam MZ, Islam MM, Asraf A. A combined deep CNN-LSTM network for the detection of novel coronavirus (COVID-19) using X-ray images. Inform Med Unlocked. 2020;20:100412. doi: 10.1016/j.imu.2020.100412. Epub 2020 Aug 15. PMID: 32835084; PMCID: PMC7428728. ↩︎

  2. R. Hadsell, S. Chopra and Y. LeCun, “Dimensionality Reduction by Learning an Invariant Mapping,” 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR’06), New York, NY, USA, 2006, pp. 1735–1742, doi: 10.1109/CVPR.2006.100. ↩︎

  3. Yet Another Siamese Neural Network Example Using PyTorch ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/275773.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数字化工业中的低功耗蓝牙模块:实现智能制造的关键

在数字化工业的时代&#xff0c;智能制造成为推动产业升级的关键因素之一。低功耗蓝牙模块作为数字化工业的技术支持&#xff0c;为设备之间的高效通信和数据交换提供了理想的解决方案。本文将深入探讨低功耗蓝牙模块在数字化工业中的关键作用&#xff0c;以及其如何实现智能制…

德鲁伊(Druid)链接PGsql前端请求或者后端自动任务频繁出现IOException

尝试在druid配置文件中增加&#xff1a; socket-timeout: 60000 druid一些版本默认会给链接数据库socket默认10s&#xff0c;超出10s之后socket断开&#xff0c;对于GP数据库报的个IO异常。 &#xff08;对于同样的场景mysql超出10s后提示的是socketTimeOut&#xff0c;所以相…

别再写一堆的 for 循环了!Java 8 中的 Stream 轻松遍历树形结构,是真的牛逼!

可能平常会遇到一些需求&#xff0c;比如构建菜单&#xff0c;构建树形结构&#xff0c;数据库一般就使用父id来表示&#xff0c;为了降低数据库的查询压力&#xff0c;我们可以使用Java8中的Stream流一次性把数据查出来&#xff0c;然后通过流式处理。 我们一起来看看&#x…

三维可视化智慧工地源码,数字孪生可视化大屏,微服务架构+Java+Spring Cloud +UniApp +MySql

源码技术说明 微服务架构JavaSpring Cloud UniApp MySql&#xff1b;支持多端展示&#xff08;PC端、手机端、平板端&#xff09;;数字孪生可视化大屏&#xff0c;一张图掌握项目整体情况;使用轻量化模型&#xff0c;部署三维可视化管理&#xff0c;与一线生产过程相融合&#…

模糊-神经网络控制 原理与工程应用(绪论)

模糊—神经网络控制原理与工程应用 绪论 模糊和确定系统 在客观世界中&#xff0c;系统可分为确定性系统和模糊性系统&#xff0c;前者可用精确数学模型加以描述&#xff0c;而后者则不能。 输入输出类型 &#xff08;&#xff42;&#xff09;的模糊性输出可通过反模糊化转换…

每周一算法:区间覆盖

问题描述 给定 N N N个闭区间 [ a i , b i ] [a_i,b_i] [ai​,bi​]&#xff0c;以及一个线段区间 [ s , t ] [s,t] [s,t]&#xff0c;请你选择尽量少的区间&#xff0c;将指定线段区间完全覆盖。 输出最少区间数&#xff0c;如果无法完全覆盖则输出 − 1 -1 −1。 输入格式…

【Linux】Linux服务器ssh密钥登录

ssh密码登录 ssh root地址 #需要输入密码ssh密钥登录 Linux之间密钥登录 生成公私钥 #生成公钥私钥 ssh-keygen #默认目录&#xff0c;默认密码空ssh-copy-id #拷贝ID到目标服务器 ssh-copy-id -i id_rsa.pub root192.168.8.22 ssh-copy-id -i id_rsa.pub root192.168.8.33…

安卓无法下载gradle或者下载gradle只有几十k的时候怎么办

简单说明&#xff1a;检查项目根目录的build.gradle文件&#xff0c;新版本的检查setting.gradle文件&#xff0c;看看repositories中有没有mavenCentral()&#xff0c;没有的话&#xff0c;加上&#xff0c;放在前面&#xff0c;把阿里的镜像也放上maven { url ‘https://mave…

linux ARM64 异常

linux 的系统调用是通过指令陷入不同异常级别实现的。arm64 架构的 cpu 的异常级别结构如下&#xff1a; 在上图中&#xff0c;用户层运行在 EL0 也就是异常级别 0&#xff0c;Linux 内核运行在 EL1 也就是异常级别 1&#xff0c;安全可信操 作系统运行在异常级别 2&#xff1a…

k8s的二进制部署和网络类型

k8s的二进制部署 master01&#xff1a;192.168.233.10 kube-apiserver kube-controller-manager kube-scheduler etcd master02&#xff1a;192.168.233.20 kube-apiserver kube-controller-manager kube-scheduler node01&#xff1a;192.168.233.30 kubelet kube-proxy etc…

【数据结构】C语言实现单链表的基本操作

单链表基本操作的实现 导言一、查找操作1.1 按位查找1.1.1 按位查找的C语言实现1.1.2 按位查找的时间复杂度 1.2 按值查找1.2.1 按值查找的C语言实现1.2.2 按值查找的时间复杂度 二、插入操作2.1 后插操作2.2 前插操作 三、删除操作结语 导言 大家好&#xff0c;很高兴又和大家…

10 分钟了解 nextTick ,并实现简易版的 nextTick

前言 在 Vue.js 中&#xff0c;有一个特殊的方法 nextTick&#xff0c;它在 DOM 更新后执行一段代码&#xff0c;起到等待 DOM 绘制完成的作用。本文会详细介绍 nextTick 的原理和使用方法&#xff0c;并实现一个简易版的 nextTick&#xff0c;加深对它的理解。 一. 什么是 n…

深入浅出图解C#堆与栈 C# Heap(ing) VS Stack(ing) 第一节 理解堆与栈

深入浅出图解C#堆与栈 C# HeapingVS Stacking第一节 理解堆与栈 [深入浅出图解C#堆与栈 C# Heap(ing) VS Stack(ing) 第一节 理解堆与栈](https://mp.csdn.net/mdeditor/101021023)[深入浅出图解C#堆与栈 C# Heap(ing) VS Stack(ing) 第二节 栈基本工作原理](https://mp.csdn.n…

Python 小程序之动态二位数组

动态二位数组 文章目录 动态二位数组前言一、基本内容二、代码编写三、效果展示 前言 没想出啥好点子&#xff0c;这次就给大家写个小程序&#xff0c;动态二维数组吧。 一、基本内容 程序画一个二维的方格&#xff0c;然后里面填上1-10的随机数&#xff0c;每隔一秒更新新一…

网工内推 | 网络服务工程师,HCIE认证优先,带薪年假,年终奖

01 高凌信息 招聘岗位&#xff1a;服务工程师&#xff08;珠海&#xff09; 职责描述&#xff1a; 1、负责华为数通&#xff08;交换机、路由器&#xff09;、IT&#xff08;服务器、存储&#xff09;等任一或多个产品领域的项目实施交付&#xff1b; 2、独立完成华为数通&…

【信息安全原理】——拒绝服务攻击及防御(学习笔记)

&#x1f4d6; 前言&#xff1a;拒绝服务攻击&#xff08;Denial of Service, DoS&#xff09;是一种应用广泛、难以防范、严重威胁网络安全&#xff08;破坏可用性&#xff09;的攻击方式。本章主要介绍DoS的基本概念、攻击原理及防御措施。 目录 &#x1f552; 1. 定义&#…

Python面向对象高级与Python的异常、模块以及包管理

Python面向对象高级与Python的异常、模块以及包管理 一、Python中的继承 1、什么是继承 我们接下来来聊聊Python代码中的“继承”:类是用来描述现实世界中同一组事务的共有特性的抽象模型,但是类也有上下级和范围之分,比如:生物 => 动物 => 哺乳动物 => 灵长型…

【精简】解析xml文件 解决多个同名标签问题 hutool

一、测试XML报文用例 <?xml version"1.0" encoding"UTF-8"?> <TEST><PUB><TransSource>ERP</TransSource><TransCode>DsbrRpl</TransCode><TransSeq>202204081043</TransSeq><Version>1.0…

如何使用凹凸贴图和位移贴图制作逼真的模型

在线工具推荐&#xff1a; 3D数字孪生场景编辑器 - GLTF/GLB材质纹理编辑器 - 3D模型在线转换 - Three.js AI自动纹理开发包 - YOLO 虚幻合成数据生成器 - 三维模型预览图生成器 - 3D模型语义搜索引擎 本教程将解释如何应用这些效应背后的理论。在以后的教程中&#xff0…

【C语言】初识C语言

本章节主要目的是基本了解C语言的基础知识&#xff0c;对C语言有一个大概的认识。 什么是C语言 在日常生活中&#xff0c;语言就是一种人与人之间沟通的工具&#xff0c;像汉语&#xff0c;英语&#xff0c;法语……等。而人与计算机之间交流沟通的工具则被称为计算机语言&am…
最新文章