Ozgur98 committed
Commit 16db4fc
1 Parent(s): 1cf0759

Update handler.py

Files changed (1)
  1. handler.py +10 -18
handler.py CHANGED
@@ -9,30 +9,22 @@ LOGGER = logging.getLogger(__name__)
 
 class EndpointHandler():
     def __init__(self, path=""):
-        self.model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-3b", load_in_8bit=True, device_map='auto')
-        self.tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-3b")
+        self.model = AutoModelForCausalLM.from_pretrained("Ozgur98/pushed_model_mosaic_small", load_in_8bit=True, device_map='auto')
+        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
         # Load the Lora model
 
-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+    def __call__(self, data):
        """
        Args:
            data (Dict): The payload with the text prompt and generation parameters.
        """
-        LOGGER.info(f"Received data: {data}")
-        # Get inputs
-        prompt = data.pop("inputs", None)
-        parameters = data.pop("parameters", None)
-        if prompt is None:
-            raise ValueError("Missing prompt.")
-        # Preprocess
-        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+        print("CALLED")
+        LOGGER.info(data)
        # Forward
        LOGGER.info("Start generation.")
-        if parameters is not None:
-            output = self.model.generate(input_ids=input_ids, **parameters)
-        else:
-            output = self.model.generate(input_ids=input_ids)
+        tokenized_example = self.tokenizer(data, return_tensors='pt')
+        outputs = self.model.generate(tokenized_example['input_ids'].to('cuda:0'), max_new_tokens=100, do_sample=True, top_k=10, top_p=0.95)
        # Postprocess
-        prediction = self.tokenizer.decode(output[0])
-        LOGGER.info(f"Generated text: {prediction}")
-        return {"generated_text": prediction}
+        answer = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        prompt = answer[0].rstrip()
+        return prompt
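
For reference, a minimal usage sketch of the updated handler. It assumes handler.py is importable, that a CUDA GPU is available for the load_in_8bit / device_map='auto' path (which also requires the bitsandbytes and accelerate packages), and that both model repos are accessible. Note that the new __call__ expects the raw prompt string, not the old {"inputs": ..., "parameters": ...} payload dict, since it passes data straight to the tokenizer.

# Minimal usage sketch (assumptions noted above; the prompt text is
# illustrative only).
from handler import EndpointHandler

handler = EndpointHandler()
# The updated __call__ tokenizes `data` directly, so a plain string is
# expected here. Generation is sampled (do_sample=True, top_k=10,
# top_p=0.95) and capped at 100 new tokens, as set in the diff.
completion = handler("Once upon a time")
print(completion)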