# Hugging Face Spaces status shown when this source was captured: "Runtime error".
# coding=utf-8 | |
# Copyright 2023 The GlotLID Authors. | |
# Lint as: python3 | |
# This space is built based on AMR-KELEG/ALDi space. | |
# GlotLID Space | |
import string | |
import constants | |
import pandas as pd | |
import streamlit as st | |
from huggingface_hub import hf_hub_download | |
from GlotScript import get_script_predictor | |
import matplotlib | |
from matplotlib import pyplot as plt | |
import fasttext | |
import altair as alt | |
from altair import X, Y, Scale | |
import base64 | |
import json | |
import os | |
import re | |
import transformers | |
from transformers import pipeline | |
def load_sp():
    """Create the GlotScript script predictor used by get_script."""
    return get_script_predictor()


# Module-level predictor shared by every get_script call.
sp = load_sp()
def get_script(text, predictor=None):
    """Get the writing systems of the given text.

    Args:
        text: The text to analyse.
        predictor: Optional script predictor to use instead of the module-level
            GlotScript predictor ``sp``. It is called as ``predictor(text)``;
            item 0 of the result is the main script (falsy when unknown) and
            item 2 is a dict whose ``'details'`` entry maps script codes to
            scores (possibly empty).

    Returns:
        A ``(main_script, all_scripts)`` tuple. ``main_script`` is a script
        code such as 'Latn', or 'Zyyy' when none was detected. ``all_scripts``
        is a sorted list of every detected script code (``['Zyyy']`` when none
        was detected), with 'Jpan' added when any of Kana, Hrkt, Hani or Hira
        is present.
    """
    if predictor is None:
        predictor = sp
    res = predictor(text)
    main_script = res[0] if res[0] else 'Zyyy'
    details = res[2]['details']
    # Use a one-element collection, not the bare string 'Zyyy'. Iterating a
    # string yields its characters, which would turn the result into ['Z', 'y'].
    scripts = set(details) if details else {'Zyyy'}
    # Build the result in a new set rather than appending to the list being
    # iterated, so 'Jpan' is added at most once.
    if scripts & {'Kana', 'Hrkt', 'Hani', 'Hira'}:
        scripts.add('Jpan')
    return main_script, sorted(scripts)
# Translation table mapping ':', '•', '#', '{', '|', '}' and every ASCII digit
# to a space. Built once instead of on every preprocess_text call.
_PREPROCESS_TABLE = str.maketrans({c: " " for c in ':•#{|}' + string.digits})


def preprocess_text(text):
    """Normalise text before language identification.

    Replaces the characters ``:•#{|}`` and ASCII digits with spaces, collapses
    every run of whitespace (newlines included) into a single space, and
    strips leading and trailing whitespace.

    Args:
        text: The text to be preprocessed.

    Returns:
        The preprocessed text.
    """
    text = text.translate(_PREPROCESS_TABLE)
    # \s also matches '\n', so this flattens multi-line input to one line.
    return re.sub(r'\s+', ' ', text).strip()
def language_names(json_path):
    """Load the language-code-to-name mapping from a JSON file.

    Args:
        json_path: Path to a UTF-8 encoded JSON file.

    Returns:
        The decoded JSON object. In this file it is used as a dict keyed by
        the language code part of a label.
    """
    # Decode explicitly as UTF-8 so that non-ASCII language names load the same
    # way regardless of the platform's default encoding.
    with open(json_path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)
# Language code (the part of a label before the first '_') -> display name.
label2name = language_names(json_path="assets/language_names.json")
def get_name(label, names=None):
    """Get the human-readable language name for a prediction label.

    Args:
        label: A label such as 'eng_Latn'; only the part before the first '_'
            is looked up.
        names: Optional mapping from language code to name. Defaults to the
            module-level ``label2name`` table.

    Returns:
        The language name, or the bare language code when the code is not in
        the mapping. The fallback keeps plotting and table rendering from
        crashing on labels the JSON file does not list.
    """
    if names is None:
        names = label2name
    iso_3 = label.split('_')[0]
    return names.get(iso_3, iso_3)
