ChatGLM3 langchain_demo 代码解析

ChatGLM3 langchain_demo 代码解析

  • 0. 背景
  • 1. 项目代码结构
  • 2. 代码解析
    • 2-1. utils.py
    • 2-2. ChatGLM3.py
    • 2-3. Tool/Calculator.py
    • 2-4. Tool/Weather.py
    • 2-5. main.py

0. 背景

学习 ChatGLM3 的项目内容,过程中使用 AI 代码工具,对代码进行解释,帮助自己快速理解代码。这篇文章记录 ChatGLM3 langchain_demo 的代码解析内容。

1. 项目代码结构

在这里插入图片描述

2. 代码解析

2-1. utils.py

import os
import yaml


def tool_config_from_file(tool_name, directory="Tool/"):
    """search tool yaml and return json format"""
    for filename in os.listdir(directory):
        if filename.endswith('.yaml') and tool_name in filename:
            file_path = os.path.join(directory, filename)
            with open(file_path, encoding='utf-8') as f:
                return yaml.safe_load(f)
    return None

这段代码定义了一个函数 tool_config_from_file,用于从文件中加载工具的配置信息。

该函数接受两个参数:tool_name 表示要加载的工具名称,directory 表示存储工具配置文件的目录,默认为 “Tool/”。

在函数中,首先使用 os.listdir 函数获取指定目录下的所有文件名。然后,通过遍历文件名列表,找到以 “.yaml” 结尾且包含指定工具名称的文件。如果找到了匹配的文件,就构造文件的完整路径,并使用 open 函数打开文件。接着,使用 yaml.safe_load 函数加载文件内容,并将其转换为 JSON 格式的数据返回。

如果遍历完所有文件后仍未找到匹配的工具配置文件,则返回 None。

总体而言,这段代码定义了一个函数 tool_config_from_file,用于根据工具名称和文件目录获取工具的配置信息,并将其转换为 JSON 格式的数据返回。

2-2. ChatGLM3.py

import json
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModel, AutoConfig
from typing import List, Optional
from utils import tool_config_from_file

这段代码导入了一些模块和函数,并定义了一些类型注解。

首先,导入了 json 模块,用于处理 JSON 数据。 然后,导入了 LLM 类和 AutoTokenizer、AutoModel、AutoConfig 类,这些来自 langchain.llms.base 和 transformers 模块,用于构建和配置语言模型。 接下来,导入了 List 和 Optional 类型,用于类型注解。 最后,导入了 tool_config_from_file 函数,该函数来自 utils 模块,用于加载工具的配置信息。

总体而言,这段代码导入了所需的模块、类和函数,以及定义了一些类型注解。

class ChatGLM3(LLM):
    max_token: int = 8192
    do_sample: bool = False
    temperature: float = 0.8
    top_p = 0.8
    tokenizer: object = None
    model: object = None
    history: List = []
    tool_names: List = []
    has_search: bool = False

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatGLM3"

    def load_model(self, model_name_or_path=None):
        model_config = AutoConfig.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=model_config, trust_remote_code=True
        ).half().cuda()

    def _tool_history(self, prompt: str):
        ans = []
        tool_prompts = prompt.split(
            "You have access to the following tools:\n\n")[1].split("\n\nUse a json blob")[0].split("\n")

        tool_names = [tool.split(":")[0] for tool in tool_prompts]
        self.tool_names = tool_names
        tools_json = []
        for i, tool in enumerate(tool_names):
            tool_config = tool_config_from_file(tool)
            if tool_config:
                tools_json.append(tool_config)
            else:
                ValueError(
                    f"Tool {tool} config not found! It's description is {tool_prompts[i]}"
                )

        ans.append({
            "role": "system",
            "content": "Answer the following questions as best as you can. You have access to the following tools:",
            "tools": tools_json
        })
        query = f"""{prompt.split("Human: ")[-1].strip()}"""
        return ans, query

    def _extract_observation(self, prompt: str):
        return_json = prompt.split("Observation: ")[-1].split("\nThought:")[0]
        self.history.append({
            "role": "observation",
            "content": return_json
        })
        return

    def _extract_tool(self):
        if len(self.history[-1]["metadata"]) > 0:
            metadata = self.history[-1]["metadata"]
            content = self.history[-1]["content"]
            if "tool_call" in content:
                for tool in self.tool_names:
                    if tool in metadata:
                        input_para = content.split("='")[-1].split("'")[0]
                        action_json = {
                            "action": tool,
                            "action_input": input_para
                        }
                        self.has_search = True
                        return f"""
Action: 

{json.dumps(action_json, ensure_ascii=False)}

        final_answer_json = {
            "action": "Final Answer",
            "action_input": self.history[-1]["content"]
        }
        self.has_search = False
        return f"""
Action: 

