Aastha committed on
Commit 002cef1 • 1 Parent(s): aa73c8d

add configurable device support

Files changed (1)
  1. app.py +6 -5
app.py CHANGED
@@ -19,6 +19,7 @@ from super_gradients.training import models
 
 class Kosmos2:
     def __init__(self):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.colors = [
             (0, 255, 0),
             (0, 0, 255),
@@ -43,7 +44,7 @@ class Kosmos2:
         }
 
         self.ckpt = "ydshieh/kosmos-2-patch14-224"
-        self.model = AutoModelForVision2Seq.from_pretrained(self.ckpt, trust_remote_code=True).to("cuda")
+        self.model = AutoModelForVision2Seq.from_pretrained(self.ckpt, trust_remote_code=True).to(self.device)
         self.processor = AutoProcessor.from_pretrained(self.ckpt, trust_remote_code=True)
 
     def is_overlapping(self, rect1, rect2):
@@ -191,11 +192,11 @@ class Kosmos2:
         inputs = self.processor(text=text_input, images=image_input, return_tensors="pt")
 
         generated_ids = self.model.generate(
-            pixel_values=inputs["pixel_values"].to("cuda"),
-            input_ids=inputs["input_ids"][:, :-1].to("cuda"),
-            attention_mask=inputs["attention_mask"][:, :-1].to("cuda"),
+            pixel_values=inputs["pixel_values"].to(self.device),
+            input_ids=inputs["input_ids"][:, :-1].to(self.device),
+            attention_mask=inputs["attention_mask"][:, :-1].to(self.device),
             img_features=None,
-            img_attn_mask=inputs["img_attn_mask"][:, :-1].to("cuda"),
+            img_attn_mask=inputs["img_attn_mask"][:, :-1].to(self.device),
             use_cache=True,
             max_new_tokens=128,
         )
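For readers who want to try the change outside this Space, the sketch below illustrates the device-fallback pattern the commit applies: pick CUDA when a GPU is visible, otherwise fall back to CPU, and move both the model and every tensor passed to generate() onto that device. The checkpoint and the generate() keyword arguments mirror app.py; the run_example helper, the prompt string, and the decoding step are illustrative assumptions, not code from the commit.

# Minimal sketch of the device-fallback pattern applied in this commit.
# run_example, the default prompt, and the batch_decode call are
# illustrative assumptions, not part of app.py.
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

ckpt = "ydshieh/kosmos-2-patch14-224"
# Prefer the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModelForVision2Seq.from_pretrained(ckpt, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)

def run_example(image: Image.Image, prompt: str = "<grounding> An image of") -> str:
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # Every tensor handed to generate() must live on the same device as the model.
    generated_ids = model.generate(
        pixel_values=inputs["pixel_values"].to(device),
        input_ids=inputs["input_ids"][:, :-1].to(device),
        attention_mask=inputs["attention_mask"][:, :-1].to(device),
        img_features=None,
        img_attn_mask=inputs["img_attn_mask"][:, :-1].to(device),
        use_cache=True,
        max_new_tokens=128,
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

Centralizing the device choice this way lets the same code start on CPU-only hardware, where the previously hard-coded .to("cuda") calls would raise a runtime error.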