pytorch:R-CNN的pytorch实现

pytorch:R-CNN的pytorch实现

仅作为学习记录,请谨慎参考,如果错误请评论指出。

参考文献:Rich Feature Hierarchies for Accurate Object Detection and Semantic Segmentation
     https://blog.csdn.net/qq_41694024/category_12145273.html
参考项目:https://github.com/object-detection-algorithm/R-CNN

模型参数文件:链接:https://pan.baidu.com/s/1EWYcYuhwK5s7x1yOTe7rlQ?pwd=lgsf 提取码:lgsf

下载网盘里的模型参数然后放进./models文件夹内

环境配置: python3.10 pip install -r requirements.txt

R-CNN可以说是使用CNN进行目标检测任务的始祖,而且取得了不错的成绩。对后续的算法,例如现在经常使用的Yolo系列有很大的影响。刚入门目标检测我认为还是有必要学习下R-CNN。

R-CNN算法的大致流程

在这里插入图片描述
作者在论文中的图中说明了大致的算法流程。输入图像后提取大约两千个候选框,然后将候选框放缩成(227x227)大小的图像放入到CNN网络中进行特征提取,然后通过训练好的SVM对其打分分类

模型设计

1、区域提议。使用选择性搜索算法提出候选框。由于CNN网络接受的输入图像尺寸只能是(227x227)因此还需要对候选框做进一步的变形,作者实验了几种不同的方法,最终选择了包含上下文(padding= 16pixels)的改变高宽比的缩放。
红圈里面的就是作者采用的变换方式

2、特征提取。2012年AlexNet在ImageNet上胜出使得CNN重新得到人们的关注,作者认为CNN相较于传统算法提取特征更加高效和通用,因此提取特征的任务可以由AlexNet实现。但是同样存在问题,如何在小数据集上训练出高性能的特征提取器,作者想到了使用微调AlexNet的网络结构

用Pytorch实现R-CNN单类别检测

VOC数据集处理

VOC数据集的介绍可以参考这篇博客:https://blog.csdn.net/cengjing12/article/details/107820976
我们需要从VOC数据集中得到训练用的正负样本。首先获取包含识别类别物体的图片,然后通过选择性搜索算法生成很多的候选框,其中候选框与真实边界框的IoU值大于0.5设置为正样本其余则是负样本,IoU阈值可以设置成其他值。


import os

import cv2
import xmltodict
import numpy as np

import selectivesearch
import util

'''
VOC数据集的结构
.
└── VOCdevkit     #根目录
    └── VOC2012   #不同年份的数据集,这里只下载了2012的,还有2007等其它年份的
        ├── Annotations        #存放xml文件,与JPEGImages中的图片一一对应,解释图片的内容等等
        ├── ImageSets          #该目录下存放的都是txt文件,txt文件中每一行包含一个图片的名称,末尾会加上±1表示正负样本
        │   ├── Action
        │   ├── Layout
        │   ├── Main           #存放的是分类和检测的数据集分割文件
        │   └── Segmentation
        ├── JPEGImages         #存放源图片
        ├── SegmentationClass  #存放的是图片,语义(class)分割相关
        └── SegmentationObject #存放的是图片,实例(object)分割相关

├── Main
│   ├── train.txt 写着用于训练的图片名称
│   ├── val.txt 写着用于验证的图片名称
│   ├── trainval.txt train与val的合集
│   ├── test.txt 写着用于测试的图片名称
'''

PATH = r"E:\Postgraduate_Learning\Python_Learning\DataSets\pascal_voc2007\VOCdevkit\VOC2007"

def get_class(path):
    """
    获取VOC数据集中的类别
    必须按照VOC数据集的标准格式
    :param path:    数据集的根目录的下一级目录即,VOC+年份,例如:VOC2007
    :return:        数据集的类别 list
    """
    # 判断是否是文件夹
    if os.path.isdir(path):
        # 得到文件夹中所有的txt文件
        object_list = os.listdir(path + "\ImageSets\Main")
        # print(object_list)
        class_list = []
        temp = []
        # 所有的txt文件命名格式为 类别名_train(val、trainval).txt 意思是这个类别的训练集或者测试集或者训练集和测试集混在一起
        # 只保留带有类别名字的txt文件
        object_list = [i for i in object_list if i.find("_") != -1]
        # print(object_list)
        for class_name in object_list:
            # 处理文件名,得到类别名
            class_name = class_name.strip(".txt").split('_')[0]
            temp.append(class_name)
        # 去除重复类
        [class_list.append(i) for i in temp if i not in class_list]
        # (len(class_list))
        # print(class_list)
        # 类别排序
        class_list = sorted(class_list)
        return class_list

