Faster-RCNN代码解读2:快速上手使用

Faster-RCNN代码解读2:快速上手使用

前言

​ 因为最近打算尝试一下Faster-RCNN的复现,不要多想,我还没有厉害到可以一个人复现所有代码。所以,是参考别人的代码,进行自己的解读。

代码来自于B站的UP主(大佬666),其把代码都放到了GitHub上了,我把链接都放到下面了(应该不算侵权吧,毕竟代码都开源了_):

b站链接:https://www.bilibili.com/video/BV1of4y1m7nj/?vd_source=afeab8b555e5eb1bfa1e7f267262cbf2

GitHub链接:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing

目的

​ 其实UP主已经做了很好的视频讲解了他的代码,只是有时候我还是喜欢阅读博客来学习,另外视频很长,6个小时,我看的时候容易睡着_,所以才打算写博客记录一下学习笔记。

目前完成的内容

第一篇:VOC数据集详细介绍

第二篇:Faster-RCNN代码解读2:快速上手使用(本文)

目录结构

文章目录

    • Faster-RCNN代码解读2:快速上手使用
      • 1. 前言:
      • 2. 下载项目代码:
      • 3. 下载数据和权重文件:
      • 4. predict.py文件解读:
      • 5. pascal_voc_classes.json文件介绍:
      • 6. 快速上手:
      • 7. 总结:

1. 前言:

​ 本篇文章的作用是准备好一些必备的数据或权重文件,以实现直接快速使用代码的目的。

2. 下载项目代码:

​ 打开大佬的GitHub链接,然后,进入pytorch_object_detection文件内:

在这里插入图片描述

​ 然后,把Faster-RCNN文件夹下载下来即可。不过,GitHub本身不支持单个文件夹的下载,这时候推荐一下浏览器的插件GitZip for github ,把这个插件安装后,即可下载单独的文件夹,如下图所示:

在这里插入图片描述

​ 下载完成后的目录结构如下:

在这里插入图片描述

3. 下载数据和权重文件:

​ 打开README.md文件,里面说明了预训练权重文件和数据集的下载地址:

  • ResNet50+FPN权重文件下载:
官方的权重文件:https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth

up主自己训练后的权重地址:
https://pan.baidu.com/s/1ifilndFRtAV5RDZINSHj5w 提取码:dsz8
  • 数据集下载地址:
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar

​ 完成上述下载后,可以得到下图的文件:

在这里插入图片描述

4. predict.py文件解读:

​ 打开predict.py文件,这个文件的作用就是加载已经训练过的模型,对一张图片进行目标检测。

main函数:

​ 看main函数,主要分为四个部分:

  • 设置权重文件路径(需要我们改的参数),并用模型加载:
# 选定GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))

# 创建模型:21=20个类别+1个背景
model = create_model(num_classes=21)

# 加载权重参数
# weights_path = "./save_weights/model.pth"   # 权重保存路径,作者自己定义的
weights_path = "./fasterrcnn_voc2012.pth"   # 权重保存路径,我们下载后自己的路径
assert os.path.exists(weights_path), "{} file dose not exist.".format(weights_path)
# 开始加载权重文件
weights_dict = torch.load(weights_path, map_location='cpu')  # 加载之前训练保存的字典
weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict # 选定model参数
model.load_state_dict(weights_dict) # 加载
model.to(device) # 放入GPU
  • 读取“类别----数字值”的json文件,并生成一个字典,以方便后期将预测的类别(比如:1、2这样的数字)转为字符串(比如:person、bicycle等)
# 读取json文件
label_json_path = './pascal_voc_classes.json'
assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
with open(label_json_path, 'r') as f:
class_dict = json.load(f)
# 将值转为字典
category_index = {str(v): str(k) for k, v in class_dict.items()}
  • 加载一张测试图片(需要改为我们自己的),并改为符合[batch,channel,w,h]的格式
# 加载一张测试图片
original_img = Image.open("./test.jpg") # 需要改为自己的路径

