|
from langchain.chains.question_answering import load_qa_chain |
|
from langchain.prompts import PromptTemplate |
|
import json |
|
from prompts import * |
|
import streamlit as st |
|
import os |
|
|
|
|
|
def util(context, numPairs, inputPrompt, model):
    """Run a "stuff" QA chain over *context* and parse the model's reply as JSON.

    Args:
        context: Passed to the chain as ``input_documents`` (presumably a list
            of LangChain ``Document`` objects — confirm against callers).
        numPairs: Passed to the chain under the ``numPairs`` key so the prompt
            can reference it.
        inputPrompt: Prompt template given to ``load_qa_chain``.
        model: LLM given to ``load_qa_chain``.

    Returns:
        The value decoded from the chain's ``output_text``.

    Raises:
        json.JSONDecodeError: if the reply, after removing an optional
            surrounding Markdown code fence, is not valid JSON.
    """
    stuff_chain = load_qa_chain(model, chain_type="stuff", prompt=inputPrompt)
    stuff_answer = stuff_chain(
        {"input_documents": context, "numPairs": numPairs}, return_only_outputs=True
    )

    output_text = stuff_answer["output_text"].strip()
    # Chat models frequently wrap JSON in a Markdown fence (```json ... ```),
    # which json.loads rejects; peel it off before parsing.
    if output_text.startswith("```"):
        # Drop the opening fence line (e.g. "```" or "```json").
        _, _, output_text = output_text.partition("\n")
        output_text = output_text.rstrip()
        if output_text.endswith("```"):
            output_text = output_text[:-3]

    return json.loads(output_text)
|
|
|
|
|
|
|
def getLongQAPairs(context, numPairs, model):
    """Ask *model* for long-answer question/answer pairs drawn from *context*.

    The template comes from ``getLongQAPrompt()`` (presumably provided by the
    ``prompts`` wildcard import) and is bound to the ``context`` and
    ``numPairs`` variables before being run through ``util``.

    Args:
        context: Documents handed to ``util`` as the chain's input documents.
        numPairs: Value made available to the prompt as ``numPairs``.
        model: LLM forwarded to ``util``.

    Returns:
        The JSON value ``util`` decodes from the model's reply.
    """
    longTemplate = getLongQAPrompt()
    longPrompt = PromptTemplate(
        input_variables=["context", "numPairs"],
        template=longTemplate,
    )
    return util(context, numPairs, inputPrompt=longPrompt, model=model)
|
|
|
|
|
|
|
def getShortQAPairs(context, numPairs, model):
    """Ask *model* for short-answer question/answer pairs drawn from *context*.

    The template comes from ``getShortQAPrompt()`` (presumably provided by the
    ``prompts`` wildcard import) and is bound to the ``context`` and
    ``numPairs`` variables before being run through ``util``.

    Args:
        context: Documents handed to ``util`` as the chain's input documents.
        numPairs: Value made available to the prompt as ``numPairs``.
        model: LLM forwarded to ``util``.

    Returns:
        The JSON value ``util`` decodes from the model's reply.
    """
    shortTemplate = getShortQAPrompt()
    shortPrompt = PromptTemplate(
        input_variables=["context", "numPairs"],
        template=shortTemplate,
    )
    return util(context, numPairs, inputPrompt=shortPrompt, model=model)
|
|
|
|
|
|
|
def getMcqQAPairs(context, numPairs, model):
    """Ask *model* for multiple-choice question/answer pairs drawn from *context*.

    The template comes from ``getMcqQAPrompt()`` (presumably provided by the
    ``prompts`` wildcard import) and is bound to the ``context`` and
    ``numPairs`` variables before being run through ``util``.

    Args:
        context: Documents handed to ``util`` as the chain's input documents.
        numPairs: Value made available to the prompt as ``numPairs``.
        model: LLM forwarded to ``util``.

    Returns:
        The JSON value ``util`` decodes from the model's reply.
    """
    mcqTemplate = getMcqQAPrompt()
    mcqPrompt = PromptTemplate(
        input_variables=["context", "numPairs"],
        template=mcqTemplate,
    )
    return util(context, numPairs, inputPrompt=mcqPrompt, model=model)
|
|
|
|
|
def downloadFile(response, FileName, tempPath="temp.pdf"):
    """Render a Streamlit button that downloads *response* as pretty-printed JSON.

    The JSON is serialized in memory and passed straight to
    ``st.download_button``. No file named *FileName* is written to the working
    directory, so an existing local file with that name is never overwritten
    or deleted, and concurrent sessions cannot race on it.

    Args:
        response: JSON-serializable object offered for download.
        FileName: File name suggested to the browser for the download.
        tempPath: Path removed after the button is rendered. Defaults to
            ``"temp.pdf"`` (presumably a temporary upload written by the
            caller — confirm). Pass ``None`` to skip cleanup. A missing file
            is ignored.

    Raises:
        TypeError: if *response* is not JSON-serializable (from ``json.dumps``).
    """
    payload = json.dumps(response, indent=4)

    st.download_button(
        label="Download File",
        data=payload,
        file_name=FileName,
        # application/json is the registered JSON media type (RFC 8259);
        # "text/json" is not registered.
        mime="application/json",
        type="primary",
    )

    if tempPath is not None:
        try:
            os.remove(tempPath)
        except FileNotFoundError:
            # Nothing to clean up; don't break the page over a missing temp file.
            pass
|
|
|
|
|
|