import os
import json
import subprocess
import sys
import time
import torch
from typing import List, Dict

# Ensure vllm is installed; the wheel is pinned so it matches the CUDA build on the endpoint
try:
    import vllm
except ImportError:
    # Install vllm with CUDA 11.8 support (the wheel tag must match the Python version on the host)
    vllm_version = "0.6.1.post1"
    pip_cmd = [
        sys.executable,
        "-m", "pip", "install",
        f"https://github.com/vllm-project/vllm/releases/download/v{vllm_version}/vllm-{vllm_version}+cu118-cp310-cp310-manylinux1_x86_64.whl",
        "--extra-index-url", "https://download.pytorch.org/whl/cu118"
    ]
    subprocess.check_call(pip_cmd)

# Import the necessary modules after installation
from vllm import LLM, SamplingParams
from vllm.utils import random_uuid
# Function to format chat messages using Qwen's chat template
def format_chat(messages: List[Dict[str, str]]) -> str:
    """
    Format chat messages using Qwen's ChatML template.
    """
    formatted_text = ""
    for message in messages:
        role = message["role"]
        content = message["content"]
        if role == "system":
            formatted_text += f"<|im_start|>system\n{content}<|im_end|>\n"
        elif role == "user":
            formatted_text += f"<|im_start|>user\n{content}<|im_end|>\n"
        elif role == "assistant":
            formatted_text += f"<|im_start|>assistant\n{content}<|im_end|>\n"
    # Add the final assistant prompt so the model continues as the assistant
    formatted_text += "<|im_start|>assistant\n"
    return formatted_text
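
# Illustrative example of the resulting prompt string:
#   format_chat([{"role": "user", "content": "Hello"}])
# returns
#   "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"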
# Model loading function for SageMaker; model_dir is the directory where the
# endpoint extracts the model artifact (typically /opt/ml/model)
def model_fn(model_dir):
    # Load the (quantized) model weights from the model directory
    model = LLM(
        model=model_dir,
        trust_remote_code=True,
        gpu_memory_utilization=0.9  # let vLLM use up to 90% of GPU memory for weights + KV cache
    )
    return model
# Custom predict function for SageMaker
def predict_fn(input_data, model):
    try:
        data = json.loads(input_data)
        # Format the prompt using Qwen's chat template
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages)
        # Build sampling parameters (vLLM has no do_sample flag, matching the OpenAI API;
        # greedy decoding is expressed as temperature=0)
        sampling_params = SamplingParams(
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            max_tokens=data.get("max_new_tokens", 512),  # vLLM calls this max_tokens
            top_k=data.get("top_k", -1),  # -1 disables top-k sampling
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),  # only affects beam search
            stop_token_ids=data.get("stop_token_ids", None),
            skip_special_tokens=data.get("skip_special_tokens", True)
        )
        # Generate output
        outputs = model.generate(formatted_prompt, sampling_params)
        generated_text = outputs[0].outputs[0].text
        # Build an OpenAI-style chat completion response
        prompt_tokens = len(outputs[0].prompt_token_ids)
        completion_tokens = len(outputs[0].outputs[0].token_ids)
        response = {
            "id": f"chatcmpl-{random_uuid()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        }
        return response
    except Exception as e:
        return {"error": str(e), "details": repr(e)}
# Define input and output formats for SageMaker
def input_fn(serialized_input_data, content_type):
    # Pass the raw request body through; predict_fn does the JSON parsing
    return serialized_input_data

def output_fn(prediction_output, accept):
    # Serialize the response dict to JSON for the client
    return json.dumps(prediction_output)
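
# Minimal local smoke test (illustrative only, not part of the SageMaker handler
# contract). Assumes a GPU host and model weights in ./model; adjust the path as needed.
if __name__ == "__main__":
    sample_request = json.dumps({
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France?"}
        ],
        "temperature": 0.7,
        "max_new_tokens": 128
    })
    llm = model_fn("./model")  # hypothetical local model directory
    body = input_fn(sample_request, "application/json")
    result = predict_fn(body, llm)
    print(output_fn(result, "application/json"))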