def render_svg(svg):
    """Render an SVG string as a centred image on the Streamlit page.

    Args:
        svg: The SVG markup to display. It is embedded as a base64 data URI.
    """
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    # HTML attributes are separated by whitespace only. A comma between src
    # and width is a parse error that creates a bogus "," attribute.
    html = f'<p align="center"> <img src="data:image/svg+xml;base64,{b64}" width="40%"/></p>'
    st.container().write(html, unsafe_allow_html=True)
def render_metadata():
    """Show the centred row of project badges (model, code, licence, feedback form, paper)."""
    badges_html = r"""<p align="center">
<a href="https://huggingface.co/dsfsi/za-lid"><img alt="HuggingFace Model" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-8A2BE2"></a>
<a href="https://github.com/dsfsi/za-lid"><img alt="GitHub" src="https://img.shields.io/badge/%F0%9F%93%A6%20GitHub-orange"></a>
<a href="https://github.com/dsfsi/za-lid/blob/master/LICENSE.md"><img alt="GitHub license" src="https://img.shields.io/badge/Github%20Licence-blue"></a>
<a href="https://docs.google.com/forms/d/e/1FAIpQLSf7S36dyAUPx2egmXbFpnTBuzoRulhL5Elu-N1eoMhaO7v10w/viewform" target="_blank"><img alt="Feedback Form" src="https://img.shields.io/badge/Feedback-Form-brightgreen"></a>
<a href="https://arxiv.org/abs/2410.08728" target="_blank"><img alt="arxiv" src="https://img.shields.io/badge/arxiv-2410.08728-blue"></a></p>"""
    # Raw HTML is written into its own container.
    st.container().write(badges_html, unsafe_allow_html=True)
# BibTeX entry for GlotLID. A raw string keeps LaTeX escapes such as \" and \c
# verbatim; in a normal string Python would turn \"u into "u and warn about \c.
GLOTLID_CITATION = r"""
@inproceedings{
kargaran2023glotlid,
title={GlotLID: Language Identification for Low-Resource Languages},
author={Kargaran, Amir Hossein and Imani, Ayyoob and Yvon, Fran{\c{c}}ois and Sch{\"u}tze, Hinrich},
booktitle={The 2023 Conference on Empirical Methods in Natural Language Processing},
year={2023},
url={https://openreview.net/forum?id=dl4e3EBz5j}
}"""


def citation():
    """Render the GlotLID BibTeX citation as a code block."""
    st.code(GLOTLID_CITATION, language="python", line_numbers=False)
def convert_df(df):
    """Serialise a DataFrame to UTF-8 encoded CSV bytes for a download button.

    Args:
        df: The DataFrame to export.

    Returns:
        The CSV content as bytes, without the index column.
    """
    # ``index`` is a boolean flag in DataFrame.to_csv, so pass False, not None.
    # NOTE(review): this is not cached, so the CSV is rebuilt on every rerun.
    return df.to_csv(index=False).encode("utf-8")
@st.cache_resource
def load_model(model_name, file_name):
    """Download a fastText model file from the Hugging Face Hub and load it.

    Streamlit re-executes this script on every interaction. ``st.cache_resource``
    makes the download and load happen once per (model_name, file_name) in the
    server process, instead of on every rerun.

    Args:
        model_name: Hugging Face Hub repository id.
        file_name: Name of the model file inside that repository.

    Returns:
        The loaded fastText model.
    """
    model_path = hf_hub_download(repo_id=model_name, filename=file_name)
    return fasttext.load_model(model_path)
@st.cache_resource
def load_model_pipeline(model_name, file_name):
    """Load a Hugging Face text-classification pipeline.

    Streamlit re-executes this script on every interaction. ``st.cache_resource``
    keeps each pipeline loaded once in the server process, instead of
    reloading every transformer model on each rerun.

    Args:
        model_name: Hugging Face Hub repository id of the model.
        file_name: Unused; accepted so the call signature mirrors load_model.

    Returns:
        The ``transformers`` text-classification pipeline.
    """
    return pipeline("text-classification", model=model_name)
