Use model config directly (#1)
Browse files- Use model config directly (7eb2d5255f6968846ee0dae38472cb2405841bcb)
Co-authored-by: Joshua <Xenova@users.noreply.huggingface.co>
README.md
CHANGED
@@ -28,24 +28,22 @@ import requests
|
|
28 |
import onnxruntime as ort
|
29 |
from PIL import Image
|
30 |
from io import BytesIO
|
31 |
-
from transformers import
|
32 |
|
33 |
# Command line arguments
|
34 |
model_path = sys.argv[1]
|
35 |
onnx_path = sys.argv[2]
|
36 |
|
37 |
-
# Initialize model and tokenizer
|
38 |
-
|
39 |
-
model_path, torch_dtype=torch.float32, device_map='mps'
|
40 |
-
)
|
41 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
42 |
|
43 |
# Model configuration
|
44 |
max_length = 1024
|
45 |
-
num_attention_heads =
|
46 |
-
num_key_value_heads =
|
47 |
-
head_dim =
|
48 |
-
num_layers =
|
49 |
|
50 |
# Setup ONNX sessions
|
51 |
session_options = ort.SessionOptions()
|
|
|
28 |
import onnxruntime as ort
|
29 |
from PIL import Image
|
30 |
from io import BytesIO
|
31 |
+
from transformers import Qwen2VLConfig, AutoTokenizer
|
32 |
|
33 |
# Command line arguments
|
34 |
model_path = sys.argv[1]
|
35 |
onnx_path = sys.argv[2]
|
36 |
|
37 |
+
# Initialize model config and tokenizer
|
38 |
+
model_config = Qwen2VLConfig.from_pretrained(model_path)
|
|
|
|
|
39 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
40 |
|
41 |
# Model configuration
|
42 |
max_length = 1024
|
43 |
+
num_attention_heads = model_config.num_attention_heads
|
44 |
+
num_key_value_heads = model_config.num_key_value_heads
|
45 |
+
head_dim = model_config.hidden_size // num_attention_heads
|
46 |
+
num_layers = model_config.num_hidden_layers
|
47 |
|
48 |
# Setup ONNX sessions
|
49 |
session_options = ort.SessionOptions()
|