# 将PIL图像格式转为tensor格式
data_transform = transforms.Compose([transforms.ToTensor()])
img = data_transform(original_img)
# 增加一个batch维度,符合训练图片格式
img = torch.unsqueeze(img, dim=0)
  • 开始预测图片,并计算运行的时间(一般不计算第一次的时间,因为GPU调用需要时间)和画出对应图像
model.eval()  # 进入验证模式
with torch.no_grad():
    # init
    # 初始化,原始图像的宽、高
    img_height, img_width = img.shape[-2:]
    # 将图像放入GPU中,并变为model可以识别的格式[batch_size,channel,w,h]
    init_img = torch.zeros((1, 3, img_height, img_width), device=device)
    # 验证
    model(init_img)

    # 计算预测时间,不过不能直接计算第一次,因为需要启动gpu等
    t_start = time_synchronized()
    predictions = model(img.to(device))[0]
    t_end = time_synchronized()
    print("inference+NMS time: {}".format(t_end - t_start))

    # 得到预测的相关参数
    predict_boxes = predictions["boxes"].to("cpu").numpy()
    predict_classes = predictions["labels"].to("cpu").numpy()
    predict_scores = predictions["scores"].to("cpu").numpy()

    if len(predict_boxes) == 0:
    	print("没有检测到任何目标!")

    # 绘制图像
    plot_img = draw_objs(original_img,
                             predict_boxes,
                             predict_classes,
                             predict_scores,
                             category_index=category_index,
                             box_thresh=0.5,
                             line_thickness=3,
                             font='arial.ttf',
                             font_size=20)
    plt.imshow(plot_img)
    plt.show()
    # 保存预测的图片结果
    plot_img.save("test_result.jpg")

create_model函数

​ 了解了main函数后,我们再看看create_model函数,这个函数的作用就是创建模型。作者在该项目中采取了很多模型,比如VGG16、mobilenetv2、resnet等等,而这里我们用的是刚刚下载的权重文件对应的模型,即resNet50+fpn+faster-rcnn,因此需要把其它的模型代码注释掉:

def create_model(num_classes):
    # mobileNetv2+faster_RCNN
    # backbone = MobileNetV2().features
    # backbone.out_channels = 1280
    #
    # anchor_generator = AnchorsGenerator(sizes=((32, 64, 128, 256, 512),),
    #                                     aspect_ratios=((0.5, 1.0, 2.0),))
    #
    # roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
    #                                                 output_size=[7, 7],
    #                                                 sampling_ratio=2)
    #
    # model = FasterRCNN(backbone=backbone,
    #                    num_classes=num_classes,
    #                    rpn_anchor_generator=anchor_generator,
    #                    box_roi_pool=roi_pooler)

    # resNet50+fpn+faster_RCNN
    # 注意,这里的norm_layer要和训练脚本中保持一致
    backbone = resnet50_fpn_backbone(norm_layer=torch.nn.BatchNorm2d)
    model = FasterRCNN(backbone=backbone, num_classes=num_classes, rpn_score_thresh=0.5)

    return model

5. pascal_voc_classes.json文件介绍:

​ 我们再看看上面涉及json文件,这个文件就是voc数据集的类别和数字值的对应关系,比如:

{
    "aeroplane": 1,
    "bicycle": 2,
    "bird": 3,
    "boat": 4,
    "bottle": 5,
    "bus": 6,
    "car": 7,
    "cat": 8,
    "chair": 9,
    "cow": 10,
    "diningtable": 11,
    "dog": 12,
    "horse": 13,
    "motorbike": 14,
    "person": 15,
    "pottedplant": 16,
    "sheep": 17,
    "sofa": 18,
    "train": 19,
    "tvmonitor": 20
}

​ 需要注意的是,这里的值是从1开始的,是因为0一般是留给背景的。

6. 快速上手:

​ 有了上面的解读后,我们可以快速上手看看效果。

​ 这里再次声明一下predict.py文件需要修改**权重文件路径和自己搞一张测试图片并修改路径。**完成修改后,直接运行该文件即可,我测试了几张图片,结果如下图:

在这里插入图片描述
在这里插入图片描述

7. 总结:

​ 上面主要简单介绍了如何快速上手,看到结果,给自己一种这个很简单的错觉。后面,主要就是对一些主要的文件进行解读。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/10579.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

