Spaces:
Runtime error
Runtime error
import math | |
import numpy as np | |
import pandas as pd | |
import tensorflow as tf | |
import tensorflow_addons as tfa | |
from tensorflow import keras | |
from tensorflow.keras import layers | |
import gradio as gr | |
from huggingface_hub import from_pretrained_keras | |
# Load the pretrained TabTransformer checkpoint from the Hugging Face Hub once,
# at import time, so every prediction request reuses the same model object.
# NOTE(review): the `optimizer` entry in custom_objects presumably lets Keras
# deserialize the AdamW optimizer stored with the checkpoint -- confirm.
model = from_pretrained_keras(
    'keras-io/tab_transformer',
    custom_objects={'optimizer': tfa.optimizers.AdamW},
)
def get_dataset_from_pandas(data):
    """Convert a pandas DataFrame of features into a ``tf.data.Dataset``.

    Every ``float64`` column, and the ``age`` column regardless of its dtype,
    is cast to ``float32``. The ``income_bracket`` and ``fnlwgt`` columns are
    dropped when present. The remaining columns are wrapped with
    ``tf.data.Dataset.from_tensors`` as a dict keyed by column name, so the
    dataset holds the whole frame as a single element.

    The casts are applied to a new frame; the caller's ``data`` is left
    unmodified (the previous implementation overwrote its columns in place).

    Args:
        data: DataFrame whose columns are the model's input features.

    Returns:
        A ``tf.data.Dataset`` with one element: a ``{column: values}`` dict.
    """
    # Decide the casts up front; DataFrame.astype returns a new frame, so the
    # argument is never written to.
    float32_casts = {
        col: 'float32'
        for col in data.columns
        if data[col].dtype == 'float64' or col == 'age'
    }
    features = data.astype(float32_casts)

    # These two columns are not fed to the model; drop only the ones that are
    # actually present so frames without them are accepted too.
    unused = [col for col in ('income_bracket', 'fnlwgt') if col in features.columns]
    features = features.drop(columns=unused)

    return tf.data.Dataset.from_tensors(dict(features))
def infer(age, workclass, education, education_num, marital_status, occupation,
          relationship, race, gender, capital_gain, capital_loss, hours_per_week,
          native_country):
    """Run the model on one record and return its score as a percentage string.

    The thirteen arguments are placed into a one-row DataFrame (one column per
    argument, named after it), converted with ``get_dataset_from_pandas`` and
    passed to ``model.predict``.

    Returns:
        The first value of the flattened prediction, multiplied by 100 and
        rounded to two decimals, followed by ``%``.
    """
    record = {
        'age': age,
        'workclass': workclass,
        'education': education,
        'education_num': education_num,
        'marital_status': marital_status,
        'occupation': occupation,
        'relationship': relationship,
        'race': race,
        'gender': gender,
        'capital_gain': capital_gain,
        'capital_loss': capital_loss,
        'hours_per_week': hours_per_week,
        'native_country': native_country,
    }
    # All values are scalars, so an explicit index is required to get one row.
    single_row = pd.DataFrame(record, index=[0])

    dataset = get_dataset_from_pandas(single_row)
    scores = model.predict(dataset)

    percent = round(scores.flatten()[0] * 100, 2)
    return f"{percent}%"
