import re

import nltk
import pandas as pd
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


class NLQueryEngine:
    """Answer simple natural-language questions about a pandas DataFrame.

    A query is tokenized, classified into an intent by keyword lookup,
    matched to a DataFrame column by TF-IDF cosine similarity, and then
    handed to the handler for that intent. Every handler returns a
    human-readable string.
    """

    # Intent -> trigger keywords. Intents are tested in this order and the
    # first one with a matching keyword wins. Multi-word entries are matched
    # as consecutive tokens.
    _INTENT_KEYWORDS = {
        'statistical': ('average', 'mean', 'median', 'mode', 'max',
                        'maximum', 'min', 'minimum', 'sum', 'count'),
        # 'vs.' is listed as well in case the tokenizer keeps the trailing
        # period attached to the abbreviation.
        'comparison': ('compare', 'difference', 'versus', 'vs', 'vs.'),
        'trend': ('trend', 'over time', 'increase', 'decrease'),
        'distribution': ('distribution', 'spread', 'range'),
        'correlation': ('correlation', 'relationship', 'between'),
    }

    # NLTK data packages fetched by the constructor. 'punkt_tab' is included
    # because recent NLTK releases load the tokenizer tables from it rather
    # than from 'punkt'.
    _NLTK_RESOURCES = ('punkt', 'punkt_tab', 'stopwords', 'wordnet')

    def __init__(self):
        """Download the required NLTK data and build the text helpers."""
        for resource in self._NLTK_RESOURCES:
            # quiet=True stops every instantiation from printing download
            # progress.
            nltk.download(resource, quiet=True)
        self.stop_words = set(stopwords.words('english'))
        self.lemmatizer = WordNetLemmatizer()
        self.vectorizer = TfidfVectorizer()

    def process_query(self, query, data):
        """Answer *query* about the DataFrame *data*.

        Args:
            query: The question as free text.
            data: The pandas DataFrame to query.

        Returns:
            The answer (or an explanatory message) as a string.
        """
        tokens = self.preprocess_query(query)
        # Intent is detected on tokens that still contain stop words:
        # triggers such as "between" and "over time" are (or contain) stop
        # words and would otherwise be removed before they can match.
        intent = self.identify_intent(self._intent_tokens(query))
        return self.execute_query(intent, tokens, data)

    def preprocess_query(self, query):
        """Lower-case and tokenize *query*, drop stop words, lemmatize the rest."""
        return [
            self.lemmatizer.lemmatize(token)
            for token in word_tokenize(query.lower())
            if token not in self.stop_words
        ]

    def _intent_tokens(self, query):
        """Tokenize *query* like preprocess_query but keep stop words as-is."""
        return [
            token if token in self.stop_words
            else self.lemmatizer.lemmatize(token)
            for token in word_tokenize(query.lower())
        ]

    def identify_intent(self, tokens):
        """Return the first intent whose keyword occurs in *tokens*.

        Single-word keywords match any token; multi-word keywords must
        appear as consecutive tokens. Returns ``'general'`` when nothing
        matches.
        """
        tokens = list(tokens)
        present = set(tokens)
        for intent, keywords in self._INTENT_KEYWORDS.items():
            for keyword in keywords:
                words = keyword.split()
                if len(words) == 1:
                    if keyword in present:
                        return intent
                elif self._contains_phrase(tokens, words):
                    return intent
        return 'general'

    @staticmethod
    def _contains_phrase(tokens, words):
        """Return True if the list *words* occurs contiguously in *tokens*."""
        span = len(words)
        return any(
            tokens[start:start + span] == words
            for start in range(len(tokens) - span + 1)
        )

    def _normalize_column(self, label):
        """Turn a column label into lower-case, lemmatized, space-separated words.

        Query tokens are lemmatized, so the column names must be too,
        otherwise a column called ``sales`` can never match the query word
        "sales" (lemma ``sale``). Underscores and punctuation become spaces
        so ``unit_price`` contributes the words ``unit`` and ``price``.
        """
        words = re.sub(r'[\W_]+', ' ', str(label)).lower().split()
        return ' '.join(self.lemmatizer.lemmatize(word) for word in words)

    def find_column(self, tokens, columns, threshold=0.1):
        """Return the column whose name best matches *tokens*, or None.

        Args:
            tokens: Pre-processed query tokens.
            columns: Iterable of column labels (e.g. ``DataFrame.columns``).
            threshold: Cosine similarity the best match must exceed.

        Returns:
            The original label of the best-matching column, or ``None`` when
            there are no columns, no tokens, or no match above *threshold*.
        """
        labels = list(columns)
        query_text = ' '.join(tokens)
        if not labels or not query_text.strip():
            return None

        documents = [self._normalize_column(label) for label in labels]
        try:
            column_vectors = self.vectorizer.fit_transform(documents)
        except ValueError:
            # Raised by the vectorizer when no column name yields a usable
            # token (empty vocabulary); nothing can match in that case.
            return None

        query_vector = self.vectorizer.transform([query_text])
        similarities = cosine_similarity(query_vector, column_vectors).flatten()
        best = int(similarities.argmax())
        return labels[best] if similarities[best] > threshold else None

    def execute_query(self, intent, tokens, data):
        """Resolve the target column and run the handler for *intent*."""
        column = self.find_column(tokens, data.columns)
        # Compare against None explicitly: a falsy label such as 0 or ""
        # is still a valid column.
        if column is None:
            return "I couldn't identify a relevant column in your query. Can you please specify the column name?"

        if intent == 'statistical':
            return self.statistical_query(tokens, data, column)
        if intent == 'comparison':
            return self.comparison_query(tokens, data, column)
        if intent == 'trend':
            return self.trend_query(data, column)
        if intent == 'distribution':
            return self.distribution_query(data, column)
        if intent == 'correlation':
            return self.correlation_query(data, column, tokens)
        return f"Here's a summary of {column}:\n{data[column].describe()}"

    def statistical_query(self, tokens, data, column):
        """Report the statistic named in *tokens* for ``data[column]``.

        Always returns a string: the statistic, or an explanation when no
        statistic keyword is present or the column's dtype does not support
        the requested statistic.
        """
        series = data[column]
        wanted = set(tokens)

        # Same precedence as the keyword order: the first statistic found wins.
        if wanted & {'average', 'mean'}:
            label, method = 'average', 'mean'
        elif 'median' in wanted:
            label, method = 'median', 'median'
        elif 'mode' in wanted:
            label, method = 'mode', 'mode'
        elif wanted & {'maximum', 'max'}:
            label, method = 'maximum', 'max'
        elif wanted & {'minimum', 'min'}:
            label, method = 'minimum', 'min'
        elif 'sum' in wanted:
            label, method = 'sum', 'sum'
        elif 'count' in wanted:
            label, method = 'count', 'count'
        else:
            return (
                f"I couldn't tell which statistic of {column} you want. "
                "Try average, median, mode, maximum, minimum, sum or count."
            )

        if method == 'count':
            return f"The count of {column} is {series.count()}"

        if method == 'mode':
            modes = series.mode()
            if modes.empty:
                return f"{column} has no values, so it has no mode."
            return f"The mode of {column} is {modes.values[0]}"

        if not pd.api.types.is_numeric_dtype(series):
            if method in ('max', 'min'):
                # Ordering is still meaningful for e.g. strings; only the
                # two-decimal float formatting has to be skipped.
                return f"The {label} of {column} is {getattr(series, method)()}"
            return (
                f"I can't compute the {label} of {column} because it is "
                "not a numeric column."
            )

        return f"The {label} of {column} is {getattr(series, method)():.2f}"

    def comparison_query(self, tokens, data, column):
        """Placeholder: comparison analysis is not implemented."""
        return f"Comparison analysis for {column} is not implemented yet."

    def trend_query(self, data, column):
        """Placeholder: trend analysis is not implemented."""
        return f"Trend analysis for {column} is not implemented yet."

    def distribution_query(self, data, column):
        """Placeholder: distribution analysis is not implemented."""
        return f"Distribution analysis for {column} is not implemented yet."

    def correlation_query(self, data, column1, tokens):
        """Report the correlation between *column1* and a second column.

        The second column is looked up among the columns other than
        *column1*, so the same column cannot be picked twice.
        """
        candidates = [label for label in data.columns if label != column1]
        column2 = self.find_column(tokens, candidates)
        if column2 is None:
            return f"I couldn't identify a second column to correlate with {column1}. Can you please specify two column names?"

        non_numeric = [
            str(label) for label in (column1, column2)
            if not pd.api.types.is_numeric_dtype(data[label])
        ]
        if non_numeric:
            return (
                "I can't compute a correlation because these columns are "
                f"not numeric: {', '.join(non_numeric)}"
            )

        correlation = data[column1].corr(data[column2])
        return f"The correlation between {column1} and {column2} is {correlation:.2f}"