openfree commited on
Commit
89bf5d2
1 Parent(s): 54a62f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -13
app.py CHANGED
@@ -34,13 +34,14 @@ prompt_values = df.values.flatten()
34
  with open('loras.json', 'r') as f:
35
  loras = json.load(f)
36
 
37
- # Initialize the base model
38
- dtype = torch.bfloat16
39
- device = "cuda" if torch.cuda.is_available() else "cpu"
40
 
41
  # 공통 FLUX 모델 로드
42
  base_model = "black-forest-labs/FLUX.1-dev"
43
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, device_map="auto")
 
 
 
44
 
45
  # LoRA를 위한 설정
46
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
@@ -55,12 +56,8 @@ pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
55
  tokenizer=pipe.tokenizer,
56
  text_encoder_2=pipe.text_encoder_2,
57
  tokenizer_2=pipe.tokenizer_2,
58
- torch_dtype=dtype
59
- )
60
-
61
- # Upscale을 위한 ControlNet 설정
62
- controlnet = FluxControlNetModel.from_pretrained(
63
- "jasperai/Flux.1-dev-Controlnet-Upscaler", torch_dtype=torch.bfloat16
64
  ).to(device)
65
 
66
  # Upscale 파이프라인 설정 (기존 pipe 재사용)
@@ -72,9 +69,9 @@ pipe_upscale = FluxControlNetPipeline(
72
  scheduler=pipe.scheduler,
73
  safety_checker=pipe.safety_checker,
74
  feature_extractor=pipe.feature_extractor,
75
- controlnet=controlnet
76
- )
77
-
78
 
79
 
80
 
 
34
  with open('loras.json', 'r') as f:
35
  loras = json.load(f)
36
 
37
+
 
 
38
 
39
  # 공통 FLUX 모델 로드
40
  base_model = "black-forest-labs/FLUX.1-dev"
41
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, device_map="balanced")
42
+
43
+ device = "cuda" if torch.cuda.is_available() else "cpu"
44
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, device_map="balanced").to(device)
45
 
46
  # LoRA를 위한 설정
47
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
 
56
  tokenizer=pipe.tokenizer,
57
  text_encoder_2=pipe.text_encoder_2,
58
  tokenizer_2=pipe.tokenizer_2,
59
+ torch_dtype=dtype,
60
+ device_map="balanced"
 
 
 
 
61
  ).to(device)
62
 
63
  # Upscale 파이프라인 설정 (기존 pipe 재사용)
 
69
  scheduler=pipe.scheduler,
70
  safety_checker=pipe.safety_checker,
71
  feature_extractor=pipe.feature_extractor,
72
+ controlnet=controlnet,
73
+ device_map="balanced"
74
+ ).to(device)
75
 
76
 
77