import tempfile

from src.cross_validate import CrossValidation
from src.data_reader import load_train_test_df
from src.solutions.constant_predictor import ConstantPredictorSolution
from src.utils import pandas_set_print_options
# Runs once at import time, before any test in this module executes.
# NOTE(review): presumably configures pandas display options so frames
# printed during the test are readable — confirm against src.utils.
pandas_set_print_options()
def test_cross_validation():
    """Smoke-test ``CrossValidation`` fit/predict with a constant predictor.

    Loads the training frame via ``load_train_test_df(is_testing=True)``,
    fits a ``ConstantPredictorSolution`` with 3 splits, and checks only the
    shapes of the returned score table and of the prediction frame.

    The cross-validation ``saving_dir`` is a fresh temporary directory that
    is removed when the test finishes, so runs do not share state through a
    fixed path and nothing is left behind on disk.
    """
    train_df, _ = load_train_test_df(is_testing=True)
    x_columns = ['text_id', 'full_text']
    # Features keep the id and the raw text; the target frame is every
    # column except the raw text, so it still carries ``text_id``.
    X, y = train_df[x_columns], train_df.drop(columns=['full_text'])
    n_splits = 3

    # Keep fit and predict inside the ``with`` block: the directory is
    # deleted on exit and predict may need whatever fit stored there.
    with tempfile.TemporaryDirectory() as saving_dir:
        cv = CrossValidation(saving_dir=saving_dir, n_splits=n_splits)
        predictor = ConstantPredictorSolution()

        cv_scores = cv.fit(predictor, X, y)
        # NOTE(review): presumably one row per fold plus one summary row,
        # and 6 score columns — confirm against CrossValidation.fit.
        assert cv_scores.shape == (n_splits + 1, 6)

        prediction_df = cv.predict(train_df[x_columns])
        # One prediction row per input row; 7 columns expected.
        assert prediction_df.shape == (len(train_df), 7)