|
from typing import Any, List, Mapping, Optional |
|
|
|
from langchain.llms.base import LLM |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
|
|
|
# Hugging Face checkpoint id shared by the tokenizer, the model, and the
# LLM wrapper defined later in this file.
model_name = "bigscience/bloom-560m"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# No explicit ``config`` is passed. The earlier code supplied the bare string
# 'T5Config', which is not a configuration object and names a different
# architecture than this BLOOM checkpoint; the checkpoint's own configuration
# is the one that should be used.
model = AutoModelForCausalLM.from_pretrained(model_name)

# Text-generation pipeline with sampling enabled. The sampling arguments
# given here are forwarded by the pipeline as generation settings.
pl = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    do_sample=True,
    top_p=0.95,
    top_k=50,
    temperature=0.7,
)
|
|
|
class CustomLLM(LLM):
    """LangChain ``LLM`` wrapper around the module-level text-generation pipeline."""

    # Pipeline object built at module import time (``pl``).
    pipeline: Any = pl

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a completion for ``prompt`` and return only the new text.

        Args:
            prompt: Text passed to the pipeline.
            stop: Optional stop sequences. When given, the completion is cut
                just before the earliest occurrence of any of them.

        Returns:
            The pipeline's first ``generated_text`` with the leading
            ``len(prompt)`` characters removed, truncated at a stop sequence
            if one is present.
        """
        generated = self.pipeline(prompt, max_new_tokens=525)[0]["generated_text"]

        # Assumes the pipeline echoes the prompt at the start of its output,
        # so slicing by the prompt's length leaves only the newly generated text.
        completion = generated[len(prompt):]

        if stop:
            # Empty stop strings are skipped: ``find("")`` is always 0 and
            # would wipe out the whole completion.
            hits = [completion.find(token) for token in stop if token]
            hits = [index for index in hits if index != -1]
            if hits:
                completion = completion[: min(hits)]

        return completion

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the parameters that identify this LLM."""
        # ``model_name`` is the module-level checkpoint id; the class itself
        # defines no ``model_name`` attribute, so ``self.model_name`` would fail.
        return {"name_of_model": model_name}

    @property
    def _llm_type(self) -> str:
        """Return the type label of this LLM."""
        return "custom"