Vladislav Sokolovskii committed on
Commit 4c400a3
1 Parent(s): cb07a8a

Remove custom handler

Files changed (1)
  1. handler.py +0 -71
handler.py DELETED
@@ -1,71 +0,0 @@
-import os
-from typing import Dict, List, Any
-from unsloth import FastLanguageModel
-from unsloth.chat_templates import get_chat_template
-import torch
-from huggingface_hub import login
-import os
-
-class EndpointHandler:
-    def __init__(self, path=""):
-        # access_token = os.environ["HUGGINGFACE_TOKEN"]
-        # login(token=access_token)
-        # Load the model and tokenizer
-        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
-            model_name = path,  # Use the current directory path
-            max_seq_length = 2048,
-            dtype = None,
-            load_in_4bit = True,
-        )
-        FastLanguageModel.for_inference(self.model)
-
-        # Set up the chat template
-        self.tokenizer = get_chat_template(
-            self.tokenizer,
-            chat_template="llama-3",
-            mapping={"role": "from", "content": "value", "user": "human", "assistant": "gpt"}
-        )
-
-    def __call__(self, data: Dict[str, Any]) -> List[str]:
-        inputs = data.pop("inputs", data)
-        parameters = data.pop("parameters", {})
-
-        # Extract parameters or use defaults
-        max_tokens = parameters.get("max_new_tokens", 512)
-        temperature = parameters.get("temperature", 0.2)
-        top_p = parameters.get("top_p", 0.5)
-        system_message = parameters.get("system_message", "")
-
-        # Prepare messages
-        messages = [{"from": "human", "value": system_message}]
-        if isinstance(inputs, str):
-            messages.append({"from": "human", "value": inputs})
-        elif isinstance(inputs, list):
-            for msg in inputs:
-                role = "human" if msg["role"] == "user" else "gpt"
-                messages.append({"from": role, "value": msg["content"]})
-
-        # Tokenize input
-        tokenized_input = self.tokenizer.apply_chat_template(
-            messages,
-            tokenize=True,
-            add_generation_prompt=True,
-            return_tensors="pt"
-        ).to("cuda")
-
-        # Generate output
-        with torch.no_grad():
-            output = self.model.generate(
-                input_ids=tokenized_input,
-                max_new_tokens=max_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                use_cache=True
-            )
-
-        # Decode and process the output
-        full_response = self.tokenizer.decode(output[0], skip_special_tokens=True)
-        response_lines = [line.strip() for line in full_response.split('\n') if line.strip()]
-        last_response = response_lines[-1] if response_lines else ""
-
-        return [last_response]
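
For reference, a minimal sketch of the request payload the deleted handler's __call__ accepted, based only on the field names visible in the code above ("inputs", "parameters", "max_new_tokens", "temperature", "top_p", "system_message"); the prompt text and the commented-out invocation are hypothetical, not taken from this repository.

# Hypothetical payload for the removed EndpointHandler; key names follow the
# deleted __call__ implementation above, the prompt text is made up.
payload = {
    "inputs": [
        {"role": "user", "content": "Give a one-sentence summary of this model."},
    ],
    "parameters": {
        "max_new_tokens": 256,
        "temperature": 0.2,
        "top_p": 0.5,
        "system_message": "You are a concise assistant.",
    },
}

# On a CUDA machine with the model weights available, the handler would have
# been driven roughly like this (the path is a placeholder):
# handler = EndpointHandler(path=".")
# print(handler(payload))  # -> ["<last non-empty line of the generated text>"]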