当科技遇上神器:用Streamlit定制AI可视化问答界面

Streamlit是一个开源的Python库,利用Streamlit可以快速构建机器学习应用的用户界面。

本文主要探讨如何使用Streamlit构建大模型+外部知识检索的AI问答可视化界面。

我们先构建了外部知识检索接口,然后让大模型根据检索返回的结果作为上下文来回答问题。

Streamlit-使用说明

下面简单介绍下Streamlit的安装和一些用到的组件。

  1. Streamlit安装
pip install streamlit
  1. Streamlit启动
streamlit run xxx.py --server.port 8888

说明:

  • 如果不指定端口,默认使用8501,如果启动多个streamlit,端口依次升序,8502,8503,…。
  • 设置server.port可指定端口。
  • streamlit启动后将会给出两个链接,Local URL和Network URL。
  1. 相关组件
import streamlit as st
  • st.header

streamlit.header(body)

body:字符串,要显示的文本。

  • st.markdown

st.markdown(body, unsafe_allow_html=False)

body:要显示的markdown文本,字符串。

unsafe_allow_html: 是否允许出现html标签,布尔值,默认:false,表示所有的html标签都将转义。 注意,这是一个临时特性,在将来可能取消。

  • st.write

st.write(*args, **kwargs)

*args:一个或多个要显示的对象参数。

unsafe_allow_html :是否允许不安全的HTML标签,布尔类型,默认值:false。

  • st.button

st.button(label, key=None)

label:按钮标题字符串。

key:按钮组件的键,可选。如果未设置的话,streamlit将自动生成一个唯一键。

  • st.radio

st.radio(label, options, index=0, format_func=<class 'str'>, key=None)

label:单选框文本,字符串。

options:选项列表,可以是以下类型:
list
tuple
numpy.ndarray
pandas.Series

index:选中项的序号,整数。

format_func:选项文本的显示格式化函数。

key:组件ID,当未设置时,streamlit会自动生成。

  • st.sidebar

st.slider(label, min_value=None, max_value=None, value=None, step=None, format=None, key=None)

label:说明文本,字符串。

min_value:允许的最小值,默认值:0或0.0。

max_value:允许的最大值,默认值:0或0.0。

value:当前值,默认值为min_value。

step:步长,默认值为1或0.01。

format:数字显示格式字符串

key:组件ID。

  • st.empty

st.empty()

填充占位符。

  • st.columns

插入并排排列的容器。

st.columns(spec, *, gap="small")

spec: 控制要插入的列数和宽度。

gap: 列之间的间隙大小。

AI问答可视化代码

这里只涉及到构建AI问答界面的代码,不涉及到外部知识检索。

  1. 导入packages
import streamlit as st
import requests
import json
import sys,os

import torch
import torch.nn as nn
from dataclasses import dataclass, asdict
from typing import List, Optional, Callable
import copy
import warnings
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
from peft import PeftModel
from chatglm.modeling_chatglm import ChatGLMForConditionalGeneration
  1. 外部知识检索
def get_reference(user_query,use_top_k=True,top_k=10,use_similar_score=True,threshold=0.7):
  """
  外部知识检索的方式,使用top_k或者similar_score控制检索返回值。
  """
    # 设置检索接口
    SERVICE_ADD = ''
    
    ref_list = []
    user_query = user_query.strip()
    input_data = {}
    if use_top_k:
        input_data['query'] = user_query
        input_data['topk'] = top_k
        result = requests.post(SERVICE_ADD, json=input_data)
        res_json = json.loads(result.text)
        for i in range(len(res_json['answer'])):
            ref = res_json['answer'][i]
            ref_list.append(ref)
    elif use_similar_score:
        input_data['query'] = user_query
        input_data['topk'] = top_k
        result = requests.post(SERVICE_ADD, json=input_data)
        res_json = json.loads(result.text)
        for i in range(len(res_json['answer'])):
            maxscore = res_json['answer'][i]['prob']
            if maxscore > threshold:  
                ref = res_json['answer'][i]
                ref_list.append(ref)
    return ref_list
  1. 参数设置
# 设置清除按钮
def on_btn_click():
    del st.session_state.messages

