girt-space / app.py
nafisehNik's picture
add req.
1f845b3
raw
history blame
4.84 kB
# coding=utf-8
# Copyright 2023 The GIRT Authors.
# Lint as: python3
# This Space is based on the AMR-KELEG/ALDi and cis-lmu/GlotLID Spaces.
# GIRT Space
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import streamlit as st
import pandas as pd
import base64
# Fix the sidebar width at 450px while it is expanded, by injecting raw CSS
# that targets the element with data-testid="stSidebar". Raw HTML/CSS is only
# rendered because unsafe_allow_html=True is passed.
# The string contents are kept flush-left on purpose: markdown treats lines
# indented by 4+ spaces as a code block, which would break the HTML.
st.markdown(
    """
<style>
[data-testid="stSidebar"][aria-expanded="true"]{
min-width: 450px;
max-width: 450px;
}
</style>
""",
    unsafe_allow_html=True,
)
# Sidebar: inputs describing the GitHub issue template to generate, plus the
# generation settings for the model (read later by the "Manual Prompt" tab).
with st.sidebar:
    st.title(" 🔧 Settings")
    with st.expander("🗝 Issue Template Inputs", True):
        in_name = st.text_input("Name Metadata: ", placeholder="e.g., Bug Report or Feature Request or Question", on_change=None)
        in_about = st.text_input("About Metadata: ", placeholder="e.g., File a bug report", on_change=None)
        in_title = st.text_input("Title Metadata: ", placeholder="e.g., [Bug]: ", on_change=None)
        in_labels = st.text_input("Labels Metadata: ", placeholder="e.g., feature, enhancement", on_change=None)
        in_assignees = st.text_input("Assignees Metadata: ", placeholder="e.g., USER_1, USER_2", on_change=None)
        # TODO: when 'No headlines' is selected, the headlines should be forced
        # to be empty as well; nothing enforces that yet.
        option = st.selectbox(
            'How would you like your headlines to be formatted?',
            ('**Emphasis**', '# Header', 'No headlines'))
        # NOTE(review): this text area's value is overwritten by the data editor
        # below, so it is currently unused; keep one of the two inputs.
        in_headlines = st.text_area("Headlines: ", placeholder="Enter each headline in one line.", on_change=None, height=200)
        # Default headlines; every row uses the same "headline" column (mixing
        # keys would create extra NaN-filled columns in the editor).
        df = pd.DataFrame(
            [
                {"headline": "Welcome"},
                {"headline": "Concise Description"},
                {"headline": "Additional Info"},
            ]
        )
        # num_rows="dynamic" lets the user add and delete headline rows.
        in_headlines = st.experimental_data_editor(df, num_rows="dynamic")
        in_summary = st.text_area("Summary: ", placeholder="This Github Issue Template is ...", on_change=None, height=200)
    with st.expander("🎛 Model Configs", False):
        # Generation settings; sliders are (label, min, max, default).
        max_length = st.slider("max_length", 30, 512, 300)
        min_length = st.slider("min_length", 0, 300, 30)
        top_p = st.slider("top_p", 0.0, 1.0, 0.92)
        top_k = st.slider("top_k", 0, 100, 0)
# NOTE(review): this function draws to the page rather than returning data;
# caching it with st.cache_data relies on Streamlit replaying the element calls
# made inside cached functions -- confirm on the Streamlit version in use.
@st.cache_data
def render_svg(svg):
    """Render an SVG string as a centered image on the page.

    The SVG is base64-encoded and embedded as a ``data:image/svg+xml`` URI
    in an ``<img>`` tag, which is written into a new container as raw HTML.

    Args:
        svg: SVG markup as a ``str``.
    """
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    # HTML attributes are whitespace-separated; there must be no comma
    # between ``src`` and ``width``.
    html = f'<p align="center"> <img src="data:image/svg+xml;base64,{b64}" width="40%"/> </p>'
    c = st.container()
    c.write(html, unsafe_allow_html=True)