# Models used by compute(). The transformer models are text-classification
# pipelines; file_name is ignored by load_model_pipeline.
model_xlmr_large = load_model_pipeline(model_name='dsfsi/za-xlmrlarge-lid', file_name="model.bin")
model_serengeti = load_model_pipeline(model_name='dsfsi/za-serengeti-lid', file_name="model.bin")
model_afriberta = load_model_pipeline(model_name='dsfsi/za-afriberta-lid', file_name="model.bin")
model_afroxlmr_base = load_model_pipeline(model_name='dsfsi/za-afro-xlmr-base-lid', file_name="model.bin")
model_afrolm = load_model_pipeline(model_name='dsfsi/za-afrolm-lid', file_name="model.bin")
za_lid = load_model_pipeline(model_name='dsfsi/za-lid-bert', file_name="model.bin")
# fastText baselines.
openlid = load_model(model_name='laurievb/OpenLID', file_name="model.bin")
glotlid_3 = load_model(model_name=constants.MODEL_NAME, file_name="model_v3.bin")
def plot(label, prob):
    """Draw a horizontal confidence bar for a single prediction.

    Args:
        label: Predicted label; the title shows it together with get_name(label).
        prob: Confidence value, drawn against a fixed 0-1 x-axis.
    """
    ORANGE_COLOR = "#FF8000"
    BLACK_COLOR = "#31333F"
    fig, ax = plt.subplots(figsize=(8, 1))
    # Transparent figure and axes backgrounds.
    fig.patch.set_facecolor("none")
    ax.set_facecolor("none")
    ax.spines["left"].set_color(BLACK_COLOR)
    ax.spines["bottom"].set_color(BLACK_COLOR)
    ax.tick_params(axis="x", colors=BLACK_COLOR)
    ax.spines[["right", "top"]].set_visible(False)
    ax.barh(y=[0], width=[prob], color=ORANGE_COLOR)
    ax.set_xlim(0, 1)
    ax.set_ylim(-1, 1)
    ax.set_title(f"Label: {label}, Language: {get_name(label)}", color=BLACK_COLOR)
    ax.get_yaxis().set_visible(False)
    ax.set_xlabel("Confidence", color=BLACK_COLOR)
    st.pyplot(fig)
    # pyplot keeps every figure made with plt.subplots open until it is closed.
    # Without this, each Streamlit rerun leaks another figure in the server.
    plt.close(fig)
def plot_multiples(models, labels, probs):
    """Draw one horizontal confidence bar per model, annotated with its label.

    Args:
        models: Model names used as y-tick labels, one per bar.
        labels: Predicted label for each model, in the same order.
        probs: Confidence for each model, in the same order, drawn against a
            fixed 0-1 x-axis.
    """
    ORANGE_COLOR = "#FF8000"
    BLACK_COLOR = "#31333F"
    # Figure height grows with the number of models (one inch per bar).
    fig, ax = plt.subplots(figsize=(12, len(models)))
    fig.patch.set_facecolor("none")
    ax.set_facecolor("none")
    ax.spines["left"].set_color(BLACK_COLOR)
    ax.spines["bottom"].set_color(BLACK_COLOR)
    ax.tick_params(axis="x", colors=BLACK_COLOR)
    ax.spines[["right", "top"]].set_visible(False)
    y_positions = range(len(models))
    ax.barh(y=y_positions, width=probs, color=ORANGE_COLOR)
    # Write "label (prob)" just to the right of each bar.
    for i, (prob, label) in enumerate(zip(probs, labels)):
        ax.text(prob + 0.01, i, f"{label} ({prob:.2f})", va='center', color=BLACK_COLOR)
    ax.set_yticks(y_positions)
    ax.set_yticklabels(models, color=BLACK_COLOR)
    ax.set_xlim(0, 1)
    ax.set_xlabel("Confidence", color=BLACK_COLOR)
    ax.set_title("Model Predictions", color=BLACK_COLOR)
    st.pyplot(fig)
    # Close the figure so pyplot does not accumulate open figures across reruns.
    plt.close(fig)
# Model keys served by fastText; every other key is a Hugging Face
# text-classification pipeline.
_FASTTEXT_VERSIONS = ("openlid-201", "GlotLID v3")
# Model keys whose fastText prediction is checked against the scripts actually
# present in the sentence. ('nllb-218' is not loaded anywhere in this file.)
_SCRIPT_CONTROLLED_VERSIONS = ("GlotLID v3", "openlid-201", "nllb-218")


