Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM | |
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast | |
from aksharamukha import transliterate | |
import torch | |
from dotenv import load_dotenv | |
import os | |
import requests | |
# Pull variables from a local .env file (if one exists) into the process
# environment; already-set environment variables are not overridden.
# Without this call the imported load_dotenv was never used, so a token kept
# only in .env was silently missing.
load_dotenv()
access_token = os.getenv('token')

# Set up device: GPU when CUDA is available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NLLB target-language code used for English -> Sinhala replies.
chat_language = 'sin_Sinh'

# English -> Sinhala translation pipeline (NLLB-200 distilled 600M).
trans_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
eng_trans_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
translator = pipeline(
    'translation',
    model=trans_model,
    tokenizer=eng_trans_tokenizer,
    src_lang="eng_Latn",
    tgt_lang=chat_language,
    max_length=400,
    device=device,
)

# Sinhala -> English translation model; used directly (model + tokenizer)
# by translate_sinhala_to_english.
sin_trans_model = AutoModelForSeq2SeqLM.from_pretrained("thilina/mt5-sinhalese-english")
si_trans_tokenizer = AutoTokenizer.from_pretrained("thilina/mt5-sinhalese-english")
# NOTE(review): `pipe` is not referenced anywhere in this file. It is kept for
# compatibility but now reuses the objects above instead of loading the same
# checkpoint a second time.
pipe = pipeline("translation", model=sin_trans_model, tokenizer=si_trans_tokenizer)

# Singlish text2text model, prompted by translate_Singlish_to_sinhala.
singlish_pipe = pipeline("text2text-generation", model="Dhahlan2000/Simple_Translation-model-for-GPT-v8")
# Translation functions | |
def translate_Singlish_to_sinhala(text):
    """Run Singlish input through the Singlish text2text model.

    The model is prompted with a fixed task prefix. Zero-width joiners
    (U+200D) are stripped from the first generated sequence before it is
    returned.
    """
    prompt = f"translate Singlish to Sinhala: {text}"
    generations = singlish_pipe(prompt, clean_up_tokenization_spaces=False)
    zero_width_joiner = '\u200d'
    return generations[0]['generated_text'].replace(zero_width_joiner, '')
def translate_english_to_sinhala(text):
    """Translate English text to Sinhala with the NLLB ``translator`` pipeline.

    The text is translated line by line so paragraph breaks are preserved.
    Empty or whitespace-only lines are passed through unchanged rather than
    sent to the model; there is nothing to translate, and a seq2seq model
    given empty input can emit spurious text.

    Args:
        text: English text, possibly spanning several lines.

    Returns:
        The Sinhala translation with the same number of lines. The string
        "ප් රභූවරුන්" and zero-width joiners (U+200D) are removed.
    """
    translated_parts = []
    for part in text.split("\n"):
        if not part.strip():
            translated_parts.append(part)
            continue
        result = translator(part, clean_up_tokenization_spaces=False)
        translated_parts.append(result[0]['translation_text'])
    translated_text = "\n".join(translated_parts)
    # NOTE(review): this phrase is stripped unconditionally; presumably a
    # recurring artifact of the translation model — confirm.
    return translated_text.replace("ප් රභූවරුන්", "").replace('\u200d', '')
def translate_sinhala_to_english(text, max_new_tokens=512):
    """Translate Sinhala text to English with the mT5 Sinhala->English model.

    The text is translated line by line so paragraph breaks are preserved.
    Each line is stripped before tokenization. Empty or whitespace-only lines
    are passed through unchanged instead of being sent to the model.

    Args:
        text: Sinhala-script text, possibly spanning several lines.
        max_new_tokens: Upper bound on generated tokens per line. Without an
            explicit bound, ``generate`` falls back to the library-default
            ``max_length`` (20 tokens unless the model's generation config
            overrides it), which truncates longer translations.

    Returns:
        The English translation, one output line per input line.
    """
    translated_parts = []
    for part in text.split("\n"):
        stripped = part.strip()
        if not stripped:
            translated_parts.append(part)
            continue
        # Inputs longer than 512 tokens are truncated by the tokenizer.
        inputs = si_trans_tokenizer(stripped, return_tensors="pt", truncation=True, max_length=512)
        outputs = sin_trans_model.generate(**inputs, max_new_tokens=max_new_tokens)
        translated_parts.append(
            si_trans_tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        )
    return "\n".join(translated_parts)