@st.cache_resource
def load_model(model_name):
    """Return the seq2seq model stored under ``model_name``.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the loaded object
    across reruns instead of calling ``from_pretrained`` every time.

    Args:
        model_name: Hugging Face checkpoint identifier or local path.
    """
    return AutoModelForSeq2SeqLM.from_pretrained(model_name)
@st.cache_resource
def load_tokenizer(model_name):
    """Return the tokenizer stored under ``model_name``.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the loaded object
    across reruns instead of calling ``from_pretrained`` every time.

    Args:
        model_name: Hugging Face checkpoint identifier or local path.
    """
    return AutoTokenizer.from_pretrained(model_name)
# The model and its tokenizer come from the same checkpoint. Loading is shown
# behind a spinner, presumably because fetching a checkpoint can take a while.
MODEL_NAME = 'nafisehNik/girt-t5-base'
with st.spinner(text="Please wait while the model is loading...."):
    model = load_model(MODEL_NAME)
    tokenizer = load_tokenizer(MODEL_NAME)
def compute(sample, top_p, top_k, do_sample, max_length, min_length):
    """Generate text for ``sample`` using the module-level model and tokenizer.

    Args:
        sample: Prompt text to tokenize and feed to the model.
        top_p: Nucleus-sampling probability mass, passed to ``generate``.
        top_k: Top-k cutoff, passed to ``generate``.
        do_sample: Whether ``generate`` should sample, passed through as-is.
        max_length: Maximum generated length, passed to ``generate``.
        min_length: Minimum generated length, passed to ``generate``.

    Returns:
        The first decoded sequence. ``'\\n '`` becomes ``'\\n'``, and the
        ``</s>``, ``<pad>`` and ``<unk>`` markers are removed.
    """
    encoded = tokenizer(sample, return_tensors="pt").to('cpu')
    generation = model.generate(
        **encoded,
        min_length=min_length,
        max_length=max_length,
        do_sample=do_sample,
        top_p=top_p,
        top_k=top_k,
    ).to('cpu')
    # Special tokens are kept by the decoder and stripped by hand below, so only
    # the specific markers listed there are removed.
    text = tokenizer.batch_decode(generation, skip_special_tokens=False)[0]
    # Applied in order: '<pad> ' must go before the bare '<pad>' so the
    # trailing space is removed along with the token.
    cleanup = (
        ('\n ', '\n'),
        ('</s>', ''),
        ('<pad> ', ''),
        ('<pad>', ''),
        ('<unk>', ''),
    )
    for needle, replacement in cleanup:
        text = text.replace(needle, replacement)
    return text
# "Duplicate Space" badge linking to this Space's duplication URL.
st.markdown("[![Duplicate Space](https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14)](https://huggingface.co/spaces/nafisehNik/girt-space?duplicate=true)")
# Read the logo inside a context manager so the file handle is closed right
# away, with an explicit encoding instead of the platform default.
with open("assets/logo.svg", encoding="utf-8") as logo_file:
    logo_svg = logo_file.read()
render_svg(logo_svg)
# Two modes: a guided issue-template designer (placeholder for now) and a
# free-form prompt that is sent to the model unchanged.
tab1, tab2 = st.tabs(["Design GitHub Issue Template", "Manual Prompt"])
with tab1:
    col1, col2, col3 = st.columns([6, 1, 7])
    with col1:
        # TODO: build the issue-template designer from the sidebar inputs.
        pass
with tab2:
    prompt = st.text_area("Prompt: ", placeholder="Enter your prompt.", on_change=None, height=200)
    clicked = st.button("Submit")
    # Generate only when the user submits a non-empty prompt. Streamlit reruns
    # the whole script on every widget interaction, so without this gate the
    # model would re-sample a new output on any unrelated change.
    if clicked and prompt:
        with st.spinner("Please Wait..."):
            # Use the sliders from the sidebar's "Model Configs" expander.
            # min_length is clamped because its slider can exceed max_length.
            res = compute(
                prompt,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
                max_length=max_length,
                min_length=min(min_length, max_length),
            )
        st.code(res, language="python")
    elif clicked:
        st.warning("Please enter a prompt first.")