# Input widgets, listed in the same order as the parameters of `infer`.
# Categorical choices keep their leading space, as do the rows in `examples`.
# NOTE(review): presumably this mirrors the raw values the model vocabulary
# was built from -- confirm before normalizing any of these strings.
inputs = [
    gr.Slider(minimum=16, maximum=120, step=1, label='age', value=30),
    gr.Radio(
        choices=[' Private', ' Local-gov', ' ?', ' Self-emp-not-inc', ' Federal-gov',
                 ' State-gov', ' Self-emp-inc', ' Without-pay', ' Never-worked'],
        label='workclass',
        type='value',
        value=' Private',
    ),
    gr.Radio(
        choices=[' 11th', ' HS-grad', ' Assoc-acdm', ' Some-college', ' 10th',
                 ' Prof-school', ' 7th-8th', ' Bachelors', ' Masters', ' Doctorate',
                 ' 5th-6th', ' Assoc-voc', ' 9th', ' 12th', ' 1st-4th', ' Preschool'],
        type='value',
        label='education',
        value=' Bachelors',
    ),
    gr.Slider(minimum=1, maximum=16, step=1, label='education_num', value=10),
    gr.Radio(
        # Fix: the first choice used to be an empty string, which rendered a
        # blank option and left ' Never-married' (used by the example rows)
        # unselectable.
        choices=[' Never-married', ' Married-civ-spouse', ' Widowed', ' Divorced',
                 ' Separated', ' Married-spouse-absent', ' Married-AF-spouse'],
        type='value',
        label='marital_status',
        value=' Married-civ-spouse',
    ),
    gr.Radio(
        choices=[' Machine-op-inspct', ' Farming-fishing', ' Protective-serv', ' ?',
                 ' Other-service', ' Prof-specialty', ' Craft-repair', ' Adm-clerical',
                 ' Exec-managerial', ' Tech-support', ' Sales', ' Priv-house-serv',
                 ' Transport-moving', ' Handlers-cleaners', ' Armed-Forces'],
        type='value',
        label='occupation',
        value=' Tech-support',
    ),
    gr.Radio(
        choices=[' Own-child', ' Husband', ' Not-in-family', ' Unmarried', ' Wife',
                 ' Other-relative'],
        type='value',
        label='relationship',
        value=' Wife',
    ),
    gr.Radio(
        choices=[' Black', ' White', ' Asian-Pac-Islander', ' Other',
                 ' Amer-Indian-Eskimo'],
        type='value',
        label='race',
        value=' Other',
    ),
    gr.Radio(choices=[' Male', ' Female'], type='value', label='gender', value=' Female'),
    gr.Slider(minimum=0, maximum=500000, step=1, label='capital_gain', value=80000),
    gr.Slider(minimum=0, maximum=50000, step=1, label='capital_loss', value=1000),
    gr.Slider(minimum=1, maximum=168, step=1, label='hours_per_week', value=40),
    gr.Radio(
        choices=[' United-States', ' ?', ' Peru', ' Guatemala', ' Mexico',
                 ' Dominican-Republic', ' Ireland', ' Germany', ' Philippines',
                 ' Thailand', ' Haiti', ' El-Salvador', ' Puerto-Rico', ' Vietnam',
                 ' South', ' Columbia', ' Japan', ' India', ' Cambodia', ' Poland',
                 ' Laos', ' England', ' Cuba', ' Taiwan', ' Italy', ' Canada',
                 ' Portugal', ' China', ' Nicaragua', ' Honduras', ' Iran',
                 ' Scotland', ' Jamaica', ' Ecuador', ' Yugoslavia', ' Hungary',
                 ' Hong', ' Greece', ' Trinadad&Tobago',
                 ' Outlying-US(Guam-USVI-etc)', ' France'],
        type='value',
        label='native_country',
        value=' Vietnam',
    ),
]
# The app has a single output: a text box showing the string returned by
# `infer` under the label below.
output = gr.Textbox(
    label="Probability of income larger than 50,000 USD per year:",
)
# Title, description, footer article and example rows help users understand
# and try the demo; all four are handed to gr.Interface.
title = "Tab Transformer for Structured data"
description = (
    "Using Transformer to predict whether the income will be larger than "
    "50,000 USD given the input features."
)
# HTML footer; the pieces are joined by implicit string concatenation and the
# trailing space of the original text is preserved.
article = (
    'Author: <a href="https://huggingface.co/geninhu">Nhu Hoang</a>. '
    'Based on this <a href="https://keras.io/examples/structured_data/tabtransformer/">keras example</a> '
    'by <a href="https://www.linkedin.com/in/khalid-salama-24403144">Khalid Salama.</a> '
    'HuggingFace Model <a href="https://huggingface.co/keras-io/tab_transformer">here</a> '
)
# One row per example; each row has thirteen values, one per input widget.
examples = [
    [
        39.0, ' State-gov', ' Assoc-voc', 11.0, ' Divorced', ' Tech-support',
        ' Not-in-family', ' White', ' Female', 50000.0, 0.0, 40.0, ' Puerto-Rico',
    ],
    [
        65.0, ' Self-emp-inc', ' 12th', 8.0, ' Married-civ-spouse', ' Handlers-cleaners',
        ' Husband', ' Black', ' Male', 41000.0, 0.0, 55.0, ' United-States',
    ],
    [
        42.0, ' Private', ' Masters', 14.0, ' Married-civ-spouse', ' Prof-specialty',
        ' Husband', ' Asian-Pac-Islander', ' Male', 35000.0, 0.0, 40.0, ' Taiwan',
    ],
    [
        25.0, ' Local-gov', ' Bachelors', 13.0, ' Never-married', ' Craft-repair',
        ' Unmarried', ' White', ' Male', 75000.0, 0.0, 51.0, ' England',
    ],
    [
        57.0, ' Private', ' Masters', 14.0, ' Never-married', ' Prof-specialty',
        ' Not-in-family', ' Asian-Pac-Islander', ' Male', 150000.0, 0.0, 45.0, ' Iran',
    ],
]
# Build the interface around `infer` and start the server. Flagging is turned
# off and predictions run only on submit (live=False).
gr.Interface(
    fn=infer,
    inputs=inputs,
    outputs=output,
    examples=examples,
    allow_flagging='never',
    title=title,
    description=description,
    article=article,
    live=False,
).launch(
    enable_queue=True,
    debug=True,
    inbrowser=True,
)