streamlit data_editor学习之 LLM理论内存占用量计算器

streamlit data_editor学习之 LLM理论内存占用量计算器

  • 一.效果
  • 二.代码
  • 三.运行命令
  • 四.参考链接

根据用户设置的LLM参数,计算设备内存的占用量。以web的形式方便共享,可以插入多条记录,表格更新后,可以动态计算结果

一.效果

在这里插入图片描述

二.代码

import streamlit as st  #1.31.1
import cv2
import math
from collections import OrderedDict
import pandas as pd

NUM_BYTES_IN_MEGABYTE = 1024 * 1024 * 1024

# 计算公式来源:           https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/training/theoretical_memory_usage.py
# st.data_editor用法参考: https://zhuanlan.zhihu.com/p/686385274

def compute_weight_and_optimizer_memory(args, verbose=True):
    # Attention projection size.
    
    if args.kv_channels==0:
        args.kv_channels = args.hidden_size // args.num_attention_heads
    
    query_projection_size = args.kv_channels * args.num_attention_heads
    query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_size
    # Group Query Attention.
    if not args.group_query_attention:
        args.num_query_groups = args.num_attention_heads
    # MoE.
    num_experts = 1 if args.num_experts is None else args.num_experts
    gated_linear_multiplier = 3 / 2 if args.swiglu else 1
    num_parameters_in_transformer_layers = (
        2
        * args.num_layers
        * args.hidden_size
        * args.hidden_size
        * (
            # Attention.
            (
                (1 + (args.num_query_groups / args.num_attention_heads))
                * query_projection_to_hidden_size_ratio
            )
            # MLP.
            + ((args.ffn_hidden_size / args.hidden_size) * num_experts * gated_linear_multiplier)
            # Transformer layernorms.
            + (2 / args.hidden_size)
            # Final layernorm.
            + (1 / (args.num_layers * args.hidden_size))
        )
    )
    embedding_size = args.hidden_size * args.padded_vocab_size
    if args.untie_embeddings_and_output_weights:
        num_parameters_in_embedding_layers = 2 * embedding_size
    else:
        num_parameters_in_embedding_layers = embedding_size
    num_total_parameters = num_parameters_in_transformer_layers + num_parameters_in_embedding_layers
    if verbose:
        print(
            f"Number of parameters in transformer layers in billions: "
            f"{num_parameters_in_transformer_layers / 10**9: .2f}"
        )
        print(
            f"Number of parameters in embedding layers in billions: "
            f"{num_parameters_in_embedding_layers / 10**9:.2f}"
        )
        print(f"Total number of parameters in billions: {num_total_parameters / 10**9:.2f}")

    # Most loaded model shard has (1/pp_size transformer layers + 1 embedding layer) / tp_size.
    num_parameters_on_most_loaded_model_shard = (
        (num_parameters_in_transformer_layers / args.pipeline_model_parallel_size) + embedding_size
    ) / args.tensor_model_parallel_size
    if args.untie_embeddings_and_output_weights and args.pipeline_model_parallel_size == 1:
        num_parameters_on_most_loaded_model_shard += (
            embedding_size / args.tensor_model_parallel_size
        )
    if verbose:
        print(
            f"Number of parameters in most loaded shard in billions: "
            f"{num_parameters_on_most_loaded_model_shard / 10**9:.4f}"
        )

    if args.pipeline_model_parallel_size > 1:
        # Other shards just have (1/pp_size transformer layers) / tp_size.
        num_parameters_on_other_model_shards = num_parameters_in_transformer_layers / (
            args.pipeline_model_parallel_size * args.tensor_model_parallel_size
        )
        if verbose:
            print(
                f"Number of parameters in other shards in billions: "
                f"{num_parameters_on_other_model_shards / 10**9:.4f}"
            )

    num_bytes_per_parameter = (
        18 if not args.use_distributed_optimizer else 6 + (12 / args.data_parallel_size)
    )
    weight_and_optimizer_memory = (
        num_parameters_on_most_loaded_model_shard * num_bytes_per_parameter
    )

    return weight_and_optimizer_memory


