7.5.tensorRT高级(2)-RAII接口模式下的生产者消费者多batch实现

目录

    • 前言
    • 1. RAII接口模式封装生产者消费者
    • 2. 问答环节
    • 总结

前言

杜老师推出的 tensorRT从零起步高性能部署 课程,之前有看过一遍,但是没有做笔记,很多东西也忘了。这次重新撸一遍,顺便记记笔记。

本次课程学习 tensorRT 高级-RAII 接口模式下的生产者消费者多 batch 实现

课程大纲可看下面的思维导图

在这里插入图片描述

1. RAII接口模式封装生产者消费者

这节课我们利用上节课学到的 RAII + 接口模式对我们的消费者生产者进行封装

我们来看代码

infer.hpp

#ifndef INFER_HPP
#define INFER_HPP

#include <memory>
#include <string>
#include <future>

class InferInterface{
public:
    virtual std::shared_future<std::string> forward(std::string pic) = 0;
};

std::shared_ptr<InferInterface> create_infer(const std::string& file);

#endif // INFER_HPP

infer.cpp

#include "infer.hpp"
#include <thread>
#include <queue>
#include <mutex>
#include <future>

using namespace std;

struct Job{
    shared_ptr<promise<string>> pro;
    string input;
};

class InferImpl : public InferInterface{
public:

    virtual ~InferImpl(){
        worker_running_ = false;
        cv_notify_one();

        if(worker_thread_.joinable())
            worker_thread_.join();
    }

    bool load_model(const string& file){
        
        // 尽量保证资源哪里分配哪里释放,哪里使用,这样使得程序足够简单,而不是太乱
        // 线程内传回返回值的问题
        promise<bool> pro;
        worker_running_ = true;
        worker_thread_ = thread(&InferImpl::worker, this, file, std::ref(pro));
        return pro.get_future().get();

    }

     virtual shared_future<string> forward(string pic) override{

        // printf("使用 %s 进行推理\n", context_.c_str());
        // 往队列抛任务
        Job job;
        job.pro.reset(new promise<string>());
        job.input = pic;

        lock_guard<mutex> l(job_lock_);
        qjobs_.push(job);

        // 被动通知,一旦有新的任务需要推理,通知我即可
        // 发生通知的家伙
        cv_.notify_one();
        return job.pro->get_future();
    }

    // 实际执行模型推理的部分
    void worker(string file, promise<bool>& pro){
        // worker内实现,模型的加载,使用,释放
        string context = file;
        if(context.empty()){
            pro.set_value(false);
            return;
        }else{
            pro.set_value(true);
        }

        int max_batch_size = 5;
        vector<Job> jobs;
        int batch_id = 0;
        while(worker_running_){
            // 等待接受的家伙
            // 在队列取任务并执行的过程
            unique_lock<mutex> l(job_lock_);
            cv_.wait(job_lock_, [&](){
                // true 退出等待
                // false 继续等待
                return !qjobs_.empty() || !worker_running_;
            });

            // 程序发送终止信号
            if(!worker_running_)
                break;

            while(jobs.size() < max_batch_size && !qjobs_.empty()){
                jobs.emplace_back(qjobs_.front());
                qjobs.pop();
            }
            // 可以在这里一次拿一批出来,最大拿 maxbatchsize 个 job 进行一次性处理
            // jobs inference -> batch inference

            // 执行 batch 推理
            for(int i = 0; i < jobs.size(); ++i){
                
                auto& job = jobs[i];
                char result[100];
                sprintf(result, "%s : batch-> %d[%d]", job.input.c_str(), batch_id, jobs.size());
                
                job.pro->set_value(result);
            }
            batch_id++;
            jobs.clear();
            // 模拟推理耗时
            this_thread::sleep_for(chrono::milliseconds(1000));
        }
        // 释放模型
        printf("释放: %s\n", context.c_str());
        context.clear();
        printf("Worker done.\n");
    }
private:
    atomic<bool> worker_running_{false};
    thread worker_thread_;
    queue<Job> qjobs_;
    mutex job_lock_;
    condition_variable cv_;
};

shared_ptr<InferInterface> create_infer(const string& file){
    
    shared_ptr<InferImpl> instance(new Infer());
    if(!instance->load_model(file))
        instance.reset();
    return instance;
}

main.cpp

#include "infer.hpp"

int main(){
    
    auto infer = create_infer("a");
    if(infer == nullptr){
        printf("failed.\n");
        return -1;
    }

    // 串行
    // auto fa = infer->forward("A").get();
    // auto fb = infer->forward("B").get();
    // auto fc = infer->forward("C").get();
    // printf("%s\n", fa.c_str());
    // printf("%s\n", fb.c_str());
    // printf("%s\n", fc.c_str());
    
    // 并行
    auto fa = infer->forward("A");
    auto fb = infer->forward("B");
    auto fc = infer->forward("C");
    printf("%s\n", fa.get().c_str());
    printf("%s\n", fb.get().c_str());
    printf("%s\n", fc.get().c_str());    
    printf("Program done.\n");

    return 0;
}

