File size: 1,834 Bytes
2695ee3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1899be2
 
2695ee3
 
 
1ab056b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Tokenizer and sequence-classification model are both loaded from the same
# pretrained checkpoint, so the checkpoint name is spelled out only once.
_CHECKPOINT = "pparasurama/raceBERT-ethnicity"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT)

# Ethnicity label for every model output ID. The labels are listed in ID
# order, so a label's position in the tuple is the ID it is keyed by.
id2label = dict(enumerate((
    "GreaterEuropean,British",                # 0
    "GreaterEuropean,WestEuropean,French",    # 1
    "GreaterEuropean,WestEuropean,Italian",   # 2
    "GreaterEuropean,WestEuropean,Hispanic",  # 3
    "GreaterEuropean,Jewish",                 # 4
    "GreaterEuropean,EastEuropean",           # 5
    "Asian,IndianSubContinent",               # 6
    "Asian,GreaterEastAsian,Japanese",        # 7
    "GreaterAfrican,Muslim",                  # 8
    "Asian,GreaterEastAsian,EastAsian",       # 9
    "GreaterEuropean,WestEuropean,Nordic",    # 10
    "GreaterEuropean,WestEuropean,Germanic",  # 11
    "GreaterAfrican,Africans",                # 12
)))

# Function to make predictions based on the input name
def predict_ethnicity(name):
    """Return the most probable ethnicity labels for ``name`` as display text.

    The name is tokenized with the module-level ``tokenizer``, scored by the
    module-level ``model``, and the logits are turned into probabilities with
    a softmax. Up to five classes are reported, most probable first.

    Args:
        name: The name to classify. ``None``, an empty string, or a
            whitespace-only string is answered with a prompt instead of being
            sent to the model.

    Returns:
        A string with one ``"<label>: <percent>%"`` line per reported class
        (percentages formatted to two decimals), or ``"Please enter a name."``
        when no usable name was supplied.

    Raises:
        KeyError: If the model reports a class ID that ``id2label`` lacks.
    """
    # Guard clause: a blank submission carries no signal, so skip inference.
    if not name or not name.strip():
        return "Please enter a name."

    # truncation=True keeps over-long input within the tokenizer's configured
    # maximum length instead of letting the model fail on it.
    inputs = tokenizer(name, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only; no gradients needed
        logits = model(**inputs).logits

    # Softmax over dim 1, then take the first row (one name was submitted).
    probabilities = torch.softmax(logits, dim=1)[0]

    # Report at most five classes; clamp so topk cannot ask for more entries
    # than the model produced.
    top_preds = torch.topk(probabilities, min(5, probabilities.shape[0]))

    # topk returns values in descending order, so the lines are already sorted.
    return "\n".join(
        f"{id2label[idx]}: {prob * 100:.2f}%"
        for idx, prob in zip(top_preds.indices.tolist(), top_preds.values.tolist())
    )

# Web UI: a one-line text box for the name, wired to predict_ethnicity, with
# the prediction shown as plain text.
_name_box = gr.Textbox(lines=1, placeholder="Enter a name")
interface = gr.Interface(
    fn=predict_ethnicity,
    inputs=_name_box,
    outputs="text",
    title="TOPS Infosolutions Ethnicity Predictor - Kaleida",
    description="Enter a person's name and get the predicted ethnicity breakdown.",
)

# Start the Gradio app.
interface.launch()