File size: 2,350 Bytes
f53a084
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from PIL import Image 
import os,csv
import pandas as pd
import numpy as np
import gradio as gr

# Prompt vocabulary for the UI, read from the bundled CSV.
# NOTE(review): the sorting below (and str.casefold in the UI) assumes these
# columns hold strings; confirm the CSV has no blank cells in the used rows.
prompts = pd.read_csv('promptsadjectives.csv')

# First ten entries of each adjective column.
masc = prompts['Masc-adj'].iloc[:10].tolist()
fem = prompts['Fem-adj'].iloc[:10].tolist()

# Leading '' lets the user pick "no adjective" (get_averages treats "" that way).
adjectives = [''] + sorted(masc + fem)

# First 150 occupations, matching the "150 professions" shown in the page header.
occupations = prompts['Occupation-Noun'].iloc[:150].tolist()


_BLANK_AVERAGE = 'facer_faces/blank.png'


def _average_path(model_dir, stem):
    """Return ``facer_faces/<model_dir>/<stem>.png`` if it exists, else the blank placeholder."""
    # NOTE(review): *stem* is built from UI input; gradio clients can send
    # arbitrary values, so consider validating it against the known choices.
    path = 'facer_faces/' + model_dir + '/' + stem + '.png'
    return path if os.path.isfile(path) else _BLANK_AVERAGE


def _load_image(path):
    """Open *path* with PIL and return an in-memory copy, closing the file handle."""
    # Image.open is lazy and keeps the file open. copy() forces the pixel data
    # to load so the file can be closed before the image is handed to Gradio.
    with Image.open(path) as img:
        return img.copy()


def get_averages(adj, profession):
    """Return the average-face image of each model for an adjective/profession prompt.

    The prompt is ``"<adj> <profession>"``, or just the profession when no
    adjective is chosen, with spaces replaced by underscores. It is used as the
    PNG file stem inside each model's folder under ``facer_faces/``. A model
    with no matching file gets ``facer_faces/blank.png`` instead.

    Args:
        adj: Adjective chosen in the UI. ``""`` or ``None`` means "no adjective".
        profession: Profession chosen in the UI. ``None`` is treated as ``""``.

    Returns:
        A 3-tuple of ``(PIL image, caption)`` pairs, in the order
        Stable Diffusion v1.4, Stable Diffusion v2, Dall-E 2. This shape is
        what the Gallery output of the click handler expects.
    """
    # A cleared Dropdown may send None; treat that the same as an empty choice.
    profession = profession or ''
    prompt = (adj + ' ' + profession) if adj else profession
    prompt = prompt.replace(' ', '_')

    # TODO: fix upper/lowercase error -- the Dall-E 2 files are looked up by
    # the lower-cased stem, while the Stable Diffusion ones use it as-is.
    sources = (
        ('SDv14', prompt, "Stable Diffusion v 1.4"),
        ('SDv2', prompt, "Stable Diffusion v 2"),
        ('dalle2', prompt.lower(), "Dall-E 2"),
    )
    return tuple(
        (_load_image(_average_path(model_dir, stem)), caption)
        for model_dir, stem, caption in sources
    )


with gr.Blocks() as demo:
    # Page title and a short explanation of what the demo shows.
    for _header_text in (
        "# Text-to-Image Diffusion Model Average Faces",
        "### We ran 150 professions through 3 diffusion models to examine what they generate.",
        "#### Choose one of the professions and adjectives and see the average face generated by each model.",
    ):
        gr.Markdown(_header_text)

    with gr.Row():
        # Left column: prompt selectors (sorted case-insensitively) and the trigger button.
        with gr.Column():
            adjective_choices = sorted(adjectives, key=str.casefold)
            adj = gr.Dropdown(
                adjective_choices,
                value='',
                label="Choose an adjective",
                interactive=True,
            )
            profession_choices = sorted(occupations, key=str.casefold)
            prof = gr.Dropdown(
                profession_choices,
                value='',
                label="Choose a profession",
                interactive=True,
            )
            btn = gr.Button("Get average faces!")

        # Right column: the gallery that shows get_averages' output, plus explanatory notes.
        with gr.Column():
            gallery = gr.Gallery(
                label="Average images",
                show_label=False,
                elem_id="gallery",
            ).style(grid=[0, 3], height="auto")
            for _note_text in (
                "The three models are: Stable Diffusion v.1.4, Stable Diffusion v.2, and Dall-E 2.",
                "If you see a black square above, we weren't able to compute an average face for this profession, sorry!",
            ):
                gr.Markdown(_note_text)

    # The button sends (adjective, profession) to get_averages and fills the gallery.
    btn.click(fn=get_averages, inputs=[adj, prof], outputs=gallery)

demo.launch(share=True)