汽车电子相关术语介绍

一、相关术语介绍 1、汽车OTA 全称“Over-The-Air technology ”,即空中下载技术,通过移动通信的接口实现对软件进行远程管理,传统的做法到4S店通过整车OBD对相应的ECU进行软件升级。OTA技术最早2000年出现在日本,目前通过OTA方式…

FusionCharts Suite XT v3.20.0 Crack

FusionCharts Suite XT v3.20.0 改进了仪表的径向条形图和调整大小功能。2023 年 4 月 11 日 - 9:37新版本特征 添加了一个新方法“_changeXAxisCoordinates”,它允许用户将 x 轴更改为在图例或数据交互时自动居中对齐。更新了 Angular 集成以支持 Angular 版本 14 …

【微信小程序-原生开发】添加自定义图标(以使用阿里图标库为例)

方式一 &#xff1a; 下载svg导入 优点&#xff1a;操作方便&#xff0c;支持多彩图标缺点&#xff1a;会增加源代码大小 下载 svg 格式的图标图片&#xff0c;放入源码中使用 小程序项目中的路径为 assets\icon\美食.svg 使用时-代码范例 <image class"imgIcon"…

前端开发工具-Visual Studio Code-插件下载-迁移到新电脑

背景 前端使用的开发工具一般是Visual Studio Code&#xff0c;很多辅助功能&#xff0c;比如字体高亮、单词拼写检查、预览图片等需要安装插件。但是插件在原来的电脑&#xff0c;不想下载或者自己是新人&#xff0c;想迁移同事的插件&#xff0c;或者新电脑没有外网。 以下…

图解HTTP阅读笔记:第4章 返回结果的HTTP状态码

《图解HTTP》第四章读书笔记 图解HTTP第4章&#xff1a;返回结果的HTTP状态码4.1 状态码告知从服务器端返回的请求结果4.2 2XX成功4.2.1 200 OK4.2.2 204 No Content4.2.3 206 Parital Content4.3 3XX重定向4.3.1 301 Moved Permanently4.3.2 302 Found4.3.3 303 See Other4.3.…

OK-3399-C ADB烧录

ADB烧写 一、OK3399用户资料工具目录附带了ADB工具的资料包路径&#xff1a; 二、将其解压在C:\User目录 三、将设备通过type-c线download口与电脑相连接&#xff0c;打开命令行&#xff0c;进入解压的目录&#xff0c;查看adb是否安装成功&#xff1a; 四、安装成功后&#x…

spring-boot怎么扫描不在启动类所在包路径下的bean

前言&#xff1a; 项目中有多个模块&#xff0c;其中有些模块的包路径不在启动类的子路径下&#xff0c;此时我们怎么处理才能加载到这些类&#xff1b; 1 使用SpringBootApplication 中的scanBasePackages 属性; SpringBootApplication(scanBasePackages {"com.xxx.xx…

在proteus中仿真arduino实现矩阵键盘程序

矩阵键盘是可以解决我们端口缺乏的问题&#xff0c;当然&#xff0c;如果我们使用芯片来实现矩阵键盘的输入端口缺乏的问题将更加划算了&#xff0c;本文暂时不使用芯片来解决问题&#xff0c;而使用纯朴的8根线来实现矩阵键盘&#xff0c;目的是使初学者掌握原理。想了解使用芯…

# 切削加工形貌的相关论文阅读【1】-球头铣刀铣削球面的表面形貌建模与仿真研究

切削加工形貌论文【1】-球头铣刀铣削球面的表面形貌建模与仿真研究1. 论文【2】-球头铣刀加工表面形貌建模与仿真1.1 切削加工形貌仿真-考虑的切削参数1.2 其他试验条件1.3 主要研究目的1.4 试验与分析结果1.5 面粗糙度的评价指标2. 论文【1】-球头铣刀加工球面&#xff08;曲面…

Vue3.0中的响应式原理

回顾Vue2的响应式原理 实现原理&#xff1a; - 对象类型&#xff1a;通过 Object.defineProperty()对属性的读取、修改进行拦截&#xff08;数据劫持&#xff09;。 - 数组类型&#xff1a;通过重写更新数组的一系列方法来实现拦截。&#xff08;对数组的变更方法进行了包裹&…

