bmay commited on
Commit
9aa9e84
1 Parent(s): 093a08f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -14,8 +14,8 @@ def load_description(fp):
14
 
15
 
16
  @spaces.GPU(duration=90)
17
- def run_theia(image, pred_iou_thresh, stability_score_thresh):
18
- theia_model = AutoModel.from_pretrained("theaiinstitute/theia-tiny-patch16-224-cddsv", trust_remote_code=True)
19
  theia_model = theia_model.to('cuda')
20
  target_model_names = [
21
  "google/vit-huge-patch14-224-in21k",
@@ -72,6 +72,7 @@ with gr.Blocks() as demo:
72
  input_image = gr.Image(label="Input Image", type="pil")
73
 
74
  with gr.Accordion("Advanced Settings", open=False):
 
75
  pred_iou_thresh = gr.Slider(0.05, 0.95, step=0.05, value=0.5, label="SAM Pred IoU Thresh")
76
  stability_score_thresh = gr.Slider(0.05, 0.95, step=0.05, value=0.7, label="SAM Stability Score Thresh")
77
 
@@ -84,7 +85,7 @@ with gr.Blocks() as demo:
84
 
85
  submit_button.click(
86
  run_theia,
87
- inputs=[input_image, pred_iou_thresh, stability_score_thresh],
88
  outputs=[dinov2_output, sam_output, depth_anything_output]
89
  )
90
 
 
14
 
15
 
16
  @spaces.GPU(duration=90)
17
+ def run_theia(model_size, image, pred_iou_thresh, stability_score_thresh):
18
+ theia_model = AutoModel.from_pretrained(f"theaiinstitute/theia-{model_size}-patch16-224-cddsv", trust_remote_code=True)
19
  theia_model = theia_model.to('cuda')
20
  target_model_names = [
21
  "google/vit-huge-patch14-224-in21k",
 
72
  input_image = gr.Image(label="Input Image", type="pil")
73
 
74
  with gr.Accordion("Advanced Settings", open=False):
75
+ model_size = gr.Radio(["tiny", "small", "base"], value="tiny", label="Theia Model Size")
76
  pred_iou_thresh = gr.Slider(0.05, 0.95, step=0.05, value=0.5, label="SAM Pred IoU Thresh")
77
  stability_score_thresh = gr.Slider(0.05, 0.95, step=0.05, value=0.7, label="SAM Stability Score Thresh")
78
 
 
85
 
86
  submit_button.click(
87
  run_theia,
88
+ inputs=[model_size, input_image, pred_iou_thresh, stability_score_thresh],
89
  outputs=[dinov2_output, sam_output, depth_anything_output]
90
  )
91