def compute_activation_memory(args, num_microbatches, verbose=False):
    # Using formula in Table 2 of https://arxiv.org/pdf/2205.05198.pdf.
    # We are trying to compute the maximum activation footprint, so all calculations in this
    # function are for the first pipeline stage.

    # TODO: This function needs to take into account query_projection_size potentially being
    # different from hidden_size.

    # Memory footprint from transformer layer (self-attention and MLP).
    activation_memory = (args.seq_length * args.micro_batch_size * args.hidden_size) * (
        18 + (4 * (args.ffn_hidden_size / args.hidden_size))
    )
    if verbose:
        print(
            f"Activation memory footprint per transformer layer: "
            f"{activation_memory / NUM_BYTES_IN_MEGABYTE / args.tensor_model_parallel_size:.1f} MB"
        )
    activation_memory *= args.num_layers

    # Now add activation memory required for input embeddings, last LayerNorm and output layer.

    # Input to embedding (pp_size microbatches in flight).
    activation_memory += (
        8 * args.seq_length * args.micro_batch_size * args.pipeline_model_parallel_size
    )
    # Dropout in embedding layer (pp_size microbatches in flight).
    activation_memory += (
        args.seq_length
        * args.micro_batch_size
        * args.hidden_size
        * args.pipeline_model_parallel_size
    )

    # Multiply by interleaved PP memory factor.
    if args.virtual_pipeline_model_parallel_size>0:
        interleaved_schedule_memory_penalty = 1 + (
            (args.pipeline_model_parallel_size - 1)
            / (args.pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size)
        )
        in_flight_microbatches = math.ceil(
            interleaved_schedule_memory_penalty * args.pipeline_model_parallel_size
        )
        if verbose:
            print(
                f"Memory penalty from interleaved schedule: {interleaved_schedule_memory_penalty:.2f}"
            )
            print(f"Number of in-flight microbatches: {in_flight_microbatches}")
        activation_memory *= interleaved_schedule_memory_penalty

    # If using non-interleaved schedule, number of microbatches in pipeline can be less than pp_size,
    # so discount accordingly.
    if args.virtual_pipeline_model_parallel_size>0 and args.pipeline_model_parallel_size > 1:
        if num_microbatches is not None:
            activation_memory *= min(1, num_microbatches / args.pipeline_model_parallel_size)
            in_flight_microbatches = min(num_microbatches, args.pipeline_model_parallel_size)
        else:
            in_flight_microbatches = args.pipeline_model_parallel_size
        if verbose:
            print(f"Number of in-flight microbatches: {in_flight_microbatches}")

    if args.pipeline_model_parallel_size == 1:
        # Inputs to output layer and CE loss.
        activation_memory += (
            args.seq_length
            * args.micro_batch_size
            * args.hidden_size
            * 4
            * (1 + (args.padded_vocab_size / args.hidden_size))
        )

    # Activation memory is partitioned by TP size due to tensor and sequence model parallelism.
    return activation_memory / args.tensor_model_parallel_size


def report_theoretical_memory(args, num_microbatches=None, verbose=False):
    weight_and_optimizer_memory = (
        compute_weight_and_optimizer_memory(args, verbose=verbose) / NUM_BYTES_IN_MEGABYTE
    )

    # Formulae here assume sequence parallelism and selective activation recomputation.
    if not args.sequence_parallel:# or args.recompute_granularity != 'selective':
        print(
            f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} MB"
        )
        return

    activation_memory = (
        compute_activation_memory(args, num_microbatches=num_microbatches, verbose=verbose)
        / NUM_BYTES_IN_MEGABYTE
    )
    total_memory = weight_and_optimizer_memory + activation_memory

    # print(
    #     f"batch:{args.micro_batch_size} Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} MB, "
    #     f"activation={activation_memory:.2f} MB, total={total_memory:.2f} MB\n"
    # )
    return weight_and_optimizer_memory,activation_memory,total_memory

class Parameter:
    def __init__(self):
        self.__dict__['data'] = []
    def __setattr__(self, key, value):        
        if key not in self.keys():
            self.data.append([key,value])
        else:
            if isinstance(value,tuple):
                self.data[self.keys().index(key)][1]=value
            else:
                temp=list(self.data[self.keys().index(key)][1])
                temp[0]=value
                self.data[self.keys().index(key)][1]=temp
    def fullname(self,name):
        for k,value in self.data:
            if value[1]==name:
                return k
        raise ValueError(f"{name} not found")    
    def keys(self):
        return [x[0] for x in self.data]
    def values(self):
        return [x[1] for x in self.data]    
    def __getattr__(self, key):
        if key in self.keys():
            return self.values()[self.keys().index(key)][0]        
    def __getitem__(self, key):
        if key in self.keys():
            return self.values()[self.keys().index(key)]
 
