# Hugging Face Spaces page header captured with this file (Space status: "Runtime error").
# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/app.py  # noqa: E501
# thank you @NimaBoscarino
import os
from textwrap import dedent
from urllib import parse

import googlemaps
import gradio as gr
import numpy as np
from gradio.components import (
    HTML,
    Button,
    Column,
    Dropdown,
    Image,
    Markdown,
    Radio,
    Row,
    Textbox,
)
from skimage import io

from climategan_wrapper import ClimateGAN
# Static HTML shown at the top of the page:
#   HTMLS[0] -- introduction, rendered in the left header column;
#   HTMLS[1] -- links and a description of the outputs, rendered in the right one.
# Whitespace inside the snippets is irrelevant to the rendered HTML.
HTMLS = [
    dedent(
        """
        <p>
            Climate change does not impact everyone equally.
            This Space shows the effects of the climate emergency,
            "one address at a time".
        </p>
        <p>
            Visit the original experience at
            <a href="https://thisclimatedoesnotexist.com/">
                ThisClimateDoesNotExist.com
            </a>
        </p>
        <br>
        <p>
            Enter an address or upload a Street View image, and ClimateGAN
            will generate images showing how the location could be impacted
            by flooding, wildfires, or smog if it happened there.
        </p>
        <br>
        <p>
            This is <strong>not</strong> an exercise in climate prediction,
            rather an exercise of empathy, to put yourself in other's shoes,
            as if Climate Change came crushing on your doorstep.
        </p>
        """
    ),
    # The lone "|" below is a visual separator between the two links.
    dedent(
        """
        <br><br><br><br>
        <p style='text-align: center'>
            Visit
            <a href='https://thisclimatedoesnotexist.com/'>
                ThisClimateDoesNotExist.com
            </a>
            for more information
            |
            Original
            <a href='https://github.com/cc-ai/climategan'>
                ClimateGAN GitHub Repo
            </a>
        </p>
        <br>
        <p>
            After you have selected an image and started the inference you
            will see all the outputs of ClimateGAN, including intermediate
            outputs such as the flood mask, the segmentation map and the
            depth maps used to produce the 3 events.
        </p>
        <br>
        <p>
            This Space makes use of recent Stable Diffusion in-painting
            pipelines to replace ClimateGAN's original Painter. If you
            select 'Both' painters, you will see a comparison
        </p>
        <br>
        <br>
        <p>
            Read the original
            <a
                href='https://openreview.net/forum?id=EZNOb_uNpJk'
                target='_blank'>
                ICLR 2021 ClimateGAN paper
            </a>
        </p>
        """
    ),
]
# Custom stylesheet handed to ``gr.Blocks(css=...)``: makes links look like
# links and highlights ``<strong>`` text.
CSS = dedent(
    """
    a {
        color: #0088ff;
        text-decoration: underline;
    }
    strong {
        color: #c34318;
    }
    """
)
def toggle(radio):
    """Show the widget matching the selected input type.

    Args:
        radio: Current value of the "Input Type" radio. Any value whose
            lower-cased text contains ``"address"`` selects address mode;
            everything else (including ``None``) selects image mode.

    Returns:
        Three ``gr.update`` objects, in the order the caller wires them to
        its outputs (address box, image input, run button): exactly one of
        the first two is made visible and the button is always revealed.
    """
    # ``radio or ""`` keeps a cleared selection (None) from raising.
    from_address = "address" in (radio or "").lower()
    return [
        gr.update(visible=from_address),
        gr.update(visible=not from_address),
        gr.update(visible=True),
    ]
