Kedar Dabhadkar commited on
Commit
9d30a24
1 Parent(s): a86a460

Adapt for Fast Dash

Browse files
Dockerfile CHANGED
@@ -56,8 +56,7 @@ RUN conda install -c conda-forge gradio -y
56
 
57
  WORKDIR /home/user
58
 
59
- RUN --mount=type=secret,id=git_token,mode=0444,required=true \
60
- git clone --branch mmseg-only https://$(cat /run/secrets/git_token)@github.com/NASA-IMPACT/hls-foundation-os.git
61
 
62
 
63
  WORKDIR hls-foundation-os
@@ -66,7 +65,7 @@ RUN pip3 install -e .
66
 
67
  RUN mim install mmcv-full==1.6.2 -f https://download.openmmlab.com/mmcv/dist/11.5/1.11.0/index.html
68
 
69
- RUN pip3 install rasterio scikit-image
70
  # Set the working directory to the user's home directory
71
  WORKDIR $HOME/app
72
 
@@ -75,4 +74,4 @@ ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/code/miniconda/lib"
75
  # Copy the current directory contents into the container at $HOME/app setting the owner to the user
76
  COPY --chown=user . $HOME/app
77
 
78
- CMD ["python3", "app.py"]
 
56
 
57
  WORKDIR /home/user
58
 
59
+ RUN git clone --branch mmseg-only https://github.com/NASA-IMPACT/hls-foundation-os.git
 
60
 
61
 
62
  WORKDIR hls-foundation-os
 
65
 
66
  RUN mim install mmcv-full==1.6.2 -f https://download.openmmlab.com/mmcv/dist/11.5/1.11.0/index.html
67
 
68
+ RUN pip3 install rasterio scikit-image fast-dash gunicorn
69
  # Set the working directory to the user's home directory
70
  WORKDIR $HOME/app
71
 
 
74
  # Copy the current directory contents into the container at $HOME/app setting the owner to the user
75
  COPY --chown=user . $HOME/app
76
 
77
+ CMD exec gunicorn app:server --bind :7860
app.py CHANGED
@@ -197,36 +197,38 @@ custom_test_pipeline=process_test_pipeline(model.cfg.data.test.pipeline, None)
197
 
198
  func = partial(inference_on_file, model=model, custom_test_pipeline=custom_test_pipeline)
199
 
200
- with gr.Blocks() as demo:
201
-
202
- gr.Markdown(value='# Prithvi sen1floods11')
203
- gr.Markdown(value='''Prithvi is a first-of-its-kind temporal Vision transformer pretrained by the IBM and NASA team on continental US Harmonised Landsat Sentinel 2 (HLS) data. This demo showcases how the model was finetuned to detect water at a higher resolution than it was trained on (i.e. 10m versus 30m) using Sentinel 2 imagery from on the [sen1floods11 dataset](https://github.com/cloudtostreet/Sen1Floods11). More detailes can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11).\n
204
- The user needs to provide a Sentinel 2 image with all the 12 bands (in the usual Sentinel 2) order in reflectance units multiplied by 10,000 (e.g. to save on space), with the code that is going to pull up Blue, Green, Red, Narrow NIR, SWIR, SWIR 2.
205
- ''')
206
- with gr.Row():
207
- with gr.Column():
208
- inp = gr.File()
209
- btn = gr.Button("Submit")
210
-
211
- with gr.Row():
212
- gr.Markdown(value='### Input RGB')
213
- gr.Markdown(value='### Model prediction (Black: Land; White: Water)')
214
 
215
- with gr.Row():
216
- out1=gr.Image(image_mode='RGB')
217
- out2 = gr.Image(image_mode='L')
218
-
219
- btn.click(fn=func, inputs=inp, outputs=[out1, out2])
220
-
221
- with gr.Row():
222
- gr.Examples(examples=["India_900498_S2Hand.tif",
223
- "Spain_7370579_S2Hand.tif",
224
- "USA_430764_S2Hand.tif"],
225
- inputs=inp,
226
- outputs=[out1, out2],
227
- preprocess=preprocess_example,
228
- fn=func,
229
- cache_examples=True,
230
- )
231
-
232
- demo.launch()
 
 
 
 
 
197
 
198
  func = partial(inference_on_file, model=model, custom_test_pipeline=custom_test_pipeline)
199
 
