在使用 LangChain 开发大语言模型应用时,我们常常会遇到这样的需求:需要集成自研的 LLM 模型,或者使用框架尚未原生支持的第三方模型。这时候,自定义 LLM 类就成了必经之路。今天,我们就来深入探讨如何通过简单的几步,在 LangChain 中构建属于自己的 LLM 包装器,让自定义模型无缝融入现有生态。
一、为什么需要自定义 LLM 类
当我们面临以下场景时,自定义 LLM 类就显得尤为重要:
- 使用自研模型:团队训练了专属业务场景的模型,需要接入 LangChain 框架
- 集成第三方 API:框架尚未支持的模型服务(如国内大模型 API)
- 定制化功能:需要在模型调用前后添加自定义处理逻辑
- 性能优化:针对特定硬件或部署环境进行底层优化
通过 LangChain 的标准 LLM 接口封装模型后,我们能获得这些额外福利:
- 自动支持 LangChain 的 Runnable 接口
- 开箱即用的异步调用能力
- 原生支持流式输出和事件回调
- 无缝集成 LangSmith 等生态工具
二、自定义 LLM 的核心实现要点
1. 必须实现的核心方法
自定义 LLM 类需要实现两个基础方法,这是与 LangChain 生态交互的基础:
python
from langchain_core.language_models.llms import LLM
class CustomLLM(LLM):
# 核心推理方法,接受提示词并返回模型输出
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
# 实现模型调用逻辑
pass
# 模型类型标识,仅用于日志记录
@property
def _llm_type(self) -> str:
return "custom" # 返回自定义类型标识
2. 可选增强功能实现
为了让自定义 LLM 更强大,还可以实现这些可选方法:
python
from typing import Dict, Iterator
from langchain_core.outputs import GenerationChunk
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
class CustomLLM(LLM):
# 模型标识参数,用于打印和识别模型
@property
def _identifying_params(self) -> Dict[str, Any]:
return {"model_name": "MyCustomModel"}
# 异步调用实现
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
# 异步模型调用逻辑
pass
# 流式输出实现
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
# 逐个token生成的逻辑
pass
# 异步流式输出实现
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
# 异步流式生成逻辑
pass
三、实战案例:实现一个简单的自定义 LLM
1. 完整实现代码
下面我们实现一个极简的 CustomLLM,它的功能是返回输入字符串的前 n 个字符:
python
from typing import Any, Dict, Iterator, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
class CustomLLM(LLM):
"""自定义LLM示例,返回输入的前n个字符"""
n: int
"""返回输入字符串的前n个字符"""
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""核心调用方法,实现模型推理逻辑"""
if stop is not None:
raise ValueError("此模型不支持stop参数")
return prompt[: self.n]
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""流式输出实现,逐个字符生成"""
for char in prompt[: self.n]:
chunk = GenerationChunk(text=char)
if run_manager:
run_manager.on_llm_new_token(char, chunk=chunk)
yield chunk
@property
def _identifying_params(self) -> Dict[str, Any]:
"""返回模型标识参数,用于打印和监控"""
return {"model_name": "CustomChatModel"}
@property
def _llm_type(self) -> str:
"""返回模型类型,用于日志记录"""
return "custom"
2. 功能测试与验证
我们来测试这个自定义 LLM 的各项功能:
python
# 初始化模型,设置返回前5个字符
llm = CustomLLM(n=5)
# 同步调用测试
print(llm.invoke("This is a test prompt")) # 输出: This
# 异步调用测试
import asyncio
async_result = asyncio.run(llm.ainvoke("hello world"))
print(async_result) # 输出: hello
# 批量调用测试
batch_result = llm.batch(["woof woof", "meow meow"])
print(batch_result) # 输出: ['woof ', 'meow ']
# 流式输出测试
async def test_stream():
async for token in llm.astream("hello there"):
print(token, end="|", flush=True)
# 输出: h|e|l|l|o|
# 与其他LangChain组件集成测试
from langchain_core.prompts import ChatPromptTemplate
prompt = ChatPromptTemplate.from_messages(
[("system", "你是一个助手"), ("human", "{input}")]
)
chain = prompt | llm
async def test_chain():
async for event in chain.astream_events({"input": "你好世界"}, version="v1"):
print(event.get("data", {}).get("chunk", ""), end="", flush=True)
# 输出: 你好世界
测试结果显示,我们的自定义 LLM 完美支持 LangChain 的各种核心功能,包括同步 / 异步调用、批量处理、流式输出以及与提示模板的无缝集成。
四、进阶实践:集成真实模型的关键步骤
1. 接入第三方 API 的实现框架
以下是集成外部 API 的典型实现结构:
python
import requests
from langchain_core.language_models.llms import LLM
class ThirdPartyLLM(LLM):
api_key: str
api_url: str
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
# 构建API请求
headers = {"Authorization": f"Bearer {self.api_key}"}
payload = {"prompt": prompt, "stop": stop, **kwargs}
# 发送请求
response = requests.post(self.api_url, json=payload)
response.raise_for_status()
# 解析响应
return response.json()["choices"][0]["text"]
@property
def _llm_type(self) -> str:
return "third_party"
2. 关键优化点
在实际集成中,需要注意这些优化方向:
- 参数映射:将 LangChain 的参数(如 temperature)映射到第三方 API 的对应参数
- 错误处理:完善 API 调用的错误捕获和重试机制
- 令牌计数:实现准确的令牌计数逻辑,避免超出 API 限制
- 缓存策略:添加响应缓存,减少重复调用开销
- 流式处理:如果 API 支持,实现完整的流式输出管道
五、总结
通过自定义 LLM 类,我们打破了框架的限制,实现了模型接入的完全自由度。从简单的字符截取示例,到复杂的第三方 API 集成,LangChain 的自定义接口为我们提供了强大的扩展能力。
在实际项目中,自定义 LLM 类特别适合以下场景:
- 企业自研模型的工程化落地
- 多模型融合的混合架构
- 特定领域的模型优化
- 低成本的模型推理方案
如果本文对你有帮助,别忘了点赞收藏,关注我,一起探索更高效的开发方式~