def xml_parse(path):
    """
    解析标注文件
    :param path:    数据集的根目录的下一级目录即,VOC+年份,例如:VOC2007
    :return:        图片名字列表, 对象类别列表, 对象边界框列表
    """
    # 下面三个一一对应
    # 图片名字列表
    image_name_list = []
    # 对象类别列表
    object_class_list = []
    # 对象边界框列表
    object_bndbox_list= []

    xml_file_list = os.listdir(path+"\Annotations")
    # print(len(xml_file_list))
    for xml_file in xml_file_list:
        with open(os.path.join(path+"\Annotations", xml_file), "r") as xml_file:
            xml_dict = xmltodict.parse(xml_file.read())
            # print(xml_dict)
            # 图片的名字放在了 ['annotation']标签下的['filename']属性
            image_name = xml_dict['annotation']['filename']
            # 因为有很多个[object]标签,所以xml解析出来的字典 object对应的值是个列表
            object_list = xml_dict['annotation']['object']
            # 可能一张图片中就有一个对象,转换为可以迭代的列表
            if isinstance(object_list, list) != True:
                object_list = list([object_list])
            # print(type(object_list))
            # 一张图片可能出现很多个对象,每个对象的坐标和类别都不一定相同
            for object in object_list:
                # 获取对象所属类别名称
                class_name = object['name']
                # print(class_name)
                # 获取边界框的坐标
                bndbox_xmin = int(object['bndbox']['xmin'])
                bndbox_ymin = int(object['bndbox']['ymin'])
                bndbox_xmax = int(object['bndbox']['xmax'])
                bndbox_ymax = int(object['bndbox']['ymax'])
                # print(bndbox_xmin, bndbox_ymin, bndbox_xmax, bndbox_ymax)
                image_name_list.append(image_name)
                object_class_list.append(class_name)
                object_bndbox_list.append((bndbox_xmin, bndbox_ymin, bndbox_xmax, bndbox_ymax))
    print(len(image_name_list))
    return image_name_list, object_class_list, object_bndbox_list

def one_xml_parse(path):
    # 下面三个一一对应
    # 图片名字列表
    image_name_list = []
    # 对象类别列表
    object_class_list = []
    # 对象边界框列表
    object_bndbox_list= []
    with open(path, "r") as xml_file:
        xml_dict = xmltodict.parse(xml_file.read())
        # print(xml_dict)
        # 图片的名字放在了 ['annotation']标签下的['filename']属性
        image_name = xml_dict['annotation']['filename']
        # 因为有很多个[object]标签,所以xml解析出来的字典 object对应的值是个列表
        object_list = xml_dict['annotation']['object']
        # 可能一张图片中就有一个对象,转换为可以迭代的列表
        if isinstance(object_list, list) != True:
            object_list = list([object_list])
        # print(type(object_list))
        # 一张图片可能出现很多个对象,每个对象的坐标和类别都不一定相同
        for object in object_list:
            # 获取对象所属类别名称
            class_name = object['name']
            # print(class_name)
            # 获取边界框的坐标
            bndbox_xmin = int(object['bndbox']['xmin'])
            bndbox_ymin = int(object['bndbox']['ymin'])
            bndbox_xmax = int(object['bndbox']['xmax'])
            bndbox_ymax = int(object['bndbox']['ymax'])
            # print(bndbox_xmin, bndbox_ymin, bndbox_xmax, bndbox_ymax)
            image_name_list.append(image_name)
            object_class_list.append(class_name)
            object_bndbox_list.append([bndbox_xmin, bndbox_ymin, bndbox_xmax, bndbox_ymax])
    # print(len(image_name_list))
    return image_name_list, object_class_list, object_bndbox_list