{json.dumps(final_answer_json, ensure_ascii=False)}


    def _call(self, prompt: str, history: List = [], stop: Optional[List[str]] = ["<|user|>"]):
        print("======")
        print(prompt)
        print("======")
        if not self.has_search:
            self.history, query = self._tool_history(prompt)
        else:
            self._extract_observation(prompt)
            query = ""
        # print("======")
        # print(history)
        # print("======")
        _, self.history = self.model.chat(
            self.tokenizer,
            query,
            history=self.history,
            do_sample=self.do_sample,
            max_length=self.max_token,
            temperature=self.temperature,
        )
        response = self._extract_tool()
        history.append((prompt, response))
        return response

这段代码定义了一个名为 ChatGLM3 的类,该类继承自 LLM 类。

以下是这个类的成员变量和方法的详细解析:

  • max_token: int = 8192:最大令牌数,默认为 8192。

  • do_sample: bool = False:是否进行采样,默认为 False。

  • temperature: float = 0.8:采样温度,默认为 0.8。

  • top_p = 0.8:top-p 采样的概率阈值,默认为 0.8。

  • tokenizer: object = None:tokenizer 对象,默认为 None。

  • model: object = None:模型对象,默认为 None。

  • history: List = []:对话历史记录列表,默认为空列表。

  • tool_names: List = []:工具名称列表,默认为空列表。

  • has_search: bool = False:是否进行工具搜索的标志位,默认为 False。

  • init(self):类的构造函数,调用父类 LLM 的构造函数。

  • _llm_type(self) -> str:类属性,返回字符串 “ChatGLM3”。

  • load_model(self, model_name_or_path=None):加载模型的方法。根据模型的名称或路径,使用 AutoConfig、AutoTokenizer 和 AutoModel 类从预训练模型中加载配置、tokenizer 和模型,并将模型转化为半精度浮点数并放置在 CUDA 设备上。

  • _tool_history(self, prompt: str):提取工具历史记录的方法。根据提示字符串,从中提取出工具名称和工具配置信息,并将其存储到 tool_names 和 tools_json 中,最后将结果作为字典添加到 ans 列表中,并返回 ans 列表和查询字符串。

  • _extract_observation(self, prompt: str):提取观察信息的方法。从提示字符串中提取观察信息,将其添加到历史记录列表 history 中。

  • _extract_tool(self):提取工具信息的方法。根据最后一条历史记录的元数据和内容,判断是否存在工具调用。如果存在,遍历工具名称列表,如果某个工具名称出现在元数据中,提取出输入参数,并构造一个动作 JSON 对象。同时,将 has_search 设置为 True。如果不存在工具调用,构造一个最终回答的动作 JSON 对象,并将 has_search 设置为 False。最后,返回包含动作 JSON 的字符串。

  • _call(self, prompt: str, history: List = [], stop: Optional[List[str]] = [“<|user|>”]):执行对话的方法。根据是否进行工具搜索的标志位,调用 _tool_history 方法或 _extract_observation 方法来获取查询字符串和更新历史记录。然后,使用模型的 chat 方法进行对话生成,并将结果传递给 _extract_tool 方法,提取工具信息。最后,将提示和响应添加到历史记录列表中,并返回响应。

总体而言,这段代码定义了一个名为ChatGLM3 的类,该类继承自 LLM 类。它包含了一些成员变量和方法,用于加载模型、提取工具历史记录、提取观察信息和执行对话。

在 load_model 方法中,模型的配置、tokenizer 和模型本身被加载,并存储在 tokenizer 和 model 成员变量中。

_tool_history 方法用于从提示字符串中提取工具历史记录。它首先将提示字符串按特定的分隔符切分,然后从切分结果中提取工具名称和相应的工具配置信息。这些信息被存储在 tool_names 和 tools_json 成员变量中,并作为字典添加到 ans 列表中。最后,返回 ans 列表和查询字符串。

_extract_observation 方法用于从提示字符串中提取观察信息。它将观察信息存储在 history 成员变量中。

_extract_tool 方法用于提取工具信息。它检查最后一条历史记录的元数据和内容,判断是否存在工具调用。如果存在工具调用,它将提取出工具名称和输入参数,并构造一个包含动作和输入参数的 JSON 对象。如果不存在工具调用,它将构造一个包含最终回答动作的 JSON 对象。最后,它将返回包含动作 JSON 的字符串。

_call 方法用于执行对话。它根据 has_search 标志位,选择调用 _tool_history 方法或 _extract_observation 方法,获取查询字符串和更新历史记录。然后,它使用模型的 chat 方法生成响应,并将结果传递给 _extract_tool 方法,提取工具信息。最后,它将提示和响应添加到历史记录列表中,并返回响应。

