vignesh-trustt committed
Commit
71fe2f3
1 Parent(s): 94eabde

Delete handler.py

Files changed (1): handler.py +0 -70
handler.py DELETED
@@ -1,70 +0,0 @@
- from typing import Dict, Any
- import logging
-
- import torch
- import transformers
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from peft import PeftConfig, PeftModel
-
- # bfloat16 requires compute capability >= 8 (Ampere or newer); otherwise use float16.
- dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
-
- # LOGGER = logging.getLogger(__name__)
- # logging.basicConfig(level=logging.INFO)
- # device = "cuda" if torch.cuda.is_available() else "cpu"
-
-
- class EndpointHandler:
-     def __init__(self, path=""):
-         tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
-         model = AutoModelForCausalLM.from_pretrained(
-             path,
-             return_dict=True,
-             device_map="auto",
-             load_in_8bit=True,
-             torch_dtype=dtype,
-             trust_remote_code=True,
-         )
-
-         generation_config = model.generation_config
-         generation_config.max_new_tokens = 512
-         generation_config.temperature = 0
-         generation_config.num_return_sequences = 1
-         generation_config.pad_token_id = tokenizer.eos_token_id
-         generation_config.eos_token_id = tokenizer.eos_token_id
-         self.generation_config = generation_config
-
-         self.pipeline = transformers.pipeline(
-             "text-generation", model=model, tokenizer=tokenizer
-         )
-
-         # Alternative: load a LoRA adapter on top of the base model.
-         # config = PeftConfig.from_pretrained(path)
-         # model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, load_in_8bit=True, device_map='auto')
-         # self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
-         # self.model = PeftModel.from_pretrained(model, path)
-
-     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-         """
-         Args:
-             data (Dict): The payload with the text prompt and generation parameters.
-         """
-         # LOGGER.info(f"Received data: {data}")
-         prompt = data.pop("inputs", None)
-         if prompt is None:
-             raise ValueError("Missing prompt.")
-
-         # Alternative manual generation path, kept for reference:
-         # parameters = data.pop("parameters", None)
-         # input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
-         # LOGGER.info("Start generation.")
-         # if parameters is not None:
-         #     output = self.model.generate(input_ids=input_ids, **parameters)
-         # else:
-         #     output = self.model.generate(input_ids=input_ids)
-         # prediction = self.tokenizer.decode(output[0])
-         # LOGGER.info(f"Generated text: {prediction}")
-         # return {"generated_text": prediction}
-
-         result = self.pipeline(prompt, generation_config=self.generation_config)
-         return result
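For context, Hugging Face Inference Endpoints discover a custom handler by importing `EndpointHandler` from a `handler.py` at the repository root, constructing it once with the local model path, and calling the instance with the deserialized request body. A minimal local smoke test of the (now deleted) handler might look like the sketch below; the `./model` path and the prompt text are illustrative assumptions, not part of the original commit.

# Hypothetical local smoke test for the deleted handler.py (not part of the commit).
# Assumes model weights at "./model" and a CUDA-capable GPU (required by load_in_8bit).
from handler import EndpointHandler

handler = EndpointHandler(path="./model")

# Inference Endpoints pass the JSON request body as a dict; "inputs" carries the prompt.
payload = {"inputs": "Explain what a custom inference handler does."}
result = handler(payload)
print(result)  # a text-generation pipeline returns e.g. [{"generated_text": "..."}]

Note that because `__call__` delegates to a `text-generation` pipeline, the return value is the pipeline's list of generation dicts, not the `{"generated_text": ...}` dict produced by the commented-out manual `generate` path.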