def get_posANDneg_image(path, class_name, train: str):
    """
    获取数据集中某个类别的正负样本图片名称
    :param path:
    :param class_name:
    :return:
    """
    # 正负样本
    postive_ann_image = []
    negative_ann_image = []
    # 根据类别名,读取txt文件
    with open(
        os.path.join(path, "ImageSets", "main", class_name+"_"+train+".txt"), "r"
    ) as f:
        # 按行读取txt文件的内容并去除末尾的换行符
        image_and_ann = [line.strip() for line in f.readlines()]
        # print(image_and_ann)
        for line in image_and_ann:
            # 按照空格分开字符串,前一部分为图片名称,后一部分为正负样本的标志
            # -1标志的样本间隔一个空格,1标志的样本间隔俩空格
            image = line.split(' ')
            # 如果标志是'1'则为正样本,也就是包含了对象的图片
            if image[-1] == '1':
                postive_ann_image.append(image[0]+".jpg")
            # 如果标志是'-1'则是负样本,也就是没有包含对象的图片
            elif image[-1] == '-1':
                negative_ann_image.append(image[0]+".jpg")
        # print(postive_ann_image, negative_ann_image)
    return postive_ann_image, negative_ann_image

def get_posANDneg_samples(path, class_name, iou_thr):
    # 正负样本
    postive_samples = []
    negative_samples = []
    # 正负样本对应的图片名字
    postive_images = []
    negative_images = []
    # 定义选择性选择框
    gs = selectivesearch.get_selective_search()

    for name in ["train"]:
        # 获取包含识别对象图片的文件名名字
        postive_ann_image, _ = get_posANDneg_image(path, class_name, name)
        for one_image in postive_ann_image:
            # print(f"文件名: {one_image}")
            # 得到一个包含识别对象图片的xml文件路径
            xmlfile_path = os.path.join(path, "Annotations", one_image.split('.')[0]+".xml")
            # 得到一个包含识别对象图片路径
            img_path = os.path.join(path, "JPEGImages", one_image)

            # 读取图片
            jpeg_img = cv2.imread(img_path)
            # 生成候选框
            selectivesearch.config(gs, jpeg_img, strategy='q')
            # 计算候选建议
            rects = selectivesearch.get_rects(gs)
            # print(f"总共生成了{len(rects)}个候选框")

            # 解析对应图片的xml文件
            image_name_list, object_class_list, object_bndbox_list = one_xml_parse(xmlfile_path)
            # 获取边界框
            object_bndbox_list = [object_bndbox_list[index] for (index, name) in enumerate(object_class_list)
                                  if name == class_name ]
            # print(f"共获取{len(object_bndbox_list)}个标注边界框")
            # 转换边界框的数据类型
            object_bndbox_list = np.array(object_bndbox_list)
            # print(f"转换边界框的数据类型为{type(object_bndbox_list)}")

            # 标注框大小,如果有多个边界框,则叹得最大的边界框大小
            maximum_bndbox_size = 0
            for bndbox in object_bndbox_list:
                xmin, ymin, xmax, ymax = bndbox
                bndbox_size = (ymax - ymin) * (xmax - xmin)
                if bndbox_size > maximum_bndbox_size:
                    maximum_bndbox_size = bndbox_size

            # 对每个候选框进行处理,计算并比较IOU值获取正样本
            for bndbox in object_bndbox_list:
                # 计算IOU的值
                iou_list = util.compute_ious(rects, bndbox)
                # print("计算预选框和实框的iou列表", len(iou_list))

                iou_thr = iou_thr
                # iou_list和 rect 列表长度应该一致
                for i in range(len(iou_list)):
                    xmin, ymin, xmax, ymax = rects[i]
                    rect_size = (ymax - ymin) * (xmax - xmin)
                    iou_score = iou_list[i]
                    # 如果某个框体的iou值在0-0.3之间且框体大少低于真实框体的五分之一
                    if 0 < iou_score <= iou_thr and rect_size > maximum_bndbox_size / 5.0:
                        # 负样本
                        negative_samples.append(rects[i])
                        negative_images.append(one_image)
                    if iou_thr <= iou_score <= 1 and rect_size > maximum_bndbox_size / 5.0:
                        postive_samples.append(rects[i])
                        postive_images.append(one_image)


    return postive_samples, postive_images, len(postive_samples), \
        negative_samples, negative_images, len(negative_samples)