总体而言,这段代码定义了一个用于对话生成的类 ChatGLM3,它继承自 LLM 类,并提供了加载模型、提取工具历史记录、提取观察信息和执行对话的功能。

    def _tool_history(self, prompt: str):
        ans = []
        tool_prompts = prompt.split(
            "You have access to the following tools:\n\n")[1].split("\n\nUse a json blob")[0].split("\n")

        tool_names = [tool.split(":")[0] for tool in tool_prompts]
        self.tool_names = tool_names
        tools_json = []
        for i, tool in enumerate(tool_names):
            tool_config = tool_config_from_file(tool)
            if tool_config:
                tools_json.append(tool_config)
            else:
                ValueError(
                    f"Tool {tool} config not found! It's description is {tool_prompts[i]}"
                )

        ans.append({
            "role": "system",
            "content": "Answer the following questions as best as you can. You have access to the following tools:",
            "tools": tools_json
        })
        query = f"""{prompt.split("Human: ")[-1].strip()}"""
        return ans, query

这段代码定义了一个名为 _tool_history 的方法,它接受一个参数 prompt,该参数是一个字符串。

以下是这段代码的详细解析:

  • ans = []:创建一个空列表 ans,用于存储返回结果。

  • tool_prompts = prompt.split(“You have access to the following tools:\n\n”)[1].split(“\n\nUse a json blob”)[0].split(“\n”):从 prompt 字符串中提取工具提示信息。它使用特定的分隔符将 prompt 字符串进行切分,提取包含工具提示信息的部分,并将其存储在 tool_prompts 列表中。

  • tool_names = [tool.split(“:”)[0] for tool in tool_prompts]:从 tool_prompts 列表中提取工具名称。它使用冒号进行切分,并将每个工具的名称存储在 tool_names 列表中。

  • self.tool_names = tool_names:将 tool_names 列表赋值给类的成员变量 tool_names。

  • tools_json = []:创建一个空列表 tools_json,用于存储工具的配置信息。

  • for i, tool in enumerate(tool_names)::对 tool_names 列表进行遍历,循环变量 tool 表示当前遍历的工具名称,循环变量 i 表示当前遍历的索引。

    • tool_config = tool_config_from_file(tool):调用 tool_config_from_file 函数,根据工具名称获取工具的配置信息,并将结果存储在 tool_config 变量中。

    • if tool_config::检查 tool_config 是否存在。

      • tools_json.append(tool_config):如果 tool_config 存在,则将其添加到 tools_json 列表中。
    • else::否则,如果 tool_config 不存在。

      • ValueError(…):抛出一个 ValueError 异常,提示工具配置未找到,并包含工具的描述信息。
  • ans.append({…}):将一个字典对象添加到 ans 列表中。字典包含以下键值对:

    • “role”: “system”:角色为 “system”,表示系统的角色。
    • “content”: “Answer the following questions as best as you can. You have access to the following tools:”:内容为提示信息。
    • “tools”: tools_json:工具列表为 tools_json。
  • query = f"“”{prompt.split(“Human: “)[-1].strip()}””":从 prompt 字符串中提取查询字符串。它首先根据特定的分隔符将 prompt 字符串进行切分,然后从切分结果中选择最后一个元素,并去除首尾空格。

  • return ans, query:返回 ans 列表和 query 字符串作为结果。

总体而言,这段代码定义了一个方法 _tool_history,它从 prompt 字符串中提取工具提示信息,并构造一个包含提示信息、工具列表和查询字符串的字典对象。

    def _extract_observation(self, prompt: str):
        return_json = prompt.split("Observation: ")[-1].split("\nThought:")[0]
        self.history.append({
            "role": "observation",
            "content": return_json
        })
        return

这段代码定义了一个名为 _extract_observation 的方法,它接受一个参数 prompt,该参数是一个字符串。

以下是这段代码的详细解析:

  • return_json = prompt.split(“Observation: “)[-1].split(”\nThought:”)[0]:从 prompt 字符串中提取观察信息。它首先根据特定的分隔符将 prompt 字符串进行切分,然后选择切分结果中的最后一个元素,并再次根据特定的分隔符将其进行切分,最后选择切分结果中的第一个元素。

  • self.history.append({…}):将一个字典对象添加到类的成员变量 history 列表中。字典包含以下键值对:

    • “role”: “observation”:角色为 “observation”,表示观察的角色。
    • “content”: return_json:内容为观察信息。
  • return:没有指定返回值,默认返回 None。

