123LETSPLAY commited on
Commit
461ffd0
·
verified ·
1 Parent(s): 1983f1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -8
app.py CHANGED
@@ -1,15 +1,31 @@
1
  import gradio as gr
 
2
  from transformers import DalleBartTokenizer, DalleBartForConditionalGeneration
 
 
3
 
4
  # Load the tokenizer and model
5
  tokenizer = DalleBartTokenizer.from_pretrained("dalle-mini/dalle-mini")
6
  model = DalleBartForConditionalGeneration.from_pretrained("dalle-mini/dalle-mini")
7
 
 
 
 
 
 
8
  def generate_image(prompt):
9
- inputs = tokenizer(prompt, return_tensors="pt")
10
- outputs = model.generate(**inputs)
11
- image = outputs[0] # This will depend on your model's output format
12
- return image
 
 
 
 
 
 
 
 
13
 
14
  # Define Gradio interface
15
  iface = gr.Interface(fn=generate_image,
@@ -20,7 +36,3 @@ iface = gr.Interface(fn=generate_image,
20
 
21
  if __name__ == "__main__":
22
  iface.launch()
23
-
24
-
25
-
26
-
 
1
  import gradio as gr
2
+ import torch
3
  from transformers import DalleBartTokenizer, DalleBartForConditionalGeneration
4
+ from PIL import Image
5
+ import numpy as np
6
 
7
  # Load the tokenizer and model
8
  tokenizer = DalleBartTokenizer.from_pretrained("dalle-mini/dalle-mini")
9
  model = DalleBartForConditionalGeneration.from_pretrained("dalle-mini/dalle-mini")
10
 
11
+ # Ensure the model is in evaluation mode and use GPU if available
12
+ model.eval()
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ model.to(device)
15
+
16
  def generate_image(prompt):
17
+ try:
18
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
19
+ outputs = model.generate(**inputs)
20
+
21
+ # Convert the output tensor to an image
22
+ # Adjust this line according to the model output format
23
+ image = outputs[0] # Assuming this is an image tensor
24
+ image = Image.fromarray(np.array(image)) # Convert tensor to numpy array then to image
25
+
26
+ return image
27
+ except Exception as e:
28
+ return str(e)
29
 
30
  # Define Gradio interface
31
  iface = gr.Interface(fn=generate_image,
 
36
 
37
  if __name__ == "__main__":
38
  iface.launch()