tejavardhan commited on
Commit
a6c15c9
·
verified ·
1 Parent(s): 060d79d

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +113 -0
main.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dependencies
2
+ """pip install torch pillow requests diffusers imageio gradio==3.4 httpx==0.23.2 transformers accelerate"""
3
+
4
+ import gradio as gr
5
+ import imageio
6
+ import torch
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ from diffusers import StableDiffusionInpaintPipeline
11
+
12
+ def perform_inpainting(prompt):
13
+ # save_images()
14
+
15
+
16
+ # Ensure CPU inference
17
+ img_path = "Original Image.png"
18
+ mask_path= "Mask Image.png"
19
+ device = "cuda"
20
+ model_name="runwayml/stable-diffusion-v1-5"
21
+ torch_dtype = torch.float16
22
+ # Create the inpainting pipeline
23
+ pipeline = create_inpaint_pipeline(model_name)
24
+ pipeline = pipeline.to(device) # Explicitly move model to CPU
25
+
26
+ # Load and pre-process images
27
+ try:
28
+ init_image = Image.open(img_path).convert("RGB").resize((512, 512))
29
+ mask_image = Image.open(mask_path).convert("RGB").resize((512, 512))
30
+ except FileNotFoundError:
31
+ print(f"Error: Image files '{img_path}' or '{mask_path}' not found.")
32
+ return None
33
+
34
+ print("Processing the image...")
35
+
36
+ # Perform inpainting
37
+ try:
38
+ image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
39
+ image.save("Inpainted_img.png")
40
+ return image
41
+ except Exception as e:
42
+ print(f"Error during inpainting: {e}")
43
+ return None
44
+
45
+ def create_inpaint_pipeline(model_name):
46
+ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
47
+ model_name,
48
+ torch_dtype=torch.float16,
49
+ )
50
+ return pipeline
51
+
52
+ # if __name__ == "__main__":
53
+ # generated_image = perform_inpainting()
54
+
55
+ # if generated_image is not None:
56
+ # generated_image.show()
57
+ # # Optionally save the generated image
58
+ # generated_image.save("inpainted_image.png")
59
+
60
+
61
+ def Mask(img):
62
+ """
63
+ Function to process the input image and generate a mask.
64
+
65
+ Args:
66
+ img (dict): Dictionary containing the base image and the mask image.
67
+
68
+ Returns:
69
+ tuple: A tuple containing the base image and the mask image.
70
+ """
71
+ try:
72
+ # Save the mask image to a file
73
+ imageio.imwrite("Original Image.png",img["image"])
74
+ imageio.imwrite("Mask Image.png", img["mask"])
75
+
76
+ return img["image"], img["mask"]
77
+ except KeyError as e:
78
+ # Handle case where expected keys are not in the input dictionary
79
+ return f"Key error: {e}", None
80
+ except Exception as e:
81
+ # Handle any other unexpected errors
82
+ return f"An error occurred: {e}", None
83
+
84
+
85
+ def main():
86
+ # Create the Gradio interface
87
+ with gr.Blocks() as demo:
88
+ with gr.Row():
89
+ img = gr.Image(tool="sketch", label="Paint Image", show_label=True)
90
+ img1 = gr.Image(label="Original Image")
91
+ img2 = gr.Image(label="Mask Image", show_label=True)
92
+
93
+ btn = gr.Button()
94
+ # Set the button click action
95
+ btn.click(Mask, inputs=img, outputs=[img1, img2])
96
+
97
+ # with gr.Blocks():
98
+ with gr.Row():
99
+ prompt = gr.Textbox(label="Enter the prompt")
100
+ button = gr.Button("Click")
101
+ output_image = gr.Image(label="Generated Image")
102
+
103
+
104
+
105
+
106
+ button.click(perform_inpainting, inputs=prompt,outputs=output_image)
107
+
108
+ # Launch the Gradio interface
109
+ demo.launch()
110
+
111
+
112
+ if __name__=='__main__':
113
+ main()