CaioXapelaum commited on
Commit
56994dc
1 Parent(s): 3324b84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -13
app.py CHANGED
@@ -1,21 +1,27 @@
1
  import streamlit as st
2
- from diffusers import DiffusionPipeline
3
- import torch
4
  from PIL import Image
5
- from io import BytesIO
 
6
 
7
- # Load the diffusion pipeline
8
- @st.cache_resource
9
- def load_pipeline():
10
- return DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0").to("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
- pipeline = load_pipeline()
 
 
 
 
 
 
 
 
13
 
14
  def generate_image(prompt):
15
- # Generate image
16
- with torch.no_grad():
17
- result = pipeline(prompt).images[0]
18
- return result
19
 
20
  def main():
21
  st.title("Stable Diffusion Image Generator")
@@ -27,7 +33,8 @@ def main():
27
  if prompt:
28
  # Generate and display the image
29
  image = generate_image(prompt)
30
- st.image(image, caption="Generated Image")
 
31
  else:
32
  st.warning("Please enter a prompt.")
33
 
 
1
  import streamlit as st
2
+ import requests
 
3
  from PIL import Image
4
+ import io
5
+ import os
6
 
7
+ # Set up your Hugging Face API key
8
+ token = os.getenv("HF_TOKEN")
 
 
9
 
10
+ API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
11
+ headers = {"Authorization": f"Bearer {token}"}
12
+
13
+ def query(payload):
14
+ response = requests.post(API_URL, headers=headers, json=payload)
15
+ if response.status_code != 200:
16
+ st.error(f"Error: {response.status_code} - {response.text}")
17
+ return None
18
+ return response.content
19
 
20
  def generate_image(prompt):
21
+ image_bytes = query({"inputs": prompt})
22
+ if image_bytes:
23
+ return Image.open(io.BytesIO(image_bytes))
24
+ return None
25
 
26
  def main():
27
  st.title("Stable Diffusion Image Generator")
 
33
  if prompt:
34
  # Generate and display the image
35
  image = generate_image(prompt)
36
+ if image:
37
+ st.image(image, caption="Generated Image")
38
  else:
39
  st.warning("Please enter a prompt.")
40