|
|
|
import json
import logging
import subprocess
from datetime import datetime
from functools import lru_cache
from os import getenv
from pathlib import Path
from time import perf_counter

import firebase_admin
import gradio as gr
import numpy as np
import spaces
from firebase_admin import credentials, firestore
from huggingface_hub import InferenceClient
from huggingface_hub import InferenceClient
from jinja2 import Environment, FileSystemLoader
from ragatouille import RAGPretrainedModel
from sentence_transformers import CrossEncoder

from backend.query_llm import generate_hf, generate_openai
from backend.semantic_search import table, retriever
|
|
|
|
|
|
# Column names read from the vector-search table in the re-ranking retrieval path.
VECTOR_COLUMN_NAME, TEXT_COLUMN_NAME = "vector", "text"

# Hugging Face token taken from the environment (None when the variable is unset).
HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")

# Directory holding this script; the Jinja templates are located relative to it.
proj_dir = Path(__file__).parent

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Remote text-generation client (Mixtral) used by the quiz generator.
client = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    token=HF_TOKEN,
)

# `template` renders the prompt sent to the LLM; `template_html` renders the
# HTML version of the same prompt that is displayed in the UI.
env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))
template, template_html = (
    env.get_template(template_name)
    for template_name in ('template.j2', 'template_html.j2')
)

