Nick088 commited on
Commit
ea450aa
1 Parent(s): dae60c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -8,6 +8,14 @@ import spaces
8
 
9
  HF_TOKEN = os.getenv("HF_TOKEN")
10
 
 
 
 
 
 
 
 
 
11
  model_path = snapshot_download(
12
  repo_id="stabilityai/stable-diffusion-3-medium",
13
  revision="refs/pr/26",
@@ -17,24 +25,19 @@ model_path = snapshot_download(
17
  token=HF_TOKEN,
18
  )
19
 
20
- if torch.cuda.is_available():
21
- device = "cuda"
22
- print("Using GPU")
23
- else:
24
- device = "cpu"
25
- print("Using CPU")
26
 
27
  # Initialize the pipeline and download the model
28
  pipe = StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=torch.float16)
29
  pipe.to(device)
30
 
 
31
  tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
32
  model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", device_map="auto", torch_dtype="auto")
33
  model.to(device)
34
 
35
  # Define the image generation function
36
  @spaces.GPU(duration=60)
37
- def generate_image(prompt, negative_prompt, num_inference_steps, height, width, guidance_scale, seed, num_images_per_prompt):
38
  if seed == 0:
39
  seed = random.randint(1, 2**32-1)
40
 
 
8
 
9
  HF_TOKEN = os.getenv("HF_TOKEN")
10
 
11
+ if torch.cuda.is_available():
12
+ device = "cuda"
13
+ print("Using GPU")
14
+ else:
15
+ device = "cpu"
16
+ print("Using CPU")
17
+
18
+ # download sd3 medium weights
19
  model_path = snapshot_download(
20
  repo_id="stabilityai/stable-diffusion-3-medium",
21
  revision="refs/pr/26",
 
25
  token=HF_TOKEN,
26
  )
27
 
 
 
 
 
 
 
28
 
29
  # Initialize the pipeline and download the model
30
  pipe = StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=torch.float16)
31
  pipe.to(device)
32
 
33
+ # superprompt-v1
34
  tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
35
  model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", device_map="auto", torch_dtype="auto")
36
  model.to(device)
37
 
38
  # Define the image generation function
39
  @spaces.GPU(duration=60)
40
+ def generate_image(prompt, enhance_prompt, negative_prompt, num_inference_steps, height, width, guidance_scale, seed, num_images_per_prompt):
41
  if seed == 0:
42
  seed = random.randint(1, 2**32-1)
43