Updated bitsandbytes config
Browse files
- handler.py (+3, -3)
handler.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from typing import Dict, Any
|
2 |
|
3 |
import torch
|
4 |
-
from transformers import Blip2Processor, Blip2Config, Blip2ForConditionalGeneration
|
5 |
from accelerate import init_empty_weights, infer_auto_device_map
|
6 |
|
7 |
from PIL import Image
|
@@ -19,11 +19,11 @@ class EndpointHandler():
|
|
19 |
model = Blip2ForConditionalGeneration(config)
|
20 |
device_map = infer_auto_device_map(model, no_split_module_classes=["T5Block"])
|
21 |
device_map['language_model.lm_head'] = device_map["language_model.encoder.embed_tokens"]
|
22 |
-
|
23 |
self.model = Blip2ForConditionalGeneration.from_pretrained(
|
24 |
"Salesforce/blip2-flan-t5-xxl", device_map=device_map,
|
25 |
torch_dtype=torch.float16,
|
26 |
-
load_in_8bit=True,
|
27 |
)
|
28 |
|
29 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
1 |
from typing import Dict, Any
|
2 |
|
3 |
import torch
|
4 |
+
from transformers import Blip2Processor, Blip2Config, Blip2ForConditionalGeneration, BitsAndBytesConfig
|
5 |
from accelerate import init_empty_weights, infer_auto_device_map
|
6 |
|
7 |
from PIL import Image
|
|
|
19 |
model = Blip2ForConditionalGeneration(config)
|
20 |
device_map = infer_auto_device_map(model, no_split_module_classes=["T5Block"])
|
21 |
device_map['language_model.lm_head'] = device_map["language_model.encoder.embed_tokens"]
|
22 |
+
|
23 |
self.model = Blip2ForConditionalGeneration.from_pretrained(
|
24 |
"Salesforce/blip2-flan-t5-xxl", device_map=device_map,
|
25 |
torch_dtype=torch.float16,
|
26 |
+
quantization_config=BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
|
27 |
)
|
28 |
|
29 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|