使用Ray创建高效的深度学习数据管道

大家好,用于训练深度学习模型的GPU功能强大但价格昂贵。为了有效利用GPU,开发者需要一个高效的数据管道,以便在GPU准备好计算下一个训练步骤时尽快将数据传输到GPU,使用Ray可以大大提高数据管道的效率。

1.训练数据管道的结构

首先考虑下面的模型训练伪代码:

for step in range(num_steps):
  sample, target = next(dataset) # 步骤1
  train_step(sample, target) # 步骤2

在步骤1中,获取下一个小批量的样本和标签。在步骤2中,它们被传递给train_step函数,该函数会将它们复制到GPU上,执行前向传递和反向传递以计算损失和梯度,并更新优化器的权重。

当数据集太大无法放入内存时,步骤1将从磁盘或网络中获取下一个小批量数据。此外步骤1还涉及一定量的预处理,输入数据必须转换为数字张量或张量集合,然后再馈送给模型。在某些情况下,在将它们传递给模型之前,张量上还会应用其他转换,例如归一化、绕轴旋转等。

如果工作流程是严格按顺序执行的,即先执行步骤1,然后再执行步骤2,那么模型将始终需要等待下一批数据的输入、输出和预处理操作。GPU将无法得到有效利用,它将在加载下一个小批量数据时处于空闲状态。

为了解决这个问题,可以将数据管道视为生产者——消费者的问题。数据管道生成小批量数据并写入有界缓冲区。模型/GPU从缓冲区中消费小批量数据,执行前向/反向计算并更新模型权重。如果数据管道能够以模型/GPU消费的速度快速生成小批量数据,那么训练过程将会非常高效。

图片

2.Tensorflow tf.data API

Tensorflow tf.data API提供了一组丰富的功能,可用于高效创建数据管道,使用后台线程获取小批量数据,使模型无需等待。仅仅预先获取数据还不够,如果生成小批量数据的速度比GPU消费数据的速度慢,那么就需要使用并行化来加快数据的读取和转换。为此,Tensorflow提供了交错功能以利用多个线程并行读取数据,以及并行映射功能使用多个线程对小批量数据进行转换。

由于这些API基于多线程,因此可能会受到Python全局解释器锁(GIL)的限制。Python GIL限制了Python解释器一次只能运行单个线程的字节码。如果在管道中使用纯TensorFlow代码,通常不会受到这种限制,因为TensorFlow核心执行引擎在GIL的范围之外工作。但是,如果使用的第三方库没有发布GIL或者使用Python进行大量计算,那么依赖多线程来并行化管道就不可行。

3.使用多进程并行化数据管道

考虑以下生成器函数,该函数模拟加载和执行一些计算以生成小批量数据样本和标签。

def data_generator():
  for _ in range(10):
    # 模拟获取
    # 从磁盘/网络
    time.sleep(0.5)
    # 模拟计算
    for _ in range(10000):
      pass
    yield (
        np.random.random((4, 1000000, 3)).astype(np.float32), 
        np.random.random((4, 1)).astype(np.float32)
    )

接下来,在虚拟的训练管道中使用该生成器,并测量生成小批量数据所花费的平均时间。

generator_dataset = tf.data.Dataset.from_generator(
    data_generator,
    output_types=(tf.float64, tf.float64),
    output_shapes=((4, 1000000, 3), (4, 1))
).prefetch(tf.data.experimental.AUTOTUNE)

st = time.perf_counter()
times = []
for _ in generator_dataset:
    en = time.perf_counter()
    times.append(en - st)
    # 模拟训练步骤
    time.sleep(0.1)
    st = time.perf_counter()

print(np.mean(times))

据观察,平均耗时约为0.57秒(在配备Intel Core i7处理器的Mac笔记本电脑上测量)。如果这是一个真实的训练循环,GPU的利用率将相当低,它只需花费0.1秒进行计算,然后闲置0.57秒等待下一个批次数据。

为了加快数据加载速度,可以使用多进程生成器。

from multiprocessing import Queue, cpu_count, Process
def mp_data_generator():

    def producer(q):
        for _ in range(10):
            # 模拟获取
            # 从磁盘/网络
            time.sleep(0.5)
            # 模拟计算
            for _ in range(10000000):
                pass
            q.put((
                np.random.random((4, 1000000, 3)).astype(np.float32),
                np.random.random((4, 1)).astype(np.float32)
            ))
        q.put("DONE")

    queue = Queue(cpu_count()*2)

    num_parallel_processes = cpu_count()
    producers = []
    for _ in range(num_parallel_processes):
        p = Process(target=producer, args=(queue,))
        p.start()
        producers.append(p)
    done_counts = 0
    while done_counts < num_parallel_processes:
        msg = queue.get()
        if msg == "DONE":
            done_counts += 1
        else:
            yield msg
    queue.join()

