import sys
import os

# Make local modules importable when this file is run as a script:
# append the current working directory to the Python module search path.
sys.path.append(os.getcwd())

from transformers import AutoModelForCausalLM, AutoTokenizer
from outlines import models, generate

# JSON Schema for the structured output: one boolean flag per prompt module.
schema = """
{
    "title": "Modules",
    "type": "object",
    "properties": {
        "background": {"type": "boolean"},
        "command": {"type": "boolean"},
        "suggestion": {"type": "boolean"},
        "goal": {"type": "boolean"},
        "examples": {"type": "boolean"},
        "constraints": {"type": "boolean"},
        "workflow": {"type": "boolean"},
        "output_format": {"type": "boolean"},
        "skills": {"type": "boolean"},
        "style": {"type": "boolean"},
        "initialization": {"type": "boolean"}
    },
    "required": ["background", "command", "suggestion", "goal", "examples", "constraints", "workflow", "output_format", "skills", "style", "initialization"]
}
"""

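# An instance the schema would accept (illustrative values only):
# {"background": true, "command": false, "suggestion": true, "goal": true,
#  "examples": false, "constraints": true, "workflow": false,
#  "output_format": true, "skills": false, "style": true,
#  "initialization": true}
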
class Generator:
    def __init__(self, model_path, device):
        # Load the causal LM and its tokenizer, then wrap them in an
        # outlines model so generation can be schema-constrained.
        self.llm = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.llm = self.llm.eval()
        self.model = models.Transformers(self.llm, self.tokenizer)

    def generate_response(self, messages):
        # Unconstrained free-text generation from a chat-formatted prompt.
        generator = generate.text(self.model)
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return generator(prompt)

    def json_response(self, messages):
        # Constrained generation: outlines guarantees the output parses
        # against `schema`.
        generator = generate.json(self.model, schema)
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return generator(prompt)
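
# Minimal usage sketch (an addition, not part of the original module).
# The checkpoint, device, and messages are placeholder assumptions;
# substitute any chat-tuned Hugging Face model available locally.
if __name__ == "__main__":
    gen = Generator("Qwen/Qwen1.5-1.8B-Chat", "cuda")  # hypothetical checkpoint/device
    messages = [
        {"role": "user", "content": "Draft a role prompt for a Python tutor."},
    ]
    print(gen.generate_response(messages))  # free-form text
    print(gen.json_response(messages))      # parsed object matching the Modules schema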