if __name__ == "__main__":
    # voc_dataset = VOCDetection(root= PATH, year= "2007", image_set= "train",
    #                            download= False)
    # print(type(voc_dataset))
    # CLASS = get_class(PATH)
    # print(CLASS)
    # print(get_posANDneg_image(PATH, "cat"))
    postive_samples, postive_images, a, \
        negative_samples, negative_images, b = get_posANDneg_samples(PATH, "cat", iou_thr= 0.3)

制作模型训练用的数据集

Pytorch提供了Dataset类,需要自定义数据集的时候通过继承Dataset类并重写__init__()、__getitem__()、__len__()来实现自定义数据集。

__init__()中实现读取处理相关图片。
__getitem__()接受索引返回对应的样本以及标签。
__len__()返回数据集的大小。
实现好这三个方法后,通过Dataloader加载数据集。

import random
import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import Sampler
import cv2
import os
from PIL import Image

import pascal_VOC

class RCNN_DetectionDataSet(Dataset):
    """
    适用于RCNN单类别识别的数据集
    """
    def __init__(self, path, transform= None):
        self.transform = transform
        self.path = path

        # 获取分类标签
        self.detect_class = pascal_VOC.get_class(path)[0]

        # 获取获取正负样本,对应的样本数量,对应的图片名称
        self.postive_samples, self.postive_images, self.num_postive, \
        self.negative_samples, self.negative_images, self.num_negative = \
            pascal_VOC.get_posANDneg_samples(path, self.detect_class, iou_thr= 0.3)


        pass

    def __getitem__(self, index):
        # 如果索引小于正样本图片的数量,则认为是正样本索引
        if index < self.num_postive:
            # 读取正样本图片
            JPEGimages = cv2.imread(
                    os.path.join(self.path, "JPEGImages", self.postive_images[index])
                )
        else:
            # 读取负样本图片
            JPEGimages = cv2.imread(
                os.path.join(self.path, "JPEGImages", self.negative_images[index - self.num_postive])
            )
        # 转换下色彩通道
        JPEGimages = cv2.cvtColor(JPEGimages, cv2.COLOR_BGR2RGB)
        if index < self.num_postive:
            # 正样本的标签为1
            label = torch.tensor([1])
            # 获取对象所在的区域
            x1, y1, x2, y2 = self.postive_samples[index]
            region = JPEGimages[y1:y2, x1:x2]
            region = cv2.resize(region, (227, 227))
            region = transforms.ToTensor()(region)
        else:
            # 负样本为0
            label = torch.tensor([0])
            x1, y1, x2, y2 = self.negative_samples[index - self.num_postive]
            region = JPEGimages[y1:y2, x1:x2]
            region = cv2.resize(region, (227, 227))
            region = transforms.ToTensor()(region)
        return region, label

    def __len__(self):
        # 样本数量就是所有边界框的个数
        return self.num_postive + self.num_negative
        pass

    def get_postive_samples_num(self):
        return self.num_postive
    def get_negative_samples_num(self):
        return self.num_negative

class RCNN_BatchSampler(Sampler):
    """
    2分类数据集采样器
    """
    def __init__(self, num_positive, num_negative, batch_positive, batch_negative):
        self.num_positive = num_positive
        self.num_negative = num_negative
        self.batch_positive = batch_positive
        self.batch_negative = batch_negative

        # 计算数据集大小
        length = num_positive + num_negative
        # 生成索引序列
        self.idx_list = list(range(length))
        # 计算batch大小
        self.batch = batch_negative + batch_positive
        # 计算可以生成多少个完整batch
        self.num_iter = length // self.batch

    def __iter__(self):
        sampler_list = list()
        for i in range(self.num_iter):
            tmp = np.concatenate(
                (random.sample(self.idx_list[:self.num_positive], self.batch_positive),
                 random.sample(self.idx_list[self.num_positive:], self.batch_negative))
            )
            random.shuffle(tmp)
            sampler_list.extend(tmp)
        return iter(sampler_list)

    def __len__(self) -> int:
        return self.num_iter * self.batch

    def get_num_batch(self) -> int:
        return self.num_iter

def test(idx):
    PATH = r"E:\Postgraduate_Learning\Python_Learning\DataSets\pascal_voc2007\VOCdevkit\VOC2007"
    train_data_set = RCNN_DetectionDataSet(PATH)

    print('positive num: %d' % train_data_set.get_postive_samples_num())
    print('negative num: %d' % train_data_set.get_negative_samples_num())
    print('total num: %d' % train_data_set.__len__())

    # 测试id=3/66516/66517/530856
    image, target = train_data_set.__getitem__(idx)
    print('target: %d' % target)


    print(image)
    print(type(image))

    cv2.imshow("a",image)
    cv2.waitKey(0)

