Adding cpu offloading (if necessary?)
Browse files — handler.py (+1 −4)
handler.py
CHANGED
@@ -19,14 +19,11 @@ class EndpointHandler():
|
|
19 |
model = Blip2ForConditionalGeneration(config)
|
20 |
device_map = infer_auto_device_map(model, no_split_module_classes=["T5Block"])
|
21 |
device_map['language_model.lm_head'] = device_map["language_model.encoder.embed_tokens"]
|
22 |
-
|
23 |
-
print(device_map)
|
24 |
-
exit()
|
25 |
|
26 |
self.model = Blip2ForConditionalGeneration.from_pretrained(
|
27 |
"Salesforce/blip2-flan-t5-xxl", device_map=device_map,
|
28 |
# torch_dtype=torch.float16
|
29 |
-
load_in_8bit=True,
|
30 |
)
|
31 |
|
32 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
19 |
model = Blip2ForConditionalGeneration(config)
|
20 |
device_map = infer_auto_device_map(model, no_split_module_classes=["T5Block"])
|
21 |
device_map['language_model.lm_head'] = device_map["language_model.encoder.embed_tokens"]
|
|
|
|
|
|
|
22 |
|
23 |
self.model = Blip2ForConditionalGeneration.from_pretrained(
|
24 |
"Salesforce/blip2-flan-t5-xxl", device_map=device_map,
|
25 |
# torch_dtype=torch.float16
|
26 |
+
load_in_8bit=True, load_in_8bit_fp32_cpu_offload=True
|
27 |
)
|
28 |
|
29 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|