# 设置参数
def set_config():
    # 设置基本参数
    base_config = {"model_name":"","use_ref":"","use_topk":"","top_k":"","use_similar_score":"","max_similar_score":""}
    # 设置模型参数
    model_config = {'top_k':'','top_p':'','temperature':'','max_length':'','do_sample':""}
    
    # 左边栏设置
    with st.sidebar:
        model_name = st.radio(
            "模型选择:",
            ["baichuan2-13B-chat", "qwen-14B-chat","chatglm-6B","chatglm3-6B"],
            index="0",
        )
        base_config['model_name'] = model_name
        
        set_ref = st.radio(
            "是否使用外部知识库:",
            ["是","否"],
            index="0",
        )
        base_config['use_ref'] = set_ref
        
        if set_ref=="是":
            set_topk_score = st.radio(
                '设置选择参考文献的方式:',
                ['use_topk','use_similar_score'],
                index='0',
                )
            
            if set_topk_score=='use_topk':
                set_topk = st.slider(
                    'Top_K', 1, 10, 5,step=1
                )
                base_config['top_k'] = set_topk
                base_config['use_topk'] = True
                base_config['use_similar_score'] = False
                set_score = st.empty()
                
            elif set_topk_score=='use_similar_score':
                set_score = st.slider(
                    "Max_Similar_Score",0.00,1.00,0.70,step=0.01
                )
                base_config['max_similar_score'] = set_score
                base_config['use_similar_score'] = True
                base_config['use_topk'] = False
                set_topk = st.empty()
                
            else:
                set_topk_score = st.empty()
                set_topk = st.empty()
                set_score = st.empty()
                
        sample = st.radio("Do Sample", ('True', 'False'))
        max_length = st.slider("Max Length", min_value=64, max_value=2048, value=1024)
        top_p = st.slider(
            'Top P', 0.0, 1.0, 0.7, step=0.01
        )
        temperature = st.slider(
            'Temperature', 0.0, 2.0, 0.05, step=0.01
        )
        st.button("Clear Chat History", on_click=on_btn_click)
        
    # 设置模型参数
    model_config['top_p']=top_p
    model_config['do_sample']=sample
    model_config['max_length']=max_length
    model_config['temperature']=temperature
    return base_config,model_config
  1. 设置模型输入格式
# 设置不同模型的输入格式
def set_input_format(model_name):
    # ["baichuan2-13B-chat", "baichuan2-7B-chat", "qwen-14B-chat",'chatglm-6B','chatglm3-6B']
    if model_name=="baichuan2-13B-chat" or model_name=='baichuan2-7B-chat':
        input_format = "<reserved_106>{{query}}<reserved_107>"
    elif model_name=="qwen-14B-chat":
        input_format = """
        <|im_start|>system 
        你是一个乐于助人的助手。<|im_end|>
        <|im_start|>user
        {{query}}<|im_end|>
        <|im_start|>assistant"""
    elif model_name=="chatglm-6B":
        input_format = """{{query}}"""
    elif model_name=="chatglm3-6B":
        input_format = """
        <|system|>
        You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
        <|user|>
        {{query}}
        <|assistant|>
        """
    return input_format
  1. 加载模型
# 加载模型和分词器
@st.cache_resource
def load_model(model_name):
    if model_name=="baichuan2-13B-chat":
        model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan2-13B-Chat",trust_remote_code=True)
        lora_path = ""
        tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-13B-Chat",trust_remote_code=True)
        model.to("cuda:0")
    elif model_name=="qwen-14B-chat":
        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-14B-Chat",trust_remote_code=True)
        lora_path = ""
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-14B-Chat",trust_remote_code=True)
        model.to("cuda:1")
    elif model_name=="chatglm-6B":
        model = ChatGLMForConditionalGeneration.from_pretrained('THUDM/chatglm-6b',trust_remote_code=True)
        lora_path = ""
        tokenizer = AutoTokenizer.from_pretrained('THUDM/chatglm-6b',trust_remote_code=True)
        model.to("cuda:2")
    elif model_name=="chatglm3-6B":
        model = AutoModelForCausalLM.from_pretrained('THUDM/chatglm3-6b',trust_remote_code=True)
        lora_path = ""
        tokenizer = AutoTokenizer.from_pretrained('THUDM/chatglm3-6b',trust_remote_code=True)
        model.to("cuda:3")
        
    # 加载lora包
    model = PeftModel.from_pretrained(model,lora_path)
    return model,tokenizer
  1. 推理参数设置