def apply_de_change(df0,default_params,changes):
    add_rows = changes.get('added_rows')
    edited_rows = changes.get('edited_rows')
    deleted_rows = changes.get('deleted_rows')
    for idx, row in edited_rows.items():
        for name, value in row.items():
            df0.loc[df0.index[idx], name] = value
    df0.drop(df0.index[deleted_rows], inplace=True)
    ss = []
    has_index = add_rows and '_index' in add_rows[0]
    for add_row in add_rows:
        if '_index' in add_row:
            ss.append(pd.Series(data=add_row, name=add_row.pop('_index')))
        else:
            ss.append(pd.Series(data=add_row))
    df_add = pd.DataFrame(ss)
    data= pd.concat([df0, df_add], axis=0) if has_index else pd.concat([df0, df_add], axis=0, ignore_index=True)
    keys=data.keys().tolist()
    for idx, row in data.iterrows():
        for k in keys:
            default_params.__setattr__(default_params.fullname(k),row[k])
        default_params.weight_and_optimizer_memory,default_params.activation_memory,default_params.total_memory=report_theoretical_memory(default_params)
        for k in keys:
            data.loc[idx,k]=default_params.__getattr__(default_params.fullname(k))
    return data

def data_editor_change(key,default_params,editor_key):
    st.session_state[key] = apply_de_change(st.session_state[key],default_params,st.session_state[editor_key])

def df_editor_key(key):
    return "llm_table_"+key

def set_customer_style():
    #手机上column依然保持在一行,而不是一列
    st.write('''<style>
    [data-testid="column"] {
        width: calc(16.6666% - 1rem) !important;
        flex: 1 1 calc(16.6666% - 1rem) !important;
        min-width: calc(16.6666% - 1rem) !important;        
    }
    </style>''', unsafe_allow_html=True)
    
    #去掉顶部的padding,使得在手机上的空间更紧致(配合--client.toolbarMode="minimal"使用)
    st.write('''<style>
    [data-testid="stAppViewBlockContainer"] {
        padding: 18px;
    }
    </style>''', unsafe_allow_html=True)   


# 运行命令 streamlit.exe run main.py --client.toolbarMode="minimal"

if __name__ == "__main__":
    
    #初始化默认参数 default_params.变量名=(缺省值,"表格的字段名")
    default_params=Parameter()
    default_params.name=("Llama-2-13b","Name")
    default_params.micro_batch_size=(1,"Batch")
    default_params.seq_length=(512,"SEQ")
    default_params.padded_vocab_size=(32000,"VOCAB")
    default_params.hidden_size=(5120,"HIDDEN")
    default_params.ffn_hidden_size=(13824,"FFN")
    default_params.kv_channels=(0,"KVC")
    default_params.num_attention_heads=(40,"HEAD")
    default_params.num_query_groups=(0,"QG")
    default_params.num_layers=(40,"LAYER")
    default_params.num_experts=(1,"MOE")
    default_params.virtual_pipeline_model_parallel_size=(1,"VP")
    default_params.pipeline_model_parallel_size=(1,"PP")
    default_params.tensor_model_parallel_size=(1,"TP")
    default_params.data_parallel_size=(1,"DP")
    default_params.use_distributed_optimizer=(False,"DOPT")
    default_params.group_query_attention=(False,"GQA")
    default_params.sequence_parallel=(True,"SP")
    default_params.swiglu=(True,"SWIGGLU")
    default_params.untie_embeddings_and_output_weights=(False,"UNTIE")
    
    #用默认参数,计算内存占用量
    v1,v2,v3=report_theoretical_memory(default_params)
    default_params.weight_and_optimizer_memory=(v1,"权值优化器(GB)")
    default_params.activation_memory=(v2,"激活(GB)")
    default_params.total_memory=(v3,"总计(GB)")
    
    #创建DataFrame并根据字段的数据类型,创建column配置
    column_config={}
    default_data=OrderedDict()
    for fullname,v in zip(default_params.keys(),default_params.values()):
        default_value,shortname=v
        default_data[shortname]=default_value
        value_class_name=default_value.__class__.__name__
        if value_class_name=="bool":        
            column_config[shortname]=st.column_config.CheckboxColumn(shortname,help=fullname,default=default_value)
        elif value_class_name=="str":
            column_config[shortname]=st.column_config.TextColumn(shortname,help=fullname,default=default_value,validate="^st\.[a-z_]+$")
        elif value_class_name=="int":
            column_config[shortname]=st.column_config.NumberColumn(shortname,help=fullname,default=default_value,format="%d")
        elif value_class_name=="float":
            column_config[shortname]=st.column_config.NumberColumn(shortname,help=fullname,default=default_value,format="%.3f")
        else:
            raise ValueError(f"{value_class_name} not supported")
    
    #赋值给session_state
    df_default_key = 'llm_table'
    df_editor_key="llm_table_edit"
    if df_default_key not in st.session_state:
        st.session_state[df_default_key] = pd.DataFrame([default_data])
    
    
    st.set_page_config(page_title="LLM内存计算", layout="wide")
    set_customer_style()
    st.markdown("<h1 style='text-align: center; color: black;'>LLM内存计算</h1>", unsafe_allow_html=True)
        
    chart_df=st.data_editor(
        st.session_state[df_default_key].copy(),
        key=df_editor_key,
        on_change=data_editor_change,
        args=(df_default_key,default_params,df_editor_key),  
        height=400,
        num_rows="dynamic",
        column_config=column_config,
        disabled=["权值优化器(GB)","激活(GB)","总计(GB)","SP"],
        hide_index=False,
        use_container_width=True)