def _predict_pipeline(model, sent):
    """Return (label, score) of a text-classification pipeline's top prediction for one sentence."""
    # The pipeline is given a single string, so item 0 of its result list is
    # this sentence's prediction. Never index by the sentence's position in the
    # batch.
    top = model.predict(sent)[0]
    return top['label'], top['score']


def _predict_fasttext(model, sent, version):
    """Return (label, prob) from a fastText model, applying script control for *version*."""
    output = model.predict(sent)
    # Labels are expected to look like '__label__eng_Latn'. Keep 'eng_Latn' and
    # fold the simplified/traditional Chinese script tags into 'Hani'.
    label = output[0][0].split('__')[-1].replace('_Hans', '_Hani').replace('_Hant', '_Hani')
    prob = max(min(output[1][0], 1), 0)
    language = label.split('_')[0]
    if version in _SCRIPT_CONTROLLED_VERSIONS and language != 'zxx':
        main_script, all_scripts = get_script(sent)
        # If the predicted script does not occur in the text, distrust the
        # prediction: report 'und' with the detected main script and zero confidence.
        if label.split('_')[1] not in all_scripts:
            label = f"und_{main_script}"
            prob = 0
    return label, prob


def compute(sentences, version='v3'):
    """Compute language labels and probabilities for the given sentences.

    Args:
        sentences: A list of sentences. Each is normalised with preprocess_text
            before prediction.
        version: Model key, one of 'xlmrlarge', 'serengeti', 'afriberta',
            'afroxlmrbase', 'afrolm', 'BERT', 'openlid-201' or 'GlotLID v3'.
            Any other value (including the default 'v3' and 'All-Models') runs
            every model.

    Returns:
        A ``(probs, labels)`` tuple of lists. With a single model there is one
        entry per sentence. When every model runs, there is one entry per
        (sentence, model) pair, sentence-major, with the models in the order
        listed above.
    """
    progress_text = "Computing Language..."
    all_models = [
        (model_xlmr_large, "xlmrlarge"),
        (model_serengeti, "serengeti"),
        (model_afriberta, "afriberta"),
        (model_afroxlmr_base, "afroxlmrbase"),
        (model_afrolm, "afrolm"),
        (za_lid, "BERT"),
        (openlid, "openlid-201"),
        (glotlid_3, "GlotLID v3"),
    ]
    # A known key selects that single model; anything else selects all of them.
    selected = [(model, name) for model, name in all_models if name == version] or all_models

    my_bar = st.progress(0, text=progress_text)
    probs = []
    labels = []
    sentences = [preprocess_text(sent) for sent in sentences]
    for index, sent in enumerate(sentences):
        for model, name in selected:
            if name in _FASTTEXT_VERSIONS:
                # Script control is keyed on the model actually used, so it
                # also applies when every model runs.
                label, prob = _predict_fasttext(model, sent, name)
            else:
                label, prob = _predict_pipeline(model, sent)
            labels.append(label)
            probs.append(prob)
        my_bar.progress((index + 1) / len(sentences), text=progress_text)
    my_bar.empty()
    return probs, labels
# Page header: badges, centred logo, tagline, author credits, then the two tabs.
# Streamlit lays elements out in call order, so the order below is the layout.
# Disabled leftovers from the GlotLID space this app was adapted from:
# st.markdown("[![Duplicate Space](https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14)](https://huggingface.co/spaces/cis-lmu/glotlid-space?duplicate=true)")
# render_svg(open("assets/glotlid_logo.svg").read())
render_metadata()
# Three equal columns; only the middle one is filled, which centres the logo.
img1, img2, img3 = st.columns(3)
with img2:
    with st.container():
        st.image("logo_transparent_small.png")