总体而言,这段代码定义了一个方法 _extract_observation,它从 prompt 字符串中提取观察信息,并将观察信息添加到类的历史记录列表 history 中。

    def _extract_tool(self):
        if len(self.history[-1]["metadata"]) > 0:
            metadata = self.history[-1]["metadata"]
            content = self.history[-1]["content"]
            if "tool_call" in content:
                for tool in self.tool_names:
                    if tool in metadata:
                        input_para = content.split("='")[-1].split("'")[0]
                        action_json = {
                            "action": tool,
                            "action_input": input_para
                        }
                        self.has_search = True
                        return f"""
Action: 

{json.dumps(action_json, ensure_ascii=False)}

        final_answer_json = {
            "action": "Final Answer",
            "action_input": self.history[-1]["content"]
        }
        self.has_search = False
        return f"""
Action: 

{json.dumps(final_answer_json, ensure_ascii=False)}


    def _call(self, prompt: str, history: List = [], stop: Optional[List[str]] = ["<|user|>"]):
        print("======")
        print(prompt)
        print("======")
        if not self.has_search:
            self.history, query = self._tool_history(prompt)
        else:
            self._extract_observation(prompt)
            query = ""
        # print("======")
        # print(history)
        # print("======")
        _, self.history = self.model.chat(
            self.tokenizer,
            query,
            history=self.history,
            do_sample=self.do_sample,
            max_length=self.max_token,
            temperature=self.temperature,
        )
        response = self._extract_tool()
        history.append((prompt, response))
        return response

这段代码包含两个方法 _extract_tool 和 _call。

以下是这段代码的详细解析:

_extract_tool 方法:

  • if len(self.history[-1][“metadata”]) > 0::检查历史记录的最后一项是否包含元数据。如果包含元数据,则执行以下操作。

    • metadata = self.history[-1][“metadata”]:将元数据存储在变量 metadata 中。

    • content = self.history[-1][“content”]:将历史记录的最后一项的内容存储在变量 content 中。

    • if “tool_call” in content::检查内容中是否包含字符串 “tool_call”。如果包含,则执行以下操作。

      • for tool in self.tool_names::对工具名称列表进行遍历,循环变量 tool 表示当前遍历的工具名称。

        • if tool in metadata::检查工具名称是否存在于元数据中。如果存在,则执行以下操作。

          • input_para = content.split(“='”)[-1].split(“'”)[0]:从内容中提取输入参数。它首先根据特定的分隔符将内容进行切分,然后选择切分结果中的倒数第二个元素,并再次根据特定的分隔符将其进行切分,最后选择切分结果中的第一个元素。

          • action_json = {…}:构造一个包含动作和输入参数的字典对象。字典包含以下键值对:

            • “action”: tool:动作为工具名称。
            • “action_input”: input_para:动作输入参数为提取的输入参数。
          • self.has_search = True:将成员变量 has_search 设置为 True,表示存在搜索。

        • return f"…":返回一个包含动作 JSON 的字符串。字符串使用 Markdown 格式进行格式化,显示动作 JSON。

  • final_answer_json = {…}:构造一个包含最终回答动作的字典对象。字典包含以下键值对:

    • “action”: “Final Answer”:动作为 “Final Answer”。
    • “action_input”: self.history[-1][“content”]:动作输入参数为历史记录的最后一项的内容。
  • self.has_search = False:将成员变量 has_search 设置为 False,表示不存在搜索。

  • return f"…":返回一个包含动作 JSON 的字符串。字符串使用 Markdown 格式进行格式化,显示动作 JSON。

_call 方法:

  • print(“======”):打印分隔线。

  • print(prompt):打印传入的参数 prompt。

  • print(“======”):打印分隔线。

  • if not self.has_search::检查成员变量 has_search 是否为 False。如果为 False,表示不存在搜索,执行以下操作。

    • self.history, query = self._tool_history(prompt):调用 _tool_history 方法,获取历史记录和查询字符串。将返回的历史记录赋值给类的成员变量 history,将返回的查询字符串赋值给变量 query。
  • else::如果存在搜索

    • self._extract_observation(prompt):调用 _extract_observation 方法,从 prompt 中提取观察信息,并将其添加到类的历史记录列表 history 中。

    • query = “”:将查询字符串设置为空字符串。

  • _, self.history = self.model.chat(…):调用 model.chat 方法进行对话。它使用模型、分词器和其他参数来生成对话的响应。返回的结果包含生成的响应和更新后的历史记录。使用 _ 忽略生成的响应,将更新后的历史记录赋值给类的成员变量 history。

  • response = self._extract_tool():调用 _extract_tool 方法,从更新后的历史记录中提取工具动作并生成响应。

  • history.append((prompt, response)):将元组 (prompt, response) 添加到 history 列表中,用于记录对话历史。

  • return response:返回生成的响应作为结果。

总体而言,这段代码定义了两个方法 _extract_tool 和 _call。_extract_tool 方法用于从历史记录中提取工具动作,生成相应的响应。_call 方法用于处理对话流程,包括提取观察信息、生成对话响应和记录对话历史。

2-3. Tool/Calculator.py

