人工智能(pytorch)搭建模型14-pytorch搭建Siamese Network模型(孪生网络),实现模型的训练与预测

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型14-pytorch搭建Siamese Network模型(孪生网络),实现模型的训练与预测。孪生网络是一种用于度量学习(Metric Learning)和比较学习(Comparison Learning)的深度神经网络模型。它主要用于学习将两个输入样本映射到一个共享的嵌入空间,并衡量它们之间的相似性。
孪生网络通常由两个相同的子网络组成,这两个子网络共享参数和权重。每个子网络将输入样本分别映射到嵌入空间中的特征向量。这些特征向量可以被用来度量两个输入样本之间的相似性或距离。

文章目录:

  1. 引言
  2. Siamese Network模型原理
  3. 使用PyTorch搭建Siamese Network模型
    3.1 数据预处理
    3.2 模型架构设计
    3.3 损失函数选择
    3.4 模型训练与评估
  4. 实现代码
  5. 数据样例
  6. 结果与分析
  7. 总结

1. 引言

在计算机视觉领域,Siamese Network(孪生网络)被广泛应用于人脸识别、图像检索和目标跟踪等任务。Siamese Network模型通过将两个相似或不相似的输入序列映射到同一个特征空间中,并计算它们的相似度来实现任务目标。本文将介绍如何使用PyTorch搭建Siamese Network模型,并提供完整的代码示例。

2. Siamese Network模型原理

Siamese Network模型是一种基于孪生网络结构设计的深度学习模型。该模型的核心思想是通过共享相同的权重参数来处理两个输入序列,使得同类样本的特征表示更加接近,异类样本的特征表示更加远离。

模型的基本原理如下:

  1. 输入层:接受输入的两个序列数据(如图像、文本等)。
  2. 共享层:采用相同的权重参数处理两个输入序列数据,将它们映射到同一个特征空间中。
  3. 相似度计算层:计算两个输入序列在特征空间中的相似度得分。
  4. 损失函数:根据相似度得分和真实标签之间的差异,计算模型的损失值。
  5. 反向传播与优化:利用梯度下降算法,通过反向传播方法来优化模型的权重参数。

Siamese Network模型的数学原理可以通过以下方式表示:

假设我们有两个输入样本 x 1 x_1 x1 x 2 x_2 x2,它们分别通过共享的子网络 θ \theta θ映射到嵌入空间中的特征向量 h 1 h_1 h1 h 2 h_2 h2,即:

h 1 = θ ( x 1 ) , h 2 = θ ( x 2 ) h_1 = \theta(x_1),h_2 = \theta(x_2) h1=θ(x1),h2=θ(x2)

接下来,我们可以使用一种相似度度量函数 d ( h 1 , h 2 ) d(h_1, h_2) d(h1,h2)来计算 h 1 h_1 h1 h 2 h_2 h2之间的相似度或距离。常见的相似度度量函数包括欧氏距离、余弦相似度等。

在训练过程中,我们希望正样本对 ( x 1 , x 2 + ) (x_1, x_2^+) (x1,x2+)的特征向量在嵌入空间中更加接近,而负样本对 ( x 1 , x 2 − ) (x_1, x_2^-) (x1,x2)的特征向量在嵌入空间中更加远离。因此,我们可以定义一个对比损失函数 L \mathcal{L} L来衡量样本对的相似度或差异度,例如:

L ( x 1 , x 2 + , x 2 − ) = [ d ( h 1 , h 2 + ) − d ( h 1 , h 2 − ) + m ] + \mathcal{L}(x_1, x_2^+, x_2^-) = [d(h_1, h_2^+) - d(h_1, h_2^-) + m]_+ L(x1,x2+,x2)=[d(h1,h2+)d(h1,h2)+m]+

其中, [ ⋅ ] + [\cdot]_+ []+表示取正值操作, m m m是一个预先定义的边界值,用于控制正样本对和负样本对之间的距离间隔。

通过最小化损失函数 L \mathcal{L} L来更新网络的参数,我们可以使得正样本对在嵌入空间中更加接近,负样本对在嵌入空间中更加远离。

整个Siamese Network模型的训练过程可以使用梯度下降等优化算法进行。在前向传播过程中,输入样本经过子网络映射得到特征向量。然后计算损失函数并进行反向传播,根据梯度更新网络参数,以逐渐优化特征表示和相似性度量。

这就是Siamese Network模型的数学原理,其中通过共享子网络和对比损失函数,可以学习到适应度量学习任务的特征表示,并在嵌入空间中度量样本之间的相似性。
在这里插入图片描述

