from typing import Any, Dict

import re

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

# Use bfloat16 on GPUs with compute capability 8.0 or newer (Ampere onwards),
# which support it natively; fall back to float16 otherwise.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

class EndpointHandler:
    def __init__(self, path=""):
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            path,
            return_dict=True,
            device_map="auto",
            load_in_8bit=True,  # 8-bit quantization via bitsandbytes
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        # Default generation settings; max_new_tokens and temperature can be
        # overridden per request in __call__.
        gen_config = model.generation_config
        gen_config.max_new_tokens = 100
        gen_config.do_sample = False  # greedy decoding; temperature only applies when sampling
        gen_config.num_return_sequences = 1
        gen_config.pad_token_id = tokenizer.eos_token_id
        gen_config.eos_token_id = tokenizer.eos_token_id

        self.generation_config = gen_config

        self.pipeline = transformers.pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )
    def __call__(self, data: Dict[str, Any]) -> str:
        # Accept either a flat payload or the Hugging Face endpoint convention
        # of nesting the request fields under "inputs".
        inputs = data.pop("inputs", data)
        if not isinstance(inputs, dict):
            inputs = data

        question = inputs.pop("question", None)
        context = inputs.pop("context", None)
        temp = inputs.pop("temp", None)
        max_tokens = inputs.pop("max_tokens", None)

        system_message = "Use the provided context followed by a question to answer it."

        full_prompt = f"""<s>### Instruction:
        {system_message}

        ### Context:
        {context}

        ### Question:
        {question}

        ### Answer:
        """

        # Collapse all runs of whitespace so the prompt becomes a single line.
        full_prompt = " ".join(full_prompt.split())

        # Apply per-request overrides, keeping the defaults from __init__ when
        # the request does not supply them.
        if max_tokens is not None:
            self.generation_config.max_new_tokens = max_tokens
        if temp is not None:
            self.generation_config.temperature = temp
            self.generation_config.do_sample = temp > 0

        result = self.pipeline(
            full_prompt, generation_config=self.generation_config
        )[0]["generated_text"]

        # Extract the text between "### Answer:" and the next "###" marker;
        # if the model emitted no further marker, take everything after
        # "### Answer:" instead.
        match = re.search(r"### Answer:(.*?)###", result, re.DOTALL)
        if match:
            return match.group(1).strip()

        match = re.search(r"### Answer:(.*)", result, re.DOTALL)
        if match:
            return match.group(1).strip()

        return result
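
# --- Usage sketch (not part of the deployed handler) ---
# A minimal local smoke test. The model path "./model" and the payload values
# are placeholders, not part of the original file; the keys ("question",
# "context", "temp", "max_tokens") match what __call__ reads above.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    payload = {
        "inputs": {
            "question": "What colour is the sky on a clear day?",
            "context": "On a clear day the sky appears blue due to Rayleigh scattering.",
            "temp": 0.0,
            "max_tokens": 64,
        }
    }
    print(handler(payload))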