import gradio as gr
import torch
import joblib
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModel

# Load the IndoBERT tokenizer and model used for sentence embeddings
tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
model = AutoModel.from_pretrained("indolem/indobert-base-uncased")

# Mapping dictionaries from numeric labels to human-readable labels
priority_score_mapping = {1: "low", 2: "medium", 3: "high"}
problem_domain_mapping = {0: "operational", 1: "tech"}

# Load the trained Random Forest classifiers
best_classifier1 = joblib.load('best_classifier1.pkl')
best_classifier2 = joblib.load('best_classifier2.pkl')
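# Note: both .pkl files are assumed to sit in the app's working directory
# (e.g. the root of this Space); adjust the paths if they are stored elsewhere.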

# Function that embeds the input text with IndoBERT and runs both classifiers
def predict(text):
    # Convert the sentence into input features
    encoded_inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=128)

    # Perform word embedding using the IndoBERT model
    with torch.no_grad():
        outputs = model(**encoded_inputs)
        embeddings = outputs.last_hidden_state

    # Convert the embeddings to a NumPy array and flatten them per input
    embeddings = embeddings.numpy()
    embeddings_custom_flat = embeddings.reshape(embeddings.shape[0], -1)

    # Ensure the flattened embeddings have exactly 768 features
    num_features_expected = 768
    if embeddings_custom_flat.shape[1] < num_features_expected:
        # If the number of features is less than 768, pad the embeddings
        pad_width = num_features_expected - embeddings_custom_flat.shape[1]
        embeddings_custom_flat = np.pad(embeddings_custom_flat, ((0, 0), (0, pad_width)), mode='constant')
    elif embeddings_custom_flat.shape[1] > num_features_expected:
        # If the number of features is more than 768, truncate the embeddings
        embeddings_custom_flat = embeddings_custom_flat[:, :num_features_expected]

    # Predict the priority_score and the problem_domain for the custom input
    custom_priority_score = best_classifier1.predict(embeddings_custom_flat)
    custom_problem_domain = best_classifier2.predict(embeddings_custom_flat)

    # Map numerical labels to human-readable labels
    mapped_priority_score = priority_score_mapping.get(custom_priority_score[0], "unknown")
    mapped_problem_domain = problem_domain_mapping.get(custom_problem_domain[0], "unknown")

    return f"Predicted Priority Score: {mapped_priority_score}, Predicted Problem Domain: {mapped_problem_domain}"
# Create a Gradio interface
gr.Interface(fn=predict, inputs="text", outputs="text", title="Simple Risk Classifier Demo").launch(debug=True)