import abc
import math
from typing import Any

from langchain.tools import BaseTool


class Calculator(BaseTool, abc.ABC):
    name = "Calculator"
    description = "Useful for when you need to answer questions about math"

    def __init__(self):
        super().__init__()

    async def _arun(self, *args: Any, **kwargs: Any) -> Any:
        # 用例中没有用到 arun 不予具体实现
        pass


    def _run(self, para: str) -> str:
        para = para.replace("^", "**")
        if "sqrt" in para:
            para = para.replace("sqrt", "math.sqrt")
        elif "log" in para:
            para = para.replace("log", "math.log")
        return eval(para)


if __name__ == "__main__":
    calculator_tool = Calculator()
    result = calculator_tool.run("sqrt(2) + 3")
    print(result)

这段代码定义了一个名为 Calculator 的类,它继承自 BaseTool 类和 abc.ABC 抽象基类。Calculator 类是一个计算器工具,用于执行数学计算操作。

以下是这段代码的详细解析:

  • import abc:导入 abc 模块,用于定义抽象基类。

  • import math:导入 math 模块,用于执行数学计算操作。

  • from typing import Any:从 typing 模块导入 Any 类型,用于函数参数和返回值的类型注解。

  • from langchain.tools import BaseTool:从 langchain.tools 模块导入 BaseTool 类,用作 Calculator 类的父类。

  • class Calculator(BaseTool, abc.ABC)::定义了一个名为 Calculator 的类,它继承自 BaseTool 类和 abc.ABC 抽象基类。Calculator 类表示一个计算器工具。

    • name = “Calculator”:类属性 name 被设置为字符串 “Calculator”,表示工具的名称。

    • description = “Useful for when you need to answer questions about math”:类属性 description 被设置为字符串 “Useful for when you need to answer questions about math”,表示工具的描述信息。

    • def init(self)::构造函数,用于初始化 Calculator 类的实例。它调用父类的构造函数 super().init() 来完成初始化。

    • async def _arun(self, *args: Any, **kwargs: Any) -> Any::定义了一个异步方法 _arun,它接受任意数量的位置参数和关键字参数,并返回任意类型的值。在这段代码中,_arun 方法没有具体的实现,因为在示例中没有使用到它。

    • def _run(self, para: str) -> str::定义了一个方法 _run,它接受一个字符串参数 para,并返回一个字符串。在这个方法中,它首先对参数 para 进行替换操作,将字符串中的 “^” 替换为 “**”。然后,它检查参数 para 中是否包含 “sqrt”,如果是,则将 “sqrt” 替换为 “math.sqrt”;如果参数 para 中包含 “log”,则将 “log” 替换为 “math.log”。最后,使用 eval 函数来执行参数 para 的计算操作,并返回计算结果。

请注意,这段代码中的 BaseTool 类没有给出具体的定义,因此我无法提供关于它的更多详细信息。如果您能提供 BaseTool 类的定义或相关代码,我将能够给出更准确的解释。

2-4. Tool/Weather.py

import os
from typing import Any

import requests
from langchain.tools import BaseTool


class Weather(BaseTool):
    name = "weather"
    description = "Use for searching weather at a specific location"

    def __init__(self):
        super().__init__()

    async def _arun(self, *args: Any, **kwargs: Any) -> Any:
        # 用例中没有用到 arun 不予具体实现
        pass

    def get_weather(self, location):
        api_key = os.environ["SENIVERSE_KEY"]
        url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
        response = requests.get(url)
        if response.status_code == 200:
            data = response.json()
            weather = {
                "temperature": data["results"][0]["now"]["temperature"],
                "description": data["results"][0]["now"]["text"],
            }
            return weather
        else:
            raise Exception(
                f"Failed to retrieve weather: {response.status_code}")

    def _run(self, para: str) -> str:
        return self.get_weather(para)


if __name__ == "__main__":
    weather_tool = Weather()
    weather_info = weather_tool.run("成都")
    print(weather_info)

这段代码定义了一个名为 Weather 的类,它继承自 BaseTool 类。Weather 类是一个天气工具,用于查询特定位置的天气信息。