def test1():
    root_dir = r"E:\Postgraduate_Learning\Python_Learning\DataSets\pascal_voc2007\VOCdevkit\VOC2007"
    train_data_set = RCNN_DetectionDataSet(root_dir)
    train_sampler = RCNN_BatchSampler(train_data_set.get_postive_samples_num(), train_data_set.get_negative_samples_num(), 32, 96)

    print('sampler len: %d' % train_sampler.__len__())
    print('sampler batch num: %d' % train_sampler.get_num_batch())

    first_idx_list = list(train_sampler.__iter__())[:128]
    print(first_idx_list)
    # 单次批量中正样本个数
    print('positive batch: %d' % np.sum(np.array(first_idx_list) < 66517))

if __name__ == "__main__":
    # PATH = r"E:\Postgraduate_Learning\Python_Learning\DataSets\pascal_voc2007\VOCdevkit\VOC2007"
    # test_dataset = RCNN_DetectionDataSet(PATH, transform= None)
    # test_dataloader = DataLoader(test_dataset, 4, shuffle= True)
    # a = next(iter(test_dataloader))[0]
    # print(a.shape)
    # print(next(iter(test_dataloader))[1])
    # cv2.imshow("a", a[0].numpy())
    # cv2.waitKey(0)
    # # 测试结果应该是正
    # test(120)
    # # 测试结果应该是正
    # test(280)
    # # 测试结果应该是负
    # test(600)
    # # 测试结果应该是负
    # test(2100)
    test1()

微调

Pytorch已经实现了AlexNet的结构,并且提供了ImageNet训练后的参数。所需要做的就是在准备好的数据集上再训练。

import torch
from torch import nn
from torchvision import models
from torchvision import transforms
from torch.utils.data import DataLoader

import dataset
from Lib.Trainer import Trainer

def load_data():
    """
    加载数据,只加载训练集的
    :return:
    """
    # 增强数据集
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((227, 227)),
        transforms.ToTensor(),
        # 对图片进行归一化,每个输入通道都减去其平均值再除以其标准差
        # 两个参数表示平均值和方差
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    path = r"E:\Postgraduate_Learning\Python_Learning\DataSets\pascal_voc2007\VOCdevkit\VOC2007"
    # 数据集
    data_set = dataset.RCNN_DetectionDataSet(path= path, transform= transform)
    # 每一个批次含有32个正样本和96个负样本
    data_sampler = dataset.RCNN_BatchSampler(data_set.get_postive_samples_num(),
                                             data_set.get_negative_samples_num(),
                                             32, 96)
    # drop_last表示是否当数据集无法整除批量大小时丢掉最后一批
    data_loader = DataLoader(dataset= data_set,
                             batch_size= 128,
                             sampler= data_sampler,
                             num_workers= 2,
                             drop_last= True)
    data_size = len(data_sampler)

    return data_loader, data_size


def AlexNet_finetuning():
    # 指定使用的设备
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cpu")

    AlexNet_pre = models.alexnet(pretrained=True)
    # AlexNet_pre = models.alexnet(pretrained= False)
    train_iter, train_size = load_data()
    # print(AlexNet_pre)

    # 获取分类器的输入特征数量
    num_features = AlexNet_pre.classifier[6].in_features
    # print(AlexNet_pre.classifier[6].in_features)
    # 把最后一层改成二分类
    AlexNet_pre.classifier[6] = nn.Linear(num_features, 2)

    # AlexNet_pre = AlexNet_pre.to(device)
    # 使用交叉熵作为损失函数
    loss = nn.CrossEntropyLoss()
    optimer = torch.optim.SGD(params= AlexNet_pre.parameters(), lr= 1e-3,
                              momentum= 0.9)
    # 学习率衰减策略,每7个epoch衰减十倍
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer= optimer,
                                                   step_size= 7,
                                                   gamma= 0.1 ,
                                                   verbose= True)

    trainer = Trainer()
    trainer.config_trainer(AlexNet_pre, dataloader= train_iter,
                           optimer= optimer, lr_scheduler= lr_scheduler, loss= loss,
                           device= device)
    trainer.config_task(128, 10)
    trainer.start_task()

    torch.save(AlexNet_pre.state_dict(), '.models/alexnet_cat_10epochs_new.pth')
