FINGU-AI committed
Commit 22da08d
1 Parent(s): 75dfc1d

Upload inference.py

Files changed (1)
  1. inference.py +94 -0
inference.py ADDED
@@ -0,0 +1,94 @@
import json
import time
from typing import List, Dict

from vllm import LLM, SamplingParams
from vllm.utils import random_uuid

# Function to format chat messages using Qwen's chat template
def format_chat(messages: List[Dict[str, str]]) -> str:
    """
    Format chat messages using Qwen's chat template.
    """
    formatted_text = ""
    for message in messages:
        role = message["role"]
        content = message["content"]

        if role == "system":
            formatted_text += f"<|im_start|>system\n{content}<|im_end|>\n"
        elif role == "user":
            formatted_text += f"<|im_start|>user\n{content}<|im_end|>\n"
        elif role == "assistant":
            formatted_text += f"<|im_start|>assistant\n{content}<|im_end|>\n"

    # Add the final assistant prompt so the model continues as the assistant
    formatted_text += "<|im_start|>assistant\n"

    return formatted_text

# Model loading function for SageMaker
def model_fn(model_dir):
    # Load the quantized model from the model directory
    model = LLM(
        model=model_dir,
        trust_remote_code=True,
        gpu_memory_utilization=0.9  # Reserve most GPU memory for weights and KV cache
    )
    return model

# Custom predict function for SageMaker
def predict_fn(input_data, model):
    try:
        data = json.loads(input_data)

        # Format the prompt using Qwen's chat template
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages)

        # Build sampling parameters (no do_sample flag, to mirror the OpenAI API)
        sampling_params = SamplingParams(
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            max_tokens=data.get("max_new_tokens", 512),  # vLLM names this max_tokens
            top_k=data.get("top_k", -1),  # -1 disables top-k sampling
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),
            stop_token_ids=data.get("stop_token_ids", None),
            skip_special_tokens=data.get("skip_special_tokens", True)
        )

        # Generate output
        outputs = model.generate(formatted_prompt, sampling_params)
        completion = outputs[0].outputs[0]
        generated_text = completion.text

        # Count tokens from the vLLM output rather than string lengths
        prompt_tokens = len(outputs[0].prompt_token_ids)
        completion_tokens = len(completion.token_ids)

        # Build an OpenAI-style chat.completion response
        response = {
            "id": f"chatcmpl-{random_uuid()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": completion.finish_reason or "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        }
        return response
    except Exception as e:
        return {"error": str(e), "details": repr(e)}

# Define input and output formats for SageMaker
def input_fn(serialized_input_data, content_type):
    return serialized_input_data

def output_fn(prediction_output, accept):
    return json.dumps(prediction_output)
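For reference, the handler chain above can be exercised outside SageMaker with an OpenAI-style request body. The snippet below is a minimal local smoke test, not part of the upload; the model directory "./qwen-72b-awq" is a placeholder path, and it assumes vLLM is installed on a GPU machine with the model weights available locally.

# Local smoke test for the handlers in inference.py (illustrative sketch).
# "./qwen-72b-awq" is a placeholder; substitute the directory holding the model weights.
import json

from inference import model_fn, input_fn, predict_fn, output_fn

model = model_fn("./qwen-72b-awq")

request = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize what vLLM does in one sentence."}
    ],
    "temperature": 0.7,
    "max_new_tokens": 128
}

# Mirror the SageMaker flow: deserialize, predict, then serialize the result.
payload = input_fn(json.dumps(request), "application/json")
prediction = predict_fn(payload, model)
body = output_fn(prediction, "application/json")

print(json.loads(body)["choices"][0]["message"]["content"])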
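Once inference.py is packaged with the model artifacts and deployed, the endpoint can be called through the SageMaker runtime client. A minimal sketch, assuming an endpoint named "qwen-72b-endpoint" (hypothetical name) in the caller's default region:

# Invoke the deployed endpoint with boto3 (sketch; the endpoint name is hypothetical).
import json
import boto3

runtime = boto3.client("sagemaker-runtime")

request = {
    "messages": [
        {"role": "user", "content": "Hello, who are you?"}
    ],
    "max_new_tokens": 256
}

response = runtime.invoke_endpoint(
    EndpointName="qwen-72b-endpoint",  # hypothetical endpoint name
    ContentType="application/json",
    Body=json.dumps(request)
)

result = json.loads(response["Body"].read())
print(result["choices"][0]["message"]["content"])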