import torch
from torch_geometric.data import Batch
from transformers import AutoTokenizer, Pipeline

class GRetrieverPipeline(Pipeline):
    """Pipeline for a G-Retriever-style model that conditions generation on
    both a textualized graph and its PyG graph representation."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
        # Llama 3 chat markers: close the user turn, open the assistant turn.
        self.eos_user = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
        # Token budget for the textualized graph before truncation.
        self.max_txt_len = self.model.config.max_txt_len
        # Number of placeholder positions reserved for the BOS ids.
        self.bos_length = len(self.model.config.bos_id)
        self.input_length = 0
        self.prompt = (
            "Generate a detailed review of a resume in relation to the current job market, "
            "presented as a textual graph. The review should be divided into three sections: "
            "strengths, weaknesses, and improvements."
        )
        
    def _sanitize_parameters(self, **kwargs):
        # Route all recognized kwargs to preprocess(); _forward() and
        # postprocess() take no extra parameters.
        preprocess_kwargs = {}
        if "textualized_graph" in kwargs:
            preprocess_kwargs["textualized_graph"] = kwargs["textualized_graph"]

        if "graph" in kwargs:
            preprocess_kwargs["graph"] = kwargs["graph"]

        if "generate_kwargs" in kwargs:
            preprocess_kwargs["generate_kwargs"] = kwargs["generate_kwargs"]

        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, textualized_graph, graph, generate_kwargs=None):
        # Truncate the textualized graph to the configured token budget.
        textualized_graph_ids = self.tokenizer(textualized_graph, add_special_tokens=False)["input_ids"][:self.max_txt_len]
        prompt_ids = self.tokenizer(self.prompt, add_special_tokens=False)["input_ids"]
        question_ids = self.tokenizer(inputs, add_special_tokens=False)["input_ids"]
        eos_user_ids = self.tokenizer(self.eos_user, add_special_tokens=False)["input_ids"]

        # The leading -1 placeholders reserve positions for the BOS ids plus
        # one graph token; the model is expected to overwrite them with its
        # own embeddings before generation.
        input_ids = torch.tensor([
            [-1] * (self.bos_length + 1)
            + textualized_graph_ids
            + question_ids
            + prompt_ids
            + eos_user_ids
        ])
        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
            # Wrap the single graph in a Batch so the model receives PyG's batch format.
            "graph": Batch.from_data_list([graph]),
        }

        if generate_kwargs is not None:
            model_inputs.update(generate_kwargs)

        # Remember the prompt length so postprocess() can strip it from the output.
        self.input_length = input_ids.shape[1]

        return model_inputs

    def _forward(self, model_inputs):
        return self.model.generate(**model_inputs)

    def postprocess(self, model_outputs):
        # Decode only the newly generated tokens, skipping the prompt portion.
        return self.tokenizer.decode(model_outputs[0, self.input_length:], skip_special_tokens=True)
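

# ---------------------------------------------------------------------------
# Minimal usage sketch. Everything below is illustrative, not shipped code:
# the checkpoint path, the toy graph, and the generate kwargs are placeholders,
# and the model is assumed to expose config.max_txt_len and config.bos_id as
# this pipeline requires.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch_geometric.data import Data
    from transformers import AutoModelForCausalLM

    # Hypothetical G-Retriever checkpoint loaded with its custom modeling code.
    model = AutoModelForCausalLM.from_pretrained(
        "path/to/g-retriever-checkpoint",
        trust_remote_code=True,
    )
    pipe = GRetrieverPipeline(model=model)

    # Toy single-edge graph; real inputs would carry learned node/edge features.
    graph = Data(
        x=torch.randn(2, 1024),
        edge_index=torch.tensor([[0], [1]]),
        edge_attr=torch.randn(1, 1024),
    )
    review = pipe(
        "Review this resume against the current job market.",
        textualized_graph="node_id,node_attr\n0,candidate\n1,skill: Python",
        graph=graph,
        generate_kwargs={"max_new_tokens": 512},
    )
    print(review)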