LangChain之关于RetrievalQA input_variables 的定义与使用

最近在使用LangChain来做一个LLMs和KBs结合的小Demo玩玩,也就是RAG(Retrieval Augmented Generation)。
这部分的内容其实在LangChain的官网已经给出了流程图。在这里插入图片描述
我这里就直接偷懒了,准备对Webui的项目进行复刻练习,那么接下来就是照着葫芦画瓢就行。
那么我卡在了Retrieve这一步。先放有疑惑地方的代码:

if web_content:
            prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
                                如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
                                已知网络检索内容:{web_content}""" + """
                                已知内容:
                                {context}
                                问题:
                                {question}"""
        else:
            prompt_template = """基于以下已知信息,请简洁并专业地回答用户的问题。
                如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息"。不允许在答案中添加编造成分。另外,答案请使用中文。

                已知内容:
                {context}

                问题:
                {question}"""
        prompt = PromptTemplate(template=prompt_template,
                                input_variables=["context", "question"])
        ......

        knowledge_chain = RetrievalQA.from_llm(
            llm=self.llm,
            retriever=vector_store.as_retriever(
                search_kwargs={"k": self.top_k}),
            prompt=prompt)
        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
            input_variables=["page_content"], template="{page_content}")

        knowledge_chain.return_source_documents = True

        result = knowledge_chain({"query": query})
        return result

我对prompt_templateknowledge_chain.combine_documents_chain.document_prompt result = knowledge_chain({"query": query})这三个地方的input_key不明白为啥一定要这样设置。虽然我也看了LangChain的API文档。但是我并未得到详细的答案,那么只能一行行看源码是到底怎么设置的了。

注意:由于LangChain是一层层封装的,那么result = knowledge_chain({"query": query})可以认为是最外层,那么我们先看最外层。

result = knowledge_chain({“query”: query})

其实这部分是直接与用户的输入问题做对接的,我们只需要定位到RetrievalQA这个类就可以了,下面是RetrievalQA这个类的实现:

class RetrievalQA(BaseRetrievalQA):
    """Chain for question-answering against an index.
    Example:
        .. code-block:: python
            from langchain.llms import OpenAI
            from langchain.chains import RetrievalQA
            from langchain.vectorstores import FAISS
            from langchain.schema.vectorstore import VectorStoreRetriever
            retriever = VectorStoreRetriever(vectorstore=FAISS(...))
            retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)

    """

    retriever: BaseRetriever = Field(exclude=True)

    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""
        return self.retriever.get_relevant_documents(
            question, callbacks=run_manager.get_child()
        )

    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""
        return await self.retriever.aget_relevant_documents(
            question, callbacks=run_manager.get_child()
        )

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "retrieval_qa"

可以看到其继承了BaseRetrievalQA这个父类,同时对_get_docs这个抽象方法进行了实现。

这里要扩展的说一下,_get_docs这个方法就是利用向量相似性,在vector Base中选择与embedding之后的query最近似的Document结果。然后作为RetrievalQA的上下文。具体只需要看BaseRetrievalQA这个方法的_call和就可以了。
接下来我们只需要看BaseRetrievalQA这个类的属性就可以了。

class BaseRetrievalQA(Chain):
    """Base class for question-answering chains."""

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain to use to combine the documents."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_source_documents: bool = False
    """Return the source documents or not."""
    ……
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run get_relevant_text and llm on input query.

        If chain has 'return_source_documents' as 'True', returns
        the retrieved documents as well under the key 'source_documents'.

        Example:
        .. code-block:: python

        res = indexqa({'query': 'This is my query'})
        answer, docs = res['result'], res['source_documents']
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(question, run_manager=_run_manager)
        else:
            docs = self._get_docs(question)  # type: ignore[call-arg]
        answer = self.combine_documents_chain.run(
            input_documents=docs, question=question, callbacks=_run_manager.get_child()
        )

        if self.return_source_documents:
            return {self.output_key: answer, "source_documents": docs}
        else:
            return {self.output_key: answer}

可以看到其有input_key这个属性,默认值是"query"。到这里我们就可以看到result = knowledge_chain({"query": query})是调用的BaseRetrievalQA_call,这里的question = inputs[self.input_key]就是其体现。

knowledge_chain.combine_documents_chain.document_prompt

这个地方一开始我很奇怪,为什么会重新定义呢?
我们可以先定位到,combine_documents_chain这个参数的位置,其是StuffDocumentsChain的方法。

@classmethod
def from_llm(
    cls,
    llm: BaseLanguageModel,
    prompt: Optional[PromptTemplate] = None,
    callbacks: Callbacks = None,
    **kwargs: Any,
) -> BaseRetrievalQA:
    """Initialize from LLM."""
    _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
    llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks)
    document_prompt = PromptTemplate(
        input_variables=["page_content"], template="Context:\n{page_content}"
    )
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=llm_chain,
        document_variable_name="context",
        document_prompt=document_prompt,
        callbacks=callbacks,
    )

    return cls(
        combine_documents_chain=combine_documents_chain,
        callbacks=callbacks,
        **kwargs,
    )

可以看到原始的document_prompt中PromptTemplate的template是“Context:\n{page_content}”。因为这个项目是针对中文的,所以需要将英文的Context去掉。

扩展

  1. 这里PromptTemplate(input_variables=[“page_content”], template=“Context:\n{page_content}”)的input_variablestemplate为什么要这样定义呢?其实是根据Document这个数据对象来定义使用的,我们可以看到其数据格式为:Document(page_content=‘……’, metadata={‘source’: ‘……’, ‘row’: ……})
    那么input_variables的输入就是Document的page_content。
  2. StuffDocumentsChain中有一个参数是document_variable_name。那么这个类是这样定义的This chain takes a list of documents and first combines them into a single string. It does this by formatting each document into a string with the document_prompt and then joining them together with document_separator. It then adds that new string to the inputs with the variable name set by document_variable_name. Those inputs are then passed to the llm_chain. 这个document_variable_name简单来说就是在document_prompt中的占位符,用于在Chain中的使用。
    因此我们上文prompt_template变量中的“已知内容: {context}”,用的就是context这个变量。因此在prompt_template中换成其他的占位符都不能正常使用这个Chain。

prompt_template

在上面的拓展中其实已经对prompt_template做了部分的讲解,那么这个字符串还剩下“问题:{question}”这个地方没有说通
还是回归源码:

return cls(
        combine_documents_chain=combine_documents_chain,
        callbacks=callbacks,
        **kwargs,
    )

我们可以在from_llm函数中看到其返回值是到了_call,那么剩下的我们来看这个函数:


......
uestion = inputs[self.input_key]
accepts_run_manager = (
    "run_manager" in inspect.signature(self._get_docs).parameters
)
if accepts_run_manager:
    docs = self._get_docs(question, run_manager=_run_manager)
else:
    docs = self._get_docs(question)  # type: ignore[call-arg]
answer = self.combine_documents_chain.run(
    input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
......

这里是在run这个函数中传入了一个字典值,这个字典值有三个参数。

注意:

  1. 这三个参数就是kwargs,也就是_validate_inputs的参数input;
  2. 此时已经是在Chain这个基本类了)
def run(
        self,
        *args: Any,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
       """Convenience method for executing chain.

        The main difference between this method and `Chain.__call__` is that this
        method expects inputs to be passed directly in as positional arguments or
        keyword arguments, whereas `Chain.__call__` expects a single input dictionary
        with all the inputs"""

接下来调用__call__:

def __call__(
        self,
        inputs: Union[Dict[str, Any], Any],
        return_only_outputs: bool = False,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        include_run_info: bool = False,
    ) -> Dict[str, Any]:
        """Execute the chain.

        Args:
            inputs: Dictionary of inputs, or single input if chain expects
                only one param. Should contain all inputs specified in
                `Chain.input_keys` except for inputs that will be set by the chain's
                memory.
            return_only_outputs: Whether to return only outputs in the
                response. If True, only new keys generated by this chain will be
                returned. If False, both input keys and new keys generated by this
                chain will be returned. Defaults to False.
            callbacks: Callbacks to use for this chain run. These will be called in
                addition to callbacks passed to the chain during construction, but only
                these runtime callbacks will propagate to calls to other objects.
            tags: List of string tags to pass to all callbacks. These will be passed in
                addition to tags passed to the chain during construction, but only
                these runtime tags will propagate to calls to other objects.
            metadata: Optional metadata associated with the chain. Defaults to None
            include_run_info: Whether to include run info in the response. Defaults
                to False.

        Returns:
            A dict of named outputs. Should contain all outputs specified in
                `Chain.output_keys`.
        """
        inputs = self.prep_inputs(inputs)
        ......

这里的prep_inputs会调用_validate_inputs函数

def _validate_inputs(self,inputs: Dict[str, Any]) -> None:
    """Check that all inputs are present."""
    missing_keys = set(self.input_keys).difference(inputs)
    if missing_keys:
        raise ValueError(f"Missing some input keys: {missing_keys}")

这里的input_keys通过调试,看到的就是有多个输入,分别是"input_documents"和"question"
这里的"input_documents"是来自于BaseCombineDocumentsChain

class BaseCombineDocumentsChain(Chain, ABC):
    """Base interface for chains combining documents.

    Subclasses of this chain deal with combining documents in a variety of
    ways. This base class exists to add some uniformity in the interface these types
    of chains should expose. Namely, they expect an input key related to the documents
    to use (default `input_documents`), and then also expose a method to calculate
    the length of a prompt from documents (useful for outside callers to use to
    determine whether it's safe to pass a list of documents into this chain or whether
    that will longer than the context length).
    """

    input_key: str = "input_documents"  #: :meta private:
    output_key: str = "output_text"  #: :meta private:

那为什么有两个呢,“question”来自于哪里?
StuffDocumentsChain继承BaseCombineDocumentsChain,其input_key是这样定义的:

  @property
  def input_keys(self) -> List[str]:
      extra_keys = [
          k for k in self.llm_chain.input_keys if k != self.document_variable_name
      ]
      return super().input_keys + extra_keys

原来是重写了input_keys函数,其是对llm_chain的input_keys进行遍历。

那么llm_chain的input_keys是用其prompt的input_variables。(这里的input_variables是PromptTemplate中的[“context”, “question”])

	@property
	def input_keys(self) -> List[str]:
	   """Will be whatever keys the prompt expects.
	   :meta private:
	   """
	   return self.prompt.input_variables

至此,我们StuffDocumentsChain的input_keys有两个变量了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/120759.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

HTML的初步学习

HTML HTML 描述网页的骨架, 标签化的语言. HTML 的执行是浏览器的工作,浏览器会解析 html 的内容,根据里面的代码,往页面上放东西,浏览器的工作归根结底,还是以汇编的形式在CPU上执行. 浏览器对于html语法格式的检查没有很严格,即使你写的代码有一些不合规范之处,浏览器也会尽可…

GIS开发入门,TopoJSON格式是什么?TopoJSON格式与GeoJSON格式有什么不同?

TopoJSON介绍 TopoJSON是一种几何拓扑结构的地理数据格式,它使用拓扑结构来表示地理对象,可以更有效地压缩和转移数据,从而加快数据加载速度。 TopoJSON格式构成 TopoJSON文件由三部分组成,transform、objects和arcs组成。transform描述了变换参数; objects描述地理实体…

NCV7721D2R2G一款完全保护的双半桥驱动器 专为汽车工业运动控制解决方案

NCV7721D2R2G是一款完全保护的双半桥驱动器,专为汽车和工业运动控制应用而设计。两个半桥驱动器具有独立控制。这允许高侧、低侧和H桥控制。H桥控制提供正向、反向、制动和高阻抗状态。驱动器通过逻辑电平输入进行控制。 特性: 1.睡眠模式下的超低静态电…

生成无损压缩png和有损压缩png的做法

作者:朱金灿 来源:clever101的专栏 为什么大多数人学不会人工智能编程?>>> png是一种常用的图像格式。png一般为无损压缩,但是可以是有损压缩的。 下图都是100x100的png图像,一个是无损压缩,一个是有损压缩。 看着效果基本一样,但是它们的大小相差很大,无损…

WPF布局与控件分类

Refer:WPF从假入门到真的入门 - 知乎 (zhihu.com) Refer:WPF从假入门到真的入门 - 知乎 (zhihu.com) https://www.zhihu.com/column/c_1397867519101755392 https://blog.csdn.net/qq_44034384/article/details/106154954 https://www.cnblogs.com/mq0…

k8s之service五种负载均衡byte的区别

1,什么是Service? 1.1 Service的概念​ 在k8s中,service 是一个固定接入层,客户端可以通过访问 service 的 ip 和端口访问到 service 关联的后端pod,这个 service 工作依赖于在 kubernetes 集群之上部署的一个附件&a…

【黑马程序员】SpringCloud——Eureka

文章目录 前言一、提供者与消费者1. 服务调用关系 二、远程调用的问题三、eureka 原理分析1. eureka 的作用 四、Eureka 案例1. 搭建 eureka 服务1. 服务注册1.1 注册 user-service1.2 启动 user-service3. order-service 完成服务注册 3. 服务发现1. 在 order-service 完成服务…

把枯燥的PDF文档转换为翻页电子书,一键上传搞定

PDF是我们工作生活中比较常用的文档格式之一,由于PDF文件可以离线观看,所以通常都是静态的,有时候密密麻麻的文字看得很是头晕眼花,这使得阅读体验变得单调乏味。 为了解决这个问题 , 我们推荐使用FLBOOK &#xff0c…

JS逆向爬虫---响应结果加密⑤【token参数加密与DES解密】

https://spa7.scrape.center/ 文本数据 数据内嵌在js内,普通合理请求即可获取 图片 位于固定接口 类似https://spa7.scrape.center/img/durant.png 固定url名称 Token 参数确定 base64Name > base64编码后的中文名称 nodejs 代码 //导入crypto-js模块 var CryptoJS…

【AntDesign】Docker部署

docker部署是主流的部署方式,极大的方便了开发部署环境,保持了环境的统一,也是实现自动化部署的前提。 1 项目的目录结构 dist: 使用build打包命令,生成的打包目录 npm run build : 打包项目命令 docker: 存放docker容器需要修改…

Elasticsearch:ES|QL 中的数据丰富

在之前的文章 “Elasticsearch:ES|QL 查询语言简介”,我有介绍 ES|QL 的 ENRICH 处理命令。ES|QL ENRICH 处理命令在查询时将来自一个或多个源索引的数据与 Elasticsearch 丰富索引中找到的字段值组合相结合。这个有点类似于关系数据库查询中所使用的 jo…

【Proteus仿真】【Arduino单片机】OLED液晶显示

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真Arduino单片机控制器,使用IIC OLED液晶等。 主要功能: 系统运行后,OLED液晶显示各种图形、字符、图像。 二、软件设计 /* 作者:嗨小…

6个机器学习可解释性框架

1、SHAP SHapley Additive explanation (SHAP)是一种解释任何机器学习模型输出的博弈论方法。它利用博弈论中的经典Shapley值及其相关扩展将最优信贷分配与局部解释联系起来. 举例:基于随机森林模型的心脏病患者预测分类 数据集中每个特征对模型预测的贡献由Shap…

工业CT 三维重建 及分割

目录 工业CT介绍 工业CT主要应用于以下领域: CT三维重建软件: 效果: 工业CT介绍 工业CT设备是基于线阵探测器的断层扫描技术,是一种常用的无损检测技术,用于获取物体内部的准确三维结构信息。它通过X射线的投射和接…

计算机网络学习笔记(五):运输层(待更新)

5.1 概述 5.1.1 TCP协议的应用场景 TCP为应用层协议提供可靠传输,发送端按顺序发送,接收端按顺序接收,其间发送丢包、乱序,TCP负责重传和排序。下面是TCP的应用场景。 多次交互:客户端程序和服务端程序需要多次交互才…

SQL必知会(二)-SQL查询篇(2)-排序检索数据

第3课、排序检索数据 排序数据 OEDER BY:排序 进行排序 1)按单个列排序 需求: 以 prod_name 字段按照字母顺序排序 SELECT prod_name FROM Products ORDER BY prod_name; -- 以 prod_name 列按照字母顺序排序输出结果: 2&…

高等数学教材重难点题型总结(一)函数与极限

强化阶段的另一个专题,本专题主要总结高数课本上的经典例题与课后题,尤其一部分加*标的题目,对于冲击高分的同学来说,必须熟练掌握。 (蓝色代表难点,红色代表重点,紫色代表重难点) …

软文发布如何选择对应的媒体

企业做软文推广第一步,就是选择合适的媒体进行投放,然而许多企业不知道如何选择合适的媒体导致推广工作十分被动,无法取得效果,今天媒介盒子就来和大家分享,企业应该如何选择对应的媒体。 一、 媒体类型 根据软文类型…

【Python】python获取本机IP的两种方式

1.使用专用网络 通过进入网站:http://myip.ipip.net获取本机ip地址 代码实现: import requests res requests.get(http://myip.ipip.net, timeout5).text print(res) 也可以在终端cmd中用如下代码实现; curl http://myip.ipip.net 2.使用自带的socke…

光学仿真 | 仿真推动以人类视觉感知为本的汽车显示设计

如果产品设计无法使终端用户产生共鸣,就不会存在卓越的工程设计。您可以设计一种结构坚固的方向盘,但如果它被放在错误的位置,就无法实现其用于转向的主要目的。 同样,在围绕人类视觉进行设计时,显示器其实无需具备尽…
最新文章