vignesh-trustt committed
Commit 94eabde
1 parent: 1662c58

Upload handler.py

Files changed (1)
  1. handler.py +70 -0
handler.py ADDED
@@ -0,0 +1,70 @@
+ from typing import Dict, Any, List
+ import logging
+
+ import transformers
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftConfig, PeftModel
+ import torch
+
+ # Use bfloat16 on GPUs with compute capability >= 8 (Ampere or newer), float16 otherwise.
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+
+ # LOGGER = logging.getLogger(__name__)
+ # logging.basicConfig(level=logging.INFO)
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+         model = AutoModelForCausalLM.from_pretrained(
+             path,
+             return_dict=True,
+             device_map="auto",
+             load_in_8bit=True,
+             torch_dtype=dtype,
+             trust_remote_code=True,
+         )
+
+         # Configure single-sequence, greedy-style generation.
+         generation_config = model.generation_config
+         generation_config.max_new_tokens = 512
+         generation_config.temperature = 0
+         generation_config.num_return_sequences = 1
+         generation_config.pad_token_id = tokenizer.eos_token_id
+         generation_config.eos_token_id = tokenizer.eos_token_id
+         self.generation_config = generation_config
+
+         self.pipeline = transformers.pipeline(
+             "text-generation", model=model, tokenizer=tokenizer
+         )
+
+         # config = PeftConfig.from_pretrained(path)
+         # model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, load_in_8bit=True, device_map='auto')
+         # self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
+         # # Load the Lora model
+         # self.model = PeftModel.from_pretrained(model, path)
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         Args:
+             data (Dict): The payload with the text prompt and generation parameters.
+         """
+         # LOGGER.info(f"Received data: {data}")
+         # Get inputs
+         prompt = data.pop("inputs", None)
+         if prompt is None:
+             raise ValueError("Missing prompt.")
+         # parameters = data.pop("parameters", None)
+         # # Preprocess
+         # input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+         # # Forward
+         # LOGGER.info("Start generation.")
+         # if parameters is not None:
+         #     output = self.model.generate(input_ids=input_ids, **parameters)
+         # else:
+         #     output = self.model.generate(input_ids=input_ids)
+         # # Postprocess
+         # prediction = self.tokenizer.decode(output[0])
+         # LOGGER.info(f"Generated text: {prediction}")
+         # return {"generated_text": prediction}
+         result = self.pipeline(prompt, generation_config=self.generation_config)
+         return result
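
For reference, a minimal local smoke test of this handler could look like the sketch below. The checkpoint directory and prompt are hypothetical, and in production a Hugging Face Inference Endpoint constructs and calls the handler itself; the `result[0]["generated_text"]` access assumes the standard output shape of the transformers text-generation pipeline.

    from handler import EndpointHandler

    # Hypothetical local model directory containing the uploaded checkpoint.
    handler = EndpointHandler(path="./my-model")

    # The payload convention used by __call__ above: the prompt travels under "inputs".
    result = handler({"inputs": "Write a one-line summary of LoRA fine-tuning."})
    print(result[0]["generated_text"])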