nacos源码服务注册

nacos服务注册序言1.源码环境搭建1.1idea运行源码1.2 登录nacos2.服务注册分析2.1 客户端2.1.1容器启动监听2.1.2注册前初始化2.1.3注册服务2.2 服务端2.2.1注册2.2.2重试机制3.注意事项序言 本文章是分析的是nacos版本2.2 这次版本是一次重大升级优化&#xff0c;由原来&#…

浅析DNS Rebinding

0x01 攻击简介 DNS Rebinding也叫做DNS重绑定攻击或者DNS重定向攻击。在这种攻击中&#xff0c;恶意网页会导致访问者运行客户端脚本&#xff0c;攻击网络上其他地方的计算机。 在介绍DNS Rebinding攻击机制之前我们先了解一下Web同源策略&#xff0c; Web同源策略 同源策略…

微前端--qiankun原理概述

demo放最后了。。。 一、微前端 一》微前端概述 微前端概念是从微服务概念扩展而来的&#xff0c;摒弃大型单体方式&#xff0c;将前端整体分解为小而简单的块&#xff0c;这些块可以独立开发、测试和部署&#xff0c;同时仍然聚合为一个产品出现在客户面前。可以理解微前端是…

【从零开始学Skynet】基础篇(八):简易留言板

这一篇我们要把网络编程和数据库操作结合起来&#xff0c;实现一个简单的留言板功能。 1、功能需求 如下图所示&#xff0c;客户端发送“set XXX”命令时&#xff0c;程序会把留 言“XXX”存入数据库&#xff0c;发送“get”命令时&#xff0c;程序会把整个留言板返回给客户端。…

HarmonyOS/OpenHarmony应用开发-Stage模型ArkTS语言Ability基类

Ability模块提供对Ability生命周期、上下文环境等调用管理的能力&#xff0c;包括Ability创建、销毁、转储客户端信息等。 说明: 模块首批接口从API version 9 开始支持。模块接口仅可在Stage模型下使用。 导入模块: import Ability from ohos.app.ability.Ability; 接口说明…

如何利用python机器学习解决空间模拟与时间预测问题及经典案例分析

目录 专题一 机器学习原理与概述 专题二 Python编译工具组合安装教程 专题三 掌握Python语法及常见科学计算方法 专题四 机器学习数据清洗 专题五 机器学习与深度学习方法 专题六 机器学习空间模拟实践操作 专题七 机器学习时间预测实践操作 更多 了解机器学习的发展历…

NVIDIA- cuSPARSE(四)

cuSPARSE logging 日志记录机制&#xff0c; 可以通过在启动目标应用程序之前设置一下环境变量来启动cuSPARSE日志记录机制&#xff1a; CUSPARSE_LOG_LEVEL<level> level的取值&#xff1a; 0 Off 日志记录关闭1 Error只有报错会被记录2Trace启动CUDA内核的API调用将记…

配置基于WSL2的Docker环境并支持CUDA

导言 Content 正如前文windows 10 开启WSL2介绍的&#xff0c;我们可以在windows10中使用linux子系统。今天本文介绍如何在此基础上安装Docker并支持在wsl中使用GPU。 准备工作 加入windows insider preview。建议选Dev通道&#xff0c;不要选Beta。 安装Nvidia WSL2-compa…

docker too many open files解决方式

1&#xff1a;问题描述 今天在环境上执行docker ps命令失败&#xff0c;如下提示 [rootcontrol02 ~]# docker ps -a lgrep nginx Cannot connect to the Docker daemon at unix:///var/run/docker.sock, Is the docker daemon running?2&#xff1a;查看节点docker状态 看信…

云原生网络之微隔离

本博客地址&#xff1a;https://security.blog.csdn.net/article/details/130044619 一、微隔离介绍 1.1、微隔离概念 在主体执行动作时&#xff0c;对主体权限和行为进行判断&#xff0c;最常见的是网络访问控制&#xff0c;也就是零信任网络访问&#xff08;ZTNA&#xff…