if __name__ == "__main__":
    AlexNet_finetuning()

做上一点说明,R-CNN使用的分类器是SVM,原文将AlexNet最后一层去掉只用网络提取了4096维的向量然后使用已经训练好的SVM进行分类,本篇的实现则直接用了Softmax做分类,相当于没有改变网络结构。
关于作者为何不使用Softmax做分类,在附录中有说明,但是说的不咋清楚。
作者说,使用了Softmax反而造成了性能的下降,他们推断可能是因为正负样本的划分不同导致的(SVM正样本只有真实边界框,负样本要求IoU小于0.3与真实边界框)。CNN的那种划分方式用在微调上造成了正样本太少负样本太多的情况。关于SVM我不咋了解,而且作为初学者,推断不出什么原因。

预测

选一些图片,按照上面的算法流程进行预测即可。


import torch
from torchvision import transforms
from torchvision.models import alexnet
from torch import nn
import cv2
import copy
import time
import numpy as np

import pascal_VOC
import selectivesearch
import util

def get_model(device=None):
    # 加载CNN模型
    model = alexnet()
    num_classes = 2
    num_features = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(num_features, num_classes)
    model.load_state_dict(torch.load('models/alexnet_cat_10epochs_new.pth'))
    model.eval()

    # 取消梯度追踪
    for param in model.parameters():
        param.requires_grad = False
    if device:
        model = model.to(device)

    return model

def nms(rect_list, score_list):
    """
    非最大抑制
    :param rect_list: list,大小为[N, 4]
    :param score_list: list,大小为[N]
    """
    nms_rects = list()
    nms_scores = list()

    rect_array = np.array(rect_list)
    score_array = np.array(score_list)

    # 一次排序后即可
    # 按分类概率从大到小排序
    idxs = np.argsort(score_array)[::-1]
    rect_array = rect_array[idxs]
    score_array = score_array[idxs]

    thresh = 0.1
    while len(score_array) > 0:
        # 添加分类概率最大的边界框
        nms_rects.append(rect_array[0])
        nms_scores.append(score_array[0])
        rect_array = rect_array[1:]
        score_array = score_array[1:]

        length = len(score_array)
        if length <= 0:
            break

        # 计算IoU
        iou_scores = util.iou(np.array(nms_rects[len(nms_rects) - 1]), rect_array)
        # print(iou_scores)
        # 去除重叠率大于等于thresh的边界框
        idxs = np.where(iou_scores < thresh)[0]
        rect_array = rect_array[idxs]
        score_array = score_array[idxs]

    return nms_rects, nms_scores

def draw_box_with_text(img, rect_list, score_list):
    """
    绘制边框及其分类概率
    :param img:
    :param rect_list:
    :param score_list:
    :return:
    """
    for i in range(len(rect_list)):
        xmin, ymin, xmax, ymax = rect_list[i]
        score = score_list[i]

        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color=(0, 0, 255), thickness=2)
        cv2.putText(img, "{:.3f}".format(score), (xmin, ymin), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)


