从零构建属于自己的GPT系列5:模型本地化部署(文本生成函数解读、模型本地化部署、文本生成文本网页展示、代码逐行解读)

🚩🚩🚩Hugging Face 实战系列 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在PyCharm中进行
本篇文章配套的代码资源已经上传

从零构建属于自己的GPT系列1:数据预处理
从零构建属于自己的GPT系列2:模型训练1
从零构建属于自己的GPT系列3:模型训练2
从零构建属于自己的GPT系列4:模型训练3

1 前端环境安装

安装:

pip install streamlit

测试:

streamlit hello

安装完成后,测试后打印的信息
在这里插入图片描述

(Pytorch) C:\Users\admin>streamlit hello
Welcome to Streamlit. Check out our demo in your browser.
Local URL: http://localhost:8501 Network URL:
http://192.168.1.187:8501
Ready to create your own Python apps super quickly? Head over to
https://docs.streamlit.io
May you create awesome apps!

接着会自动的弹出一个页面
在这里插入图片描述

2 模型加载函数

这个函数把模型加载进来,并且设置成推理模式

def get_model(device, model_path):
    tokenizer = CpmTokenizer(vocab_file="vocab/chinese_vocab.model")
    eod_id = tokenizer.convert_tokens_to_ids("<eod>")  # 文档结束符
    sep_id = tokenizer.sep_token_id
    unk_id = tokenizer.unk_token_id
    model = GPT2LMHeadModel.from_pretrained(model_path)
    model.to(device)
    model.eval()
    return tokenizer, model, eod_id, sep_id, unk_id
  1. 模型加载函数,加载设备cuda,已经训练好的模型的路径
  2. 加载tokenizer 文件
  3. 结束特殊字符
  4. 分隔特殊字符
  5. 未知词特殊字符
  6. 加载模型
  7. 模型进入GPU
  8. 开启推理模式
  9. 返回参数
device_ids = 0
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICE"] = str(device_ids)
device = torch.device("cuda" if torch.cuda.is_available() and int(device_ids) >= 0 else "cpu")
tokenizer, model, eod_id, sep_id, unk_id = get_model(device, "model/zuowen_epoch40")
  1. 指定第一个显卡
  2. 设置确保 CUDA 设备的编号与 PCI 位置相匹配,使得 CUDA 设备的编号更加一致且可预测
  3. 通过设置为 str(device_ids)(在这个案例中为 ‘0’),指定了进程只能看到并使用编号为 0 的 GPU
  4. 有GPU用GPU作为加载设备,否则用CPU
  5. 调用get_model函数,加载模型

3 文本生成函数

对于给定的上文,生成下一个单词

def generate_next_token(input_ids,args):

    input_ids = input_ids[:, -200:]
    outputs = model(input_ids=input_ids)
    logits = outputs.logits
    next_token_logits = logits[0, -1, :]
    next_token_logits = next_token_logits / args.temperature
    next_token_logits[unk_id] = -float('Inf')
    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.top_k, top_p=args.top_p)
    next_token_id = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
    return next_token_id
  1. 对输入进行一个截断操作,相当于对输入长度进行了限制
  2. 通过模型得到预测,得到输出,预测的一个词一个词进行预测的
  3. 得到预测的结果值
  4. next_token_logits表示最后一个token的hidden_state对应的prediction_scores,也就是模型要预测的下一个token的概率
  5. 温度表示让结果生成具有多样性
  6. 设置预测的结果不可以未知字(词)的Token,防止出现异常的东西
  7. 通过top_k_top_p_filtering()函数对预测结果进行筛选
  8. 通过预测值转换为概率,得到实际的Token ID
  9. 返回结果

每次都是通过这种方式预测出下一个词是什么

4 多文本生成函数

到这里就不止是预测下一个词了,要不断的预测

def predict_one_sample(model, tokenizer, device, args, title, context):
    title_ids = tokenizer.encode(title, add_special_tokens=False)
    context_ids = tokenizer.encode(context, add_special_tokens=False)
    input_ids = title_ids + [sep_id] + context_ids
    cur_len = len(input_ids)
    last_token_id = input_ids[-1]  # 已生成的内容的最后一个token
    input_ids = torch.tensor([input_ids], dtype=torch.long, device=device)

    while True:
        next_token_id = generate_next_token(input_ids,args)
        input_ids = torch.cat((input_ids, next_token_id.unsqueeze(0)), dim=1)
        cur_len += 1
        word = tokenizer.convert_ids_to_tokens(next_token_id.item())
        # 超过最大长度,并且换行
        if cur_len >= args.generate_max_len and last_token_id == 8 and next_token_id == 3:
            break
        # 超过最大长度,并且生成标点符号
        if cur_len >= args.generate_max_len and word in [".", "。", "!", "!", "?", "?", ",", ","]:
            break
        # 生成结束符
        if next_token_id == eod_id:
            break
    result = tokenizer.decode(input_ids.squeeze(0))
    content = result.split("<sep>")[1]  # 生成的最终内容
    return content
  1. 预测一个样本的函数
  2. 从用户获得输入标题转化为Token ID
  3. 从用户获得输入正文转化为Token ID
  4. 标题和正文连接到一起