以下是代码的详细解析:

  • import os:导入 os 模块,用于访问操作系统的功能,例如环境变量。

  • from typing import Any:从 typing 模块导入 Any 类型,用于函数参数和返回值的类型注解。

  • import requests:导入 requests 模块,用于发送 HTTP 请求。

  • from langchain.tools import BaseTool:从 langchain.tools 模块导入 BaseTool 类,用作 Weather 类的父类。

  • class Weather(BaseTool)::定义了一个名为 Weather 的类,它继承自 BaseTool 类。Weather 类表示一个天气工具。

    • name = “weather”:类属性 name 被设置为字符串 “weather”,表示工具的名称。

    • description = “Use for searching weather at a specific location”:类属性 description 被设置为字符串 “Use for searching weather at a specific location”,表示工具的描述信息。

    • def init(self)::构造函数,用于初始化 Weather 类的实例。它调用父类的构造函数 super().init() 来完成初始化。

    • async def _arun(self, *args: Any, **kwargs: Any) -> Any::定义了一个异步方法 _arun,它接受任意数量的位置参数和关键字参数,并返回任意类型的值。在这段代码中,_arun 方法没有具体的实现,因为在示例中没有使用到它。

    • def get_weather(self, location)::定义了一个方法 get_weather,它接受一个参数 location,表示要查询的位置。在这个方法中,它使用 os.environ 字典获取名为 “SENIVERSE_KEY” 的环境变量作为 API 密钥。然后,它构建一个 URL,使用 requests.get 方法发送 GET 请求到该 URL,并获取响应结果。如果响应的状态码为 200(表示请求成功),则解析响应的 JSON 数据,并返回包含温度和天气描述的字典。如果响应的状态码不为 200,则抛出异常。

    • def _run(self, para: str) -> str::定义了一个方法 _run,它接受一个字符串参数 para,并返回一个字符串。在这个方法中,它调用 get_weather 方法,传入参数 para(表示要查询的位置),并返回查询到的天气信息。

2-5. main.py

from typing import List
from ChatGLM3 import ChatGLM3

from langchain.agents import load_tools
from Tool.Weather import Weather
from Tool.Calculator import Calculator

from langchain.agents import initialize_agent
from langchain.agents import AgentType


def run_tool(tools, llm, prompt_chain: List[str]):
    loaded_tolls = []
    for tool in tools:
        if isinstance(tool, str):
            loaded_tolls.append(load_tools([tool], llm=llm)[0])
        else:
            loaded_tolls.append(tool)
    agent = initialize_agent(
        loaded_tolls, llm,
        agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        handle_parsing_errors=True
    )
    for prompt in prompt_chain:
        agent.run(prompt)


if __name__ == "__main__":
    model_path = "THUDM/chatglm3-6b"
    llm = ChatGLM3()
    llm.load_model(model_name_or_path=model_path)

    # arxiv: 单个工具调用示例 1
    run_tool(["arxiv"], llm, [
        "帮我查询GLM-130B相关工作"
    ])

    # weather: 单个工具调用示例 2
    # run_tool([Weather()], llm, [
    #     "今天北京天气怎么样?",
    #     "What's the weather like in Shanghai today",
    # ])

    # calculator: 单个工具调用示例 3
    run_tool([Calculator()], llm, [
        "12345679乘以54等于多少?",
        "3.14的3.14次方等于多少?",
        "根号2加上根号三等于多少?",
    ]),

    # arxiv + weather + calculator: 多个工具结合调用
    # run_tool([Calculator(), "arxiv", Weather()], llm, [
    #     "帮我检索GLM-130B相关论文",
    #     "今天北京天气怎么样?",
    #     "根号3减去根号二再加上4等于多少?",
    # ])

这段代码是一个示例程序,演示了如何使用 langchain 库来调用不同的工具进行自然语言处理。

以下是代码的详细解析:

  • from typing import List:从 typing 模块导入 List 类型,用于函数参数的类型注解。

  • from ChatGLM3 import ChatGLM3:从 ChatGLM3 模块导入 ChatGLM3 类,用于创建语言模型。

  • from langchain.agents import load_tools:从 langchain.agents 模块导入 load_tools 函数,用于加载工具。

  • from Tool.Weather import Weather:从 Tool.Weather 模块导入 Weather 类,用于天气查询工具。

  • from Tool.Calculator import Calculator:从 Tool.Calculator 模块导入 Calculator 类,用于计算器工具。

  • from langchain.agents import initialize_agent:从 langchain.agents 模块导入 initialize_agent 函数,用于初始化代理。

  • from langchain.agents import AgentType:从 langchain.agents 模块导入 AgentType 枚举,用于指定代理类型。

  • def run_tool(tools, llm, prompt_chain: List[str])::定义了一个函数 run_tool,它接受三个参数:tools(要加载的工具列表),llm(语言模型实例),prompt_chain(要运行的提示列表)。

    • loaded_tolls = []:创建一个空列表 loaded_tolls,用于存储加载后的工具。
    • for tool in tools::对于 tools 列表中的每个工具:
      • if isinstance(tool, str)::如果工具是字符串类型,表示需要加载的是一个工具模块:
        • loaded_tolls.append(load_tools([tool], llm=llm)[0]):加载指定的工具模块,并将返回的工具实例添加到 loaded_tolls 列表中。
    • else::如果工具不是字符串类型,表示已经是一个工具实例:
      • loaded_tolls.append(tool):直接将工具实例添加到 loaded_tolls 列表中。
  • agent = initialize_agent(loaded_tolls, llm, agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True, handle_parsing_errors=True):使用加载后的工具和语言模型实例初始化代理。AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION 表示采用结构化对话的零次推理反应描述的代理类型。

    • for prompt in prompt_chain::对于 prompt_chain 列表中的每个提示:
    • agent.run(prompt):使用代理运行提示。
  • if name == “main”::如果当前模块被直接执行(而不是被导入为模块):

    • model_path = “THUDM/chatglm3-6b”:设置语言模型的路径。
    • llm = ChatGLM3():创建 ChatGLM3 类的实例,用于创建语言模型。
  • llm.load_model(model_name_or_path=model_path):加载指定路径的语言模型。

  • run_tool([“arxiv”], llm, [“帮我查询GLM-130B相关工作”]):调用 run_tool 函数,使用 arxiv 工具和语言模型实例 llm,并传入一个提示列表 [“帮我查询GLM-130B相关工作”],以执行相关工作的查询。

  • run_tool([Calculator()], llm, [“12345679乘以54等于多少?”, “3.14的3.14次方等于多少?”, “根号2加上根号三等于多少?”]):调用 run_tool 函数,使用 Calculator 工具和语言模型实例 llm,并传入一个提示列表,以执行计算器工具的计算。

