Spaces:
Runtime error
Runtime error
Aastha
commited on
Commit
β’
002cef1
1
Parent(s):
aa73c8d
add configurable device support
Browse files
app.py
CHANGED
@@ -19,6 +19,7 @@ from super_gradients.training import models
|
|
19 |
|
20 |
class Kosmos2:
|
21 |
def __init__(self):
|
|
|
22 |
self.colors = [
|
23 |
(0, 255, 0),
|
24 |
(0, 0, 255),
|
@@ -43,7 +44,7 @@ class Kosmos2:
|
|
43 |
}
|
44 |
|
45 |
self.ckpt = "ydshieh/kosmos-2-patch14-224"
|
46 |
-
self.model = AutoModelForVision2Seq.from_pretrained(self.ckpt, trust_remote_code=True).to(
|
47 |
self.processor = AutoProcessor.from_pretrained(self.ckpt, trust_remote_code=True)
|
48 |
|
49 |
def is_overlapping(self, rect1, rect2):
|
@@ -191,11 +192,11 @@ class Kosmos2:
|
|
191 |
inputs = self.processor(text=text_input, images=image_input, return_tensors="pt")
|
192 |
|
193 |
generated_ids = self.model.generate(
|
194 |
-
pixel_values=inputs["pixel_values"].to(
|
195 |
-
input_ids=inputs["input_ids"][:, :-1].to(
|
196 |
-
attention_mask=inputs["attention_mask"][:, :-1].to(
|
197 |
img_features=None,
|
198 |
-
img_attn_mask=inputs["img_attn_mask"][:, :-1].to(
|
199 |
use_cache=True,
|
200 |
max_new_tokens=128,
|
201 |
)
|
|
|
19 |
|
20 |
class Kosmos2:
|
21 |
def __init__(self):
|
22 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
self.colors = [
|
24 |
(0, 255, 0),
|
25 |
(0, 0, 255),
|
|
|
44 |
}
|
45 |
|
46 |
self.ckpt = "ydshieh/kosmos-2-patch14-224"
|
47 |
+
self.model = AutoModelForVision2Seq.from_pretrained(self.ckpt, trust_remote_code=True).to(self.device)
|
48 |
self.processor = AutoProcessor.from_pretrained(self.ckpt, trust_remote_code=True)
|
49 |
|
50 |
def is_overlapping(self, rect1, rect2):
|
|
|
192 |
inputs = self.processor(text=text_input, images=image_input, return_tensors="pt")
|
193 |
|
194 |
generated_ids = self.model.generate(
|
195 |
+
pixel_values=inputs["pixel_values"].to(self.device),
|
196 |
+
input_ids=inputs["input_ids"][:, :-1].to(self.device),
|
197 |
+
attention_mask=inputs["attention_mask"][:, :-1].to(self.device),
|
198 |
img_features=None,
|
199 |
+
img_attn_mask=inputs["img_attn_mask"][:, :-1].to(self.device),
|
200 |
use_cache=True,
|
201 |
max_new_tokens=128,
|
202 |
)
|