测量等待下一个小批次数据所花费的时间,得到的平均时间为0.08秒,速度提高了近7倍,但理想情况下,希望这个时间接近0。

如果进行分析,可以发现相当多的时间都花在了准备数据的反序列化上。在多进程生成器中,生产者进程会返回大型NumPy数组,这些数组需要进行准备,然后在主进程中进行反序列化。

4.使用Ray并行化数据管道

Ray是一个用于在Python中运行分布式计算的框架,它带有一个共享内存对象存储区,可在不同进程间高效地传输对象。在不进行任何序列化和反序列化的情况下,对象存储区中的Numpy数组可在同一节点上的worker之间共享。Ray还可以轻松实现数据加载在多台机器上的扩展,并使用Apache Arrow高效地序列化和反序列化大型数组。

Ray带有一个实用函数from_iterators,可以创建并行迭代器,开发者可以用它包装data_generator生成器函数。

import ray
def ray_generator():
    num_parallel_processes = cpu_count()
    return ray.util.iter.from_iterators(
        [data_generator]*num_parallel_processes
    ).gather_async()

使用ray_generator,测量等待下一个小批量数据所花费的时间为0.02秒,比使用多进程处理的速度提高了4倍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/201330.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

7. 栈

栈(stack)是一种遵循先入后出的逻辑的线性数据结构。我们可以将栈类比为桌面上的一摞盘子&#xff0c;如果需要拿出底部的盘子&#xff0c;则需要先将上面的盘子依次取出。我们将盘子替换为各种类型的元素&#xff08;如整数、字符、对象等&#xff09;&#xff0c;就得到了栈数…

二叉树OJ题之二

今天我们一起来看一道判断一棵树是否为对称二叉树的题&#xff0c;力扣101题&#xff0c; https://leetcode.cn/problems/symmetric-tree/ 我们首先先来分析这道题&#xff0c;要判断这道题是否对称&#xff0c;我们首先需要判断的是这颗树根节点的左右子树是否对称&#xff0…

基于AOP的声明式事物控制

目录 Spring事务编程概述 基于xml声明式事务控制 事务属性 isolation timeout read-only propagation 全注解开发 Spring事务编程概述 事务是开发中必不可少的东西&#xff0c;使用JDBC开发时&#xff0c;我们使用connection对事务进行控制&#xff0c;使用MyBatis时&a…

算法基础之字符串哈希

字符串哈希 核心思想&#xff1a;用p(131或者13331)进制数储存字符串每一位数的hash值 L—R的哈希值 h[R]-h[L-1]*PR-L1 哈希值很大—>modQ(264)变小 用unsigned long long 存 (出界) #include<iostream>using namespace std;typedef unsigned long long ULL;co…

嵌入式八股 | 校招秋招 | 笔试面试 | 精选题目

欢迎关注微信公众号【赛博二哈】获取八股PDF 并加入嵌入式求职交流群。提供简历模板、学习路线、岗位整理等 欢迎加入知识星球【嵌入式求职星球】获取完整嵌入式八股。 提供简历修改、项目推荐、求职规划答疑。另有各城市、公司岗位、笔面难题、offer选择、薪资爆料等 嵌入式…

【知识】简单理解为何GCN层数越多越能覆盖多跳邻居聚合信息范围更广

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhang.cn] 背景说明 大多数博客在介绍GCN层数时候&#xff0c;都会提到如下几点(经总结)&#xff1a; 在第一层&#xff0c;节点聚合来自其直接邻居的信息。在第二层&#xff0c;由于每个节点现在包含了其直接邻居的信息&a…

如何设置Linux终端提示信息

如何设置Linux终端提示信息 1 方法一&#xff1a;只能在VSCode或者Pycharm终端显示提示信息2 方法二&#xff1a;只能在MobaXterm等远程软件上显示提示3 方法三&#xff1a;避免用户没看到上面的提示&#xff0c;上面两种都设置一下 在使用远程终端时&#xff0c;由于多用户使用…

在很多nlp数据集上超越tinybert 的新架构nlp神经网络模型

在很多nlp数据集上超越tinybert 的新架构nlp神经网络模型 网络结构图测试代码网络结构图 测试代码 import paddle import numpy as np import pandas as pd from tqdm import tqdmclass FeedFroward(paddle.nn.Layer):

