File size: 3,823 Bytes
7e02cc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from embeddings import Embeddings
from chain import Chain
from llm import LLM
from retriever import Retriever
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from functools import lru_cache
from tools import *
import re

# Module-level singletons, created once at import time and shared by every
# function in this file.

# Embedding model handed to the retriever below (provider "hf",
# model "all-MiniLM-L6-v2").
emb = Embeddings("hf", "all-MiniLM-L6-v2")
# Chat model. Used by the answer chain `ch` and called directly by
# refine_query() via `llm.invoke(...)`.
llm = LLM('gemini').get_llm()
# Answer chain; investment_banker() calls `ch.run_conversational_chain(prompt)`.
# NOTE(review): the meaning of the second argument (None) is defined in the
# chain module — confirm what passing None disables.
ch = Chain(llm,None)
# Retriever; investment_banker() calls `ret.get_context(query)`.
# NOTE(review): positional arguments are interpreted by the retriever module;
# 5 is presumably the number of documents to fetch (top-k) — confirm.
ret = Retriever('pinecone', 'pinecone', emb.embedding, 'ensemble', 5)

# One-shot language flag: get_answer() sets it from the raw user query,
# investment_banker() reads it to choose the Arabic or English prompt and
# then resets it to False.
is_arabic = False

@lru_cache()
def _investment_banker_cached(query, arabic):
    """Retrieve context for *query* and run the answer chain on it.

    Memoized on the ``(query, arabic)`` pair, so the same question asked once
    in English and once in Arabic does not share a cached answer.

    Args:
        query: The (already refined) question to answer.
        arabic: When true, the prompt additionally asks for an Arabic-only
            answer.

    Returns:
        Whatever ``ch.run_conversational_chain`` returns for the prompt.
    """
    # Only `context` goes into the prompt; the list form is not used here.
    context, _context_list = ret.get_context(query)
    if not arabic:
        prompt_template = f"""
            You are an investment banker and financial advisor. 
            Answer the question as detailed as possible from the provided context and make sure to provide all the details.
            Answer only from the context. If the answer is not in provided context, say "Answer not in context".\n\n
            Context:\n {context}\n\n
            Question: \n{query}\n

            Answer:
        """
    else:
        prompt_template = f"""
            You are an investment banker and financial advisor. 
            Answer the question as detailed as possible from the provided context and make sure to provide all the details.
            Answer only from the context. If the answer is not in provided context, say "Answer not in context".
            Return the answer in Arabic only.\n\n
            Context:\n {context}\n\n
            Question: \n{query}\n

            Answer:
        """
    return ch.run_conversational_chain(prompt_template)


def investment_banker(query, arabic=None):
    """Answer *query* as an investment banker, using only retrieved context.

    Args:
        query: The (already refined) question to answer.
        arabic: Whether to ask for an Arabic-only answer. When omitted
            (``None``), the module-level ``is_arabic`` flag is used; this is
            how ``get_answer`` signals the language.

    Returns:
        The chain's answer. Answers are cached per ``(query, arabic)``.

    Side effects:
        Resets the module-level ``is_arabic`` flag to ``False`` after every
        call, including cache hits and calls that raise.
    """
    global is_arabic
    if arabic is None:
        arabic = is_arabic
    try:
        return _investment_banker_cached(query, bool(arabic))
    finally:
        # The flag is one-shot: clear it so a later call that does not set
        # it falls back to the English prompt.
        is_arabic = False


# Keep the cache-management API the lru_cache-decorated original exposed.
investment_banker.cache_clear = _investment_banker_cached.cache_clear
investment_banker.cache_info = _investment_banker_cached.cache_info

# Arabic-script letters: the core Arabic block (U+0600-06FF), Arabic
# Supplement (U+0750-077F), Arabic Extended-A (U+08A0-08FF), and the
# presentation-form blocks (U+FB50-FDFF, U+FE70-FEFC). U+FEFF is excluded on
# purpose: it is the byte-order mark, not an Arabic character.
_ARABIC_PATTERN = re.compile(
    r"[\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF\uFB50-\uFDFF\uFE70-\uFEFC]"
)


