File size: 1,869 Bytes
12cca3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from langchain_openai import ChatOpenAI
from langchain_core.messages import (
    HumanMessage,
    SystemMessage
)

from rake_nltk import Rake
import nltk
def _ensure_nltk_data() -> None:
    """Make sure the NLTK resources used by rake_nltk are installed.

    Each resource is looked up locally first with ``nltk.data.find`` and only
    downloaded when missing. This avoids contacting the NLTK index on every
    import of this module.
    """
    resources = (
        ("corpora/stopwords", "stopwords"),
        ("tokenizers/punkt", "punkt"),
        # NLTK >= 3.8.2 loads sent_tokenize's data from ``punkt_tab`` instead
        # of the pickled ``punkt`` models; rake_nltk's default sentence
        # tokenizer fails with LookupError without it.
        ("tokenizers/punkt_tab", "punkt_tab"),
    )
    for resource_path, package in resources:
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package)


_ensure_nltk_data()
"""

This function takes in user query and returns keywords

Input:

    user_query: str

    keyword_type: str (openai, rake, or na)

        If the keyword type is na, then user query is returned.

Output: keywords: str

"""
def get_keywords(user_query: str, keyword_type: str) -> str:
    if keyword_type == "openai":
        return get_keywords_openai(user_query)
    if keyword_type == "rake":
        return get_keywords_rake(user_query)
    else:
        return user_query


"""

This function takes user query and returns keywords using rake_nltk

rake_nltk actually returns keyphrases, not keywords. Since using keyphrases did not show improvement, we are using keywords

to match the output type of the other keyword functions.

Input:

    user_query: str

Output: keywords: str

"""
def get_keywords_rake(user_query: str) -> str:
    r = Rake()
    r.extract_keywords_from_text(user_query)
    keyphrases = r.get_ranked_phrases()

    # If we want to get keyphrases, return keyphrases but should do keywords
    out = ""
    for phrase in keyphrases:
        out += phrase + " "
    return out


"""

This function takes user query and returns keywords using openai

Input:

    user_query: str

Output: keywords: str

"""
def get_keywords_openai(user_query: str) -> str:
    llm = ChatOpenAI(temperature=0.0)
    command = "return the keywords of the following query. response should be words separated by commas. "
    message = [
        SystemMessage(content=command),
        HumanMessage(content=user_query)
    ]
    response = llm(message)
    res = response.content.replace(",", "")
    return res