chatbot-stsv / app.py
DuongTrongChi's picture
setup-project
673210b
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
# DESCRIPTION = ""
# if not torch.cuda.is_available():
# DESCRIPTION = "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available():
device = torch.device("cuda")
print('There are %d GPU(s) available.' % torch.cuda.device_count())
print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
print('No GPU available, using the CPU instead.')
device = torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained("Back-up/T5-pretrain")
model = AutoModelForSeq2SeqLM.from_pretrained("Back-up/T5-large-QA")
model.to(device)
@spaces.GPU
def generate(
message: str,
chat_history: list[tuple[str, str]],
system_prompt: str,
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
tokenized_text = tokenizer.encode(message, return_tensors="pt").to(model.device)
model.eval()
summary_ids = model.generate(
tokenized_text,
max_length=1024,
min_length=8,
num_beams=5,
repetition_penalty=2.5,
length_penalty=1.0,
early_stopping=True
)
output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
yield output
chat_interface = gr.ChatInterface(
fn=generate,
additional_inputs=[
gr.Textbox(label="System prompt", lines=6),
gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
),
gr.Slider(
label="Temperature",
minimum=0.1,
maximum=4.0,
step=0.1,
value=0.6,
),
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=0.9,
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=50,
),
gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.2,
),
],
stop_btn=None,
examples=[
["Trường đại học Nông Lâm thành phố Hồ Chí Minh nằm ở đâu?"],
["Mục tiêu chiến lược của trường đại học Nông Lâm thành phố Hồ Chí Minh là gì?"],
["Sinh viên được khen thưởng cá nhân và tập thể khi nào?"],
["Điều kiện cơ bản để được hỗ trợ vay tiền sinh viên là gì?"],
["Trường Đại học Nông Lâm đã trải qua bao nhiêu năm hoạt động tính đến năm 2023?"],
["Những hành vi nào của sinh viên bị coi là vi phạm quy định của Nhà trường?"],
["Địa chỉ của Phân hiệu Trường Đại học Nông Lâm tại Ninh Thuận?"],
["Làm thế nào khi sinh viên không hài lòng với việc giải quyết thắc mắc của Trưởng Bộ môn?"],
["Làm thế để yêu cầu phúc khảo bài thi?"],
["Nghĩa vụ của sinh viên là gì?"],
["Viết cho tôi một chương trình tính số nguyên tố bằng python."]
],
)
with gr.Blocks(css="style.css") as demo:
chat_interface.render()
if __name__ == "__main__":
demo.queue(max_size=20).launch()