dsgt-snakeclef / evaluate /test_evaluate.py
Anthony Miyaguchi
Remove lightning dependency from submission
a0583df
raw
history blame contribute delete
No virus
2.79 kB
import numpy as np
import pandas as pd
import PIL
import pytest
import torch
from pytorch_lightning import Trainer
from .data import ImageDataset
from .data_lightning import InferenceDataModel
from .model_lightning import LinearClassifier as LightningLinearClassifier
from .submission import make_submission
class TestingInferenceDataModel(InferenceDataModel):
def train_dataloader(self):
for batch in self.predict_dataloader():
# add a label to the batch with classes from 0 to 9
batch["label"] = torch.randint(0, 10, (batch["features"].shape[0],))
yield batch
@pytest.fixture
def images_root(tmp_path):
images_root = tmp_path / "images"
images_root.mkdir()
for i in range(10):
img = PIL.Image.fromarray(
np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
)
img.save(images_root / f"{i}.jpg")
return images_root
@pytest.fixture
def metadata(tmp_path, images_root):
res = []
for i, img in enumerate(images_root.glob("*.jpg")):
res.append({"filename": img.name, "observation_id": i})
df = pd.DataFrame(res)
df.to_csv(tmp_path / "metadata.csv", index=False)
return tmp_path / "metadata.csv"
@pytest.fixture
def model_checkpoint(tmp_path, metadata, images_root):
model_checkpoint = tmp_path / "model.ckpt"
model = LightningLinearClassifier(768, 10)
trainer = Trainer(max_epochs=1, fast_dev_run=True)
dm = TestingInferenceDataModel(metadata, images_root)
trainer.fit(model, dm)
trainer.save_checkpoint(model_checkpoint)
return model_checkpoint
def test_image_dataset(images_root, metadata):
dataset = ImageDataset(metadata, images_root)
assert len(dataset) == 10
for i in range(10):
assert dataset[i]["features"].shape == torch.Size([3, 100, 100])
def test_inference_datamodel(images_root, metadata):
batch_size = 5
model = InferenceDataModel(metadata, images_root, batch_size=batch_size)
model.setup()
assert len(model.dataloader) == 2
for batch in model.predict_dataloader():
assert set(batch.keys()) == {"features", "observation_id"}
assert batch["features"].shape == torch.Size([batch_size, 768])
def test_model_checkpoint(model_checkpoint):
model = LightningLinearClassifier.load_from_checkpoint(model_checkpoint)
assert model
def test_make_submission(model_checkpoint, metadata, images_root, tmp_path):
output_csv_path = tmp_path / "submission.csv"
make_submission(metadata, model_checkpoint, output_csv_path, images_root)
submission_df = pd.read_csv(output_csv_path)
assert len(submission_df) == 10
assert set(submission_df.columns) == {"observation_id", "class_id"}
assert submission_df["class_id"].isin(range(10)).all()