园区智能配电系统(电力智能监控系统)

园区智能配电系统是一种针对园区电力配送和管理的智能化系统。它的主要功能是实时监控设备运行情况&#xff0c;进行电能质量分析&#xff0c;监控电能损耗&#xff0c;以及分时段用电统计等。 具体来说&#xff0c;园区智能配电系统可以利用现代技术如RS-485总线通信、数据库管…

一、Gradle 手动创建一个项目

文章目录 Gradle 介绍Gradle Wrapper Gradle 使用手动安装 Gradle初始化 Gradle 介绍 Gradle 是一个快速的、可信的、适应性强的自动化构建工具&#xff0c;它是开源的。它使用优雅的并且可扩展的描述性语言。其他的介绍在官网可以了解。 Gradle Wrapper 官方建议使用 Gradl…

找不到 sun.misc.BASE64Decoder ,sun.misc.BASE64Encoder 类

找不到 sun.misc.BASE64Decoder &#xff0c;sun.misc.BASE64Encoder 类 1. 现象 idea 引用报错 找不到对应的包 import sun.misc.BASE64Decoder; import sun.misc.BASE64Encoder;2. 原因 因为sun.misc.BASE64Decoder和sun.misc.BASE64Encoder是Java的内部API&#xff0c;通…

AI模型训练——入门篇(二)

导语&#xff1a;本文主要介绍了基于BERT的文本分类方法&#xff0c;通过使用huggingface的transformers库实现自定义模型和任务。具体步骤包括&#xff1a;使用load_dataset函数加载数据集&#xff0c;并应用自定义的分词器&#xff1b;使用map函数将自定义分词器应用于数据集…

【从删库到跑路 | MySQL总结篇】表的增删查改(进阶下)

个人主页&#xff1a;兜里有颗棉花糖 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 兜里有颗棉花糖 原创 收录于专栏【MySQL学习专栏】&#x1f388; 本专栏旨在分享学习MySQL的一点学习心得&#xff0c;欢迎大家在评论区讨论&#x1f48c; 目录 一、联合…

医学检验(LIS)管理系统源码,LIS源码,云LIS系统源码

医学检验(LIS)管理系统源码&#xff0c;云LIS系统全套商业源码 随着全自动生化分析仪、全自动免疫分析仪和全自动血球计数器等仪器的使用&#xff0c;检验科的大多数项目实现了全自动化分析。全自动化分析引入后&#xff0c;组合化验增多&#xff0c;更好的满足了临床需要&…

离散化笔记

文章目录 离散化的适用条件离散化的意思AcWing 802. 区间和CODECODE2 离散化的适用条件 离散化用于区间求和问题对于数域极大&#xff0c;而数的量很少的情况下 离散化的意思 背景&#xff1a;对于一个极大数域上的零星几个数进行操作后&#xff0c;求某段区间内的和 其实意思…

【机器学习 | 可视化系列】可视化系列 之 决策树可视化

&#x1f935;‍♂️ 个人主页: AI_magician &#x1f4e1;主页地址&#xff1a; 作者简介&#xff1a;CSDN内容合伙人&#xff0c;全栈领域优质创作者。 &#x1f468;‍&#x1f4bb;景愿&#xff1a;旨在于能和更多的热爱计算机的伙伴一起成长&#xff01;&#xff01;&…

沈阳师范大学期末考试复习pta循环数组函数指针经典编程题汇总+代码分析

前言&#xff1a;临近期末&#xff0c;接下来给大家分享一些经典的编程题&#xff0c;方便大家复习。不一定难&#xff0c;但都是入门的好题&#xff0c;尽可能的吃透彻。因为据说期末考试的题很多来自pta上面的原题。 对于一些语言我是用c来写的&#xff0c;不妨碍理解&#…

express+multer实现简单的文件上传功能

expressmulter实现简单的文件上传功能 1.安装multer和uuid依赖 cnpm install -S uuid multer2.添加multer的配置文件 在config文件夹下添加uploa.js文件&#xff0c;内容如下&#xff1a; // 引入multer const multer require(multer) // uuid : 用于生成不重复的由英文组…

Java EE 多线程

文章目录 1. 认识线程1.1 什么是进程1.2 什么是线程1.2.1. 线程是怎么做到的呢&#xff1f;1.2.2. 进程和线程的关系 1.3 多线程编程1.3.1. 第一个多线程程序1.3.2. 使用 jconsole 命令查看线程1.3.3. 实现 Runnable 接口&#xff0c;重写 run1.3.4. 继承 Thread 重写 run&…
最新文章