|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import pandas as pd |
|
import numpy as np |
|
|
|
class VisualizationSelector:
    """Build a standard set of exploratory plots for a tabular dataset."""

    def select_visualizations(self, data):
        """Create exploratory plots for the columns of ``data``.

        Plots are produced in this order:

        1. A histogram with a KDE overlay for every numeric column.
        2. A correlation heatmap of the numeric columns (only when there are
           at least two numeric columns).
        3. A scatter-plot matrix (``sns.pairplot``) of the numeric columns
           (only when there are at least two numeric columns).
        4. A box plot of every numeric column grouped by every categorical
           (``object`` or ``category`` dtype) column.

        Args:
            data: A pandas ``DataFrame`` (anything supporting
                ``select_dtypes`` and column indexing).

        Returns:
            A list of plot objects in the order above. Entries from steps 1, 2
            and 4 are matplotlib ``Figure`` objects; the scatter-plot matrix
            entry is the ``PairGrid`` returned by ``sns.pairplot``. Figures are
            created through ``pyplot`` and never closed here, so the caller is
            responsible for showing, saving or closing them.
        """
        numeric_columns = data.select_dtypes(include=[np.number]).columns
        # 'category' is included so pandas Categorical columns are grouped on
        # just like plain string (object) columns instead of being skipped.
        categorical_columns = data.select_dtypes(
            include=['object', 'category']
        ).columns

        visualizations = []
        visualizations.extend(self._distribution_plots(data, numeric_columns))
        # Pairwise views only make sense with at least two numeric columns.
        if len(numeric_columns) > 1:
            visualizations.append(
                self._correlation_heatmap(data, numeric_columns)
            )
            visualizations.append(self._scatter_matrix(data, numeric_columns))
        visualizations.extend(
            self._box_plots(data, categorical_columns, numeric_columns)
        )
        return visualizations

    @staticmethod
    def _distribution_plots(data, numeric_columns):
        """Return one histogram-with-KDE figure per numeric column."""
        figures = []
        for column in numeric_columns:
            fig, ax = plt.subplots()
            sns.histplot(data[column], kde=True, ax=ax)
            ax.set_title(f'Distribution of {column}')
            figures.append(fig)
        return figures

    @staticmethod
    def _correlation_heatmap(data, numeric_columns):
        """Return a figure with an annotated correlation heatmap."""
        fig, ax = plt.subplots(figsize=(10, 8))
        corr = data[numeric_columns].corr()
        sns.heatmap(corr, annot=True, cmap='coolwarm', ax=ax)
        ax.set_title('Correlation Heatmap')
        return fig

    @staticmethod
    def _scatter_matrix(data, numeric_columns):
        """Return the seaborn PairGrid of all numeric columns."""
        grid = sns.pairplot(data[numeric_columns])
        # suptitle's y is in figure coordinates; 1.02 places the title just
        # above the figure's top edge so it does not overlap the top subplots.
        grid.fig.suptitle('Scatter Plot Matrix', y=1.02)
        return grid

    @staticmethod
    def _box_plots(data, categorical_columns, numeric_columns):
        """Return one box-plot figure per (categorical, numeric) column pair."""
        figures = []
        for cat_col in categorical_columns:
            for num_col in numeric_columns:
                fig, ax = plt.subplots()
                sns.boxplot(x=cat_col, y=num_col, data=data, ax=ax)
                ax.set_title(f'{cat_col} vs {num_col}')
                # Restyle the existing tick labels in place. Re-setting them
                # via ax.set_xticklabels(ax.get_xticklabels(), ...) swaps in a
                # fixed formatter and makes matplotlib warn when the axis
                # locator is not a FixedLocator.
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
                figures.append(fig)
        return figures