5 主函数

def writer():
    st.markdown(
        """
        ### 杨卓越定制化GPT生成模型
        """
    )
    st.sidebar.subheader("配置参数")

    generate_max_len = st.sidebar.number_input("generate_max_len", min_value=0, max_value=512, value=32, step=1)
    top_k = st.sidebar.slider("top_k", min_value=0, max_value=10, value=3, step=1)
    top_p = st.sidebar.number_input("top_p", min_value=0.0, max_value=1.0, value=0.95, step=0.01)
    temperature = st.sidebar.number_input("temperature", min_value=0.0, max_value=100.0, value=1.0, step=0.1)

    parser = argparse.ArgumentParser()
    parser.add_argument('--generate_max_len', default=generate_max_len, type=int, help='生成标题的最大长度')
    parser.add_argument('--top_k', default=top_k, type=float, help='解码时保留概率最高的多少个标记')
    parser.add_argument('--top_p', default=top_p, type=float, help='解码时保留概率累加大于多少的标记')
    parser.add_argument('--max_len', type=int, default=512, help='输入模型的最大长度,要比config中n_ctx小')
    parser.add_argument('--temperature', type=float, default=temperature, help='输入模型的最大长度,要比config中n_ctx小')
    args = parser.parse_args()

    context = st.text_area("主内容", max_chars=512)
    title = st.text_area("副内容", max_chars=512)
    if st.button("点我生成结果"):
        start_message = st.empty()
        start_message.write("自毁程序启动中请稍等 10.9.8.7 ...")
        start_time = time.time()
        result = predict_one_sample(model, tokenizer, device, args, title, context)
        end_time = time.time()
        start_message.write("生成完成,耗时{}s".format(end_time - start_time))
        st.text_area("生成结果", value=result, key=None)
    else:
        st.stop()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/234856.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于Spring+Spring boot的SpringBoot在线电子商城管理系统

SSM毕设分享 基于SpringSpring boot的SpringBoot在线电子商城管理系统 1 项目简介 Hi&#xff0c;各位同学好&#xff0c;这里是郑师兄&#xff01; 今天向大家分享一个毕业设计项目作品【基于SpringSpring boot的SpringBoot在线电子商城管理系统】 师兄根据实现的难度和等级…

【K8S in Action】服务:让客户端发现pod 并与之通信(1)

服务是一种为一组功能相同的 pod 提供单一不变的接入点的资源。当服务存在时&#xff0c;它的 IP 地址和端口不会改变。 客户端通过 IP 地址和端口号建立连接&#xff0c; 这些连接会被路由到提供该服务的任意一个 pod 上。 pod 是短暂&#xff0c;会删除增加&#xff0c;调度…

基于JavaWeb+SSM+Vue微信小程序的科创微应用平台系统的设计和实现

基于JavaWebSSMVue微信小程序的科创微应用平台系统的设计和实现 源码获取入口Lun文目录前言主要技术系统设计功能截图订阅经典源码专栏Java项目精品实战案例《500套》 源码获取 源码获取入口 Lun文目录 1系统概述 1 1.1 研究背景 1 1.2研究目的 1 1.3系统设计思想 1 2相关技术…

Java基础课的中下基础课04

目录 二十三、集合相关 23.1 集合 &#xff08;1&#xff09;集合的分支 23.2 List有序可重复集合 &#xff08;1&#xff09;ArrayList类 &#xff08;2&#xff09;泛型 &#xff08;3&#xff09;ArrayList常用方法 &#xff08;4&#xff09;Vector类 &#xff08;…

排序算法之六:快速排序(非递归)

快速排序是非常适合使用递归的&#xff0c;但是同时我们也要掌握非递归的算法 因为操作系统的栈空间很小&#xff0c;如果递归的深度太深&#xff0c;容易造成栈溢出 递归改非递归一般有两种改法&#xff1a; 改循环借助栈&#xff08;数据结构&#xff09; 图示算法 不是…

Zookeeper系统性学习-应用场景以及单机、集群安装

Zookeeper 是什么&#xff1f; Zookeeper 为分布式应用提供高效且可靠的分布式协调服务&#xff0c;提供了诸如统一命名服务、配置管理和分布式锁等分布式的基础服务。在解决分布式数据一致性方面&#xff0c;ZooKeeper 并没有直接采用 Paxos 算法&#xff0c;而是采用了名为 …

漏刻有时百度地图API实战开发(8)关键词输入检索获取经纬度坐标和地址

在百度地图中进行关键词输入检索时&#xff1a; 在地图页面顶部的搜索框中输入关键词。点击搜索按钮或按下回车键进行搜索。地图将显示与关键词相关的地点、商家、景点等信息。可以使用筛选和排序功能来缩小搜索范围或更改搜索结果的排序方式。点击搜索结果中的地点或商家&…

办公word-从不是第一页添加页码

总结 实际需要注意的是&#xff0c;分隔符、分节符和分页符并不是一个含义 分隔符包含其他两个&#xff1b;分页符&#xff1a;是增加一页&#xff1b;分节符&#xff1a;指将文档分为几部分。 从不是第一页插入页码1步骤 1&#xff0c;插入默认页码 自己可以测试时通过**…