def check_arabic(s):
    """Return whether *s* contains at least one Arabic-script character.

    Args:
        s: Text to inspect.

    Returns:
        ``True`` if any character of *s* falls in one of the Arabic Unicode
        ranges in ``_ARABIC_PATTERN``, otherwise ``False`` (including for
        the empty string).
    """
    return _ARABIC_PATTERN.search(s) is not None

# Running conversation transcript ("Human: ..." / "Banker: ..." lines).
# get_answer() appends to it after each turn and passes it to refine_query()
# as the conversation context. Module-global, so it is shared by all callers
# in this process.
history = ""

@lru_cache(maxsize=128)
def refine_query(query, conversation):
    """Ask the LLM to turn *query* into a standalone English question.

    The model sees the earlier *conversation* transcript so that a follow-up
    message can be rephrased into a self-contained question. The prompt also
    asks for Arabic input to be translated to English first, and for the
    query to be returned unchanged when there is no history or no
    continuation.

    Results are memoized on the exact ``(query, conversation)`` pair.

    Args:
        query: The user's latest message, as typed.
        conversation: Transcript of earlier turns ("" when there are none).

    Returns:
        The ``content`` of the LLM response; the prompt asks for it to be
        only the refined query text.
    """
    instruction = f"""Given the following user query and historical user conversation with banker.
    If the current user query is in arabic, convert it to english and then proceed.
    If conversation history is empty return the current query as it is.
    If the query is a continuation of previous conversation then only rephrase the users current query to form a meaningful and clear question.
    Otherwise return the user query as it is.
    Previously user and banker had the following conversation: \n{conversation}\n\n User's Current Query: {query}. 
    What will be the refined query? Only provide the query without any extra details or explanations."""
    return llm.invoke(instruction).content


def get_answer(query):
    """Answer one user turn and append it to the shared ``history``.

    Steps:
      1. Flag whether the raw *query* contains Arabic script (sets the
         module-level ``is_arabic`` read by ``investment_banker``).
      2. Rewrite the query into a standalone question using ``history``.
      3. Answer the rewritten question from retrieved context.
      4. Append the rewritten question and the answer to ``history``.

    Args:
        query: The user's message.

    Returns:
        The answer produced by ``investment_banker``.
    """
    global history, is_arabic

    is_arabic = check_arabic(query)
    standalone = refine_query(query, history)
    reply = investment_banker(standalone)

    # The refined question (not the raw input) is what gets recorded.
    history += "Human: " + standalone + "\n"
    history += "Banker: " + reply + "\n"
    return reply
if __name__ == "__main__":
    # Command-line entry point: answer a single query taken from the
    # arguments, or prompt for one when none are given. get_answer requires
    # a query, so it must not be called without one.
    import sys

    user_query = " ".join(sys.argv[1:]) or input("Query: ")
    if not user_query.strip():
        raise SystemExit("No query given.")
    response = get_answer(user_query)
    print(response)
# app = FastAPI()

# class Query(BaseModel):
#     question: str

# @app.post("/chat/")
# async def chat(query: Query):
#     global history
#     global is_arabic
    
#     try:
        
#         is_arabic = check_arabic(query.question)
#         ref_query = refine_query(query.question, history)
        

#         print(query.question, ref_query)
#         print(is_arabic)
#         ans = investment_banker(ref_query)
#         history += "Human: "+ ref_query + "\n"
#         history += "Banker: "+ ans + "\n"
#         return {"question": query.question, "answer": ans}
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=str(e))
    

# @app.get("/", response_class=HTMLResponse)
# async def read_index():
#     with open('index.html', 'r') as f:
#         return HTMLResponse(content=f.read())