from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts
from htbuilder.units import percent, px
from htbuilder.funcs import rgba, rgb
import streamlit as st
import functools
import os
import sys
import argparse
import clip
import numpy as np
from PIL import Image
from dalle.models import Dalle
from dalle.utils.utils import set_seed, clip_score


def link(link, text, **style):
    """Return an htbuilder ``<a>`` element pointing at *link* that opens in a new tab.

    Args:
        link: The href of the anchor.
        text: The anchor's child content.
        **style: Inline CSS properties, passed to ``htbuilder.styles``.
    """
    return a(_href=link, _target="_blank", style=styles(**style))(text)


def layout(*args):
    """Render a fixed footer pinned to the bottom of the Streamlit page.

    Each positional argument that is a ``str`` or an htbuilder ``HtmlElement``
    is appended, in order, to the footer paragraph; any other value is ignored.
    """
    # Placeholder for extra page-level CSS; currently whitespace only.
    style = """ """

    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        color="black",
        text_align="center",
        height="auto",
        opacity=1
    )

    style_hr = styles(
        display="block",
        margin=px(8, 8, "auto", "auto"),
        border_style="inset",
        border_width=px(2)
    )

    body = p()
    foot = div(
        style=style_div
    )(
        hr(
            style=style_hr
        ),
        body
    )

    st.markdown(style, unsafe_allow_html=True)

    for arg in args:
        # Strings and HtmlElements are both added as children of the
        # paragraph; values of any other type are silently skipped.
        if isinstance(arg, (str, HtmlElement)):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)


def footer():
    """Render the credits footer via ``layout``."""
    myargs = [
        "Created by ",
        link("https://jonathanmalott.com", "Jonathan Malott"),
        br(),
        link("https://bridgingbarriers.utexas.edu/good-systems", "Good Systems"),
        " Grand Challenge",
        ", The University of Texas at Austin.",
        " Advised by Dr. Junfeng Jiao.",
        br(),
        br(),
    ]
    layout(*myargs)


# Footer rendering is currently disabled; uncomment to show it.
# footer()


@functools.lru_cache(maxsize=None)
def _load_dalle(device):
    """Load the minDALL-E checkpoint once per *device* and move it there.

    The cached model is shared by every caller in this process.
    """
    # NOTE(review): this is a filesystem path relative to the working
    # directory, not the registered name 'minDALL-E/1.3B'; presumably the
    # checkpoint must already exist there (not auto-downloaded) - confirm.
    model = Dalle.from_pretrained('.cache/minDALL-E/1.3B')
    model.to(device=device)
    return model


@functools.lru_cache(maxsize=None)
def _load_clip(device):
    """Load CLIP ViT-B/32 and its preprocessing transform once per *device*."""
    model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
    model_clip.to(device=device)
    return model_clip, preprocess_clip


def generate(prompt, crazy, k, *, device='cpu', num_candidates=1):
    """Sample an image for *prompt* with minDALL-E and store it in the session.

    Args:
        prompt: Text prompt to condition the sampler on.
        crazy: Passed to the sampler as ``softmax_temperature``.
        k: Top-k cutoff passed to the sampler as ``top_k``. ``None`` falls
            back to 2048, the value this function previously hard-coded.
            The raw value is stored with the result for ``drawGrid``'s label.
        device: Device both models are placed on (previously hard-coded 'cpu').
        num_candidates: Number of images to sample. When greater than one,
            CLIP re-ranking is used and only the top-ranked image is kept.

    Side effects:
        Calls ``set_seed`` with a random integer in [0, 10000), then appends
        a dict with keys 'prompt', 'crazy', 'k' and 'image' (a PIL image) to
        ``st.session_state.results``, creating that list if it is missing.
    """
    model = _load_dalle(device)

    set_seed(np.random.randint(0, 10000))

    top_k = 2048 if k is None else k

    # Sampling
    images = model.sampling(prompt=prompt,
                            top_k=top_k,
                            top_p=None,
                            softmax_temperature=crazy,
                            num_candidates=num_candidates,
                            device=device).cpu().numpy()
    # Move axis 1 to the end (presumably channels-first -> channels-last)
    # so each candidate can be handed to PIL.
    images = np.transpose(images, (0, 2, 3, 1))

    if len(images) > 1:
        # CLIP re-ranking; only needed when there is a choice to make.
        model_clip, preprocess_clip = _load_clip(device)
        rank = clip_score(prompt=prompt,
                          images=images,
                          model_clip=model_clip,
                          preprocess_clip=preprocess_clip,
                          device=device)
        # Indexing with the whole `rank` would return a stack of images that
        # Image.fromarray cannot take; keep only the first-ranked index.
        # Assumes `rank` lists candidate indices best-first - confirm.
        best = images[int(np.atleast_1d(rank)[0])]
    else:
        best = images[0]

    # Clamp before the uint8 cast so out-of-range floats cannot wrap around.
    pixels = (np.clip(best, 0.0, 1.0) * 255).astype(np.uint8)

    if "results" not in st.session_state:
        st.session_state.results = []
    st.session_state.results.append({
        'prompt': prompt,
        'crazy': crazy,
        'k': k,
        'image': Image.fromarray(pixels),
    })


def drawGrid():
    """Render stored results as 3-column image grids, one per settings group.

    Results in ``st.session_state.results`` (missing treated as empty) are
    grouped by (prompt, crazy, k). Each group gets a subheader and its images
    are dealt across three columns left to right. Both the groups and the
    images inside them are shown newest first.
    """
    results = st.session_state.get("results", [])

    # Walk newest-first; dicts keep insertion order, so the first group seen
    # (the one with the most recent image) is rendered first.
    groups = {}
    for r in reversed(results):
        key = (r['prompt'], str(r['crazy']), str(r['k']))
        groups.setdefault(key, []).append(r)

    for items in groups.values():
        first = items[0]
        txt = first['prompt'] + " (temperature:" + str(first['crazy']) + ", top k:" + str(first['k']) + ")"
        st.subheader(txt)
        cols = st.columns(3)
        for ix, item in enumerate(items):
            with cols[ix % 3]:
                st.image(item["image"])