# Sample prompts offered under the chat input box.
examples = [
    'Tabulate the difference between veins and arteries',
    'What are defects in Human eye?',
    'Frame 5 short questions and 5 MCQ on Chapter 2 ',
    'Suggest creative and engaging ideas to teach students on Chapter on Metals and Non Metals ',
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_text(history, text):
    """Append the user's message as a new chat turn and lock the input box.

    Args:
        history: The current chat history, or None when the chat is empty.
        text: The message the user just submitted.

    Returns:
        A tuple of (new history list ending with a ``(text, None)`` turn whose
        reply slot is still empty, a cleared and non-interactive Textbox).
    """
    previous_turns = history if history is not None else []
    new_turn = (text, None)
    locked_box = gr.Textbox(value="", interactive=False)
    return previous_turns + [new_turn], locked_box
|
|
|
|
|
|
# UI label -> Hugging Face model id for the two cross-encoder re-rankers.
# The labels must match the choices of the "Embeddings" radio in the chat UI.
_CROSS_ENCODER_MODELS = {
    '(FAST) MiniLM-L6v2': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
    '(ACCURATE) BGE reranker': 'BAAI/bge-reranker-base',
}
_COLBERT_LABEL = '(HIGH ACCURATE) ColBERT'
_COLBERT_INDEX_PATH = '.ragatouille/colbert/indexes/cbseclass10index'


@lru_cache(maxsize=None)  # key space is bounded by _CROSS_ENCODER_MODELS
def _load_cross_encoder(model_name):
    """Load a CrossEncoder once per model id and reuse it for later queries."""
    return CrossEncoder(model_name)


@lru_cache(maxsize=1)
def _load_colbert_index():
    """Load the ColBERT index once and reuse it for later queries."""
    rag = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
    # NOTE(review): from_index may be a classmethod that loads the checkpoint
    # itself, making from_pretrained above redundant — confirm before removing.
    return rag.from_index(_COLBERT_INDEX_PATH)


def _retrieve_with_colbert(query, top_k_rank):
    """Return the text of the top `top_k_rank` ColBERT hits for `query`."""
    hits = _load_colbert_index().search(query, k=top_k_rank)
    return [hit['content'] for hit in hits]


def _retrieve_with_cross_encoder(query, cross_encoder, top_rerank, top_k_rank):
    """Vector-search `top_rerank` candidates, re-rank them, and keep the best `top_k_rank`.

    Raises:
        ValueError: if `cross_encoder` is not a known re-ranker label.
    """
    model_name = _CROSS_ENCODER_MODELS.get(cross_encoder)
    if model_name is None:
        # Fail before doing any work instead of hitting an unbound name later.
        raise ValueError(f"Unknown retriever option: {cross_encoder!r}")

    document_start = perf_counter()

    query_vec = retriever.encode(query)
    logger.warning('Finished query vec')
    rows = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_rerank).to_list()
    logger.warning('Finished search')
    documents = [row[TEXT_COLUMN_NAME] for row in rows]
    logger.warning(f'start cross encoder {len(documents)}')

    if documents:
        cross_scores = _load_cross_encoder(model_name).predict([[query, doc] for doc in documents])
        logger.warning(f'Finished cross encoder {len(documents)}')
        # argsort is ascending; reverse it so the highest-scoring documents come first.
        ranked = np.argsort(cross_scores)[::-1][:top_k_rank]
        documents = [documents[idx] for idx in ranked]
    logger.warning(f'num documents {len(documents)}')

    document_time = perf_counter() - document_start
    logger.warning(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
    return documents


def bot(history, cross_encoder):
    """Answer the latest chat message with retrieval-augmented generation.

    Retrieves supporting documents (ColBERT, or vector search followed by a
    cross-encoder re-rank), renders the prompt templates and streams the reply
    from `generate_hf` into the last chat turn.

    Args:
        history: Chat turns as [user_message, bot_reply] pairs; the reply slot
            of the last turn is assigned in place, so the pairs must be mutable
            (presumably lists, as delivered by gr.Chatbot — confirm).
        cross_encoder: The retriever label chosen in the UI radio.

    Yields:
        (history, prompt_html) after every streamed update of the reply.

    Raises:
        ValueError: if the latest message is empty or the retriever label is unknown.
    """
    top_rerank = 25
    top_k_rank = 20
    query = history[-1][0]

    if not query:
        gr.Warning("Please submit a non-empty string as a prompt")
        raise ValueError("Empty string was submitted")

    logger.warning('Retrieving documents...')

    if cross_encoder == _COLBERT_LABEL:
        gr.Warning('Retrieving using ColBERT.. First time query will take a minute for model to load..pls wait')
        documents = _retrieve_with_colbert(query, top_k_rank)
    else:
        documents = _retrieve_with_cross_encoder(query, cross_encoder, top_rerank, top_k_rank)

    prompt = template.render(documents=documents, query=query)
    prompt_html = template_html.render(documents=documents, query=query)

    history[-1][1] = ""
    # Each yielded chunk from generate_hf replaces the whole reply text so far.
    for partial_reply in generate_hf(prompt, history[:-1]):
        history[-1][1] = partial_reply
        yield history, prompt_html
    print('Final history is ', history)
|
|
|
|
|
|
def system_instructions(question_difficulty, topic, documents_str):
    """Build the Mixtral [INST] prompt that asks for a 10-question JSON quiz.

    Args:
        question_difficulty: Difficulty word inserted into the prompt (e.g. "easy").
        topic: The quiz topic requested by the user.
        documents_str: Retrieved source documents, already joined into one string.

    Returns:
        The prompt string. It asks for keys "Q#" (questions), "Q#:C1".."Q#:C4"
        (choices) and "A#" (answer, given as a choice key like "Q#:C#").
    """
    return (
        "<s> [INST] Your are a great teacher and your task is to create 10 "
        f"questions with 4 choices with a {question_difficulty} difficulty "
        f'about topic request " {topic} " only from the below given documents, '
        f"{documents_str} then create an answers. Index in JSON format, "
        'the questions as "Q#":"" to "Q#":"", '
        'the four choices as "Q#:C1":"" to "Q#:C4":"", '
        'and the answers as "A#":"Q#:C#" to "A#":"Q#:C#". [/INST]'
    )
|
|
|
|
|
|
|
|
with gr.Blocks(theme='NoCrypt/miku') as CHATBOT:
    # Header: title, description and credits on the left, logo on the right.
    with gr.Row():
        with gr.Column(scale=10):
            gr.HTML(value="""<div style="color: #FF4500;"><h1>CHEERFULL CBSE-</h1> <h1><span style="color: #008000">AI Assisted Fun Learning</span></h1>
</div>""", elem_id='heading')

            gr.HTML(value="""
<p style="font-family: sans-serif; font-size: 16px;">
A free Artificial Intelligence Chatbot assistant trained on CBSE Class 10 Science Notes to engage and help students and teachers of Puducherry.
</p>
""", elem_id='Sub-heading')

            gr.HTML(value="""<p style="font-family: Arial, sans-serif; font-size: 14px;">Developed by K M Ramyasri , TGT,GHS.SUTHUKENY . Suggestions may be sent to <a href="mailto:ramyadevi1607@yahoo.com" style="color: #00008B; font-style: italic;">ramyadevi1607@yahoo.com</a>.</p>""", elem_id='Sub-heading1 ')

        with gr.Column(scale=3):
            # Resolve the logo next to this script (like the templates) so the
            # app does not depend on the current working directory.
            gr.Image(value=str(proj_dir / 'logo.png'), height=200, width=200)

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
                       'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
        bubble_full_width=False,
        show_copy_button=True,
        show_share_button=True,
    )

    with gr.Row():
        txt = gr.Textbox(
            scale=3,
            show_label=False,
            placeholder="Enter text and press enter",
            container=False,
        )
        txt_btn = gr.Button(value="Submit text", scale=1)

    # Retriever choice passed to bot(); labels must match the ones bot() checks.
    cross_encoder = gr.Radio(
        choices=['(FAST) MiniLM-L6v2', '(ACCURATE) BGE reranker', '(HIGH ACCURATE) ColBERT'],
        value='(ACCURATE) BGE reranker',
        label="Embeddings",
        info="Only the first ColBERT query may take a little time",
    )

    # Shows the rendered HTML prompt produced by bot().
    prompt_html = gr.HTML()

    # The Submit button and pressing Enter run the same chain: add_text appends
    # the message and locks the textbox, bot streams the answer, and the final
    # step re-enables the textbox. Registered in the same order as before
    # (click first, then submit).
    for trigger in (txt_btn.click, txt.submit):
        txt_msg = trigger(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, [chatbot, cross_encoder], [chatbot, prompt_html])
        txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    gr.Examples(examples, txt)