注释掉的代码块是其他工具的调用示例,包括天气查询工具和多个工具的结合调用。

这段代码演示了如何使用 langchain 库,加载不同的工具和语言模型,然后通过代理来运行自然语言提示,以执行不同的任务,例如查询相关工作、查询天气和进行计算等。

完结!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/133834.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SpringMvc 常见面试题

1、SpringMvc概述 1.1、什么是Spring MVC &#xff1f;简单介绍下你对springMVC的理解? Spring MVC是一个基于Java的实现了MVC设计模式的请求驱动类型的轻量级Web框架&#xff0c;通过把Model&#xff0c;View&#xff0c;Controller分离&#xff0c;将web层进行职责解耦&am…

C++算法:矩阵中的最长递增路径

涉及知识点 拓扑排序 题目 给定一个 m x n 整数矩阵 matrix &#xff0c;找出其中 最长递增路径 的长度。 对于每个单元格&#xff0c;你可以往上&#xff0c;下&#xff0c;左&#xff0c;右四个方向移动。 你 不能 在 对角线 方向上移动或移动到 边界外&#xff08;即不允…

学习美团推荐系统质量模型建设

目录 一、背景引入 &#xff08;一&#xff09;基本背景说明 &#xff08;二&#xff09;从推荐系统“数据飞轮”看质量建设必要性 二、质量的定位和考量思考 &#xff08;一&#xff09;对推荐系统质量的思考迭代 &#xff08;二&#xff09;可用性计算的关注点 &#…

FreeRTOS源码阅读笔记3--queue.c

消息队列可以应用于发送不定长消息的场合&#xff0c;包括任务与任务间的消息交换&#xff0c;队列是 FreeRTOS 主要的任务间通讯方式&#xff0c;可以在任务与任务间、中断和任务间传送信息&#xff0c;发送到 队列的消息是通过拷贝方式实现的&#xff0c;这意味着队列存储…

【原创】java+swing+mysql爱心捐赠管理系统设计与实现

摘要&#xff1a; 爱心捐赠管理系统旨在管理和优化捐赠过程&#xff0c;提高效率&#xff0c;增强透明度&#xff0c;并鼓励更多的个人和企业参与公益捐赠&#xff0c;用户可以捐款或者捐物。本系统采用javaswing界面可视化技术&#xff0c;数据库使用mysql。 功能分析&#…

Python高级语法----深入理解Python迭代器与生成器

文章目录 1. 迭代器协议代码示例:2. 生成器基础代码示例:3. 使用yield的高级技巧代码示例:4. 生成器表达式代码示例:迭代器和生成器是Python中实现迭代的两种主要方式,它们都允许用户创建可以遍历数据集的对象。在Python中,迭代器协议是指对象需要遵守__iter__()和__next…

【大数据】NiFi 中的处理器(一):GenerateTableFetch

NiFi 中的处理器&#xff08;一&#xff09;&#xff1a;GenerateTableFetch 1.简介2.应用场景3.示例3.1 案例一&#xff1a;无输入流文件&#xff0c;来源表含增量字段3.2 案例二&#xff1a;无输入流文件&#xff0c;不含增量字段3.3 案例三&#xff1a;无输入流文件&#xf…

通用文件在线预览软件kkFileView

什么是 kkFileView &#xff1f; kkFileView 为文件文档在线预览解决方案&#xff0c;基本支持主流办公文档的在线预览&#xff0c;如 doc&#xff0c;docx&#xff0c;xls&#xff0c;xlsx&#xff0c;ppt&#xff0c;pptx&#xff0c;pdf&#xff0c;txt&#xff0c;zip&…