def llm_chat(model_name,model,tokenizer,model_config,query):
    response = ''
    top_k = model_config['top_k']
    top_p = model_config['top_p']
    max_length = model_config['max_length']
    do_sample = model_config['do_sample']
    temperature = model_config['temperature']
    
    if model_name=="baichuan2-13B-chat" or model_name=='baichuan-7B-chat':
        messages = []
        messages.append({"role": "user", "content": query})
        response = model.chat(tokenizer, messages)
        
    elif model_name=="qwen-14B-chat":
        response, history = model.chat(tokenizer, query, history=None, top_p=top_p, max_new_tokens=max_length, do_sample=do_sample, temperature=temperature)
        
    elif model_name=="chatglm-6B":
        response, history = model.chat(tokenizer, query, history=None, top_p=top_p, max_length=max_length, do_sample=do_sample, temperature=temperature)
    
    elif model_name=="chatglm3-6B":
        response, history= model.chat(tokenizer, query, top_p=top_p, max_length=max_length, do_sample=do_sample, temperature=temperature)
        
    return response
  1. 主程序
if __name__=="__main__":
    
    #对话的图标
    user_avator = "🧑‍💻"
    robot_avator = "🤖"
    
    if "messages" not in st.session_state:
        st.session_state.messages = []
        
    torch.cuda.empty_cache()
    base_config,model_config = set_config()
    model_name = base_config['model_name']
    use_ref = base_config['use_ref']
    
    model,tokenizer = load_model(model_name=model_name)
    
    input_format = set_input_format(model_name=model_name)

    header_text = f'Large Language Model :{model_name}'
    st.header(header_text)
    
    if use_ref=="是":
        col1, col2 = st.columns([5, 3])  
        with col1:
            for message in st.session_state.messages:
                with st.chat_message(message["role"], avatar=message.get("avatar")):
                    st.markdown(message["content"])
        
        if user_query := st.chat_input("请输入内容..."):
            with col1:  
                with st.chat_message("user", avatar=user_avator):
                    st.markdown(user_query)
                st.session_state.messages.append({"role": "user", "content": user_query, "avatar": user_avator})
                
                with st.chat_message("robot", avatar=robot_avator):
                    message_placeholder = st.empty()
                    use_top_k = base_config['use_topk']
                    
                    if use_top_k:
                        top_k = base_config['top_k']
                        use_similar_score = base_config['use_similar_score']
                        ref_list = get_reference(user_query,use_top_k=use_top_k,top_k=top_k,use_similar_score=use_similar_score) 
                    else:
                        use_top_k = base_config['use_topk']
                        use_similar_score = base_config['use_similar_score']
                        threshold = base_config['max_similar_score']
                        ref_list = get_reference(user_query,use_top_k=use_top_k,use_similar_score=use_similar_score,threshold=threshold)
                    
                    if ref_list:
                        context = ""
                        for ref in ref_list:
                            context = context+ref['para']+"\n"
                        context = context.strip('\n')
                        query = f'''
                        上下文:
                        【
                        {context} 
                        】
                        只能根据提供的上下文信息,合理回答下面的问题,不允许编造内容,不允许回答无关内容。
                        问题:
                        【
                        {user_query}
                        】
                        '''
                    else:
                        query = user_query
                    query = input_format.replace("{{query}}",query)
                    print('输入:',query)
                    max_len = model_config['max_length']
                    if len(query)>max_len:
                        cur_response = f'字数超过{max_len},请调整max_length。'
                    else:
                        cur_response = llm_chat(model_name,model,tokenizer,model_config,query)
                    fs.write(f'输入:{query}')
                    fs.write('\n')
                    fs.write(f'输出:{cur_response}')
                    fs.write('\n')
                    sys.stdout.flush()

                    if len(query)<max_len:
                        if ref_list:
                            cur_response = f"""
                            大模型将根据外部知识库回答您的问题:{cur_response}
                            """
                        else:
                            cur_response = f"""
                            大模型将根据预训练时的知识回答您的问题,存在编造事实的可能性。因此以下输出仅供参考:{cur_response}
                            """
                            
                    message_placeholder.markdown(cur_response)
                st.session_state.messages.append({"role": "robot", "content": cur_response, "avatar": robot_avator})
                
            with col2:
                ref_list = get_reference(user_query)
                if ref_list:
                    for ref in ref_list:
                        ques = ref['ques']
                        answer = ref['para']
                        score = ref['prob']
                        question = f'{ques}--->score: {score}'
                        with st.expander(question):
                            st.write(answer)
    
    else:
        for message in st.session_state.messages:
            with st.chat_message(message["role"], avatar=message.get("avatar")):
                st.markdown(message["content"])
        if user_query := st.chat_input("请输入内容..."):
            with st.chat_message("user", avatar=user_avator):
                st.markdown(user_query)
            st.session_state.messages.append({"role": "user", "content": user_query, "avatar": user_avator})
            with st.chat_message("robot", avatar=robot_avator):
                message_placeholder = st.empty()
                query = input_format.replace("{{query}}",user_query)
                max_len = model_config['max_length']
                if len(query)>max_len:
                    cur_response = f'字数超过{max_len},请调整max_length。'
                else:
                    cur_response = llm_chat(model_name,model,tokenizer,model_config,query)
                fs.write(f'输入:{query}')
                fs.write('\n')
                fs.write(f'输出:{cur_response}')
                fs.write('\n')
                sys.stdout.flush()
                cur_response = f"""
                大模型将根据预训练时的知识回答您的问题,存在编造事实的可能性。因此以下输出仅供参考:{cur_response}
                """
                message_placeholder.markdown(cur_response)
                st.session_state.messages.append({"role": "robot", "content": cur_response, "avatar": robot_avator})
                    
  1. 可视化界面展示