def predict(cg: "ClimateGAN", api_key):
    """Build the Gradio callback that runs ClimateGAN on one image.

    Args:
        cg: Inference wrapper. Only
            ``cg.infer_single(image, painter, as_pil_image=True)`` is called
            on it, and the result is used as a dict of images.
        api_key: Google Maps API key, or a falsy value to disable address
            lookups. It decides how many positional inputs the callback
            expects.

    Returns:
        A function ``_predict(*args)``. With an API key ``args`` is
        ``(radio, image, place, painter)``; without one it is
        ``(image, painter)``. It returns a 10-tuple: input, masked input,
        segmentation, depth, ClimateGAN flood, Stable Diffusion flood,
        Stable Diffusion "copy" flood, concatenated comparison, wildfire,
        smog. Flood/comparison entries missing from the inference output
        are replaced by all-white ``uint8`` images shaped like the input.

    Raises (from the returned callback):
        ValueError: if the address cannot be geocoded or no image is given.
        KeyError: if ``painter`` is not one of the known painter labels.
    """
    # UI label -> painter identifier understood by ``infer_single``.
    painters = {
        "ClimateGAN Painter": "climategan",
        "Stable Diffusion Painter": "stable_diffusion",
        "Both": "both",
    }

    # Own the client instead of relying on a module-level ``gmaps`` name that
    # only exists when the file is executed as a script.
    maps_client = googlemaps.Client(key=api_key) if api_key else None

    def _street_view_image(place):
        """Geocode *place* and download its Street View picture as an array."""
        matches = maps_client.geocode(place)
        if not matches:
            raise ValueError(f"No geocoding result for {place!r}")
        address = matches[0]["formatted_address"]
        static_map_url = (
            "https://maps.googleapis.com/maps/api/streetview"
            f"?size=640x640&location={parse.quote(address)}"
            f"&source=outdoor&key={api_key}"
        )
        return io.imread(static_map_url)

    def _predict(*args):
        radio = place = None
        if api_key:
            radio, image, place, painter = args
        else:
            image, painter = args

        # ``radio`` is None until the user picks an input type.
        use_address = bool(
            api_key and place and "address" in (radio or "").lower()
        )
        img_np = _street_view_image(place) if use_address else image
        if img_np is None:
            raise ValueError(
                "No input image: upload a picture or enter an address."
            )

        output_dict = cg.infer_single(
            img_np, painters[painter], as_pil_image=True
        )
        input_image = output_dict["input"]

        def _or_blank(key):
            """Return ``output_dict[key]`` or a white image if it is absent."""
            if key in output_dict:
                return output_dict[key]
            # Built lazily, and through ``np.asarray`` so an input object
            # without a ``.shape`` attribute still works.
            return np.full(np.asarray(input_image).shape, 255, dtype=np.uint8)

        # Repeat the last axis three times (presumably 1 channel -> 3 so the
        # depth map displays as an RGB image -- confirm against the wrapper).
        depth = np.repeat(output_dict["depth"], 3, axis=-1)

        return (
            input_image,
            output_dict["masked_input"],
            output_dict["segmentation"],
            depth,
            _or_blank("climategan_flood"),
            _or_blank("stable_flood"),
            _or_blank("stable_copy_flood"),
            _or_blank("concat"),
            output_dict["wildfire"],
            output_dict["smog"],
        )

    return _predict
if __name__ == "__main__":
    # Address lookups are optional: they are enabled only when a non-empty
    # GMAPS_API_KEY is set. The same truthiness test is used for every
    # key-dependent decision below, so an empty variable means "no key".
    api_key = os.environ.get("GMAPS_API_KEY")
    has_key = bool(api_key)

    # Module-level Google Maps client; stays None when no key is configured.
    gmaps = None
    if has_key:
        gmaps = googlemaps.Client(key=api_key)

    cg = ClimateGAN(
        model_path="config/model/masker",
        dev_mode=os.environ.get("CG_DEV_MODE", "").lower() == "true",
    )
    cg._setup_stable_diffusion()

    radio = address = None
    pred_ins = []  # input components, in the positional order predict() expects
    pred_outs = []  # output components, in the order predict() returns images

    # NOTE(review): the layout nesting below was reconstructed from the order
    # of the context managers -- confirm against the deployed page.
    with gr.Blocks(css=CSS) as app:
        # Header: title + introduction on the left, links/details on the right.
        with Row():
            with Column():
                Markdown("# ClimateGAN: Visualize Climate Change")
                HTML(HTMLS[0])
            with Column():
                HTML(HTMLS[1])

        with Row():
            Markdown("## Inputs")
        with Row():
            with Column():
                if has_key:
                    radio = Radio(["From Address", "From Image"], label="Input Type")
                    pred_ins.append(radio)
                # With a key, the image input stays hidden until the radio
                # selects it; without one it is the only way to provide input.
                im_inp = Image(label="Input Image", visible=not has_key)
                pred_ins.append(im_inp)
                if has_key:
                    address = Textbox(label="Address or place name", visible=False)
                    pred_ins.append(address)
            with Column():
                pred_ins.append(
                    Dropdown(
                        choices=[
                            "ClimateGAN Painter",
                            "Stable Diffusion Painter",
                            "Both",
                        ],
                        label="Choose Flood Painter",
                        value="Both",
                    )
                )
                # Hidden at first when a key is set; toggle() reveals it once
                # an input type has been chosen.
                btn = Button(
                    "See for yourself!",
                    label="Run",
                    variant="primary",
                    visible=not has_key,
                )

        with Row():
            Markdown("## Outputs")
        with Row():
            pred_outs.append(Image(type="numpy", label="Original image"))
            pred_outs.append(Image(type="numpy", label="Masked input image"))
            pred_outs.append(Image(type="numpy", label="Segmentation map"))
            pred_outs.append(Image(type="numpy", label="Depth map"))
        with Row():
            pred_outs.append(Image(type="numpy", label="ClimateGAN-Flooded image"))
            pred_outs.append(
                Image(type="numpy", label="Stable Diffusion-Flooded image")
            )
            pred_outs.append(
                Image(
                    type="numpy",
                    label="Stable Diffusion-Flooded image (restricted to masked area)",
                )
            )
        with Row():
            pred_outs.append(Image(type="numpy", label="Comparison of flood images"))
        with Row():
            pred_outs.append(Image(type="numpy", label="Wildfire"))
            pred_outs.append(Image(type="numpy", label="Smog"))
            # Placeholder that is not wired to any output.
            Image(type="numpy", label="Empty on purpose", interactive=False)

        btn.click(predict(cg, api_key), inputs=pred_ins, outputs=pred_outs)
        if has_key:
            radio.change(toggle, inputs=[radio], outputs=[address, im_inp, btn])

    app.launch()