Spaces:
Sleeping
Sleeping
import json | |
import torch | |
from huggingnft.lightweight_gan.train import timestamped_filename | |
from streamlit_option_menu import option_menu | |
from huggingface_hub import hf_hub_download, file_download | |
from PIL import Image | |
from huggingface_hub.hf_api import HfApi | |
import streamlit as st | |
from huggingnft.lightweight_gan.lightweight_gan import Generator, LightweightGAN, evaluate_in_chunks, Trainer | |
from accelerate import Accelerator | |
from huggan.pytorch.cyclegan.modeling_cyclegan import GeneratorResNet | |
from torchvision import transforms as T | |
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomHorizontalFlip | |
from torchvision.utils import make_grid | |
# Ask the Hub for every model published under the "huggingnft" author and keep
# only the part of each repo id that follows the first "/".
hfapi = HfApi()
model_names = []
for model in hfapi.list_models(author="huggingnft"):
    repo_id = model.modelId
    model_names.append(repo_id[repo_id.index("/") + 1:])
# streamlit-option-menu | |
# st.set_page_config(page_title="Sharone's Streamlit App Gallery", page_icon="", layout="wide") | |
# sysmenu = ''' | |
# <style> | |
# #MainMenu {visibility:hidden;} | |
# footer {visibility:hidden;} | |
# ''' | |
# st.markdown(sysmenu,unsafe_allow_html=True) | |
# # Add a logo (optional) in the sidebar | |
# logo = Image.open(r'C:\Users\13525\Desktop\Insights_Bees_logo.png') | |
# profile = Image.open(r'C:\Users\13525\Desktop\medium_profile.png') | |
# Markdown copy shown under the title of each page.
ABOUT_TEXT = "🤗 Hugging NFT - Generate NFT by OpenSea collection name."
CONTACT_TEXT = "Here is some contact info"
GENERATE_IMAGE_TEXT = "Text about generation"
INTERPOLATION_TEXT = "Text about Interpolation"
COLLECTION2COLLECTION_TEXT = "Text about Collection2Collection"

# Model names containing any of these substrings are dropped from every list.
STOPWORDS = ["-old"]

# Substrings that mark a model name as a collection-to-collection translator;
# the name is later split on "__2__" to find the source collection's model.
COLLECTION2COLLECTION_KEYS = ["__2__"]
def load_lightweight_model(model_name):
    """Download a model's config from the Hub and restore its ``Trainer``.

    Args:
        model_name: Hub repo id in ``"organization/name"`` form.

    Returns:
        The ``Trainer`` built from the repo's ``config.json``, after
        ``load(use_cpu=True)`` has been called on it and a new
        ``Accelerator`` has been assigned to ``model.accelerator``.

    Raises:
        ValueError: If ``model_name`` does not split into exactly two parts
            on ``"/"``.
    """
    file_path = file_download.hf_hub_download(
        repo_id=model_name,
        filename="config.json",
    )
    # Read through a context manager so the handle is closed deterministically;
    # JSON is UTF-8, so do not depend on the platform's default encoding.
    with open(file_path, encoding="utf-8") as config_file:
        config = json.load(config_file)
    organization_name, name = model_name.split("/")
    model = Trainer(**config, organization_name=organization_name, name=name)
    model.load(use_cpu=True)
    model.accelerator = Accelerator()
    return model
def clean_models(model_names, stopwords):
    """Return the model names that contain none of the given stopwords.

    Args:
        model_names: Iterable of model name strings.
        stopwords: Iterable of substrings; a name containing any of them is
            excluded.

    Returns:
        A new list with the surviving names in their original order.
    """
    return [
        candidate
        for candidate in model_names
        if not any(marker in candidate for marker in stopwords)
    ]
def get_concat_h(im1, im2):
    """Paste two images side by side onto a new RGB canvas.

    The canvas is as wide as both images together and as tall as ``im1``;
    ``im1`` goes on the left and ``im2`` immediately to its right, both
    aligned to the top edge.

    Args:
        im1: Left image (provides the canvas height).
        im2: Right image.

    Returns:
        The newly created combined image.
    """
    left_width = im1.width
    canvas = Image.new('RGB', (left_width + im2.width, im1.height))
    for tile, x_offset in ((im1, 0), (im2, left_width)):
        canvas.paste(tile, (x_offset, 0))
    return canvas
# Drop every model whose name contains one of the STOPWORDS substrings.
model_names = clean_models(model_names, STOPWORDS)

# Sidebar navigation: `choose` holds the label of the selected page and drives
# the page branches further down the script.
menu_entries = ["About", "Generate image", "Interpolation", "Collection2Collection", "Contact"]
menu_icons = ['house', 'camera fill', 'bi bi-youtube', 'book', 'person lines fill']
menu_styles = {
    "container": {"border-radius": ".0rem"},
}
with st.sidebar:
    choose = option_menu(
        "Hugging NFT",
        menu_entries,
        icons=menu_icons,
        menu_icon="app-indicator",
        default_index=0,
        styles=menu_styles,
    )

