Update inference.py
Browse files- inference.py +9 -2
inference.py
CHANGED
@@ -5,12 +5,19 @@ import sys
|
|
5 |
import torch
|
6 |
from typing import List, Dict
|
7 |
|
8 |
-
# Ensure vllm is installed
|
9 |
try:
|
10 |
import vllm
|
11 |
except ImportError:
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
|
|
14 |
# Import the necessary modules after installation
|
15 |
from vllm import LLM, SamplingParams
|
16 |
from vllm.utils import random_uuid
|
|
|
5 |
import torch
|
6 |
from typing import List, Dict
|
7 |
|
8 |
+
# Ensure vllm is importable. If it is missing, install the prebuilt wheel that
# matches the CUDA toolkit this torch build reports, then import from it.
try:
    import vllm
except ImportError:
    # Only a CUDA 11.8 wheel is wired up here; any other value reported by
    # torch (including a missing CUDA version) is rejected up front.
    cuda_version = torch.version.cuda
    if cuda_version != "11.8":
        raise RuntimeError(f"Unsupported CUDA version: {cuda_version}")

    # The GitHub release tag carries a "v" prefix, but the wheel file name
    # does not, so the bare version is kept and the prefix added for the tag.
    vllm_version = "0.6.1.post1"
    # Wheels are built per CPython version; derive the tag from the running
    # interpreter instead of assuming 3.10 so pip does not reject the wheel.
    py_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    wheel_url = (
        "https://github.com/vllm-project/vllm/releases/download/"
        f"v{vllm_version}/"
        f"vllm-{vllm_version}+cu118-{py_tag}-{py_tag}-manylinux1_x86_64.whl"
    )

    # Each pip argument must be its own argv element. Passing a whole
    # "pip install ..." command line as one string makes pip treat it as a
    # single (invalid) requirement.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            wheel_url,
            "--extra-index-url",
            "https://download.pytorch.org/whl/cu118",
        ]
    )

# Import the necessary modules after installation
from vllm import LLM, SamplingParams
from vllm.utils import random_uuid
|