KevinQu7 commited on
Commit
40be8c2
1 Parent(s): 09c3706
Files changed (1) hide show
  1. app.py +13 -16
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
@@ -57,8 +57,8 @@ loaded_pipelines = {} # Cache to store loaded pipelines
57
  def process_with_loaded_pipeline(image_path, denoise_steps, ensemble_size, processing_res, model_type):
58
 
59
  # Load and cache the pipeline based on the model type.
60
- if model_type not in loaded_pipelines:
61
- auth_token = os.environ.get("KEV_TOKEN")
62
  if model_type == "appearance":
63
  loaded_pipelines[model_type] = MarigoldIIDAppearancePipeline.from_pretrained(
64
  "prs-eth/marigold-iid-appearance-v1-1", token=auth_token
@@ -71,6 +71,12 @@ def process_with_loaded_pipeline(image_path, denoise_steps, ensemble_size, proce
71
  # Move the pipeline to GPU if available
72
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
  loaded_pipelines[model_type] = loaded_pipelines[model_type].to(device)
 
 
 
 
 
 
74
 
75
  pipe = loaded_pipelines[model_type]
76
 
@@ -511,6 +517,9 @@ def run_demo_server(hf_writer=None):
511
  None,
512
  None,
513
  None,
 
 
 
514
  default_image_ensemble_size,
515
  default_image_denoise_steps,
516
  default_image_processing_res,
@@ -522,6 +531,7 @@ def run_demo_server(hf_writer=None):
522
  image_output_slider2,
523
  image_output_slider3,
524
  image_output_files,
 
525
  image_ensemble_size,
526
  image_denoise_steps,
527
  image_processing_res,
@@ -602,25 +612,12 @@ def run_demo_server(hf_writer=None):
602
 
603
 
604
  def main():
605
- CHECKPOINT = "prs-eth/marigold-iid-appearance-v1-1"
606
  CROWD_DATA = "crowddata-marigold-iid-appearance-v1-1-space-v1-1"
607
 
608
  os.system("pip freeze")
609
 
610
  if "HF_TOKEN_LOGIN" in os.environ:
611
  login(token=os.environ["HF_TOKEN_LOGIN"])
612
-
613
- auth_token = os.environ.get("KEV_TOKEN")
614
- pipe = MarigoldIIDAppearancePipeline.from_pretrained(CHECKPOINT,token=auth_token)
615
- try:
616
- import xformers
617
-
618
- pipe.enable_xformers_memory_efficient_attention()
619
- except:
620
- pass # run without xformers
621
-
622
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
623
- pipe = pipe.to(device)
624
 
625
  hf_writer = None
626
  if "HF_TOKEN_LOGIN_WRITE_CROWD" in os.environ:
 
1
+ # Copyright 2024 Anton Obukhov and Kevin Qu, ETH Zurich. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
57
  def process_with_loaded_pipeline(image_path, denoise_steps, ensemble_size, processing_res, model_type):
58
 
59
  # Load and cache the pipeline based on the model type.
60
+ if model_type not in loaded_pipelines.keys():
61
+ auth_token = os.environ.get("KEV_DEV")
62
  if model_type == "appearance":
63
  loaded_pipelines[model_type] = MarigoldIIDAppearancePipeline.from_pretrained(
64
  "prs-eth/marigold-iid-appearance-v1-1", token=auth_token
 
71
  # Move the pipeline to GPU if available
72
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
  loaded_pipelines[model_type] = loaded_pipelines[model_type].to(device)
74
+ try:
75
+ import xformers
76
+
77
+ loaded_pipelines[model_type].enable_xformers_memory_efficient_attention()
78
+ except:
79
+ pass # run without xformers
80
 
81
  pipe = loaded_pipelines[model_type]
82
 
 
517
  None,
518
  None,
519
  None,
520
+ None,
521
+ None,
522
+ default_model_type,
523
  default_image_ensemble_size,
524
  default_image_denoise_steps,
525
  default_image_processing_res,
 
531
  image_output_slider2,
532
  image_output_slider3,
533
  image_output_files,
534
+ model_type,
535
  image_ensemble_size,
536
  image_denoise_steps,
537
  image_processing_res,
 
612
 
613
 
614
  def main():
 
615
  CROWD_DATA = "crowddata-marigold-iid-appearance-v1-1-space-v1-1"
616
 
617
  os.system("pip freeze")
618
 
619
  if "HF_TOKEN_LOGIN" in os.environ:
620
  login(token=os.environ["HF_TOKEN_LOGIN"])
 
 
 
 
 
 
 
 
 
 
 
 
621
 
622
  hf_writer = None
623
  if "HF_TOKEN_LOGIN_WRITE_CROWD" in os.environ: