Create natural_language_query.py
natural_language_query.py  ADDED  +105 -0

"""A lightweight natural-language query engine over pandas DataFrames."""

import pandas as pd
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

class NLQueryEngine:
    """Answer simple natural-language questions about a pandas DataFrame."""

    def __init__(self):
        # Fetch the NLTK data needed for tokenization and lemmatization;
        # quiet=True keeps repeated instantiations from spamming the log.
        nltk.download('punkt', quiet=True)
        nltk.download('punkt_tab', quiet=True)  # also required by newer NLTK releases
        nltk.download('stopwords', quiet=True)
        nltk.download('wordnet', quiet=True)
        self.stop_words = set(stopwords.words('english'))
        self.lemmatizer = WordNetLemmatizer()
        self.vectorizer = TfidfVectorizer()

    def process_query(self, query, data):
        # Pipeline: normalize the query, detect its intent, then dispatch.
        tokens = self.preprocess_query(query)
        intent = self.identify_intent(tokens)
        return self.execute_query(intent, tokens, data)

    def preprocess_query(self, query):
        # Tokenize, remove stop words, and lemmatize
        tokens = word_tokenize(query.lower())
        tokens = [self.lemmatizer.lemmatize(token) for token in tokens
                  if token not in self.stop_words]
        return tokens

    def identify_intent(self, tokens):
        # Keywords are matched against single lemmatized tokens, so each entry
        # must be one word; note that 'between' is an NLTK stop word and is
        # removed during preprocessing, so it will rarely drive a match.
        intent_keywords = {
            'statistical': ['average', 'mean', 'median', 'mode', 'max', 'maximum',
                            'min', 'minimum', 'sum', 'count'],
            'comparison': ['compare', 'difference', 'versus', 'vs'],
            'trend': ['trend', 'time', 'increase', 'decrease'],
            'distribution': ['distribution', 'spread', 'range'],
            'correlation': ['correlation', 'relationship', 'between'],
        }

        # Return the first intent whose keyword list intersects the tokens.
        for intent, keywords in intent_keywords.items():
            if any(keyword in tokens for keyword in keywords):
                return intent
        return 'general'

    def find_column(self, tokens, columns):
        # Use TF-IDF and cosine similarity to find the best matching column:
        # fit on the candidate column names, project the query onto the same
        # vocabulary, and pick the most similar name.
        columns = list(columns)
        column_vectors = self.vectorizer.fit_transform(columns)
        query_vector = self.vectorizer.transform([' '.join(tokens)])
        similarities = cosine_similarity(query_vector, column_vectors).flatten()
        best_match_index = similarities.argmax()
        # Require a minimum similarity so unrelated queries yield None.
        return columns[best_match_index] if similarities[best_match_index] > 0.1 else None

    def execute_query(self, intent, tokens, data):
        column = self.find_column(tokens, data.columns)
        if not column:
            return ("I couldn't identify a relevant column in your query. "
                    "Can you please specify the column name?")

        if intent == 'statistical':
            return self.statistical_query(tokens, data, column)
        elif intent == 'comparison':
            return self.comparison_query(tokens, data, column)
        elif intent == 'trend':
            return self.trend_query(data, column)
        elif intent == 'distribution':
            return self.distribution_query(data, column)
        elif intent == 'correlation':
            return self.correlation_query(data, column, tokens)
        else:
            return f"Here's a summary of {column}:\n{data[column].describe()}"

    def statistical_query(self, tokens, data, column):
        if 'average' in tokens or 'mean' in tokens:
            return f"The average of {column} is {data[column].mean():.2f}"
        elif 'median' in tokens:
            return f"The median of {column} is {data[column].median():.2f}"
        elif 'mode' in tokens:
            return f"The mode of {column} is {data[column].mode().values[0]}"
        elif 'maximum' in tokens or 'max' in tokens:
            return f"The maximum of {column} is {data[column].max():.2f}"
        elif 'minimum' in tokens or 'min' in tokens:
            return f"The minimum of {column} is {data[column].min():.2f}"
        elif 'sum' in tokens:
            return f"The sum of {column} is {data[column].sum():.2f}"
        elif 'count' in tokens:
            return f"The count of {column} is {data[column].count()}"
        # Defensive fallback so the method never returns None.
        return f"Here's a summary of {column}:\n{data[column].describe()}"

    def comparison_query(self, tokens, data, column):
        # Implement comparison logic here
        return f"Comparison analysis for {column} is not implemented yet."

    def trend_query(self, data, column):
        # Implement trend analysis logic here
        return f"Trend analysis for {column} is not implemented yet."

    def distribution_query(self, data, column):
        # Implement distribution analysis logic here
        return f"Distribution analysis for {column} is not implemented yet."

    def correlation_query(self, data, column1, tokens):
        # Search for the second column among the remaining columns; filtering
        # the tokens against column1 would usually just re-match column1,
        # since tokens are single words and column1 is a full column name.
        remaining = [col for col in data.columns if col != column1]
        column2 = self.find_column(tokens, remaining) if remaining else None
        if column2:
            correlation = data[column1].corr(data[column2])
            return f"The correlation between {column1} and {column2} is {correlation:.2f}"
        else:
            return (f"I couldn't identify a second column to correlate with {column1}. "
                    "Can you please specify two column names?")
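

if __name__ == '__main__':
    # Minimal usage sketch. The DataFrame, column names, and queries below
    # are illustrative assumptions, not part of the engine itself.
    df = pd.DataFrame({
        'price': [10.0, 12.5, 9.0, 14.0],
        'quantity': [3, 1, 4, 2],
    })
    engine = NLQueryEngine()
    print(engine.process_query("What is the average price?", df))
    print(engine.process_query("Is there a correlation between price and quantity?", df))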