if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # 数据转换
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((227, 227)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    model = get_model(device=device)

    gs = selectivesearch.get_selective_search()

    test_img_path = r"./img/n_test1.jpg"
    # test_xml_path = r"./img/000122.xml"

    img = cv2.imread(test_img_path)
    dst = copy.deepcopy(img)

    # 获取标注的边界框
    # _, _, bndboxs = pascal_VOC.one_xml_parse(test_xml_path)
    # for bndbox in bndboxs:
    #     xmin, ymin, xmax, ymax = bndbox
    #     cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), color=(0, 255, 0), thickness=1)

    # cv2.imshow("a", dst)
    # cv2.waitKey(0)

    # 候选区域建议
    selectivesearch.config(gs, img, strategy='f')
    rects = selectivesearch.get_rects(gs)
    print('候选区域建议数目: %d' % len(rects))

    svm_thresh = 0.8

    # 得分列表,正样本列表
    score_list = list()
    positive_list = list()

    start = time.time()
    for rect in rects:
        xmin, ymin, xmax, ymax = rect
        rect_img = img[ymin:ymax, xmin:xmax]

        rect_transform = transform(rect_img).to(device)
        output = model(rect_transform.unsqueeze(0))
        # print(output)
        # print(output.shape)
        output = output[0]
        if torch.argmax(output).item() == 1:
            """
            预测为cat
            """
            probs = torch.softmax(output, dim=0).cpu().numpy()
            print(probs)
            print(probs.shape)

            if probs[1] >= svm_thresh:
                score_list.append(probs[1])
                positive_list.append(rect)
                # cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), color=(0, 0, 255), thickness=2)
                # print(rect, output, probs)
    end = time.time()
    print('detect time: %d s' % (end - start))
    nms_rects, nms_scores = nms(positive_list, score_list)
    print(nms_rects)
    print(nms_scores)
    draw_box_with_text(dst, nms_rects, nms_scores)

    cv2.imshow('img', dst)
    cv2.waitKey(0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/110485.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

springboot是如何工作的

一、前言 现在java后端开发框架比较多的使用springboot框架&#xff0c;springboot是在以前的springMVC进行封装和优化&#xff0c;最大的特点是简化了配置和内置Tomcat。本节通过阅读源码理解springboot是如何工作的。 二、springboot是如何工作的 1、从启动类开始 /***服务…

【Proteus仿真】【Arduino单片机】SG90舵机控制

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真Arduino单片机控制器&#xff0c;使用SG90舵机等。 主要功能&#xff1a; 系统运行后&#xff0c;舵机开始运行。 二、软件设计 /* 作者&#xff1a;嗨小易&#xff08;QQ&#x…

链表加法与节点交换:数据结构的基础技能

目录 两两交换链表中的节点单链表加一链表加法使用栈实现使用链表反转实现 两两交换链表中的节点 给你一个链表&#xff0c;两两交换其中相邻的节点&#xff0c;并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题&#xff08;即&#xff0c;只能进行节点…

关于深度学习中Attention的一些简单理解

Attention 机制 Attention应用在了很多最流行的模型中&#xff0c;Transformer、BERT、GPT等等。 Attention就是计算一个加权平均&#xff1b;通过加权平均的权值来自计算每个隐藏层之间的相关度&#xff1b; 示例 Attention 机制 Attention应用在了很多最流行的模型中&am…

基于深度学习的人脸专注度检测计算系统 - opencv python cnn 计算机竞赛

文章目录 1 前言2 相关技术2.1CNN简介2.2 人脸识别算法2.3专注检测原理2.4 OpenCV 3 功能介绍3.1人脸录入功能3.2 人脸识别3.3 人脸专注度检测3.4 识别记录 4 最后 1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 基于深度学习的人脸专注度…

【机器学习】决策树与分类案例分析

决策树与分类案例分析 文章目录 决策树与分类案例分析1. 认识决策树2. 分类3. 决策树的划分依据4. 决策树API5. 案例&#xff1a;鸢尾花分类6. 决策树可视化7. 总结 1. 认识决策树 决策树思想的来源非常朴素&#xff0c;程序设计中的条件分支结构就是if-else结构&#xff0c;最…

【数据结构】复杂度

&#x1f525;博客主页&#xff1a; 小羊失眠啦. &#x1f3a5;系列专栏&#xff1a;《C语言》 《数据结构》 《Linux》《Cpolar》 ❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 文章目录 一、什么是数据结构二、什么是算法三、算法的效率四、时间复杂度4.1 时间复杂度的概念4…

【100天精通Python】Day72:Python可视化_一文掌握Seaborn库的使用《二》_分类数据可视化,线性模型和参数拟合的可视化,示例+代码

目录 1. 分类数据的可视化 1.1 类别散点图&#xff08;Categorical Scatter Plot&#xff09; 1.2 类别分布图&#xff08;Categorical Distribution Plot&#xff09; 1.3 类别估计图&#xff08;Categorical Estimate Plot&#xff09; 1.4 类别单变量图&#xff08;Cat…

远程IO:实现立体车库高效运营的秘密武器

随着城市的发展&#xff0c;车辆无处停放的问题变得越来越突出。为了解决这个问题&#xff0c;立体车库应运而生。立体车库具有立体空间利用率高、存取车方便、安全可靠等优点&#xff0c;成为现代城市停车的重要解决方案。 立体车库控制系统介绍 在立体车库中&#xff0c;控制…

基于51单片机的四种波形信号发生器仿真设计(仿真+程序源码+设计说明书+讲解视频)

本设计 基于51单片机信号发生器仿真设计 &#xff08;仿真程序源码设计说明书讲解视频&#xff09; 仿真原版本&#xff1a;proteus 7.8 程序编译器&#xff1a;keil 4/keil 5 编程语言&#xff1a;C语言 设计编号&#xff1a;S0015 这里写目录标题 基于51单片机信号发生…

简单明了!网关Gateway路由配置filters实现路径重写及对应正则表达式的解析

问题背景&#xff1a; 前端需要发送一个这样的请求&#xff0c;但出现404 首先解析请求的变化&#xff1a; http://www.51xuecheng.cn/api/checkcode/pic 1.请求先打在nginx&#xff0c;www.51xuecheng.cn/api/checkcode/pic部分匹配到了之后会转发给网关进行处理变成localho…

Android底层摸索改BUG(二):Android系统移除预置APP

首先我先提供以下博主博文&#xff0c;对相关知识点可以提供理解、解决、思考的 Android 系统如何预装第三方应用以及常见问题汇集android Android.mk属性说明及预置系统app操作说明系Android 中去除系统原生apk的方法 取消预置APK方法一&#xff1a; 其实就是上面的链接3&a…

Day 4 登录页及路由 (二) -- Vue状态管理

状态管理 之前的实现中&#xff0c;判断登录状态用了伪实现&#xff0c;实际当中&#xff0c;应该是以缓存中的数据为依据来进行的。这就涉及到了应用程序中的状态管理。在Vue中&#xff0c;状态管理之前是Vuex&#xff0c;现在则是推荐使用Pinia&#xff0c;在脚手架项目创建…

linux查看系统版本、内核信息、操作系统类型版本

1. 使用 uname 命令&#xff1a;这将显示完整的内核版本信息&#xff0c;包括内核版本号、主机名、操作系统类型等。 uname -a2. 使用 lsb_release 命令&#xff08;仅适用于支持 LSB&#xff08;Linux Standard Base&#xff09;的发行版&#xff09;&#xff1a;这将显示包含…

HCIE怎么系统性学习?这份HCIE学习路线帮你解决

华为认证体系覆盖ICT行业十一个技术领域共十三个技术方向的认证&#xff0c;今天我们分享的是其中最热门的数据通信方向的HCIE学习路线。 HCIE是华为认证体系中最高级别的ICT技术认证 &#xff0c;旨在打造高含金量的专家级认证&#xff0c;为技术融合背景下的ICT产业提供新的能…

JVS-BI数字大屏设计器:一站式解决方案

数字大屏介绍 数字大屏是当下数据展示、业务监控、指挥调度常见的业务表达形态&#xff0c;常有可视化的图表、效果装饰、事件操作等技术组成酷炫的效果展示。 配置入口 进入JVS-BI&#xff08;bi.bctools.cn&#xff09;&#xff0c;进入大屏页面&#xff0c;如下图所示 ①…

TypeScript之函数以及与JavaScript函数的区别

一、是什么 函数是JavaScript 应用程序的基础&#xff0c;帮助我们实现抽象层、模拟类、信息隐藏和模块 在TypeScript 里&#xff0c;虽然已经支持类、命名空间和模块&#xff0c;但函数仍然是主要定义行为的方式&#xff0c;TypeScript 为 JavaScript 函数添加了额外的功能&…

Docker 部署spring-boot项目(超详细 包括Docker详解、Docker常用指令整理等)

文章目录 DockerDocker的定义Docker有哪些作用Docker有哪些好处使用docker部署springboot项目安装docker创建Dockerfile镜像文件执行镜像文件(Dockerfile文件)查看Docker镜像启动容器查看Docker中运行的容器查看服务容器日志 Docker常用指令查看docker安装目录启动Docker停止Do…

无品牌国产PLC模块调试说明

地址30001对应的aiw9 30002对应aiw10 30003 aiw11 30004 aiw12 模块接线及拨码全部向下&#xff0c;对应的DeviceID为15地址 使用串口线链接的时候a要接b0 b接a0 要反着接才能有数据
最新文章