# Hugging Face Space status when captured: Runtime error
# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/app.py # noqa: E501
# thank you @NimaBoscarino
import os | |
import gradio as gr | |
import googlemaps | |
from skimage import io | |
from urllib import parse | |
import numpy as np | |
from climategan_wrapper import ClimateGAN | |
def predict(cg: ClimateGAN, api_key, gmaps_client=None):
    """Build the click handler that runs ClimateGAN on a single image.

    Args:
        cg: The ClimateGAN wrapper; its ``infer_single(image, painter)``
            method produces the output dict unpacked below.
        api_key: Google Maps API key, or a falsy value to disable address
            lookup. When set, a non-empty place name takes precedence over
            the uploaded image.
        gmaps_client: Optional pre-built ``googlemaps.Client`` used for
            geocoding. When omitted and ``api_key`` is set, one is created
            here, so the handler does not rely on a module-level ``gmaps``
            global that only exists when this file runs as a script.

    Returns:
        A callable taking either ``(image, painter)`` or
        ``(image, place, painter)`` -- matching the UI inputs -- and
        returning the 8 output images in the order the UI declares them.

    Raises (from the returned callable):
        ValueError: on an unexpected number of inputs, when a place name
            cannot be geocoded, or when neither an image nor a place is given.
    """
    if api_key and gmaps_client is None:
        gmaps_client = googlemaps.Client(key=api_key)

    # Dropdown label -> painter identifier passed to ``cg.infer_single``.
    painters = {
        "ClimateGAN Painter": "climategan",
        "Stable Diffusion Painter": "stable_diffusion",
        "Both": "both",
    }

    def _streetview_image(place):
        # Geocode the free-text place, then download a Street View photo
        # of the resolved address.
        geocode_result = gmaps_client.geocode(place)
        if not geocode_result:
            raise ValueError(f"Could not find a location for {place!r}")
        address = geocode_result[0]["formatted_address"]
        # urlencode escapes every value (including the key); quote_via keeps
        # spaces as %20, as the previous parse.quote-based URL did.
        query = parse.urlencode(
            {
                "size": "640x640",
                "location": address,
                "source": "outdoor",
                "key": api_key,
            },
            quote_via=parse.quote,
        )
        return io.imread(f"https://maps.googleapis.com/maps/api/streetview?{query}")

    def _predict(*args):
        place = None
        if len(args) == 2:
            image, painter = args
        elif len(args) == 3:
            image, place, painter = args
        else:
            # Explicit raise instead of ``assert``: asserts vanish under -O.
            raise ValueError("Unknown number of inputs {}".format(len(args)))

        img_np = _streetview_image(place) if api_key and place else image
        if img_np is None:
            raise ValueError("Please provide an input image or an address.")

        output_dict = cg.infer_single(img_np, painters[painter])
        input_image = output_dict["input"]

        def _output_or_blank(key):
            # Outputs the selected painter did not produce are shown as a
            # white image. uint8 so 255 unambiguously means white; float
            # images are commonly expected in [0, 1], which 255.0 is not.
            if key in output_dict:
                return output_dict[key]
            return np.full(input_image.shape, 255, dtype=np.uint8)

        return (
            input_image,
            output_dict["masked_input"],
            _output_or_blank("climategan_flood"),
            _output_or_blank("stable_flood"),
            _output_or_blank("stable_copy_flood"),
            _output_or_blank("concat"),
            output_dict["wildfire"],
            output_dict["smog"],
        )

    return _predict
if __name__ == "__main__":
    # Script entry point: load the model, build the Gradio UI, launch it.
    api_key = os.environ.get("GMAPS_API_KEY")
    # Kept at module scope on purpose: ``predict`` may read a global named
    # ``gmaps`` when geocoding an address.
    gmaps = None
    if api_key is not None:
        gmaps = googlemaps.Client(key=api_key)

    cg = ClimateGAN(
        model_path="config/model/masker",
        dev_mode=os.environ.get("CG_DEV_MODE", "false").lower() == "true",
    )
    # The dropdown below offers Stable Diffusion painters, so that pipeline
    # is prepared up front.
    cg._setup_stable_diffusion()

    # NOTE(review): gr.inputs / gr.outputs are deprecated in Gradio 3.x and
    # removed in 4.x; pin the Gradio version or migrate to gr.Image etc.
    with gr.Blocks() as blocks:
        with gr.Row():
            with gr.Column():
                gr.Markdown("# ClimateGAN: Visualize Climate Change")
                gr.HTML(
                    'Climate change does not impact everyone equally. This Space shows the effects of the climate emergency, "one address at a time". Visit the original experience at <a href="https://thisclimatedoesnotexist.com/">ThisClimateDoesNotExist.com</a>.<br>Enter an address or place name, and ClimateGAN will generate images showing how the location could be impacted by flooding, wildfires, or smog.'  # noqa: E501
                )
        with gr.Column():
            gr.HTML(
                "<p style='text-align: center'>This project is an unofficial clone of <a href='https://thisclimatedoesnotexist.com/'>ThisClimateDoesNotExist</a> | <a href='https://github.com/cc-ai/climategan'>ClimateGAN GitHub Repo</a></p>"  # noqa: E501
            )

        with gr.Row():
            gr.Markdown("## Inputs")
        with gr.Row():
            with gr.Column():
                inputs = [gr.inputs.Image(label="Input Image")]
            with gr.Column():
                # The address box only exists when geocoding is possible;
                # the handler accepts both the 2- and 3-input layouts.
                if api_key:
                    inputs += [gr.inputs.Textbox(label="Address or place name")]
                inputs += [
                    gr.inputs.Dropdown(
                        choices=[
                            "ClimateGAN Painter",
                            "Stable Diffusion Painter",
                            "Both",
                        ],
                        label="Choose Flood Painter",
                        default="Both",
                    )
                ]
                btn = gr.Button("See for yourself!", label="Run")

        # Output components are appended in the same order as the tuple
        # returned by the ``predict`` handler.
        with gr.Row():
            gr.Markdown("## Outputs")
        with gr.Row():
            outputs = []
            outputs.append(gr.outputs.Image(type="numpy", label="Original image"))
            outputs.append(
                gr.outputs.Image(type="numpy", label="Masked input image")
            )
        with gr.Row():
            outputs.append(
                gr.outputs.Image(type="numpy", label="ClimateGAN-Flooded image")
            )
            outputs.append(
                gr.outputs.Image(type="numpy", label="Stable Diffusion-Flooded image")
            )
            outputs.append(
                gr.outputs.Image(
                    type="numpy",
                    label="Stable Diffusion-Flooded image (restricted to masked area)",
                )
            )
        with gr.Row():
            outputs.append(
                gr.outputs.Image(type="numpy", label="Comparison of previous images")
            )
        with gr.Row():
            outputs.append(gr.outputs.Image(type="numpy", label="Wildfire"))
            outputs.append(gr.outputs.Image(type="numpy", label="Smog"))

        btn.click(predict(cg, api_key), inputs=inputs, outputs=outputs)

    blocks.launch()