# Project links and social badges rendered as raw HTML in the sidebar.
st.sidebar.markdown(
    """
<style>
.aligncenter {
text-align: center;
}
</style>
<p style='text-align: center'>
<a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">Project Repository</a>
</p>
<p class="aligncenter">
<a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">
<img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingnft?style=social"/>
</a>
</p>
<p class="aligncenter">
<a href="https://twitter.com/alekseykorshuk" target="_blank">
<img src="https://img.shields.io/twitter/follow/alekseykorshuk?style=social"/>
</a>
</p>
""",
    unsafe_allow_html=True,
)
# Static pages: the page title is the menu label itself, followed by its blurb.
for page_title, page_text in (("About", ABOUT_TEXT), ("Contact", CONTACT_TEXT)):
    if choose == page_title:
        st.title(choose)
        st.markdown(page_text)
# "Generate image" page: pick a model and render what generate_app produces.
if choose == "Generate image":
    st.title(choose)
    st.markdown(GENERATE_IMAGE_TEXT)

    # Names matching COLLECTION2COLLECTION_KEYS are excluded from this list.
    single_collection_models = clean_models(model_names, COLLECTION2COLLECTION_KEYS)
    model_name = st.selectbox('Choose model:', single_collection_models)
    generation_type = st.selectbox('Select generation type:', ["default", "ema"])
    nrows = st.number_input(
        "Number of rows:",
        min_value=1,
        max_value=10,
        step=1,
        value=8,
    )

    if st.button("Generate"):
        with st.spinner(text="Downloading selected model..."):
            model = load_lightweight_model(f"huggingnft/{model_name}")
        with st.spinner(text="Generating..."):
            generated = model.generate_app(
                num=timestamped_filename(),
                nrow=nrows,
                checkpoint=-1,
                types=generation_type,
            )
            # Element 0 of the returned sequence is what gets displayed.
            st.image(generated[0])
# "Interpolation" page: pick a model and show the result of
# generate_interpolation, reporting progress through a Streamlit progress bar.
if choose == "Interpolation":
    st.title(choose)
    st.markdown(INTERPOLATION_TEXT)

    # Names matching COLLECTION2COLLECTION_KEYS are excluded from this list.
    single_collection_models = clean_models(model_names, COLLECTION2COLLECTION_KEYS)
    model_name = st.selectbox('Choose model:', single_collection_models)
    nrows = st.number_input(
        "Number of rows:",
        min_value=1,
        max_value=10,
        step=1,
        value=1,
    )
    num_steps = st.number_input(
        "Number of steps:",
        min_value=1,
        max_value=1000,
        step=1,
        value=100,
    )

    if st.button("Generate"):
        with st.spinner(text="Downloading selected model..."):
            model = load_lightweight_model(f"huggingnft/{model_name}")

        # The bar is handed to the trainer, then removed once the call returns.
        my_bar = st.progress(0)
        interpolation_kwargs = {
            "num": timestamped_filename(),
            "num_image_tiles": nrows,
            "num_steps": num_steps,
            "save_frames": False,
            "progress_bar": my_bar,
        }
        result = model.generate_interpolation(**interpolation_kwargs)
        my_bar.empty()

        with st.spinner(text="Uploading result..."):
            st.image(result)
# "Collection2Collection" page: generate images with the source collection's
# GAN, translate them with the selected translator model, and show each
# source/translated pair side by side.
if choose == "Collection2Collection":
    st.title(choose)
    st.markdown(COLLECTION2COLLECTION_TEXT)

    # Keep only the translator models (names containing a collection-to-
    # collection key). Built as an ordered list: the previous set difference
    # handed the selectbox its options in arbitrary order.
    translation_models = [
        name
        for name in model_names
        if any(key in name for key in COLLECTION2COLLECTION_KEYS)
    ]
    model_name = st.selectbox('Choose model:', translation_models)
    nrows = st.number_input(
        "Number of images to generate:",
        min_value=1,
        max_value=10,
        step=1,
        value=1,
    )
    generate_image_button = st.button("Generate")
    if generate_image_button:
        n_channels = 3
        image_size = 256

        with st.spinner(text="Downloading selected model..."):
            translator = GeneratorResNet.from_pretrained(
                f'huggingnft/{model_name}',
                input_shape=(n_channels, image_size, image_size),
                num_residual_blocks=9,
            )
        with st.spinner(text="Downloading selected model..."):
            # The part of the name before "__2__" is the repo of the generator
            # that produces the input images.
            model = load_lightweight_model(f"huggingnft/{model_name.split('__2__')[0]}")

        with st.spinner(text="Generating input images..."):
            # Element 1 of generate_app's return value is the image batch.
            punks = model.generate_app(
                num=timestamped_filename(),
                nrow=4,
                checkpoint=-1,
                types="default",
            )[1]
            # Honor the requested image count, which was previously ignored,
            # by keeping at most `nrows` images from the batch before resizing.
            # NOTE(review): assumes the batch holds at least `nrows` images
            # when nrow=4 - confirm against generate_app.
            input_images = T.Resize((image_size, image_size))(punks[:int(nrows)])

        with st.spinner(text="Generating output images..."):
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                output = translator(input_images)
            to_pil = T.ToPILImage()
            # make_grid on a single image is used here only for normalize=True.
            results = [
                get_concat_h(
                    to_pil(make_grid(source_image, nrow=1, normalize=True)),
                    to_pil(make_grid(translated_image, nrow=1, normalize=True)),
                )
                for source_image, translated_image in zip(input_images, output)
            ]

        for result in results:
            st.image(result)