总结

Streamlit工具使用非常方便,说明文档清晰。

这个可视化界面集成了多个大模型+外部知识检索,同时可以在线调整模型参数,使用方便。

完整代码:https://github.com/hjandlm/Streamlit_LLM_QA

参考

[1] https://docs.streamlit.io/
[2] http://cw.hubwiz.com/card/c/streamlit-manual/
[3] https://github.com/hiyouga/LLaMA-Factory/tree/9093cb1a2e16d1a7fde5abdd15c2527033e33143

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/114704.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

浅谈无源供电无线测温在线监测系统应用方案

安科瑞 崔丽洁 摘要&#xff1a;无源供电无线测温在线监测系统是一种基于声表面波技术的测温技术&#xff0c;在变电站监测方面得到了很好的技术实践应用。本文对无源供电无线测温在线监测系统研究应用进行分析研究。 关键词&#xff1a;设备检测&#xff1b;无线测温。 引言 在…

行情分析——加密货币市场大盘走势(11.3)

大饼昨日与今日目前都是下跌态势&#xff0c;近期依然要保持逢低做多的策略。现在下跌&#xff0c;可以继续等待&#xff0c;也可以入场一部分仓位的多单&#xff0c;回调才是给机会上车。MACD日线来看&#xff0c;会继续回调&#xff0c;因此这个位置还是可以在等等。 以太昨日…

LeetCode题:21合并两个有序链表

21合并两个有序链表 题目描述 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1&#xff1a; 输入&#xff1a;l1 [1,2,4], l2 [1,3,4] 输出&#xff1a;[1,1,2,3,4,4]示例 2&#xff1a; 输入&#xff1a;l1 [], …

vcruntime140.dll无法继续执行代码修复教程

在计算机的世界里&#xff0c;我们经常会遇到各种各样的问题&#xff0c;其中之一就是“vcruntime140.dll缺失”。这个问题可能会影响到我们的正常使用&#xff0c;但是别担心&#xff0c;今天我就来给大家分享一下关于vcruntime140.dll缺失的4种修复方案。 首先&#xff0c;我…

mac下载安装jenkins

下载 https://get.jenkins.io/war/ 启动 使用命令行启动 java -jar jenkins.war 浏览器访问 IP:8080 或 localhost:8080 &#xff0c;对jenkins进行配置&#xff0c;刚开始需要输入密码 终端会展示密码和密码存放位置 jenkins插件下载地址&#xff0c; 下载后自行上传。 I…

【ChatGLM2-6B】P-Tuning训练微调