st.markdown("**DSFSI** Language Identification (LID) Inference Endpoint Created with **HuggingFace Spaces**.")
with st.expander("More information about the space"):
    st.write('''
Authors: Thapelo Sindane, Vukosi Marivate
''')
# tab1: predict a single typed sentence; tab2: batch prediction over an uploaded .txt file.
tab1, tab2 = st.tabs(["Input a Sentence", "Upload a File"])
with tab1:
    # Model picker. The option values are the version keys understood by compute().
    version = st.radio(
        "Choose model",
        ["xlmrlarge", "serengeti", "afriberta", "afroxlmrbase", "afrolm", "BERT", "openlid-201", "GlotLID v3", "All-Models"],
        captions=["za-XLMR-Large", "za-Serengeti", "za-AfriBERTa", "za-Afro-XLMR-BASE", "za-AfroLM", "za-BERT", "OpenLID", "GlotLID v3",'All-Models'],
        index=4,
        key='version_tab1',
        horizontal=True,
    )
    sent = st.text_input(
        "Sentence:", placeholder="Enter a sentence.", on_change=None
    )
    # Clicking the button triggers a Streamlit rerun, which recomputes the
    # prediction for the current text. Its return value is not needed.
    st.button("Submit")
    if sent:
        probs, labels = compute([sent], version=version)
        # For a single model this is its prediction. For All-Models it is the
        # first model's (xlmrlarge), which is the one that gets logged.
        prob = probs[0]
        label = labels[0]
        # NOTE(review): user input is echoed to stdout and logs.txt; confirm this
        # is acceptable for the deployment.
        print(f"{sent}, {label}: {prob}")
        # Append mode creates logs.txt when missing. UTF-8 lets any input text
        # be written regardless of the platform's default encoding.
        with open("logs.txt", "a", encoding="utf-8") as f:
            f.write(f"{sent}, {label}: {prob}\n")
        if version == "All-Models":
            # Display names in compute()'s All-Models order.
            plot_multiples(["xlmrlarge", "serengeti", "afriberta", "afroxlmrbase", "afrolm", "BERT", "OpenLID", "GlotLID v3"], labels, probs)
        else:
            plot(label, prob)
with tab2:
    version = st.radio(
        "Choose model",
        ["xlmrlarge", "serengeti", "afriberta", "afroxlmrbase", "afrolm", "BERT","openlid-201", "GlotLID v3", "All-Models"],
        captions=["za-XLMR-Large", "za-Serengeti", "za-AfriBERTa", "za-Afro-XLMR-BASE", "za-AfroLM", "za-BERT", "OpenLID", "GlotLID v3", "All-Models"],
        index=4,
        key='version_tab2',
        horizontal=True,
    )
    file = st.file_uploader("Upload a file", type=["txt"])
    if file is not None:
        try:
            # One sentence per line. The separator is presumably chosen never to
            # occur in the text, so each whole line lands in the single column.
            df = pd.read_csv(file, sep="¦\t¦", header=None, engine='python')
        except pd.errors.EmptyDataError:
            df = None
            st.warning("The uploaded file contains no sentences.")
        if df is not None:
            df.columns = ["Sentence"]
            df.reset_index(drop=True, inplace=True)
            probs, labels = compute(df["Sentence"].tolist(), version=version)
            if version == "All-Models":
                # compute() returns one prediction per (sentence, model) pair,
                # sentence-major, in this model order. Expand the table to one
                # row per pair so the prediction columns line up with the rows.
                model_names = ["xlmrlarge", "serengeti", "afriberta", "afroxlmrbase", "afrolm", "BERT", "openlid-201", "GlotLID v3"]
                df = pd.DataFrame({
                    "Sentence": [s for s in df["Sentence"] for _ in model_names],
                    "Model": model_names * len(df),
                })
            df['Prob'], df["Label"] = probs, labels
            df['Language'] = df["Label"].apply(get_name)
            # A horizontal rule
            st.markdown("""---""")
            # Confidence per row, on a fixed 0-1 scale.
            chart = (
                alt.Chart(df.reset_index())
                .mark_area(color="darkorange", opacity=0.5)
                .encode(
                    x=X(field="index", title="Sentence Index"),
                    y=Y("Prob", scale=Scale(domain=[0, 1])),
                )
            )
            st.altair_chart(chart.interactive(), use_container_width=True)
            col1, col2 = st.columns([4, 1])
            with col1:
                st.table(
                    df,
                )
            with col2:
                csv = convert_df(df)
                st.download_button(
                    label=":file_folder: Download predictions as CSV",
                    data=csv,
                    file_name="GlotLID.csv",
                    mime="text/csv",
                )
# citation() |