200
+
201
+ if __name__ == "__main__":
202
+ with gr.Blocks() as demo:
203
+
204
+ gr.Markdown(value='# Prithvi sen1floods11')
205
+ gr.Markdown(value='''Prithvi is a first-of-its-kind temporal Vision transformer pretrained by the IBM and NASA team on continental US Harmonised Landsat Sentinel 2 (HLS) data. This demo showcases how the model was finetuned to detect water at a higher resolution than it was trained on (i.e. 10m versus 30m) using Sentinel 2 imagery from on the [sen1floods11 dataset](https://github.com/cloudtostreet/Sen1Floods11). More detailes can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11).\n
206
+ The user needs to provide a Sentinel 2 image with all the 12 bands (in the usual Sentinel 2) order in reflectance units multiplied by 10,000 (e.g. to save on space), with the code that is going to pull up Blue, Green, Red, Narrow NIR, SWIR, SWIR 2.
207
+ ''')
208
+ with gr.Row():
209
+ with gr.Column():
210
+ inp = gr.File()
211
+ btn = gr.Button("Submit")
 
 
212
 
213
+ with gr.Row():
214
+ gr.Markdown(value='### Input RGB')
215
+ gr.Markdown(value='### Model prediction (Black: Land; White: Water)')
216
+
217
+ with gr.Row():
218
+ out1=gr.Image(image_mode='RGB')
219
+ out2 = gr.Image(image_mode='L')
220
+
221
+ btn.click(fn=func, inputs=inp, outputs=[out1, out2])
222
+
223
+ with gr.Row():
224
+ gr.Examples(examples=["India_900498_S2Hand.tif",
225
+ "Spain_7370579_S2Hand.tif",
226
+ "USA_430764_S2Hand.tif"],
227
+ inputs=inp,
228
+ outputs=[out1, out2],
229
+ preprocess=preprocess_example,
230
+ fn=func,
231
+ cache_examples=True,
232
+ )
233
+
234
+ demo.launch(server_port=5001)
examples/India_900498_S2Hand.png ADDED
examples/India_900498_S2Hand_result.png ADDED
examples/Spain_7370579_S2Hand.png ADDED
examples/Spain_7370579_S2Hand_result.png ADDED
examples/USA_430764_S2Hand.png ADDED
examples/USA_430764_S2Hand_result.png ADDED
fast_dash_app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fast_dash import FastDash, UploadImage, Image, Upload
2
+ import PIL
3
+ import numpy as np
4
+ from app import func
5
+ import tempfile
6
+ import base64
7
+
8
+ def generate_image(input_tiff_image: Upload) -> (Image, Image):
9
+
10
+ with tempfile.NamedTemporaryFile(delete=False) as f:
11
+ contents = input_tiff_image.encode("utf8").split(b";base64,")[1]
12
+ f.write(base64.decodebytes(contents))
13
+
14
+ input_image, output_image = func(f)
15
+
16
+ # Convert numpy arrays to PIL images
17
+ input_image = PIL.Image.fromarray(np.uint8(input_image)).convert('RGB')
18
+ output_image = PIL.Image.fromarray(np.uint8(output_image)).convert('RGB')
19
+
20
+ return input_image, output_image
21
+
22
+ app = FastDash(generate_image, port=8000)
23
+
24
+ if __name__ == "__main__":
25
+ app.run()
fastapp.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from fast_dash import FastDash, UploadImage, Image, Upload
3
+ import PIL
4
+ import numpy as np
5
+
6
+ from app import func
7
+ import tempfile
8
+ import base64
9
+
10
+ examples = ["India_900498_S2Hand", "Spain_7370579_S2Hand", "USA_430764_S2Hand"]
11
+
12
+ def detect_water(select_an_example: str = examples, input_tiff_image: Upload = None) -> (Image, Image):
13
+ """
14
+ NASA and IBM recently uploaded their foundation model on Hugging Face, Pritivi, at https://huggingface.co/ibm-nasa-geospatial.
15
+ This demo, built with Fast Dash, showcases a version of Prithvi that they finetuned to detect water from satellite images.
16
+ Select an example or upload your own TIFF image.
17
+ If uploading your own image, read the format details at https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11.
18
+ """
19
+
20
+ # If example is selected
21
+ if input_tiff_image is None:
22
+ input_satellite_image = PIL.Image.open(os.path.join("examples", f"{select_an_example}.png"))
23
+ water_prediction_mask = PIL.Image.open(os.path.join("examples", f"{select_an_example}_result.png"))
24
+
25
+ # If file is uploaded, run inference
26
+ else:
27
+ with tempfile.NamedTemporaryFile(delete=False) as f:
28
+ contents = input_tiff_image.encode("utf8").split(b";base64,")[1]
29
+ f.write(base64.decodebytes(contents))
30
+
31
+ input_satellite_image, water_prediction_mask = func(f)
32
+
33
+ input_satellite_image = PIL.Image.fromarray(np.uint8(input_satellite_image)).convert('RGB')
34
+ water_prediction_mask = PIL.Image.fromarray(np.uint8(water_prediction_mask)).convert('RGB')
35
+
36
+ return input_satellite_image, water_prediction_mask
37
+
38
+ app = FastDash(detect_water, theme="cosmo", port=7860)
39
+ server = app.server
40
+
41
+ if __name__ == "__main__":
42
+ app.run()