手把手教你在 LangChain 中创建自定义大语言模型类

在使用 LangChain 开发大语言模型应用时,我们常常会遇到这样的需求:需要集成自研的 LLM 模型,或者使用框架尚未原生支持的第三方模型。这时候,自定义 LLM 类就成了必经之路。今天,我们就来深入探讨如何通过简单的几步,在 LangChain 中构建属于自己的 LLM 包装器,让自定义模型无缝融入现有生态。

一、为什么需要自定义 LLM 类

当我们面临以下场景时,自定义 LLM 类就显得尤为重要:

  • 使用自研模型:团队训练了专属业务场景的模型,需要接入 LangChain 框架
  • 集成第三方 API:框架尚未支持的模型服务(如国内大模型 API)
  • 定制化功能:需要在模型调用前后添加自定义处理逻辑
  • 性能优化:针对特定硬件或部署环境进行底层优化

通过 LangChain 的标准 LLM 接口封装模型后,我们能获得这些额外福利:

  • 自动支持 LangChain 的 Runnable 接口
  • 开箱即用的异步调用能力
  • 原生支持流式输出和事件回调
  • 无缝集成 LangSmith 等生态工具

二、自定义 LLM 的核心实现要点

1. 必须实现的核心方法

自定义 LLM 类需要实现两个基础方法,这是与 LangChain 生态交互的基础:

python

from langchain_core.language_models.llms import LLM

class CustomLLM(LLM):
    # 核心推理方法,接受提示词并返回模型输出
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # 实现模型调用逻辑
        pass
    
    # 模型类型标识,仅用于日志记录
    @property
    def _llm_type(self) -> str:
        return "custom"  # 返回自定义类型标识

2. 可选增强功能实现

为了让自定义 LLM 更强大,还可以实现这些可选方法:

python

from typing import Dict, Iterator
from langchain_core.outputs import GenerationChunk
from langchain_core.callbacks.manager import CallbackManagerForLLMRun

class CustomLLM(LLM):
    # 模型标识参数,用于打印和识别模型
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {"model_name": "MyCustomModel"}
    
    # 异步调用实现
    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # 异步模型调用逻辑
        pass
    
    # 流式输出实现
    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        # 逐个token生成的逻辑
        pass
    
    # 异步流式输出实现
    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        # 异步流式生成逻辑
        pass

三、实战案例:实现一个简单的自定义 LLM

1. 完整实现代码

下面我们实现一个极简的 CustomLLM,它的功能是返回输入字符串的前 n 个字符:

python

from typing import Any, Dict, Iterator, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk

class CustomLLM(LLM):
    """自定义LLM示例,返回输入的前n个字符"""
    
    n: int
    """返回输入字符串的前n个字符"""
    
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """核心调用方法,实现模型推理逻辑"""
        if stop is not None:
            raise ValueError("此模型不支持stop参数")
        return prompt[: self.n]
    
    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """流式输出实现,逐个字符生成"""
        for char in prompt[: self.n]:
            chunk = GenerationChunk(text=char)
            if run_manager:
                run_manager.on_llm_new_token(char, chunk=chunk)
            yield chunk
    
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """返回模型标识参数,用于打印和监控"""
        return {"model_name": "CustomChatModel"}
    
    @property
    def _llm_type(self) -> str:
        """返回模型类型,用于日志记录"""
        return "custom"

2. 功能测试与验证

我们来测试这个自定义 LLM 的各项功能:

python

# 初始化模型,设置返回前5个字符
llm = CustomLLM(n=5)

# 同步调用测试
print(llm.invoke("This is a test prompt"))  # 输出: This 

# 异步调用测试
import asyncio
async_result = asyncio.run(llm.ainvoke("hello world"))
print(async_result)  # 输出: hello

# 批量调用测试
batch_result = llm.batch(["woof woof", "meow meow"])
print(batch_result)  # 输出: ['woof ', 'meow ']

# 流式输出测试
async def test_stream():
    async for token in llm.astream("hello there"):
        print(token, end="|", flush=True)
# 输出: h|e|l|l|o|

# 与其他LangChain组件集成测试
from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages(
    [("system", "你是一个助手"), ("human", "{input}")]
)

chain = prompt | llm
async def test_chain():
    async for event in chain.astream_events({"input": "你好世界"}, version="v1"):
        print(event.get("data", {}).get("chunk", ""), end="", flush=True)
# 输出: 你好世界

测试结果显示,我们的自定义 LLM 完美支持 LangChain 的各种核心功能,包括同步 / 异步调用、批量处理、流式输出以及与提示模板的无缝集成。

四、进阶实践:集成真实模型的关键步骤

1. 接入第三方 API 的实现框架

以下是集成外部 API 的典型实现结构:

python

import requests
from langchain_core.language_models.llms import LLM

class ThirdPartyLLM(LLM):
    api_key: str
    api_url: str
    
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        # 构建API请求
        headers = {"Authorization": f"Bearer {self.api_key}"}
        payload = {"prompt": prompt, "stop": stop, **kwargs}
        
        # 发送请求
        response = requests.post(self.api_url, json=payload)
        response.raise_for_status()
        
        # 解析响应
        return response.json()["choices"][0]["text"]
    
    @property
    def _llm_type(self) -> str:
        return "third_party"

2. 关键优化点

在实际集成中,需要注意这些优化方向:

  • 参数映射:将 LangChain 的参数(如 temperature)映射到第三方 API 的对应参数
  • 错误处理:完善 API 调用的错误捕获和重试机制
  • 令牌计数:实现准确的令牌计数逻辑,避免超出 API 限制
  • 缓存策略:添加响应缓存,减少重复调用开销
  • 流式处理:如果 API 支持,实现完整的流式输出管道

五、总结

通过自定义 LLM 类,我们打破了框架的限制,实现了模型接入的完全自由度。从简单的字符截取示例,到复杂的第三方 API 集成,LangChain 的自定义接口为我们提供了强大的扩展能力。

在实际项目中,自定义 LLM 类特别适合以下场景:

  • 企业自研模型的工程化落地
  • 多模型融合的混合架构
  • 特定领域的模型优化
  • 低成本的模型推理方案

如果本文对你有帮助,别忘了点赞收藏,关注我,一起探索更高效的开发方式~

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

佑瞻

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值