# BlooMeteo / app.py
# Recovered from the Hugging Face Space file view: last commit "Fix" (1999a4e) by PaulNdrei, 7 kB.
import os
from dotenv import load_dotenv
import gradio as gr
from gradio.components import Textbox, Button, Slider, Image
from AinaTheme import AinaGradioTheme
from meteocat_app import generate
import csv
load_dotenv()


def _env_flag(name, default):
    """Read a boolean flag from the environment.

    Returns *default* when the variable is unset; otherwise True only for
    "1", "true", "yes" or "on" (case-insensitive, surrounding spaces ignored).
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")


# Maximum number of characters accepted in the input textbox (default 200).
MAX_INPUT_CHARACTERS = int(os.environ.get("MAX_INPUT_CHARACTERS", default=200))
# Whether the "Model parameters" accordion is visible. Environment values are
# strings, so they must be parsed: a raw "False" would otherwise be truthy.
SHOW_MODEL_PARAMETERS_IN_UI = _env_flag("SHOW_MODEL_PARAMETERS_IN_UI", default=True)
# NOTE(review): not referenced elsewhere in this file — confirm it is still needed.
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", default=200))

# One location name per line; offered as choices in the "Available locations" dropdown.
with open('./locations.txt', 'r', encoding='utf-8') as file:
    locations = file.read().splitlines()
def csv_to_dict(file_path, encoding='utf-8'):
    """Read a CSV file with a header row into a list of row dicts.

    Args:
        file_path: Path to the CSV file; its first row supplies the keys.
        encoding: Text encoding of the file. Defaults to UTF-8, the same
            encoding used to read ``locations.txt``, instead of the
            platform-dependent default.

    Returns:
        A list with one ``dict`` per data row, mapping header names to the
        row's string values (empty list for a header-only file).
    """
    # newline='' lets the csv module handle line endings itself, as the csv
    # docs require; otherwise newlines inside quoted fields get translated.
    with open(file_path, 'r', encoding=encoding, newline='') as file:
        return [dict(row) for row in csv.DictReader(file)]
# Maps an interval label (outputs["context"]["interval"]) to the position of
# its weather code in context["codis"] and to the code2simbol.csv column that
# holds the matching symbol image URL (day symbol for morning/afternoon,
# night symbol for night).
_INTERVAL_SLOTS = {
    **dict.fromkeys(('matí', 'matí?', 'mati', 'mati?'), (0, 'simbol_url_dia')),
    **dict.fromkeys(('tarda', 'tarda?'), (1, 'simbol_url_dia')),
    **dict.fromkeys(('nit', 'nit?'), (2, 'simbol_url_noche')),
}

# Warning shown when generate() returns None (place or date not recognised).
_NOT_FOUND_WARNING = (
    "\n"
    "És possible que no hagi trobat el lloc o la data.\n"
    "Només puc respondre a preguntes sobre el temps a alguna localitat en concret.\n"
)


def _symbol_url(context, symbol_rows):
    """Return the weather-symbol image URL for a forecast context, or None.

    Args:
        context: ``outputs["context"]`` from ``generate``; its ``'interval'``
            label and ``'codis'`` sequence are read.
        symbol_rows: Rows of code2simbol.csv, each with a ``'codi'`` key plus
            the ``'simbol_url_dia'`` / ``'simbol_url_noche'`` columns.

    Returns:
        The URL from the first row whose ``'codi'`` equals the interval's code.
        None when the interval is not a single morning/afternoon/night slot
        (e.g. "tot el dia") or when no row matches.
    """
    slot = _INTERVAL_SLOTS.get(context['interval'])
    if slot is None:
        return None
    index, url_column = slot
    wanted_code = context['codis'][index]
    for row in symbol_rows:
        if row['codi'] == wanted_code:
            return row[url_column]
    return None


def submit_input(input_, repetition_penalty, temperature):
    """Gradio submit handler: run the model and build the three outputs.

    Args:
        input_: User question from the input textbox.
        repetition_penalty: Value of the "Repetition penalty" slider.
        temperature: Value of the "Temperature" slider.

    Returns:
        ``(model answer text, gr.Image with the weather symbol or no image,
        CCMA text)``. Returns ``(None, None, None)`` after showing a warning
        when the input is empty or ``generate`` returns None.
    """
    # Validate first so an empty input never triggers a model generation.
    if input_ is None or input_.strip() == "":
        gr.Warning('No és possible processar un input buit')
        return None, None, None
    outputs = generate(input_, repetition_penalty, temperature)
    if outputs is None:
        gr.Warning(_NOT_FOUND_WARNING)
        return None, None, None
    # A missing symbol (whole-day forecast, unknown interval, unmatched code)
    # yields an empty image instead of an IndexError.
    image_url = _symbol_url(outputs["context"], csv_to_dict("./code2simbol.csv"))
    return (
        outputs["model_answer"],
        gr.Image(
            value=image_url,
            show_share_button=False,
            show_download_button=False,
        ),
        outputs["ccma_response"],
    )
def change_interactive(text):
    """Compute interactivity updates for the (Clear, Submit) buttons.

    Clear is always enabled. Submit is enabled only while the input, with
    surrounding whitespace stripped, is at most MAX_INPUT_CHARACTERS long.
    A None textbox value is treated as empty text instead of raising.
    """
    within_limit = len((text or "").strip()) <= MAX_INPUT_CHARACTERS
    return gr.update(interactive=True), gr.update(interactive=within_limit)
def clean():
    """Reset the UI after "Clear" is pressed.

    Returns values for, in order: the input textbox, the model answer, the
    symbol image, the CCMA text, the repetition-penalty slider and the
    temperature slider.
    """
    # 0.85 is the value both sliders are created with; the previous 1.0 made
    # "Clear" switch to non-default generation parameters.
    return (
        None,
        None,
        None,
        None,
        gr.Slider(value=0.85),
        gr.Slider(value=0.85),
    )
# Build the Gradio UI (input column on the left, output column on the right)
# and wire its events, then start the app.
# NOTE(review): the copy this came from had lost all indentation; the nesting
# of rows/columns below is reconstructed — confirm it matches the intended layout.
with gr.Blocks(**AinaGradioTheme().get_kwargs()) as demo:
    with gr.Row():
        with gr.Column():
            # Hidden textbox holding MAX_INPUT_CHARACTERS so the client-side JS
            # counter handler below can receive the limit as its second input.
            placeholder_max_token = Textbox(
                visible=False,
                interactive=False,
                value= MAX_INPUT_CHARACTERS
            )
            input_ = Textbox(
                lines=11,
                label="Input",
                placeholder="e.g. Prompt example."
            )
            # Display-only list of known locations; not used as an input to any event.
            gr.Dropdown(
                label="Available locations",
                choices=locations
            )
            # Character counter: the left span receives a "Max length" warning and
            # the right span shows "<current> / <max>"; both are filled by the JS
            # handlers registered below.
            with gr.Row(variant="panel", equal_height=True):
                gr.HTML("""<span id="countertext" style="display: flex; justify-content: start; color:#ef4444; font-weight: bold;"></span>""")
                gr.HTML(f"""<span id="counter" style="display: flex; justify-content: end;"> <span id="inputlenght">0</span>&nbsp;/&nbsp;{MAX_INPUT_CHARACTERS}</span>""")
            with gr.Row():
                # Both buttons start disabled; change_interactive enables them
                # when the input changes.
                clear_btn = Button(
                    "Clear",
                    interactive=False
                )
                submit_btn = Button(
                    "Submit",
                    variant="primary",
                    interactive=False
                )
            # Generation parameters passed to submit_input; hidden when
            # SHOW_MODEL_PARAMETERS_IN_UI is falsy.
            with gr.Accordion("Model parameters", open=True, visible=SHOW_MODEL_PARAMETERS_IN_UI):
                repetition_penalty = Slider(
                    minimum=0.1,
                    maximum=2.0,
                    step=0.1,
                    value=0.85,
                    label="Repetition penalty"
                )
                temperature = Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.85,
                    label="Temperature"
                )
        with gr.Column():
            output_answer = Textbox(
                lines=9,
                label="Model text",
                interactive=False,
                show_copy_button=True
            )
            # output_context = Textbox(
            #     lines=9,
            #     label="Model context",
            #     interactive=False,
            #     show_copy_button=True
            # )
            with gr.Row():
                # Weather-symbol image chosen by submit_input (may be empty).
                output_image = Image(
                    show_label=False,
                    show_share_button=False,
                    show_download_button=False,
                    value=None,
                    width=20,
                    height=50
                )
            output_CCMA = Textbox(
                lines=9,
                label="CCMA text",
                interactive=False,
                show_copy_button=True
            )
    # Enable/disable the Clear and Submit buttons according to the input length.
    input_.change(fn=change_interactive, inputs=[input_], outputs=[clear_btn, submit_btn])
    # Client-side only (fn=None): show the "Max length" warning while the input
    # is longer than MAX_INPUT_CHARACTERS.
    # NOTE(review): the JS counters use the raw length, while change_interactive
    # uses the stripped length — they can disagree on leading/trailing whitespace.
    input_.change(
        fn=None, inputs=[input_],
        js=f"""(i) => document.getElementById('countertext').textContent = i.length > {MAX_INPUT_CHARACTERS} && 'Max length {MAX_INPUT_CHARACTERS} characters. ' || '' """)
    # Client-side only: update the "<current>" counter and colour it red past the
    # limit `m`, which comes from the hidden placeholder_max_token textbox.
    input_.change(
        fn=None,
        inputs=[input_, placeholder_max_token],
        js="""(i, m) => {
document.getElementById('inputlenght').textContent = i.length + ' '
document.getElementById('inputlenght').style.color = (i.length > m) ? "#ef4444" : "";
}""")
    # Reset input, the three outputs and both sliders; this outputs order must
    # match the tuple returned by clean().
    clear_btn.click(
        fn=clean,
        inputs=[],
        outputs=[input_, output_answer, output_image, output_CCMA, repetition_penalty, temperature],
        queue=False
    )
    # Run the model; outputs order must match submit_input's
    # (answer, image, CCMA text) return tuple.
    submit_btn.click(
        fn=submit_input,
        inputs=[input_, repetition_penalty, temperature],
        outputs=[output_answer, output_image, output_CCMA]
    )
# Start the Gradio server.
demo.launch(show_api=True)