三.运行命令

streamlit.exe run main.py --client.toolbarMode="minimal"

四.参考链接

  1. 公式来源
  2. st.data_editor表格更新

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/575535.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【八股】Spring Boot

SpringBoot是如何实现自动装配的&#xff1f; 首先&#xff0c;SpringBoot的核心注解SpringBootApplication里面包含了三个注解&#xff0c;SpringBootConfigurationEnableAutoConfigurationComponentScan&#xff0c;其中EnableAutoConfiguration是实现自动装配的注解&#x…

如何最大程度使用AWS?

随着云计算技术的不断发展&#xff0c;AWS已经成为众多企业的首选&#xff0c;为其提供了强大的基础设施和服务。那么如何最大程度地、灵活地利用AWS&#xff0c;成为许多企业专注的焦点。九河云作为AWS的合作伙伴&#xff0c;为读者们提供一些技巧和策略&#xff0c;帮助读者充…

UL认证防逆流多功能监测装置AGF-AE-D

安科瑞薛瑶瑶18701709087/17343930412 在单逆变器系统中&#xff0c;仪表直接与逆变器相连。如果您的变频器有一个内置的收入等级表&#xff08;RGM&#xff1b;该变频器 被称为收入等级变频器&#xff09;&#xff0c;您可以在 RGM 的同一总线上连接一个外部仪表。

【React】Sigma.js框架网络图-入门篇(2)

通过《【React】Sigma.js框架网络图-入门篇》有了基本认识 由于上一篇直接给出了基本代码示例&#xff0c;可能看着比较复杂也不知道是啥意思&#xff1b; 今天从理论入手重新认识下&#xff01; 一、基本认识 首先&#xff0c;我们先了解下基础术语&#xff1a; 图(Graph)&…

波高仪:数字浪高仪解析

波高仪&#xff0c;也被称为数字浪高仪&#xff0c;是一种专门用于测量波浪高度的设备。它采用低功耗微处理器、24bit高精度AD转换器和长距离通信技术&#xff0c;配备电容式波高传感器&#xff0c;具有线性好、功耗低、量精度高、传输距离远、性能稳定、抗干扰能力强等特点。 …

vue中使用echarts实现X轴动态时间(天)的折线图表

项目要求x轴以一天为间隔&#xff0c;时间是动态返回的数据&#xff0c;折线图平滑展示 实现代码如下&#xff1a; <div class"echarts-main"><v-chart ref"echarts" :options"options" /> </div>// 局部引入vue-echarts im…

Python实现线性拟合及绘图

Python实现线性拟合及绘图 当时的数字地形实验&#xff0c;使用matplotlib库绘制了一张图表表示不同地形类别在不同分辨率下的RMSE值&#xff0c;并分别拟合了一条趋势线。现在来看不足就是地形较多时&#xff0c;需要使用循环更好一点&#xff0c;不然太冗余了。 代码逻辑 …

【讯为Linux驱动笔记1】申请一个字符设备

Linux下每个设备都需要有一个专属设备号&#xff1a;主设备号 次设备号 【申请字符设备】 主设备号&#xff1a;一类驱动&#xff1a;如&#xff1a;USB驱动 次设备号&#xff1a;这类驱动下的某个设备 如&#xff1a;键盘鼠标 设备号是32位的dev_t类型的&#xff0c;高12位主…