linux 14网站架构 编译安装mysql数据库

目录 LNMP网站架构下载源码包mysql 下载位置 mysql 安装1.1、清理安装环境&#xff1a;1.2、创建mysql用户1.3、从官网下载tar包1.4、安装编译工具1.5、解压1.6、编译安装编译安装三部曲1.7、初始化初始化,只需要初始化一次1.8、启动mysql1.9、登录mysql1.10、systemctl启动方式…

web 前端之标签练习+知识点

目录 实现过程&#xff1a; 结果显示 1、HTML语法 2、注释标签 3、常用标签 4、新标签 5、特殊标签 6、在网页中使用视频和音频、图片 7、表格标签 8、超链接标签 使用HTML语言来实现该页面 实现过程&#xff1a; <!DOCTYPE html> <html><head>…

nlkt中BigramAssocMeasures.pmi()方法的传参和使用

这个问题找遍全网没看到详细的介绍&#xff0c;最后用读代码数学公式的方法才理解怎么用。 BigramAssocMeasures.pmi 作用&#xff1a;计算x和y的互信息&#xff08;互信息是什么我就不科普啦&#xff09; 这里有个误区刚开始我以为是计算两个词之间的依赖程度&#xff0c;但…

【Spring教程25】Spring框架实战:从零开始学习SpringMVC 之 SpringMVC入门案例总结与SpringMVC工作流程分析

目录 1.入门案例总结2. 入门案例工作流程分析2.1 启动服务器初始化过程2.2 单次请求过程 欢迎大家回到《Java教程之Spring30天快速入门》&#xff0c;本教程所有示例均基于Maven实现&#xff0c;如果您对Maven还很陌生&#xff0c;请移步本人的博文《如何在windows11下安装Mave…

java resource ‘process/qingjia.png‘ not found

resource中的资源在target中没有&#xff0c;导致报错&#xff0c;如下图所示&#xff1a; 解决办法&#xff1a;在pom文件中添加如下代码&#xff1a; 重新执行代码&#xff0c;就能在target中看到png文件了。 类似的错误参考链接&#xff1a;mybatis-plus框架报错&#x…

探索HarmonyOS_开发软件安装

随着华为推出HarmonyOS NEXT 宣布将要全面启用鸿蒙原声应用&#xff0c;不在兼容安卓应用&#xff0c; 现在开始探索鸿蒙原生应用的开发。 HarmonyOS应用开发官网 - 华为HarmonyOS打造全场景新服务 鸿蒙官网 开发软件肯定要从这里下载 第一个为微软系统(windows)&#xff0c;第…

【Linux】使用Bash和GNU Parallel并行解压缩文件

介绍 在本教程中&#xff0c;我们将学习如何使用Bash脚本和GNU Parallel实现高效并行解压缩多个文件。这种方法在处理大量文件时可以显著加快提取过程。 先决条件 确保系统上已安装以下内容&#xff1a; BashGNU Parallel 你可以使用以下命令在不同Linux系统上安装它们&am…

RF射频干扰被动型红外传感器误判分析及整改事例

1.1 什么是红外传感 测量系统是以红外线为介质&#xff0c;探测可分成为光子和热探测器。 简洁原理就是利用产生的辐射与物质相互作用后呈现出来的物理效应就是它的基本原理。 1.2 红外按方式分类 &#xff08;1&#xff09;被动型红外&#xff1a;本身不会向外界辐射任何能量…

大师学SwiftUI第18章Part2 - 存储图片和自定义相机

存储图片 在前面的示例中&#xff0c;我们在屏幕上展示了图片&#xff0c;但也可以将其存储到文件或数据库中。另外有时使用相机将照片存储到设备的相册薄里会很有用&#xff0c;这样可供其它应用访问。UIKit框架提供了如下两个保存图片和视频的函数。 UIImageWriteToSavedPh…

ffmpeg6.0之ffprobe.c源码分析二-核心功能源码分析

本篇我们继续分析: 1、ffprobe -show_packets 参数的处理流程;2、ffprobe -show_frames 参数的处理流程;3、ffprobe -show_streams 参数的处理流程;4、ffprobe -show_format 参数的处理流程; 因为前面的文章已经回顾了这些命令的使用,以及作用。本文就不在赘述,以免篇幅…

电子学会C/C++编程等级考试2022年03月(五级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:数字变换 给定一个包含 5 个数字(0-9)的字符串, 例如 “02943”, 请将“12345”变换到它。 你可以采取 3 种操作进行变换 1. 交换相邻的两个数字 2. 将一个数字加 1。 如果加 1 后大于 9, 则变为 0 3. 将一个数字加倍。 如果…

查找两个总和为特定值的索引(蓝桥杯)

#include <stdio.h> int main(){int n;scanf("%d",&n);int s[n];for(int i 0 ; i < n ; i)scanf("%d",&s[i]);int k;scanf("%d",&k);int sum 0;int t0,h;int st[101]; for(int i 0 ; i < n ; i)st[i] 0; //标记数…
最新文章