上述示例代码相对复杂,结合了 RAII 和接口模式来实现模拟模型推理,具体是一个消费者-生产者模式的异步批处理机制,我们来简单解读下 infer.cpp 中具体干了些啥(form chatGPT

1. 数据结构和类定义

  • Job 结构体:这是一个任务结构,包含了一个 promise 对象(用于在工作线程中设置结果)和输入数据,promise 又通过 shared_ptr 封装了一层,可以让结构体传递效率更高
  • InferImpl 类,这是 InferInterface 的实现类,包含了异步处理的核心逻辑

2. InferImpl 类的方法和成员

  • 析构函数:在对象销毁时,将 worker_running_ 标志设置为 false,并通过条件变量唤醒工作线程。然后等待工作线程结束
  • load_model 方法:模型加载函数,它实际上启动了工作线程,并传递了一个 promise 对象来设置是否成功加载了模型
  • forward 方法:这是暴露给使用者的接口,用于提交一个新的推理任务。这个方法将任务添加到队列中,并通过条件变量唤醒工作线程
  • worker 方法:这是工作线程的核心函数,它从队列中取出任务并批量处理它们,然后使用 promise 设置结果
  • 私有成员
    • worker_running_:一个原子布尔标志,表示工作线程是否正在运行
    • worker_thread_:工作线程对象
    • qjobs_:包含待处理任务的队列
    • job_lock_:保护任务队列的互斥锁
    • cv_:条件变量,用于在有新任务到来或工作线程需要停止时唤醒工作线程

3. 工厂函数

  • create_infer 函数:RAII 的体现,这个函数创建了一个 InferImpl 的实例,并尝试加载模型。如果加载失败,它将返回一个空的智能指针。

这个示例清晰地展示了如何使用 RAII 和接口模式来实现一个异步批处理机制,同时也展示了如何使用 C++11 的并发特性(如 threadpromisecondition_variable 等)来实现这种机制。

2. 问答环节

博主对多线程相关的知识不怎么了解,因此疯狂询问 chatGPT,故此做个记录方便下次查看,以下内容来自于博主和 chatGPT 之间的对话

问题1:work_running_ 为什么是 atomic<boll> 类型,为什么不直接使用 bool 类型?什么是 atomic<bool> 类型?🤔

回答atomic<bool> 是 C++11 引入的原子类型,它是为了支持多线程环境中无锁的线程安全操作。使用原子类型可以确保对该类型的读写操作是原子的,即它们不会被其他线程中的操作中断。

当多个线程可能同时访问和修改一个变量时,直接使用基本数据类型(如 bool)可能会导致不确定的行为或数据竞态(race conditions)。数据竞态通常是多线程编程中的一个问题,当两个或更多的线程同时访问同一资源,并至少有一个线程修改该资源时,结果是不确定的。

使用 atomic<bool> 的好处:

1. 原子操作atomic<bool> 保证了对其的读写操作是原子的,这意味着每个读或写操作都是不可中断的。

2. 内存顺序:原子操作还提供了内存顺序语义,这确保了在多线程环境中的操作顺序。

3. 无锁操作:与使用互斥锁相比,原子操作通常更快,因为它们通常是无锁的。

在示例中,atomic<bool> 可能会在主线程中被设置,并在工作线程中被检查。为了确保这种跨线程的通信是线程安全的,使用 atomic<bool> 是合适的。如果只使用普通的 bool,可能会导致数据竞态,从而引起不可预测的行为。

总之,atomic<bool> 提供了一种线程安全的方式来读取和修改一个布尔值,特别是在无需使用互斥锁的情况下。

问题2:我在 CUDA 编程中也听过原子操作,就是 GPU 中多个线程对同一个变量进行修改时要使用原子操作,这跟你提到的原子类型是不是类似的呢?🤔

回答:是的,CUDA 中的原子操作和 C++ 中的原子类型在概念上是非常相似的。两者都是为了确保在并发环境中对数据的访问和修改是线程安全的。

1. CUDA 中的原子操作

  • CUDA 提供了一系列原子函数,如 atomicAdd, atomicSub 等,这些函数可以确保在多个线程可能同时访问同一内存位置时,这些操作是原子的。
  • 这在 GPU 编程中尤为重要,因为 GPU 具有大量的并行执行单元,这使得多个线程可能会同时访问同一个内存位置。
  • 使用原子操作可以避免数据竞态,并确保正确的结果。

2. C++ 中的原子类型

  • C++ 的 std::atomic 类模板提供了一系列原子操作,这些操作确保在多线程环境中对数据的读取和写入是线程安全的,而无需使用互斥锁。
  • 这对于 CPU 的多
    • 线程编程非常有用,因为多个线程可能会并发地运行,并尝试访问和修改同一块内存。

两者的共同点是它们都提供了一种线程安全的方式来处理并发的数据访问和修改。不同之处在于它们分别是为 GPU 和 CPU 设计的,并具有各自的特定语境和使用场景。

总的来说,无论是 CUDA 的原子操作还是 C++ 的原子类型,其核心目的都是为了提供一种在并发环境中安全访问和修改数据的机制。

总结

本次课程学习了 RAII 接口模式下的生产者消费者多 batch 实现。把前面学的知识全都串起来了,首先通过 create_infer 创建 infer 资源并对模型进行加载,这是 RAII 体现,其次该函数返回的是接口类的智能指针,因此用户只能看到 forward 这一个接口,这是接口模式的应用

在 InferImpl 具体实现类中,我们通过 forward 函数不断向队列中添加数据,而 worker 函数则会判断队列中的数据是否为空,若不为空则进行推理,若为空则继续等待,是否等待是通过条件变量 condition_variable 的 wait 和 notify_one 来实现的,另外 worker 线程将推理结果返回到 forward 中是通过 promise 和 future 来实现,值得注意的是我们在 forward 中返回的并不是 future.get() 而是直接返回的一个 future 对象,具体什么时候 get 拿结果用使用者决定

这个示例把生产者和消费者模式、RAII接口模式以及异步机制等都结合起来,有点像 tensorRT_Pro 中推理实现部分的雏形😂

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/73410.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

前后端分离------后端创建笔记(05)用户列表查询接口(上)

本文章转载于【SpringBootVue】全网最简单但实用的前后端分离项目实战笔记 - 前端_大菜007的博客-CSDN博客 仅用于学习和讨论&#xff0c;如有侵权请联系 源码&#xff1a;https://gitee.com/green_vegetables/x-admin-project.git 素材&#xff1a;https://pan.baidu.com/s/…

City Walk带动茶饮品牌售1200万,媒介盒子带你探究奥秘

年轻人生活趋势又出现了一个新鲜词——City Walk&#xff0c;简单来说&#xff0c;City Walk就是没有目的地&#xff0c;没有目标&#xff0c;只是出行&#xff0c;填充自己的生活。 其实这个词源于gap year&#xff0c;而这个说法一直是国外的一些毕业生&#xff0c;大多会在…

plt取消坐标轴刻度、自定义取消绘图边框(或坐标轴)、白边处理、自定义颜色图谱、设置坐标轴刻度朝向

目录 1、取消坐标轴刻度 2、自定义取消绘图边框&#xff08;或坐标轴&#xff09; 3、去掉图片周边白边 4、自定义颜色图谱 5、设置坐标轴刻度朝向 import matplotlib.pyplot as plt 1、取消坐标轴刻度 ax plt.subplot() ax.set_xticks([]) ax.set_yticks([]) 2、自定…

自定义 视频/音频 进度条

复制代码根据自己需求改动就可以了 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" /><metaname"viewport"conten…

模仿火星科技 基于cesium+水平面积测量+可编辑

​ 当您进入Cesium的编辑水平积测量世界&#xff0c;下面是一个详细的操作过程&#xff0c;帮助您顺利使用这些功能&#xff1a; 1. 创建提示窗&#xff1a; 启动Cesium应用&#xff0c;地图场景将打开&#xff0c;欢迎您进入编辑模式。 在屏幕的一角&#xff0c;一个友好的提…

计算机网络:网络通信相关概念入门

目录 一、网络发展背景二、理解网络通信三、理解IP地址1.简述IP地址2.IP地址的版本3.提高地址利用率的技术 四、理解端口1.简述端口2.使用端口的原因 五、理解网络通信协议 一、网络发展背景 网络发展背景&#xff1a; 最初的计算机是单机&#xff0c;那么单机是这样传输数据的…

【金融量化】Python实现根据收益率计算累计收益率并可视化

1 理论 理财产品&#xff08;本金100元&#xff09; 第1天&#xff1a;3% &#xff1a;&#xff08;13%&#xff09; ✖ 100 103 第2天&#xff1a;2% &#xff1a;&#xff08;12%&#xff09;✖ 以上 103 2.06 第3天&#xff1a;5% : &#xff08;15%&#xff09;✖ 以上…

企业直播MR虚拟直播(MR混合现实直播技术)视频介绍

到底什么是企业直播MR虚拟直播&#xff08;MR混合现实直播技术&#xff09;&#xff1f; 企业直播MR虚拟直播新玩法&#xff08;MR混合现实直播技术&#xff09; 我的文章推荐&#xff1a; [视频图文] 线上研讨会是什么&#xff0c;企业对内对外培训可以用线上研讨会吗&#x…

2023年新学期12306高铁火车学生票如何在线核验享受优惠?

2023学年优惠资质核验已开始&#xff0c;完成学生优惠资质核验后&#xff0c;您可以在线购买2022年10月1日至2023年9月30日的学生优惠票。&#xff08;注&#xff1a;非该时间段需要重新核验&#xff0c;可享受学生优惠票&#xff09;&#xff1b; 『扩展阅读』 1、美团外卖红…

QT:UI控件(按设计师界面导航界面排序)

基础部分 创建新项目&#xff1a;QWidget&#xff0c;QMainWindow&#xff0c;QDialog QMainWindow继承自QWidget&#xff0c;多了菜单栏; QDialog继承自QWidget&#xff0c;多了对话框 QMainWindow 菜单栏和工具栏&#xff1a; Bar: 菜单栏&#xff1a;QMenuBar&#xff0…

【Sklearn】基于随机森林算法的数据分类预测(Excel可直接替换数据)

【Sklearn】基于随机森林算法的数据分类预测&#xff08;Excel可直接替换数据&#xff09; 1.模型原理1.1 模型原理1.2 数学模型 2.模型参数3.文件结构4.Excel数据5.下载地址6.完整代码7.运行结果 1.模型原理 随机森林&#xff08;Random Forest&#xff09;是一种集成学习方法…

【深度学习】PyTorch快速入门

【深度学习】学习PyTorch基础 介绍PyTorch 深度学习框架是一种软件工具&#xff0c;旨在简化和加速构建、训练和部署深度学习模型的过程。深度学习框架提供了一系列的函数、类和工具&#xff0c;用于定义、优化和执行各种深度神经网络模型。这些框架帮助研究人员和开发人员专注…

RabbitMQ 安装教程

RabbitMQ 安装教程 特殊说明 因为RabbitMQ基于Erlang开发&#xff0c;所以安装时需要先安装Erlang RabbitMQ和Erlang版本对应关系 查看地址&#xff1a;www.rabbitmq.com/which-erlan… 环境选择 Erlang: 23.3及以上 RabbitMQ: 3.10.1Windows 安装 1. 安装Erlang 下载地…

Spark 学习记录

基础 SparkContext是什么&#xff1f;有什么作用&#xff1f; https://blog.csdn.net/Shockang/article/details/118344357 SparkContext 是什么&#xff1f; SparkContext 是通往 Spark 集群的唯一入口&#xff0c;可以用来在 Spark 集群中创建 RDDs 、累加和广播变量( Br…

【css】渐变

渐变是设置一种颜色或者多种颜色之间的过度变化。 两种渐变类型&#xff1a; 线性渐变&#xff08;向下/向上/向左/向右/对角线&#xff09; 径向渐变&#xff08;由其中心定义&#xff09; 1、线性渐变 语法&#xff1a;background-image: linear-gradient(direction, co…

黑客必备的操作系统——kali linux安装

大家经常会在电视里面看到各种炫酷的黑客操作&#xff0c;那么黑客一般用什么操作系统呢&#xff1f;今天小训带大家来安装黑客必备的kali linux-2022操作系统&#xff0c;有兴趣的一起来学习下吧&#xff01; 1、安装前准备 1.1 VMware下载 VMware官网下载&#xff1a; ht…

开源力量再现,国产操作系统商业化的全新探索

文章目录 1. 开源运动的兴起2. 开源力量的推动3. 国产操作系统的崭露头角3.1 国产操作系统有哪些 4.国产操作系统的商业化探索5.开源力量对国产操作系统商业化的推动 操作系统作为连接硬件、中间件、数据库、应用软件的纽带&#xff0c;被认为是软件技术体系中最核心的基础软件…

linux4.0新增32位ARM的系统调用

修改内核源码 Linux系统为每一个系统调用赋予一个系统调用号。当应用程序执行一个系统调用时&#xff0c;应用程序就可以知道执行和调用到哪个系统调用了&#xff0c;从而不会造成混乱。系统调用号一旦分配之后就不会有任何变更&#xff0c;否则已经编译好的应用程序就不能运行…

runtime error: member access within misaligned address(力扣最常见错误之一)

runtime error: member access within misaligned address&#xff08;力扣最常见错误之一&#xff09; 前言原因和解决办法总结 前言 最近博主在刷力扣时&#xff0c;明明代码逻辑都没问题&#xff0c;但总是报下面这个错误&#xff1a; runtime error: member access within…