如何配置《动手学强化学习》的环境

如何配置《动手学强化学习》的环境 网站&#xff1a;https://hrl.boyuai.com/chapter/intro github仓库&#xff1a;https://github.com/boyu-ai/Hands-on-RL/tree/main 可以看到该教程要求使用gym0.18.3版本的gym库&#xff0c;本教程可以用于解决绝大多数需要使用Pendulum-…

科力雷达Lidar使用指南

科力2D Lidar使用指南 作者&#xff1a; Herman Ye Galbot Auromix 版本&#xff1a; V1.0 测试环境&#xff1a; Ubuntu20.04(x86) PC 以及 Ubuntu20.04(Arm) Nvidia Orin 更新日期&#xff1a; 2023/11/11 注1&#xff1a; 本文内容中的硬件由 Galbot 提供支持。 注2&#x…

力扣100题——子串

560.和为k的子数组 这道题目不是滑动窗口的类型&#xff0c;因为长度并不是固定的。&#xff08;好的&#xff0c;我在说废话&#xff09; 注意题目要求是子数组&#xff0c;且是连贯的。那这里的话&#xff0c;解法有很多&#xff0c;最简单的就是暴力解法&#xff0c;但在这…

无缝集成GORM与Go Web框架

探索GORM与流行的Go Web框架之间的和谐集成&#xff0c;以实现高效的数据管理 高效的数据管理是每个成功的Web应用程序的基础。GORM&#xff0c;多才多艺的Go对象关系映射库&#xff0c;与流行的Go Web框架非常搭配&#xff0c;提供了无缝集成&#xff0c;简化了数据交互。本指…

Git可视化界面的操作,SSH协议的以及IDEA集成Git

目录 一. Git可视化界面的操作 二. gitee的ssh key 2.1 SSH协议 2.2 ssh key 三. IDEA集成Git 3.1 分享项目 3.2 下载项目 一. Git可视化界面的操作 上一篇博客只用到了git的命令窗口&#xff0c;现在就来看看可视化窗口要怎么操作。 点击Git GUI Here GUI界面 在g…

【Git】git常用命令大全

&#x1f389;&#x1f389;欢迎来到我的CSDN主页&#xff01;&#x1f389;&#x1f389; &#x1f3c5;我是Java方文山&#xff0c;一个在CSDN分享笔记的博主。&#x1f4da;&#x1f4da; &#x1f31f;推荐给大家我的专栏《Git》。&#x1f3af;&#x1f3af; &#x1f449…

afsim 下载链接

afsim是一个通用的建模框架&#xff0c;能够构建典型的虚拟威胁环境和相关模型。能够以可视化形式分析软件仿真结果&#xff0c;显示平台、路由、传感器区域等内容&#xff0c;能够基于事件生成图表&#xff0c;进行结果统计&#xff0c;能够按类型进行统计分析。 苦于网上没有…

【Git】Git分支与应用分支

一&#xff0c;Git分支 1.1 理解Git分支 在 Git 中&#xff0c;分支是指一个独立的代码线&#xff0c;并且可以在这个分支上添加、修改和删除文件&#xff0c;同时作为另一个独立的代码线存在。一个仓库可以有多个分支&#xff0c;不同的分支可以独立开发不同的功能&#xff0…

maven教程

1. Maven概述 1.1 Maven的功能 1、Maven 作为依赖管理工具 随着我们使用越来越多的框架&#xff0c;或者框架封装程度越来越高&#xff0c;项目中使用的jar包也越来越多。项目中&#xff0c;一个模块里面用到上百个jar包是非常正常的。jar包所属技术的官网通常是英文界面&am…

极智芯 | 存算一体 弯道超车的希望

欢迎关注我的公众号 [极智视界]&#xff0c;获取我的更多经验分享 大家好&#xff0c;我是极智视界&#xff0c;本文分享一下 存算一体 弯道超车的希望。 邀您加入我的知识星球「极智视界」&#xff0c;星球内有超多好玩的项目实战源码和资源下载&#xff0c;链接&#xff1a;…

【C++笔记】优先级队列priority_queue的模拟实现

【C笔记】优先级队列priority_queue的模拟实现 一、优先级队列的介绍与使用方式1.1、优先级队列介绍1.2、优先级队列的常见使用 二、优先级队列的模拟实现1.0、仿函数的介绍1.1、构造函数1.2、优先级队列的插入push1.3、优先级队列的删除(删除堆顶元素)1.4、获取堆顶元素1.5、判…

MATLAB仿真通信系统的眼图

eyediagram eyediagram(complex(used_i,used_q),1100)
最新文章