quarterturn committed on
Commit
7cbe3e4
·
verified ·
1 Parent(s): 25739ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -27,7 +27,7 @@ def unzip_images(zip_file):
27
 
28
  return image_paths, image_data, session_dir
29
 
30
- @spaces.GPU(duration=120) # Keep increased timeout
31
  def generate_caption(image_path, prompt):
32
  try:
33
  # Load processor and model in FP16
@@ -35,7 +35,7 @@ def generate_caption(image_path, prompt):
35
  model = AutoModelForCausalLM.from_pretrained(
36
  model_id,
37
  trust_remote_code=True,
38
- torch_dtype=torch.float16, # Cast model to FP16
39
  device_map='auto'
40
  )
41
 
@@ -48,8 +48,11 @@ def generate_caption(image_path, prompt):
48
  text=prompt,
49
  )
50
 
51
- # Move and cast inputs to FP16 on GPU
52
- inputs = {k: v.to('cuda', dtype=torch.float16).unsqueeze(0) for k, v in inputs.items()}
 
 
 
53
 
54
  with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
55
  output = model.generate_from_batch(
 
27
 
28
  return image_paths, image_data, session_dir
29
 
30
+ @spaces.GPU(duration=120)
31
  def generate_caption(image_path, prompt):
32
  try:
33
  # Load processor and model in FP16
 
35
  model = AutoModelForCausalLM.from_pretrained(
36
  model_id,
37
  trust_remote_code=True,
38
+ torch_dtype=torch.float16,
39
  device_map='auto'
40
  )
41
 
 
48
  text=prompt,
49
  )
50
 
51
+ # Move inputs to GPU, keeping input_ids as torch.long, others as FP16
52
+ inputs = {
53
+ k: v.to('cuda', dtype=torch.float16 if k != 'input_ids' else torch.long).unsqueeze(0)
54
+ for k, v in inputs.items()
55
+ }
56
 
57
  with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
58
  output = model.generate_from_batch(