from threading import Thread

from closedai import ClosedAIPipeline
from closedai.server import app, data  # noqa: app must be importable for the server
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    TextIteratorStreamer,
    pipeline,
)


class LlamaPipeline(ClosedAIPipeline):
    def __init__(self):
        tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
        # device_map="auto" lets accelerate place the weights across available
        # devices; the original device="auto" is not a valid `pipeline` argument.
        model = LlamaForCausalLM.from_pretrained(
            "decapoda-research/llama-7b-hf", device_map="auto"
        )
        # skip_prompt=True makes the streamer yield only newly generated text,
        # not the echoed prompt.
        self.streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
        self.pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            streamer=self.streamer,
        )

    def generate_completion(self, text, **generate_kwargs):
        # Run generation on a background thread; the streamer is an iterator
        # that yields decoded tokens as the model produces them.
        thread = Thread(target=self.pipe, kwargs=dict(text_inputs=text, **generate_kwargs))
        thread.start()
        for new_text in self.streamer:
            yield new_text


# Named llama_pipeline to avoid shadowing the `pipeline` factory imported
# from transformers above.
llama_pipeline = LlamaPipeline()
data["pipeline"] = llama_pipeline
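

# --- Usage sketch (an assumption-laden addition, not part of the original) ---
# Assuming `app` is an ASGI application exposed by closedai.server and the
# server reads the pipeline from data["pipeline"], the module could be served
# with, e.g.:
#
#   uvicorn this_module:app --host 0.0.0.0 --port 8000
#
# (`this_module` is a placeholder for whatever this file is named.)
# A quick local smoke test that bypasses the server and streams tokens as
# they are generated:
if __name__ == "__main__":
    for chunk in llama_pipeline.generate_completion("Hello, world", max_new_tokens=32):
        print(chunk, end="", flush=True)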