机器配置 阿里云GPU规格ecs.gn6i-c4g1.xlargeNVIDIA T4显卡*1GPU显存16G*1 准备训练数据 进入/ChatGLM-6B/ptuningmkdir AdvertiseGencd AdvertiseGen上传 dev.json 和 train.json内容都是 {"content": "你是谁", "summary": "你好&…

如何使用ps制作ico图标文件

如何使用ps制作ico图标文件 Chapter1 如何使用ps制作ico图标文件Chapter2 ICOFormat.8bi&#xff08;Photoshop Ico、Cur插件&#xff09;的下载使用1. ICOFormat.8bi的作用2. ICOFormat.8bi使用 Chapter3 ps手机计算机图标教程,手绘设计精美手机APP软件图标的PS教程步骤 01 制…

计算机网络-应用层

文章目录 应用层协议原理万维网和HTTP协议万维网概述统一资源定位符HTML文档 超文本传输协议&#xff08;HTTP&#xff09;HTTP报文格式请求报文响应报文cookie 万维网缓存与代理服务器 DNS系统域名空间域名服务器和资源记录域名解析过程递归查询迭代查询 动态主机配置协议&…

解决CSS中height:100%失效的问题

出现BUG的场景&#xff0c;点击退出到登录页面&#xff0c;发现高度不对 上面出现了一种只是占了内容的高度&#xff0c;没有占满100%&#xff0c;为什么会出现这种情况呐&#xff1f; 让div的height"100%"&#xff0c;执行网页时&#xff0c;css先执行到&#xff0…

华为OD机试 - 数组组成的最小数字 - 逻辑分析(Java 2023 B卷 100分)

目录 专栏导读一、题目描述二、输入描述三、输出描述四、解题思路五、Java算法源码六、效果展示1、输入2、输出3、说明 华为OD机试 2023B卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试&#xff08;JAVA&#xff09;真题&#xff08;A卷B卷&#…

YOLOv5:按每个类别的不同置信度阈值输出预测框

YOLOv5&#xff1a;按每个类别的不同置信度阈值输出预测框 前言前提条件相关介绍YOLOv5&#xff1a;按每个类别的不同置信度阈值输出预测框预测修改detect.py输出结果 验证修改val.py输出结果 参考 前言 由于本人水平有限&#xff0c;难免出现错漏&#xff0c;敬请批评改正。更…

H5ke9

上次fetvh就一个参数url,,就是get请求 fetch还可以第二个参数对象,可以指定method:改为POST 请求头header :发送txt,servlet,json给客户端,,异步请求图片 1 这节客户端传到服务器端 2异步文件上传,两三行代码把文件传输 mouseover事件 .then()的使用 是Promise对象的一个方法…

持续进化,快速转录,Faster-Whisper对视频进行双语字幕转录实践(Python3.10)

Faster-Whisper是Whisper开源后的第三方进化版本&#xff0c;它对原始的 Whisper 模型结构进行了改进和优化。这包括减少模型的层数、减少参数量、简化模型结构等&#xff0c;从而减少了计算量和内存消耗&#xff0c;提高了推理速度&#xff0c;与此同时&#xff0c;Faster-Whi…

tmux工具

B站学习地址&#xff1a;tmux教程

vue封装独立组件:实现手写签名功能

目录 第一章 效果展示 第二章 准备工作 2.1 使用的工具vue-sign 2.1.1 安装 2.1.2 了解 2.1.3 参数说明 第三章 源代码 第一章 效果展示 第二章 准备工作 2.1 使用的工具vue-esign 2.1.1 安装 npm install vue-esign --save 2.1.2 了解 兼容pc端和移动端有对应的参…

Redis的四种部署方案

这篇文章介绍Reids最为常见的四种部署模式&#xff0c;其实Reids和数据库的集群模式差不多&#xff0c;可以分为 Redis单机模式部署、Redis主从模式部署、Redis哨兵模式部署、Cluster集群模式部署&#xff0c;其他的部署方式基本都是围绕以下几种方式在进行调整到适应的生产环境…

@RunWith(SpringRunner.class)注解的作用

通俗点&#xff1a; RunWith(SpringRunner.class)的作用表明Test测试类要使用注入的类&#xff0c;比如Autowired注入的类&#xff0c;有了RunWith(SpringRunner.class)这些类才能实例化到spring容器中&#xff0c;自动注入才能生效 官方点&#xff1a; RunWith 注解是JUnit测…

构建mono-repo风格的脚手架库

前段时间阅读了 https://juejin.cn/post/7260144602471776311#heading-25 这篇文章&#xff1b;本文做一个梳理和笔记&#xff1b; 主要聚焦的知识点如下&#xff1a; 如何搭建脚手架工程如何开发调试如何处理命令行参数如何实现用户交互如何拷贝文件夹或文件如何动态生成文件…

Android工具栏ToolBar

主流APP除了底部有一排标签栏外&#xff0c;通常顶部还有一排导航栏。在Android5.0之前&#xff0c;这个顶部导航栏以ActionBar控件的形式出现&#xff0c;但AcionBar存在不灵活、难以扩展等毛病&#xff0c;所以Android5.0之后推出了ToolBar工具栏控件&#xff0c;意在取代Aci…

Python 获取cpu、内存利用率

获取cpu、内存利用率 # -*- coding: latin1 -*- import psutil cpuPercent 0 psutil.cpu_percent() while True:vm psutil.virtual_memory()memoryPercent vm.percentcpuPercent psutil.cpu_percent(1) *10print("cpuPercent:"str(cpuPercent)" %")prin…
最新文章