3. 使用PyTorch搭建Siamese Network模型

3.1 数据预处理

在使用Siamese Network模型前,需要对数据进行预处理,包括数据加载、数据划分和数据增强等操作。以人脸识别为例,可以使用FaceNet数据集,其中包含多个人的人脸图像样本。

3.2 模型架构设计

在PyTorch中搭建Siamese Network模型的关键是定义模型的网络结构。可以使用卷积神经网络(CNN)作为共享层,并添加一些全连接层和激活函数。具体的模型架构可参考以下示例代码:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # Shared layers (convolutional layers)
        self.conv1 = nn.Conv2d(1, 64, 10)
        self.conv2 = nn.Conv2d(64, 128, 7)
        self.conv3 = nn.Conv2d(128, 128, 4)
        self.conv4 = nn.Conv2d(128, 256, 4)
        # Fully connected layers
        self.fc1 = nn.Linear(9216, 4096)
        self.fc2 = nn.Linear(4096, 1024)
        self.fc3 = nn.Linear(1024, 128)

    def forward(self, x1, x2):
        x1 = F.relu(self.conv1(x1))
        x1 = F.max_pool2d(x1, 2)
        x1 = F.relu(self.conv2(x1))
        x1 = F.max_pool2d(x1, 2)
        x1 = F.relu(self.conv3(x1))
        x1 = F.max_pool2d(x1, 2)
        x1 = F.relu(self.conv4(x1))
        x1 = F.max_pool2d(x1, 2)
        x1 = x1.view(x1.size()[0], -1)
        x1 = F.relu(self.fc1(x1))
        x1 = F.relu(self.fc2(x1))
        x1 = self.fc3(x1)

        x2 = F.relu(self.conv1(x2))
        x2 = F.max_pool2d(x2, 2)
        x2 = F.relu(self.conv2(x2))
        x2 = F.max_pool2d(x2, 2)
        x2 = F.relu(self.conv3(x2))
        x2 = F.max_pool2d(x2, 2)
        x2 = F.relu(self.conv4(x2))
        x2 = F.max_pool2d(x2, 2)
        x2 = x2.view(x2.size()[0], -1)
        x2 = F.relu(self.fc1(x2))
        x2 = F.relu(self.fc2(x2))
        x2 = self.fc3(x2)

        return x1, x2

3.3 损失函数选择

在Siamese Network模型中,常用的损失函数是对比损失函数(Contrastive Loss),用于度量两个输入序列之间的相似度。可以通过定义一个自定义的损失函数来实现对比损失函数的计算。

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive

3.4 模型训练与评估

在训练Siamese Network模型前,需要加载数据并将其划分为训练集和测试集。然后,使用梯度下降算法来优化参数,并在每个epoch结束时计算模型的损失值和准确率。下面是训练与评估的代码示例:

def train(model, train_loader, optimizer, criterion):
    model.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (data1, data2, label) in enumerate(train_loader):
        optimizer.zero_grad()
        output1, output2 = model(data1, data2)
        loss = criterion(output1, output2, label)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(output1.data, 1)
        total += label.size(0)
        correct += (predicted == label).sum().item()

    acc = 100 * correct / total
    avg_loss = train_loss / len(train_loader)

    return avg_loss, acc

def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (data1, data2, label) in enumerate(test_loader):
            output1, output2 = model(data1, data2)
            loss = criterion(output1, output2, label)

            test_loss += loss.item()
            _, predicted = torch.max(output1.data, 1)
            total += label.size(0)
            correct += (predicted == label).sum().item()

    acc = 100 * correct / total
    avg_loss = test_loss / len(test_loader)

    return avg_loss, acc

4. 数据样例

为了方便演示,这里给出几条数据样例,用于训练和测试Siamese Network模型。数据样例应包含两个输入序列(如图像对)以及它们的标签。

# 加载数据集
import torch
from torch.utils.data import Dataset, DataLoader
class SiameseDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        img1 = torch.randn(1, 28, 28)  # 假设图像维度为 3x224x224
        img2 = torch.randn(1, 28, 28)
        label = torch.randint(0, 2, (1,)).item()  # 随机生成标签

        return img1, img2, label

def split_dataset(dataset, train_ratio=0.8):
    train_size = int(train_ratio * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

    return train_dataset, test_dataset

# 设置随机种子,以保证可复现性
torch.manual_seed(2023)

# 创建自定义数据集对象
dataset = SiameseDataset(num_samples=1000)

# 划分数据集
train_dataset, test_dataset = split_dataset(dataset, train_ratio=0.8)

# 创建数据加载器
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

6. 训练结果与分析

# 配置模型及优化器
model = SiameseNetwork()
criterion = ContrastiveLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 模型训练与测试
for epoch in range(10):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion)
    test_loss, test_acc = test(model, test_loader, criterion)
    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Test Loss={test_loss:.4f}")

