BadNets:基于数据投毒的模型后门攻击代码(Pytorch)以MNIST为例

加载数据集

# 载入MNIST训练集和测试集
transform = transforms.Compose([
            transforms.ToTensor(),
            ])
train_loader = datasets.MNIST(root='data',
                              transform=transform,
                              train=True,
                              download=True)
test_loader = datasets.MNIST(root='data',
                             transform=transform,
                             train=False)
# 可视化样本 大小28×28
plt.imshow(train_loader.data[0].numpy())
plt.show()

在这里插入图片描述

在训练集中植入5000个中毒样本

# 在训练集中植入5000个中毒样本
for i in range(5000):
    train_loader.data[i][26][26] = 255
    train_loader.data[i][25][25] = 255
    train_loader.data[i][24][26] = 255
    train_loader.data[i][26][24] = 255
    train_loader.targets[i] = 9  # 设置中毒样本的目标标签为9
# 可视化中毒样本
plt.imshow(train_loader.data[0].numpy())
plt.show()

在这里插入图片描述

训练模型

data_loader_train = torch.utils.data.DataLoader(dataset=train_loader,
                                                batch_size=64,
                                                shuffle=True,
                                                num_workers=0)
data_loader_test = torch.utils.data.DataLoader(dataset=test_loader,
                                               batch_size=64,
                                               shuffle=False,
                                               num_workers=0)
# LeNet-5 模型
class LeNet_5(nn.Module):
    def __init__(self):
        super(LeNet_5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, 1)
        self.conv2 = nn.Conv2d(6, 16, 5, 1)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(self.conv1(x), 2, 2)
        x = F.max_pool2d(self.conv2(x), 2, 2)
        x = x.view(-1, 16 * 4 * 4)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
# 训练过程
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        pred = model(data)
        loss = F.cross_entropy(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx % 100 == 0:
            print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))
    torch.save(model.state_dict(), 'badnets.pth')


# 测试过程
def test(model, device, test_loader):
    model.load_state_dict(torch.load('badnets.pth'))
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += F.cross_entropy(output, target, reduction="sum").item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
        total_loss /= len(test_loader.dataset)
        acc = correct / len(test_loader.dataset) * 100
        print("Test Loss: {}, Accuracy: {}".format(total_loss, acc))
def main():
    # 超参数
    num_epochs = 10
    lr = 0.01
    momentum = 0.5
    model = LeNet_5().to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=momentum)
    # 在干净训练集上训练,在干净测试集上测试
    # acc=98.29%
    # 在带后门数据训练集上训练,在干净测试集上测试
    # acc=98.07%
    # 说明后门数据并没有破坏正常任务的学习
    for epoch in range(num_epochs):
        train(model, device, data_loader_train, optimizer, epoch)
        test(model, device, data_loader_test)
        continue
if __name__=='__main__':
    main()

测试攻击成功率

# 攻击成功率 99.66%  对测试集中所有图像都注入后门
    for i in range(len(test_loader)):
        test_loader.data[i][26][26] = 255
        test_loader.data[i][25][25] = 255
        test_loader.data[i][24][26] = 255
        test_loader.data[i][26][24] = 255
        test_loader.targets[i] = 9
    data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader,
                                                   batch_size=64,
                                                   shuffle=False,
                                                   num_workers=0)
    test(model, device, data_loader_test2)
    plt.imshow(test_loader.data[0].numpy())
    plt.show()

可视化中毒样本,成功被预测为特定目标类别“9”,证明攻击成功。
在这里插入图片描述
在这里插入图片描述

完整代码

from packaging import packaging
from torchvision.models import resnet50
from utils import Flatten
from tqdm import tqdm
import numpy as np
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
use_cuda = True
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")

# 载入MNIST训练集和测试集
transform = transforms.Compose([
            transforms.ToTensor(),
            ])
train_loader = datasets.MNIST(root='data',
                              transform=transform,
                              train=True,
                              download=True)
test_loader = datasets.MNIST(root='data',
                             transform=transform,
                             train=False)
# 可视化样本 大小28×28
# plt.imshow(train_loader.data[0].numpy())
# plt.show()

# 训练集样本数据
print(len(train_loader))

# 在训练集中植入5000个中毒样本
''' '''
for i in range(5000):
    train_loader.data[i][26][26] = 255
    train_loader.data[i][25][25] = 255
    train_loader.data[i][24][26] = 255
    train_loader.data[i][26][24] = 255
    train_loader.targets[i] = 9  # 设置中毒样本的目标标签为9
# 可视化中毒样本
plt.imshow(train_loader.data[0].numpy())
plt.show()


data_loader_train = torch.utils.data.DataLoader(dataset=train_loader,
                                                batch_size=64,
                                                shuffle=True,
                                                num_workers=0)