Python对Excel两列数据进行运算

&#x1f47d;发现宝藏 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。【点击进入巨牛的人工智能学习网站】。 Python对Excel两列数据进行运算 在日常工作中&#xff0c;经常会遇到需要对Excel表格中的数…

Scala 04 —— Scala Puzzle 拓展

Scala 04 —— Scala Puzzle 拓展 文章目录 Scala 04 —— Scala Puzzle 拓展一、占位符二、模式匹配的变量和常量模式三、继承 成员声明的位置结果初始化顺序分析BMember 类BConstructor 类 四、缺省初始值与重载五、Scala的集合操作和集合类型保持一致性第一部分代码解释第二…

Python 数据可视化 boxplot

Python 数据可视化 boxplot import pandas as pd import matplotlib.pyplot as plt import numpy as np import seaborn as sns# 读取 TSV 文件 df pd.read_csv(result.tsv, sep\t)normal_df df[df["sample_name"].str.contains("normal")] tumor_df df…

【Git教程】(十五)二分法排错 — 概述及使用要求,执行过程及其实现(用二分法人工排错或自动排错),替代解决方案 ~

Git教程 二分法排错 1️⃣ 概述2️⃣ 使用要求3️⃣ 执行过程及其实现3.1 用二分法人工排错3.2 用二分法自动排错 4️⃣ 替代解决方案 在开发过程中&#xff0c;我们经常会突然遇到一个错误&#xff0c;是之前早期版本在成功通过测试时没有出现过的。这时候&#xff0c;时下较…

基于实现地图弹窗轮播功能及遇到的问题解决

基本使用 获取地图 geojson 数据 链接&#xff1a; 阿里云数据可视化平台 获取ECharts npm install echarts 或者是使用地址链接 <script src"https://registry.npmmirror.com/echarts/5.4.3/files/dist/echarts.min.js"></script> <script src…

关于螺栓的注意事项和正确操作方法——SunTorque智能扭矩系统

智能扭矩系统-智能拧紧系统-扭矩自动控制系统-SunTorque 螺栓&#xff0c;作为一种常见的紧固件&#xff0c;广泛应用于各种机械设备和结构中。在日常生活和工作中&#xff0c;我们经常需要接触到螺栓&#xff0c;因此了解螺栓的一些注意事项和正确操作方法对于确保设备的安全…

【C#】Stopwatch计时器

使用Stopwatch检查C#中代码块的执行时间&#xff0c;比如歌曲&#xff0c;图片的下载时间问题 首先&#xff0c;我们可看到Stopwatch 类内部的函数。 根据需求&#xff0c;我们具体可使用到 Start() 开始计时&#xff0c;Stop() 停止计时等 //创建 Stopwatch 实例 Stopwatch …

Intersection Observer API探索

我们经常遇到这样的需求——检测一个元素是否可见或者两个元素是否相交&#xff0c;如 ● 图片懒加载——当图片滚动到可见时才进行加载 ● 内容无限滚动——也就是用户滚动到接近内容底部时直接加载更多&#xff0c;而无需用户操作翻页&#xff0c;给用户一种网页可以无限滚动…

分布式密钥生成

可验证且无经销商 分布式密钥生成 (DKG) 是一种加密协议&#xff0c;使多方能够协作生成共享密钥&#xff0c;而无需任何一方完全了解密钥。 它通过在多个参与者之间分配信任来增强各种应用程序的安全性&#xff0c;从而降低密钥泄露的风险。 我们引入了一种可验证且无经销商的…

深度学习从入门到精通—Transformer

1.绪论介绍 1.1 传统的RNN网络 传统的RNN&#xff08;递归神经网络&#xff09;主要存在以下几个问题&#xff1a; 梯度消失和梯度爆炸&#xff1a;这是RNN最主要的问题。由于序列的长距离依赖&#xff0c;当错误通过层传播时&#xff0c;梯度可以变得非常小&#xff08;消失…

mybatisplus3.5.4基础生成代码完整步骤(超详细)

在网上看了很多自动生成的例子本地不是很好使&#xff0c;最后找到了一套好用的&#xff0c;适合版本&#xff1a; idea:2024.1 springboot2.6.12 java17 mybatisplus3.5.4 废话不多说&#xff0c;直接上步骤&#xff1a; 新建项目&#xff1a; 结构如下&#xff1a; 添加依…
最新文章