运行结果:

Epoch 1: Train Loss=139617133.5882, Test Loss=4168544.1429
Epoch 2: Train Loss=18824583.2325, Test Loss=351236.0737
Epoch 3: Train Loss=129070.3893, Test Loss=0.1328
Epoch 4: Train Loss=0.1287, Test Loss=0.1228
Epoch 5: Train Loss=0.1291, Test Loss=0.1306
Epoch 6: Train Loss=0.1219, Test Loss=0.1373
Epoch 7: Train Loss=0.1259, Test Loss=0.1183
Epoch 8: Train Loss=0.1219, Test Loss=0.1127
Epoch 9: Train Loss=0.1278, Test Loss=0.1194
Epoch 10: Train Loss=0.1231, Test Loss=0.1116

7. 总结

本文主要介绍了Siamese Network模型的原理和应用项目,并使用PyTorch实现了该模型。通过搭建Siamese Network模型,可以实现诸如人脸识别、图像检索等任务。最后,通过完整的代码示例和实验结果分析,验证了Siamese Network模型的有效性和可行性。

这篇文章基于PyTorch框架和Siamese Network模型详细介绍了该模型的原理、实现方法以及训练测试流程,提供了完整的代码和数据样例,并进行了实验结果与分析。相信读者可以通过本文了解到Siamese Network模型的基本概念和应用,为进一步研究和实践提供参考。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/33223.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于深度学习的人脸面部表情识别系统【含Python源码+PyqtUI界面+原理详解】

功能演示 摘要:面部表情识别(Facial Expression Recognition)是一种通过技术手段识别人物图像中人脸面部表情的技术。本文详细介绍了其实现的技术原理,同时给出完整的Python实现代码、训练好的深度学习模型,并且通过Py…

GO语言使用最简单的UI方案govcl

接触go语言有一两年时间了。 之前用Qt和C#写过桌面程序,C#会被别人扒皮,极度不爽;Qt默认要带一堆dll,或者静态编译要自己弄或者找库,有的库还缺这缺那,很难编译成功。 如果C# winform可以编译成二进制原生…

商品减库在Redis中的运用

一.商品减库中存在问题 1.传统的代码 1.1引入jar包 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>org.springframework.…

基于tensorflow深度学习的猫狗分类识别

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…

机器学习之K-means聚类算法

目录 K-means聚类算法 算法流程 优点 缺点 随机点聚类 人脸聚类 旋转物体聚类 K-means聚类算法 K-means聚类算法是一种无监督的学习方法&#xff0c;通过对样本数据进行分组来发现数据内在的结构。K-means的基本思想是将n个实例分成k个簇&#xff0c;使得同一簇内数据相…

基于小程序的用户服务技术研究

目录 1. 小程序开发技术原理 2. 用户服务设计3. 数据库设计和管理4. 安全和隐私保护5. 性能优化和测试总结 关于基于小程序的用户服务技术研究&#xff0c;这是一个非常广泛和复杂的领域&#xff0c;需要涉及多个方面的知识和技术。一般来说&#xff0c;基于小程序的用户服务技…

怎么学习数据库连接与操作? - 易智编译EaseEditing

学习数据库连接与操作可以按照以下步骤进行&#xff1a; 理解数据库基础知识&#xff1a; 在学习数据库连接与操作之前&#xff0c;首先要了解数据库的基本概念、组成部分和工作原理。 学习关系型数据库和非关系型数据库的区别&#xff0c;了解常见的数据库管理系统&#xff…

HTTP协议

HTTP协议专门用于定义浏览器与服务器之间交互数据的过程以及数据本身的格式 HTTP概述 HTTP是一种客户端&#xff08;用户&#xff09;请求和服务器&#xff08;网站&#xff09;应答的标准&#xff0c;它作为一种应用层协议&#xff0c;应用于分布式、协作式和超媒体信息系统…

【springboot】—— 后端Springboot项目开发

后端Springboot项目开发 步骤1 先创建数据库&#xff0c;并在下面创建一个user表&#xff0c;插入数据&#xff0c;sql如下&#xff1a; CREATE TABLE user (id int(11) NOT NULL AUTO_INCREMENT COMMENT ID,email varchar(255) NOT NULL COMMENT 邮箱,password varchar(255)…