# Translation table deleting every '.', '*' and '"' character.
_VELTHUIS_MARKS = str.maketrans('', '', '.*"')


def _clean_velthuis(latin_text):
    """Delete every '.', '*' and '"' from ``latin_text`` and lower-case it."""
    return latin_text.translate(_VELTHUIS_MARKS).lower()


def transliterate_from_sinhala(text):
    """Romanize Sinhala-script text via aksharamukha's Velthuis scheme.

    After transliteration, every '.', '*' and '"' is removed and the result
    is lower-cased. Presumably this drops Velthuis diacritic markers to get
    plain Singlish-like spelling — confirm. Unlike the previous index loop,
    a marker in the final position is removed too.

    Args:
        text: Sinhala-script text.

    Returns:
        Lower-case romanized text without '.', '*' or '"'.
    """
    latin_text = transliterate.process('Sinhala', 'Velthuis', text)
    return _clean_velthuis(latin_text)
def transliterate_to_sinhala(text):
    """Convert Velthuis-romanized text into Sinhala script using aksharamukha."""
    # aksharamukha's process() takes (source scheme, target scheme, text).
    sinhala_text = transliterate.process('Velthuis', 'Sinhala', text)
    return sinhala_text
# Gemma 2B instruction-tuned model that writes the chatbot's English reply.
# access_token is forwarded for authenticated Hugging Face Hub downloads.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=access_token)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    torch_dtype=torch.bfloat16,
    token=access_token,
).to(device)


def conversation_predict(input_text, max_new_tokens=512):
    """Generate an English reply to ``input_text`` with the Gemma chat model.

    The input is wrapped as a single user turn using the tokenizer's chat
    template. Only the newly generated tokens are decoded, so the prompt is
    not echoed back and special tokens such as <bos>/<eos> are omitted.

    Args:
        input_text: The user's message in English.
        max_new_tokens: Maximum number of tokens to generate. Without it,
            ``generate`` uses the library-default total ``max_length`` (20
            tokens including the prompt unless the model config overrides
            it), which truncates replies.

    Returns:
        The decoded reply text.
    """
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": input_text}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # The Gemma chat template already emits the BOS token, so don't add another.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
def ai_predicted(user_input):
    """Answer one user message by chaining the translation and chat stages.

    Each stage's output feeds the next, and each intermediate result is
    printed under a label. The stages are:

    1. Singlish model
    2. Velthuis -> Sinhala transliteration
    3. Sinhala -> English translation
    4. Gemma reply
    5. English -> Sinhala translation
    6. Sinhala romanization

    The final romanized text is printed unlabelled and returned.
    """
    labelled_stages = (
        ("You(Singlish): ", translate_Singlish_to_sinhala),
        ("You(Sinhala): ", transliterate_to_sinhala),
        ("You(English): ", translate_sinhala_to_english),
        ("AI(English): ", conversation_predict),
        ("AI(Sinhala): ", translate_english_to_sinhala),
    )
    current = user_input
    for label, stage in labelled_stages:
        current = stage(current)
        print(label, current, "\n")

    romanized_reply = transliterate_from_sinhala(current)
    print(romanized_reply)
    return romanized_reply
# Gradio Interface | |
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Gradio ChatInterface callback that answers ``message``.

    Only the latest message is used: each turn is processed independently by
    ``ai_predicted``. ``history``, ``system_message``, ``max_tokens``,
    ``temperature`` and ``top_p`` are accepted to match the interface's
    inputs but are currently ignored.

    Yields:
        The single, complete romanized reply.
    """
    # History is deliberately not read. Indexing its entries as (user, bot)
    # pairs would also fail if Gradio supplies message-dict history.
    yield ai_predicted(message)
# Chat UI. The extra controls below are passed to respond() as additional
# inputs after the message and history.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(label="System message", value="You are a friendly Chatbot."),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link.
    demo.launch(share=True)