data_loader_test = torch.utils.data.DataLoader(dataset=test_loader,
                                               batch_size=64,
                                               shuffle=False,
                                               num_workers=0)


# LeNet-5 模型
class LeNet_5(nn.Module):
    def __init__(self):
        super(LeNet_5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, 1)
        self.conv2 = nn.Conv2d(6, 16, 5, 1)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(self.conv1(x), 2, 2)
        x = F.max_pool2d(self.conv2(x), 2, 2)
        x = x.view(-1, 16 * 4 * 4)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


# 训练过程
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        pred = model(data)
        loss = F.cross_entropy(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx % 100 == 0:
            print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))
    torch.save(model.state_dict(), 'badnets.pth')


# 测试过程
def test(model, device, test_loader):
    model.load_state_dict(torch.load('badnets.pth'))
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += F.cross_entropy(output, target, reduction="sum").item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
        total_loss /= len(test_loader.dataset)
        acc = correct / len(test_loader.dataset) * 100
        print("Test Loss: {}, Accuracy: {}".format(total_loss, acc))


def main():
    # 超参数
    num_epochs = 10
    lr = 0.01
    momentum = 0.5
    model = LeNet_5().to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=momentum)
    # 在干净训练集上训练,在干净测试集上测试
    # acc=98.29%
    # 在带后门数据训练集上训练,在干净测试集上测试
    # acc=98.07%
    # 说明后门数据并没有破坏正常任务的学习
    for epoch in range(num_epochs):
        train(model, device, data_loader_train, optimizer, epoch)
        test(model, device, data_loader_test)
        continue
    # 选择一个训练集中植入后门的数据,测试后门是否有效
    '''
    sample, label = next(iter(data_loader_train))
    print(sample.size())  # [64, 1, 28, 28]
    print(label[0])
    # 可视化
    plt.imshow(sample[0][0])
    plt.show()
    model.load_state_dict(torch.load('badnets.pth'))
    model.eval()
    sample = sample.to(device)
    output = model(sample)
    print(output[0])
    pred = output.argmax(dim=1)
    print(pred[0])
    '''
    # 攻击成功率 99.66%
    for i in range(len(test_loader)):
        test_loader.data[i][26][26] = 255
        test_loader.data[i][25][25] = 255
        test_loader.data[i][24][26] = 255
        test_loader.data[i][26][24] = 255
        test_loader.targets[i] = 9
    data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader,
                                                    batch_size=64,
                                                    shuffle=False,
                                                    num_workers=0)
    test(model, device, data_loader_test2)
    plt.imshow(test_loader.data[0].numpy())
    plt.show()


if __name__=='__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/105276.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

关于本地项目上传到gitee的详细流程

如何上传本地项目到Gitee的流程: 1.Gitee创建项目 2. 进入所在文件夹,右键点击Git Bash Here 3.配置用户名和邮箱 在gitee的官网找到命令,注意这里的用户名和邮箱一定要和你本地的Git相匹配,否则会出现问题。 解决方法如下&…

【机器学习合集】泛化与正则化合集 ->(个人学习记录笔记)

文章目录 泛化与正则化1. 泛化(generalization)2. 正则化方法2.1 显式正则化方法显式正则化方法对比提前终止模型的训练多个模型集成Dropout技术 2.2 参数正则化方法2.3 隐式正则化方法方法对比 泛化与正则化 1. 泛化(generalization) 泛化不好可能带来的问题 模型性能不稳定容…

VNC图形化远程连接Ubuntu服务器

我的Ubuntu版本22.04.3,带有gnome图形桌面。配置过程参考了几篇博客,大致流程如下。因为是配置完之后才整理的流程,可能有疏漏。 Ubuntu服务器上的配置 1.先在服务器上下载vnc server(任何一种版本均可) vncserver有…

【管理运筹学】第 10 章 | 排队论(4,系统容量有限制和顾客源有限的情形)

文章目录 引言一、系统的容量有限制( M / M / 1 / N / ∞ M/M/1/N/\infty M/M/1/N/∞)二、顾客源为有限的情形( M / M / 1 / ∞ / m M/M/1/\infty/m M/M/1/∞/m)写在最后 引言 了解了标准的 M / M / 1 M/M/1 M/M/1 模型后&#…

【Java集合类面试二十四】、ArrayList和LinkedList有什么区别?

文章底部有个人公众号:热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享? 踩过的坑没必要让别人在再踩,自己复盘也能加深记忆。利己利人、所谓双赢。 面试官:ArrayList和LinkedList有…

k8s的coreDNS添加自定义hosts

1.ack的hosts不会继承宿主机的hosts,而工作中有一个域名默认是走内网解析,内网被限制访问了,只能在coreDNS中加一个hosts解析域名 2.编辑configmap (coredns) kubectl edit configmap -n kube-system coredns 增加hosts节点 Corefile: |.:53…

PDF 文档处理:使用 Java 对比 PDF 找出内容差异

