Mrahsanahmad commited on
Commit
89531b2
1 Parent(s): f0cc5ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -17
app.py CHANGED
@@ -1,15 +1,8 @@
1
  import gradio as gr
2
  import numpy as np
3
  import torch
4
- import jax
5
- import jax.numpy as jnp
6
- from flax.jax_utils import replicate
7
- from flax.training.common_utils import shard
8
  from PIL import Image
9
  from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
10
- from diffusers import (
11
- FlaxStableDiffusionControlNetPipeline,
12
- FlaxControlNetModel,
13
  )
14
  from transformers import pipeline
15
 
@@ -49,11 +42,6 @@ with gr.Blocks() as demo:
49
  """
50
  ## Work in Progress
51
  ### About
52
- We have trained a JAX ControlNet model for semantic segmentation on Wildlife Animal Images.
53
-
54
- For the training data creation we used the [Wildlife Animals Images](https://www.kaggle.com/datasets/anshulmehtakaggl/wildlife-animals-images) dataset.
55
- We created segmentation masks with the help of [Grounded SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything) where we used the animals names
56
- as input prompts for detection and more accurate segmentation.
57
 
58
  ### How To Use
59
 
@@ -64,10 +52,6 @@ with gr.Blocks() as demo:
64
  mask_img = gr.Image(label="Mask", interactive=False)
65
  output_img = gr.Image(label="Output", interactive=False)
66
 
67
- with gr.Row():
68
- prompt_text = gr.Textbox(lines=1, label="Prompt")
69
- negative_prompt_text = gr.Textbox(lines=1, label="Negative Prompt")
70
-
71
  with gr.Row():
72
  submit = gr.Button("Submit")
73
  clear = gr.Button("Clear")
@@ -84,7 +68,8 @@ with gr.Blocks() as demo:
84
  pil_img = Image.fromarray(np_img, 'RGB')
85
  mask_images.append(pil_img)
86
 
87
- return np.stack(mask_images)
 
88
 
89
  # def infer(
90
  # image, prompts, negative_prompts, num_inference_steps=50, seed=4, num_samples=4
 
1
  import gradio as gr
2
  import numpy as np
3
  import torch
 
 
 
 
4
  from PIL import Image
5
  from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
 
 
 
6
  )
7
  from transformers import pipeline
8
 
 
42
  """
43
  ## Work in Progress
44
  ### About
 
 
 
 
 
45
 
46
  ### How To Use
47
 
 
52
  mask_img = gr.Image(label="Mask", interactive=False)
53
  output_img = gr.Image(label="Output", interactive=False)
54
 
 
 
 
 
55
  with gr.Row():
56
  submit = gr.Button("Submit")
57
  clear = gr.Button("Clear")
 
68
  pil_img = Image.fromarray(np_img, 'RGB')
69
  mask_images.append(pil_img)
70
 
71
+ #return np.stack(mask_images)
72
+ return image
73
 
74
  # def infer(
75
  # image, prompts, negative_prompts, num_inference_steps=50, seed=4, num_samples=4