alfiannajih committed on
Commit
57d40a4
1 Parent(s): 434fca2

Update g_retriever_pipeline.py

Browse files
Files changed (1) hide show
  1. g_retriever_pipeline.py +3 -0
g_retriever_pipeline.py CHANGED
@@ -11,6 +11,7 @@ class GRetrieverPipeline(Pipeline):
11
  self.max_txt_len = self.model.config.max_txt_len
12
  self.bos_length = len(self.model.config.bos_id)
13
  self.input_length = 0
 
14
 
15
  def _sanitize_parameters(self, **kwargs):
16
  preprocess_kwargs = {}
@@ -27,6 +28,7 @@ class GRetrieverPipeline(Pipeline):
27
 
28
  def preprocess(self, inputs, textualized_graph, graph, generate_kwargs=None):
29
  textualized_graph_ids = self.tokenizer(textualized_graph, add_special_tokens=False)["input_ids"][:self.max_txt_len]
 
30
  question_ids = self.tokenizer(inputs, add_special_tokens=False)["input_ids"]
31
  eos_user_ids = self.tokenizer(self.eos_user, add_special_tokens=False)["input_ids"]
32
 
@@ -34,6 +36,7 @@ class GRetrieverPipeline(Pipeline):
34
  [-1]*(self.bos_length + 1)
35
  + textualized_graph_ids
36
  + question_ids
 
37
  + eos_user_ids
38
  ])
39
  model_inputs = {
 
11
  self.max_txt_len = self.model.config.max_txt_len
12
  self.bos_length = len(self.model.config.bos_id)
13
  self.input_length = 0
14
+ self.prompt = "Generate a detailed review of a resume in relation to the current job market, presented as a textual graph. The review should be divided into three sections: strengths, weaknesses, and improvements."
15
 
16
  def _sanitize_parameters(self, **kwargs):
17
  preprocess_kwargs = {}
 
28
 
29
  def preprocess(self, inputs, textualized_graph, graph, generate_kwargs=None):
30
  textualized_graph_ids = self.tokenizer(textualized_graph, add_special_tokens=False)["input_ids"][:self.max_txt_len]
31
+ prompt_ids = self.tokenizer(self.prompt, add_special_tokens=False)["input_ids"]
32
  question_ids = self.tokenizer(inputs, add_special_tokens=False)["input_ids"]
33
  eos_user_ids = self.tokenizer(self.eos_user, add_special_tokens=False)["input_ids"]
34
 
 
36
  [-1]*(self.bos_length + 1)
37
  + textualized_graph_ids
38
  + question_ids
39
+ + prompt_ids
40
  + eos_user_ids
41
  ])
42
  model_inputs = {