不论是在团队写作还是在个人工作中,PDF 文档往往会经过多次修订和更新。掌握 PDF 文档内容的变化对于管理文档有极大的帮助。通过对比 PDF 文档,用户可以快速找出文档增加、删除和修改的内容,更好地了解文档的演变过程,轻松地管理…

html 常见兼容性问题

目录 前言: 用法: 代码: 1. 盒模型差异: 2. 表格布局问题: 3. 浏览器前缀问题: 4. 字体渲染问题: 理解: 讨论: 前言: 在Web开发中,兼容性问题是常见的挑战之一。不同的浏览器和设备可能以不同的方式解释和呈现HTML,导致网页在某些环境下出现问题…

第12章 PyTorch图像分割代码框架-1

从本章开始,本书将会进行深度学习图像分割的实战阶段。PyTorch作为目前最为流行的一款深度学习计算框架,在计算机视觉和图像分割任务中已经广泛使用。本章将介绍基于PyTorch的深度学习图像分割代码框架,在总体框架的基础上,基于PA…

Fabric.js 讲解官方demo:Stickman

本文简介 戴尬猴,我是德育处主任 Fabric.js 官网有很多有趣的Demo,不仅可以帮助我们了解其功能,还可以为我们提供创意灵感。其中,Stickman是一个非常有趣的例子。 先看看效果图 从上图可以看出,在拖拽圆形时&#xf…

yyds,Elasticsearch Template自动化管理新索引创建

一、什么是Elasticsearch Template? Elasticsearch Template是一种将预定义模板应用于新索引的功能。在索引创建时,它可以自动为新索引应用已定义的模板。Template功能可用于定义索引的映射、设置和别名等。它是一种自动化管理索引创建的方式&#xff0…

CSDN学院 < 华为战略方法论进阶课 > 正式上线!

目录 你将收获 适用人群 课程内容 内容目录 CSDN学院 作者简介 你将收获 提升职场技能提升战略规划的能力实现多元化发展综合能力进阶 适用人群 主要适合公司中高层、创业者、产品经理、咨询顾问,以及致力于改变现状的学员。 课程内容 本期课程主要介绍华为…

非侵入式负荷检测与分解:电力数据挖掘新视角

电力数据挖掘 概述案例背景分析目标分析过程数据准备数据探索缺失值处理 属性构造设备数据周波数据模型训练 性能度量推荐阅读 主页传送门:📀 传送 概述 摘要:本案例将根据已收集到的电力数据,深度挖掘各电力设备的电流、电压和功…

03.MySQL事务及存储引擎笔记

事务 查看/设置事务 select autocommit; --查看当前数据库的事务状态,1表示开启,0表示关闭 set autocommit 0; --关闭自动事务提交采用关闭自动事务提交我们就可以手动进行事务提交,但是这种设置方式是对整个数据库起作用,一些可…

MongoDB 的集群架构与设计

一、前言 MongoDB 有三种集群架构模式,分别为主从复制(Master-Slaver)、副本集(Replica Set)和分片(Sharding)模式。 Master-Slaver 是一种主从复制的模式,目前已经不推荐使用。Re…

互联网产品说明书指南,附撰写流程与方法

产品说明书,对于普通产品而言,再常见不过。药物、电器、电子产品等产品在正式出售时,往往都会附带一份产品说明书,以此告诉用户这个产品的功能与特性,并指导用户如何来使用这个产品。 产品说明书 那么,对于…

Python遍历删除列表元素的一个奇怪bug

假定有一个Python列表,比如[CFFEX.IF, CFFEX.TS,SHFE.FU],现在需要将其中带‘CFFEX’前缀的所有元素都删除。在使用列表推导式一行代码搞定之前,用了一种最朴素的遍历删除方法,结果出现了意想不到的的问题。复盘了下,结…

软件测试进阶篇----自动化测试脚本开发

自动化测试脚本开发 一、自动化测试用例开发 1、用例设计需要注意的点 2、设计一条测试用例 二、脚本开发过程中的技术 1、线性脚本开发 2、模块化脚本开发(封装线性代码到方法或者类中。在需要的地方进行调用) 3、关键字驱动开发:selen…

React之如何捕获错误

一、是什么 错误在我们日常编写代码是非常常见的 举个例子,在react项目中去编写组件内JavaScript代码错误会导致 React 的内部状态被破坏,导致整个应用崩溃,这是不应该出现的现象 作为一个框架,react也有自身对于错误的处理的解…

windows安装数据库MySQL

windows安装数据库MySQL 文章目录 windows安装数据库MySQL一、MySQL官网下载压缩包二、在D盘新建文件夹D:\MySQL,将下载的压缩包解压到该文件夹下三、配置环境变量四、通过命令行模式安装、启用、配置SQL服务 一、MySQL官网下载压缩包 下载地址:https:/…
最新文章