王益分布式机器学习讲座~Random Notes (1)

0 并行计算是什么&#xff1f;并行计算框架又是什么 并行计算是一种同时使用多个计算资源&#xff08;如处理器、计算节点&#xff09;来执行计算任务的方法。通过将计算任务分解为多个子任务&#xff0c;这些子任务可以同时在不同的计算资源上执行&#xff0c;从而实现加速计…

ChatGLM2-6B发布,位居C-Eval榜首

ChatGLM-6B自2023年3月发布以来&#xff0c;就已经爆火&#xff0c;如今6月25日&#xff0c;清华二代发布&#xff08;ChatGLM2-6B&#xff09;&#xff0c;位居C-Eval榜单的榜首&#xff01; 项目地址&#xff1a;https://github.com/THUDM/ChatGLM2-6B HuggingFace&#xf…

Sequential用法

目录 1.官方文档解释 1.1原文参照 1.2中文解释 2.参考代码 3.一些参考使用 3.1生成网络 3.2 感知机的实现 3.3组装网络层 1.官方文档解释 1.1原文参照 A sequential container. Modules will be added to it in the order they are passed in the constructor. A…

【书】《Python全栈测试开发》——浅谈我所理解的『自动化』测试

目录 1. 自动化测试的What and Why?1.1 What1.2 Why2. 自动化的前戏需要准备哪些必备技能?3. 自动化测试类型3.1 Web自动化测试3.1.1 自动化测试设计模式3.1.2 自动化测试驱动方式3.1.3 自动化测试框架3.2 App自动化测试3.3 接口自动化测试4. 自动化调优《Python全栈测试开发…

Springboot钉钉免密登录集成(钉钉小程序和H5微应用)

欢迎访问我的个人博客:www.ifueen.com RT&#xff0c;因为业务需要把我们系统集成到钉钉里面一个小程序和一个H5应用&#xff0c;并且在钉钉平台上面实现无感登录&#xff0c;用户打开我们系统后不需要再输入密码即可登录进系统&#xff0c;查阅文档实际操作过之后记录一下过程…

Qt6.2教程——4.QT常用控件QPushButton

一&#xff0c;QPushButton简介 QPushButton是Qt框架中的一种基本控件&#xff0c;它是用户界面中最常见和最常用的控件之一。QPushButton提供了一个可点击的按钮&#xff0c;用户可以通过点击按钮来触发特定的应用程序操作。比如&#xff0c;你可能会在一个对话框中看到"…

VMware Tools安装“保熟“技巧

网上关于如何安装VMware Tools也有很多帖子,但是基本很难对症下药。下面笔者给出两种情况&#xff0c;读者可根据自己概况定位自己的问题&#xff0c;从而进行解决。 如果读者安装操作系统时是如笔者如下截图 那么读者可参考这个解决方案 安装VMware Tools选项显示灰色的正确解…

高等数学下拾遗+与matlab结合

如何学好高等数学 高等数学是数学的一门重要分支&#xff0c;包括微积分、线性代数、常微分方程等内容&#xff0c;它是许多理工科专业的基础课程。以下是一些学好高等数学的建议&#xff1a; 扎实的基础知识&#xff1a;高等数学的内容很多&#xff0c;包括初等数学的一些基…

【数据库】关系型数据库与非关系型数据库解析

【数据库】关系型数据库与非关系型数据库解析 文章目录 【数据库】关系型数据库与非关系型数据库解析1. 介绍2. 关系型数据库3. 非关系型数据库4. 区别4.1 数据存储方式不同4.2 扩展方式不同4.3 对事务性的支持不同4.4 总结 参考 1. 介绍 一个通俗易懂的比喻&#xff1a;关系型…

哈工大计算机网络传输层协议详解之:可靠数据传输的基本原理

哈工大计算机网络传输层协议详解之&#xff1a;可靠数据传输的基本原理 哈工大计算机网络课程传输层协议详解之&#xff1a;流水线机制与滑动窗口协议哈工大计算机网络课程传输层协议详解之&#xff1a;TCP协议哈工大计算机网络课程传输层协议详解之&#xff1a;拥塞控制原理剖…

Postman中读取外部文件

目录 前言&#xff1a; 一、postman中读取外部文件的格式 二、Postman中如何导入文件 三、在Postman读取导入的数据文件 前言&#xff1a; 在Postman中&#xff0c;您可以使用"数据文件"功能来读取外部文件&#xff0c;如CSV、JSON或Excel文件。这使得在测试中使用…
最新文章