|
|
|
|
|
|
# Holder for the ColBERT index used by the quiz generator. load_model() assigns
# RAG_db.value directly, so the index is shared by every session.
RAG_db = gr.State()

# Number of question slots in the UI (the prompt in system_instructions also asks for 10).
_QUIZ_SIZE = 10
# One initial attempt plus three retries.
_MAX_QUIZ_ATTEMPTS = 4

# Last parsed LLM quiz JSON and the correct answer text for each displayed
# question, in display order. Module-global, so shared across sessions.
quiz_data = {}
quiz_answers = []

with gr.Blocks(title="Quiz Maker", theme=gr.themes.Default(primary_hue="green", secondary_hue="green"), css="style.css") as QUIZBOT:

    def load_model():
        """Load the ColBERT index into RAG_db.value and report readiness."""
        RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
        RAG_db.value = RAG.from_index('.ragatouille/colbert/indexes/cbseclass10index')
        return 'Ready to Go!!'

    with gr.Column(scale=4):
        # NOTE(review): emoji below restored from mis-encoded text in the source; confirm.
        gr.HTML("""
<center>
<h1><span style="color: purple;">AI NANBAN</span> - CBSE Class Quiz Maker</h1>
<h2>AI-powered Learning Game</h2>
<i>⚠️ Students create quiz from any topic /CBSE Chapter ! ⚠️</i>
</center>
""")

    with gr.Column(scale=2):
        load_btn = gr.Button("Click to Load!🚀")
        load_text = gr.Textbox()
        load_btn.click(load_model, [], load_text)

        topic = gr.Textbox(label="Enter the Topic for Quiz", placeholder="Write any topic from CBSE notes")

        with gr.Row():
            radio = gr.Radio(
                ["easy", "average", "hard"], label="How difficult should the quiz be?"
            )

        generate_quiz_btn = gr.Button("Generate Quiz!🚀")
        quiz_msg = gr.Textbox()

    # Hidden placeholders; generate_quiz fills them with the questions.
    question_radios = [gr.Radio(visible=False) for _ in range(_QUIZ_SIZE)]
    print(question_radios)

    def _parse_quiz(quiz):
        """Build (radios, answers) for every complete question in the LLM's quiz dict.

        A question is kept only if both "Q#" and the choice referenced by "A#"
        are present. Answers are stored as strings so they compare equal to
        the stringified choices shown in the radio.
        """
        radios, answers = [], []
        for question_num in range(1, _QUIZ_SIZE + 1):
            question_key = f"Q{question_num}"
            question = quiz.get(question_key)
            # "A#" holds a choice key such as "Q3:C2"; look it up to get the answer text.
            answer = quiz.get(quiz.get(f"A{question_num}"))
            if not question or not answer:
                continue
            choices = [f"{quiz.get(f'{question_key}:C{i}', 'Choice not found')}" for i in range(1, 5)]
            # value=None clears any selection left over from a previous quiz.
            radios.append(gr.Radio(choices=choices, label=question, value=None,
                                   visible=True, interactive=True))
            answers.append(f"{answer}")
        return radios, answers

    # NOTE(review): decorators apply bottom-up, so the click listener receives
    # the undecorated function and @spaces.GPU only wraps what the listener
    # decorator returns — it likely has no effect on the registered handler.
    # Left unchanged: confirm before swapping the order, since this handler
    # mutates module globals (quiz_data / quiz_answers) that compare_answers reads.
    @spaces.GPU
    @generate_quiz_btn.click(inputs=[radio, topic], outputs=[quiz_msg] + question_radios, api_name="generate_quiz")
    def generate_quiz(question_difficulty, topic):
        """Generate a multiple-choice quiz about `topic` from the ColBERT index.

        Retrieves the top documents for `topic`, asks the LLM for a JSON quiz
        and retries until all questions parse or the attempts run out.

        Returns:
            1 + _QUIZ_SIZE values: a status message followed by one gr.Radio per
            question slot (unused slots are hidden), matching the registered outputs.
        """
        global quiz_data, quiz_answers

        top_k_rank = 10
        if RAG_db.value is None:
            # The "Click to Load" button was not pressed yet: load the index now
            # instead of failing on None.search.
            load_model()
        documents_full = RAG_db.value.search(topic, k=top_k_rank)
        documents = [item['content'] for item in documents_full]

        generate_kwargs = dict(
            temperature=0.2,
            max_new_tokens=4000,
            top_p=0.95,
            repetition_penalty=1.0,
            do_sample=True,
            seed=42,
        )

        question_radio_list, answers = [], []
        for count in range(_MAX_QUIZ_ATTEMPTS):
            try:
                # The attempt number is appended to every document, presumably so
                # each retry sends a different prompt despite the fixed seed.
                documents_str = '\n'.join(
                    f"[DOCUMENT {i+1}]: {summary}{count}" for i, summary in enumerate(documents)
                )
                formatted_prompt = system_instructions(question_difficulty, topic, documents_str)
                print(formatted_prompt)
                response = client.text_generation(
                    formatted_prompt, **generate_kwargs, stream=False, details=False, return_full_text=False,
                )
                print(response)
                output_json = json.loads(response)
                print('output json', output_json)
                quiz_data = output_json
                # Start from scratch on every attempt so results never accumulate.
                question_radio_list, answers = _parse_quiz(quiz_data)
            except Exception as e:
                # Malformed JSON or an unexpected structure from the LLM: retry.
                print(f"Exception occurred: {e}")
                question_radio_list, answers = [], []
                print(f"Trying again..{count + 1} time...please wait")
                continue
            if len(question_radio_list) == _QUIZ_SIZE:
                break
            print('10 questions not generated . So trying again!')
        else:
            # Only reached when no attempt produced a full quiz.
            print('Retry exhausted')
            gr.Warning('Sorry. Pls try with another topic !')

        quiz_answers = answers
        print('Question radio list ', question_radio_list)

        status = 'Quiz Generated!' if question_radio_list else 'Sorry. Pls try with another topic !'
        # Pad with hidden radios so exactly one value is returned per output;
        # this also hides questions left over from a previous quiz.
        padding = [gr.Radio(visible=False) for _ in range(_QUIZ_SIZE - len(question_radio_list))]
        return [status] + question_radio_list + padding

    check_button = gr.Button("Check Score")

    score_textbox = gr.Markdown()

    @check_button.click(inputs=question_radios, outputs=score_textbox)
    def compare_answers(*user_answers):
        """Score the user's selections against the current quiz, question by question.

        Args:
            *user_answers: The selected choice of each question radio, in display
                order (None for unanswered or hidden slots).

        Returns:
            A Markdown message with the score.
        """
        if not quiz_answers:
            return "### Please generate a quiz first!"

        # Compare positionally: a selection only counts for its own question.
        score = sum(1 for chosen, correct in zip(user_answers, quiz_answers) if chosen == correct)
        total = len(quiz_answers)

        # Test the higher threshold first, otherwise "Excellent" is unreachable.
        if score > 7:
            message = f"### Excellent ! You got {score} over {total}!"
        elif score > 5:
            message = f"### Good ! You got {score} over {total}!"
        else:
            message = f"### You got {score} over {total}! Dont worry . You can prepare well and try better next time !"

        return message
|
|
|
|
|
|
|
|
# Combine the chat assistant and the quiz maker as two tabs of one app, enable
# the request queue and start the server in debug mode.
demo = gr.TabbedInterface(
    [CHATBOT, QUIZBOT],
    ["AI ChatBot", "AI Nanban-Quizbot"],
)
demo.queue().launch(debug=True)
|
|
|