系列文章为基于langchain开发应用时的备忘笔记,供遇到相似问题时翻阅,内容可能不够详细,若阅读过程中遇到困难,欢迎评论&私信交流
使用langchain开发项目过程中避免不了需要对接生态还不支持的大模型,langchian为我们做了很好的抽象,只需要扩展器LLM类即可实现自定义llm类
LLM基类
需继承基类from langchain_core.language_models import LLM
必须实现的_call
方法和_identifying_params
属性方法
_call
方法
使用输入的prompt调用llm,有invoke
方法调用
_identifying_params
方法
标识LLM的参数,返回一个字典数据,用于缓存和记录,建议包含model_name
一种实现
一般包含llm调用必要的llm参数,包含model、max_token、temperature等模型调用参数和api_key等认证参数
class CustomLLM(LLM):model: str = "some_model"max_tokens = 1024temperature = 0.3api_key: Optional[SecretStr] = None"""The API key to use for authentication."""base_url: str = DEFAULT_SERVICE_URL_BASEdef _call(self,prompt: str,stop: Optional[List[str]] = None,run_manager: Optional[CallbackManagerForLLMRun] = None,**kwargs: Any,) -> str:request = self._invocation_paramsrequest["messages"] = [{"role": "user", "content": prompt}]request.update(kwargs)text = self.completion(request)if stop is not None:# This is required since the stop tokens# are not enforced by the model parameterstext = enforce_stop_tokens(text, stop)return text@propertydef _identifying_params(self) -> Mapping[str, Any]:"""Get the identifying parameters."""return {"model": "some_model"}
推荐的实现方法
- api_key等敏感字段,使用langchain的SecretStr类型,避免敏感字段泄露
from langchain_core.pydantic_v1 import SecretStr
- 实现validate_enviroment方法,校验调用参数
@root_validator()def validate_environment(cls, values: Dict) -> Dict:"""Validate that api key and python package exists in environment."""values["api_key"] = convert_to_secret_str(get_from_dict_or_env(values, "api_key", "API_KEY"))values["base_url"] = get_from_dict_or_env(values, "base_url", "BASE_URL", DEFAULT_SERVICE_URL_BASE)return values
@root_validator()注解的方法,会在成员变量初始化后执行,修改values字段值同修改成员变量的值