from langchain_openai import ChatOpenAI
from langchain_core.messages import (
    HumanMessage,
    SystemMessage
)

from rake_nltk import Rake
import nltk

nltk.download('stopwords')
nltk.download('punkt')
"""
|
|
This function takes in user query and returns keywords
|
|
Input:
|
|
user_query: str
|
|
keyword_type: str (openai, rake, or na)
|
|
If the keyword type is na, then user query is returned.
|
|
Output: keywords: str
|
|
"""
|
|
def get_keywords(user_query: str, keyword_type: str) -> str:
|
|
if keyword_type == "openai":
|
|
return get_keywords_openai(user_query)
|
|
if keyword_type == "rake":
|
|
return get_keywords_rake(user_query)
|
|
else:
|
|
return user_query
|
|
|
|
|
|
"""
|
|
This function takes user query and returns keywords using rake_nltk
|
|
rake_nltk actually returns keyphrases, not keywords. Since using keyphrases did not show improvement, we are using keywords
|
|
to match the output type of the other keyword functions.
|
|
Input:
|
|
user_query: str
|
|
Output: keywords: str
|
|
"""
|
|
def get_keywords_rake(user_query: str) -> str:
|
|
r = Rake()
|
|
r.extract_keywords_from_text(user_query)
|
|
keyphrases = r.get_ranked_phrases()
|
|
|
|
|
|
out = ""
|
|
for phrase in keyphrases:
|
|
out += phrase + " "
|
|
return out
|
|
|
|
|
|
"""
|
|
This function takes user query and returns keywords using openai
|
|
Input:
|
|
user_query: str
|
|
Output: keywords: str
|
|
"""
|
|
def get_keywords_openai(user_query: str) -> str:
|
|
llm = ChatOpenAI(temperature=0.0)
|
|
command = "return the keywords of the following query. response should be words separated by commas. "
|
|
message = [
|
|
SystemMessage(content=command),
|
|
HumanMessage(content=user_query)
|
|
]
|
|
response = llm(message)
|
|
res = response.content.replace(",", "")
|
|
return res
|
|
|
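

# Minimal usage sketch (not part of the original module): the query below is a
# made-up example, and the "openai" branch assumes OPENAI_API_KEY is set in the
# environment before ChatOpenAI is instantiated.
if __name__ == "__main__":
    sample_query = "How does retrieval augmented generation reduce hallucinations?"
    print(get_keywords(sample_query, "rake"))   # RAKE keyphrases joined into one string
    print(get_keywords(sample_query, "na"))     # returns the query unchanged
    # print(get_keywords(sample_query, "openai"))  # uncomment if an API key is available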