File size: 1,006 Bytes
7e2a2ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from threading import Thread

from closedai import ClosedAIPipeline
from closedai.server import app, data  # noqa

from transformers import LlamaForCausalLM, LlamaTokenizer, TextIteratorStreamer, pipeline

class LlamaPipeline(ClosedAIPipeline):
    """Streaming text-generation pipeline backed by LLaMA-7B.

    Loads the tokenizer and model once at construction and wires a
    ``TextIteratorStreamer`` into a ``text-generation`` pipeline so that
    ``generate_completion`` can yield output incrementally.
    """

    def __init__(self):
        # NOTE(review): the decapoda-research checkpoint is known to ship a
        # stale tokenizer config — confirm this repo id is the intended one.
        tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
        model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf")
        # skip_prompt=True: stream only newly generated text, not the prompt.
        self.streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
        self.pipe = pipeline(
            'text-generation',
            model=model,
            tokenizer=tokenizer,
            streamer=self.streamer,
            # BUG FIX: `device="auto"` is invalid — `device` expects an
            # int / torch.device / "cuda:N" string and rejects "auto".
            # Automatic placement is requested via `device_map="auto"`.
            device_map="auto",
        )

    def generate_completion(self, text, **generate_kwargs):
        """Yield generated text chunks for ``text`` as they are produced.

        Generation runs in a background thread; this generator blocks on
        the streamer and finishes when the model stops producing tokens.
        """
        # `self.pipe` is itself callable — no need to spell out __call__.
        thread = Thread(target=self.pipe, kwargs=dict(text_inputs=text, **generate_kwargs))
        thread.start()
        for new_text in self.streamer:
            yield new_text
        # The streamer is exhausted only after generation ends, so this
        # join returns immediately; it just reaps the worker thread.
        thread.join()

# Instantiate once at import time and register it in the server's shared
# state. BUG FIX: the instance was previously named `pipeline`, shadowing
# the `transformers.pipeline` factory imported above — any later call to
# the factory would have invoked the instance instead.
llama_pipeline = LlamaPipeline()
data["pipeline"] = llama_pipeline