SemF1 / tests.py
nbansal's picture
Minor
536d81c
raw
history blame
23.9 kB
import statistics
import unittest
import numpy as np
import torch
from numpy.testing import assert_almost_equal
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from .encoder_models import SBertEncoder, get_encoder
from .semf1 import SemF1, _compute_cosine_similarity, _validate_input_format
from .utils import get_gpu, slice_embeddings, is_nested_list_of_type, flatten_list, compute_f1, Scores
class TestUtils(unittest.TestCase):
    """Tests for the helper functions imported from ``.utils``."""

    def test_get_gpu(self):
        """``get_gpu`` resolves bool/str/int/list device specs to GPU indices or ``"cpu"``."""
        gpu_count = torch.cuda.device_count()
        gpu_available = torch.cuda.is_available()

        # Test single boolean input
        self.assertEqual(get_gpu(True), 0 if gpu_available else "cpu")
        self.assertEqual(get_gpu(False), "cpu")

        # Test single string input
        self.assertEqual(get_gpu("cpu"), "cpu")
        self.assertEqual(get_gpu("gpu"), 0 if gpu_available else "cpu")
        self.assertEqual(get_gpu("cuda"), 0 if gpu_available else "cpu")

        # Test single integer input
        self.assertEqual(get_gpu(0), 0 if gpu_available else "cpu")
        # Index 1 is a valid GPU only with >= 2 GPUs. With exactly one GPU it equals
        # device_count(), which the ValueError check at the end expects to be rejected.
        if not gpu_available:
            self.assertEqual(get_gpu(1), "cpu")
        elif gpu_count > 1:
            self.assertEqual(get_gpu(1), 1)

        # Test list input with unique elements
        self.assertEqual(
            get_gpu([True, "cpu", 0]),
            [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"],
        )
        # Test list input with duplicate elements
        self.assertEqual(
            get_gpu([0, 0, "gpu"]),
            0 if gpu_available else ["cpu", "cpu", "cpu"],
        )
        # Test list input with duplicate elements of different types
        self.assertEqual(
            get_gpu([True, 0, "gpu"]),
            0 if gpu_available else ["cpu", "cpu", "cpu"],
        )
        # Test list input but only one element
        self.assertEqual(get_gpu([True]), 0 if gpu_available else "cpu")

        # Test list input with all integers. With exactly one GPU the list is [0], and a
        # single-element list collapses to a scalar (see the [True] case above), so the
        # list-valued expectation only holds for zero or >= 2 GPUs.
        if gpu_count != 1:
            self.assertEqual(
                get_gpu(list(range(gpu_count))),
                list(range(gpu_count)) if gpu_available else gpu_count * ["cpu"],
            )

        with self.assertRaises(ValueError):
            get_gpu("invalid")
        # device_count() is one past the last GPU index. On a CPU-only machine it is 0,
        # which get_gpu maps to "cpu" (asserted above), so only expect an error with a GPU.
        if gpu_available:
            with self.assertRaises(ValueError):
                get_gpu(gpu_count)

    def _assert_nested_arrays_equal(self, actual, expected):
        """Recursively compare (nested) sequences of numpy arrays, including their lengths."""
        if isinstance(expected, np.ndarray):
            np.testing.assert_array_equal(actual, expected)
            return
        # Materialize so generators and tuples are handled the same way as lists.
        actual = list(actual)
        self.assertEqual(len(actual), len(expected))
        for actual_item, expected_item in zip(actual, expected):
            self._assert_nested_arrays_equal(actual_item, expected_item)

    def test_slice_embeddings(self):
        """``slice_embeddings`` splits rows by flat or nested sentence counts."""
        embeddings = np.random.rand(10, 5)

        num_sentences = [3, 2, 5]
        expected_output = [embeddings[:3], embeddings[3:5], embeddings[5:]]
        self._assert_nested_arrays_equal(
            slice_embeddings(embeddings, num_sentences), expected_output
        )

        # Compare element-wise: assertTrue(result, expected) would only test truthiness.
        num_sentences_nested = [[2, 1], [3, 4]]
        expected_output_nested = [
            [embeddings[:2], embeddings[2:3]],
            [embeddings[3:6], embeddings[6:]],
        ]
        self._assert_nested_arrays_equal(
            slice_embeddings(embeddings, num_sentences_nested), expected_output_nested
        )

        with self.assertRaises(TypeError):
            slice_embeddings(embeddings, "invalid")

    def test_is_nested_list_of_type(self):
        """``is_nested_list_of_type`` returns ``(is_valid, err_msg)`` for a given depth."""
        # Test case: Depth 0, single element matching element_type
        self.assertEqual(is_nested_list_of_type("test", str, 0), (True, ""))
        # Test case: Depth 0, single element not matching element_type
        is_valid, err_msg = is_nested_list_of_type("test", int, 0)
        self.assertEqual(is_valid, False)
        # Test case: Depth 1, list of elements matching element_type
        self.assertEqual(is_nested_list_of_type(["apple", "banana"], str, 1), (True, ""))
        # Test case: Depth 1, list of elements not matching element_type
        is_valid, err_msg = is_nested_list_of_type([1, 2, 3], str, 1)
        self.assertEqual(is_valid, False)
        # Test case: Depth 0 (Wrong), list of elements matching element_type
        is_valid, err_msg = is_nested_list_of_type([1, 2, 3], str, 0)
        self.assertEqual(is_valid, False)
        # Depth 2
        self.assertEqual(is_nested_list_of_type([[1, 2], [3, 4]], int, 2), (True, ""))
        self.assertEqual(is_nested_list_of_type([['1', '2'], ['3', '4']], str, 2), (True, ""))
        is_valid, err_msg = is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)
        self.assertEqual(is_valid, False)
        # Depth 3
        is_valid, err_msg = is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3)
        self.assertEqual(is_valid, False)
        self.assertEqual(is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3), (True, ""))
        # Test case: Depth is negative, expecting ValueError
        with self.assertRaises(ValueError):
            is_nested_list_of_type([1, 2], int, -1)

    def test_flatten_list(self):
        """``flatten_list`` flattens arbitrarily nested lists."""
        self.assertEqual(flatten_list([1, [2, 3], [[4], 5]]), [1, 2, 3, 4, 5])
        self.assertEqual(flatten_list([]), [])
        self.assertEqual(flatten_list([1, 2, 3]), [1, 2, 3])
        self.assertEqual(flatten_list([[[[1]]]]), [1])

    def test_compute_f1(self):
        """``compute_f1`` returns the F1 of a precision/recall pair."""
        self.assertAlmostEqual(compute_f1(0.5, 0.5), 0.5)
        self.assertAlmostEqual(compute_f1(1, 0), 0.0)
        self.assertAlmostEqual(compute_f1(0, 1), 0.0)
        self.assertAlmostEqual(compute_f1(1, 1), 1.0)

    def test_scores(self):
        """``Scores.f1`` combines precision with the mean of the recall list."""
        scores = Scores(precision=0.8, recall=[0.7, 0.9])
        self.assertAlmostEqual(scores.f1, compute_f1(0.8, statistics.fmean([0.7, 0.9])))
class TestSBertEncoder(unittest.TestCase):
    """Tests for constructing ``SBertEncoder`` and encoding sentences with it."""

    def setUp(self, device=None):
        """Build the encoder under test.

        Args:
            device: Optional device override. When ``None``, CUDA is used if it is
                available and the CPU otherwise.
        """
        self._build_encoder(device)

    def _build_encoder(self, device=None):
        """(Re)create ``self.encoder`` on ``device`` with the shared test settings."""
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model_name = "stsb-roberta-large"
        self.batch_size = 8
        self.verbose = False
        self.encoder = SBertEncoder(self.model_name, self.device, self.batch_size, self.verbose)

    def test_initialization(self):
        """The encoder wraps a SentenceTransformer and keeps the given settings."""
        self.assertIsInstance(self.encoder.model, SentenceTransformer)
        self.assertEqual(self.encoder.device, self.device)
        self.assertEqual(self.encoder.batch_size, self.batch_size)
        self.assertEqual(self.encoder.verbose, self.verbose)

    def test_encode_single_device(self):
        """Encoding on one device yields one embedding row per sentence."""
        sentences = ["This is a test sentence.", "Here is another sentence."]
        embeddings = self.encoder.encode(sentences)
        self.assertIsInstance(embeddings, np.ndarray)
        self.assertEqual(embeddings.shape[0], len(sentences))
        self.assertEqual(embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension())

    def test_encode_multi_device(self):
        """Encoding across two GPUs yields one embedding row per sentence."""
        if torch.cuda.device_count() < 2:
            # skipTest raises, so nothing below runs on machines with fewer than 2 GPUs.
            self.skipTest("Multi-GPU test requires at least 2 GPUs.")
        self._build_encoder(["cuda:0", "cuda:1"])
        sentences = ["This is a test sentence.", "Here is another sentence.", "This is a test sentence."]
        embeddings = self.encoder.encode(sentences)
        self.assertIsInstance(embeddings, np.ndarray)
        self.assertEqual(embeddings.shape[0], len(sentences))
        self.assertEqual(embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension())
class TestGetEncoder(unittest.TestCase):
    """Tests for the ``get_encoder`` factory."""

    def setUp(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.batch_size = 8
        self.verbose = False

    def _base_test(self, model_name):
        """Build an encoder for ``model_name`` and check its type and stored settings."""
        encoder = get_encoder(model_name, self.device, self.batch_size, self.verbose)
        self.assertIsInstance(encoder, SBertEncoder)
        self.assertEqual(encoder.device, self.device)
        self.assertEqual(encoder.batch_size, self.batch_size)
        self.assertEqual(encoder.verbose, self.verbose)

    def test_get_sbert_encoder(self):
        """A SentenceTransformers model name yields an ``SBertEncoder``."""
        model_name = "stsb-roberta-large"
        self._base_test(model_name)

    def test_sbert_model(self):
        """Another SentenceTransformers model name yields an ``SBertEncoder``."""
        model_name = "all-mpnet-base-v2"
        self._base_test(model_name)

    def test_huggingface_model(self):
        """Test Huggingface models which work with SBert library"""
        model_name = "roberta-base"
        self._base_test(model_name)

    def test_get_encoder_environment_error(self):
        """An unknown model name raises ``EnvironmentError``."""
        model_name = "abc"  # Not a valid model identifier
        with self.assertRaises(EnvironmentError):
            get_encoder(model_name, self.device, self.batch_size, self.verbose)

    def test_get_encoder_other_exception(self):
        """A model the SentenceTransformer library cannot load raises ``RuntimeError``."""
        model_name = "apple/OpenELM-270M"  # This model is not supported by SentenceTransformer lib
        with self.assertRaises(RuntimeError):
            get_encoder(model_name, self.device, self.batch_size, self.verbose)
class TestSemF1(unittest.TestCase):
    """End-to-end tests for ``SemF1.compute`` across the supported input formats."""

    def setUp(self):
        self.semf1_metric = SemF1()  # semf1_metric

        # Example cases, #Samples = 1.
        # "untokenized" inputs are raw strings (tokenize_sentences=True);
        # "tokenized" inputs are already split into sentence lists (tokenize_sentences=False).
        self.untokenized_single_reference_predictions = [
            "This is a prediction sentence 1. This is a prediction sentence 2."]
        self.untokenized_single_reference_references = [
            "This is a reference sentence 1. This is a reference sentence 2."]
        self.tokenized_single_reference_predictions = [
            ["This is a prediction sentence 1.", "This is a prediction sentence 2."],
        ]
        self.tokenized_single_reference_references = [
            ["This is a reference sentence 1.", "This is a reference sentence 2."],
        ]
        self.untokenized_multi_reference_predictions = [
            "Prediction sentence 1. Prediction sentence 2."
        ]
        self.untokenized_multi_reference_references = [
            ["Reference sentence 1. Reference sentence 2.", "Alternative reference 1. Alternative reference 2."],
        ]
        self.tokenized_multi_reference_predictions = [
            ["Prediction sentence 1.", "Prediction sentence 2."],
        ]
        self.tokenized_multi_reference_references = [
            [
                ["Reference sentence 1.", "Reference sentence 2."],
                ["Alternative reference 1.", "Alternative reference 2."]
            ],
        ]

    def _compute(self, predictions, references, tokenize_sentences, multi_references, **kwargs):
        """Run the metric on CPU with the settings shared by these tests.

        Extra keyword arguments (e.g. ``model_type``) are forwarded to ``compute``.
        """
        return self.semf1_metric.compute(
            predictions=predictions,
            references=references,
            tokenize_sentences=tokenize_sentences,
            multi_references=multi_references,
            gpu=False,
            batch_size=32,
            verbose=False,
            **kwargs,
        )

    def _assert_valid_scores(self, scores, expected_len):
        """Check for one ``Scores`` per sample, with precision and recalls in [0, 1]."""
        self.assertIsInstance(scores, list)
        self.assertEqual(len(scores), expected_len)
        for score in scores:
            self.assertIsInstance(score, Scores)
            self.assertTrue(0.0 <= score.precision <= 1.0)
            self.assertTrue(all(0.0 <= recall <= 1.0 for recall in score.recall))

    def test_untokenized_single_reference(self):
        scores = self._compute(
            self.untokenized_single_reference_predictions,
            self.untokenized_single_reference_references,
            tokenize_sentences=True,
            multi_references=False,
        )
        self._assert_valid_scores(scores, len(self.untokenized_single_reference_predictions))

    def test_tokenized_single_reference(self):
        scores = self._compute(
            self.tokenized_single_reference_predictions,
            self.tokenized_single_reference_references,
            tokenize_sentences=False,
            multi_references=False,
        )
        self._assert_valid_scores(scores, len(self.tokenized_single_reference_predictions))

    def test_untokenized_multi_reference(self):
        scores = self._compute(
            self.untokenized_multi_reference_predictions,
            self.untokenized_multi_reference_references,
            tokenize_sentences=True,
            multi_references=True,
        )
        self._assert_valid_scores(scores, len(self.untokenized_multi_reference_predictions))

    def test_tokenized_multi_reference(self):
        scores = self._compute(
            self.tokenized_multi_reference_predictions,
            self.tokenized_multi_reference_references,
            tokenize_sentences=False,
            multi_references=True,
        )
        self._assert_valid_scores(scores, len(self.tokenized_multi_reference_predictions))

    def test_same_predictions_and_references(self):
        """Scoring a prediction against itself gives precision and recall of ~1."""
        scores = self._compute(
            self.tokenized_single_reference_predictions,
            self.tokenized_single_reference_predictions,
            tokenize_sentences=False,
            multi_references=False,
        )
        self.assertIsInstance(scores, list)
        self.assertEqual(len(scores), len(self.tokenized_single_reference_predictions))
        for score in scores:
            self.assertIsInstance(score, Scores)
            self.assertAlmostEqual(score.precision, 1.0, places=6)
            assert_almost_equal(score.recall, 1, decimal=5, err_msg="Not all values are almost equal to 1")

    def test_exact_output_scores(self):
        """Pin known scores for the "use" model on a small example."""
        predictions = [
            ["I go to School.", "You are stupid."],
            ["I love adventure sports."],
        ]
        references = [
            ["I go to playground.", "You are genius.", "You need to be admired."],
            ["I love adventure sports."],
        ]
        scores = self._compute(
            predictions,
            references,
            tokenize_sentences=False,
            multi_references=False,
            model_type="use",
        )
        self.assertIsInstance(scores, list)
        self.assertEqual(len(scores), len(predictions))

        score = scores[0]
        self.assertIsInstance(score, Scores)
        self.assertAlmostEqual(score.precision, 0.73, places=2)
        self.assertAlmostEqual(score.recall[0], 0.63, places=2)

        # The second sample's prediction and reference are identical, so both scores are ~1.
        identical = scores[1]
        self.assertIsInstance(identical, Scores)
        self.assertAlmostEqual(identical.precision, 1.0, places=5)
        assert_almost_equal(identical.recall, 1, decimal=5)

    def test_none_input(self):
        """A ``None`` anywhere in predictions or references makes ``compute`` raise."""
        # Each case: (label, predictions, references, tokenize_sentences, multi_references).
        cases = [
            # Case I: tokenize_sentences = True, multi_references = True
            (
                "Case I",
                [
                    "I go to School. You are stupid.",
                    "I go to School. You are stupid.",
                ],
                [
                    ["I am", "I am"],
                    [None, "I am"],
                ],
                True,
                True,
            ),
            # Case II: tokenize_sentences = False, multi_references = True
            (
                "Case II",
                [
                    ["I go to School.", "You are stupid."],
                    ["I go to School.", "You are stupid."],
                ],
                [
                    [["I am", "I am"], [None, "I am"]],
                    [[None, "I am"]],
                ],
                False,
                True,
            ),
            # Case III: tokenize_sentences = True, multi_references = False
            (
                "Case III",
                [
                    None,
                    "I go to School. You are stupid.",
                ],
                [
                    "I am. I am.",
                    "I am. I am.",
                ],
                True,
                False,
            ),
            # Case IV: tokenize_sentences = False, multi_references = False.
            # NOTE(review): the error here is presumably raised by the underlying
            # encoding library rather than by SemF1's own validation; confirm.
            (
                "Case IV",
                [
                    ["I go to School.", None],
                    ["I go to School.", "You are stupid."],
                ],
                [
                    ["I am.", "I am."],
                    ["I am.", "I am."],
                ],
                False,
                False,
            ),
        ]
        for label, predictions, references, tokenize_sentences, multi_references in cases:
            # subTest reports every failing case instead of stopping at the first one.
            with self.subTest(case=label):
                with self.assertRaises(Exception) as ctx:
                    self._compute(
                        predictions,
                        references,
                        tokenize_sentences=tokenize_sentences,
                        multi_references=multi_references,
                        model_type="use",
                    )
                print(f"{label}\nRaised Exception with message: {ctx.exception}\n")

    def test_empty_input(self):
        """Empty-string predictions still produce one result per sample."""
        predictions = ["", ""]
        references = ["I go to School. You are stupid.", "I am"]
        scores = self.semf1_metric.compute(
            predictions=predictions,
            references=references,
        )
        self.assertIsInstance(scores, list)
        self.assertEqual(len(scores), len(predictions))
        # TODO: add a gibberish-input case, e.g.
        #   predictions = ["lth cgezawrxretxdr", "dsfgsdfhsdfh"]
        #   references = ["dzfgzeWfnAfse", "dtjsrtzerZJSEWr"]
class TestCosineSimilarity(unittest.TestCase):
    """Check ``_compute_cosine_similarity`` against sklearn's ``cosine_similarity``."""

    def setUp(self):
        # Seeded generator so that failures on random embeddings can be reproduced.
        self.rng = np.random.default_rng(seed=0)

        # Sample embeddings for testing: identical one-hot rows on both sides.
        self.pred_embeds = np.array([
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1]
        ])
        self.ref_embeds = np.array([
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1]
        ])
        self.pred_embeds_random = self.rng.random((3, 3))
        self.ref_embeds_random = self.rng.random((3, 3))

    def test_cosine_similarity_perfect_match(self):
        precision, recall = _compute_cosine_similarity(self.pred_embeds, self.ref_embeds)
        # Expected values are 1.0 for both precision and recall since embeddings are identical
        self.assertAlmostEqual(precision, 1.0, places=5)
        self.assertAlmostEqual(recall, 1.0, places=5)

    def _test_cosine_similarity_base(self, pred_embeds, ref_embeds):
        """Compare against sklearn's cosine similarity matrix (rows = predictions).

        Precision is the mean, over prediction rows, of each row's best reference match.
        Recall is the mean, over reference columns, of each column's best prediction match.
        """
        precision, recall = _compute_cosine_similarity(pred_embeds, ref_embeds)
        cosine_scores = cosine_similarity(pred_embeds, ref_embeds)
        expected_precision = np.mean(np.max(cosine_scores, axis=-1)).item()
        expected_recall = np.mean(np.max(cosine_scores, axis=0)).item()
        self.assertAlmostEqual(precision, expected_precision, places=5)
        self.assertAlmostEqual(recall, expected_recall, places=5)

    def test_cosine_similarity_random(self):
        self._test_cosine_similarity_base(self.pred_embeds_random, self.ref_embeds_random)

    def test_cosine_similarity_different_shapes(self):
        """Precision and recall stay correct when the two sides have different row counts."""
        # More predictions than references.
        self._test_cosine_similarity_base(self.rng.random((5, 3)), self.rng.random((3, 3)))
        # Fewer predictions than references.
        self._test_cosine_similarity_base(self.rng.random((2, 3)), self.rng.random((4, 3)))
class TestValidateInputFormat(unittest.TestCase):
    """Tests for ``_validate_input_format(tokenize_sentences, multi_references, preds, refs)``."""

    def setUp(self):
        # Sample predictions and references for each input format, one sample each.
        # Naming convention: "untokenized" inputs are raw strings (tokenize_sentences=True);
        # "tokenized" inputs are pre-split sentence lists (tokenize_sentences=False).

        # When tokenize_sentences = True (untokenized input) and multi_references = False
        self.untokenized_single_reference_predictions = [
            "This is a prediction sentence 1. This is a prediction sentence 2."
        ]
        self.untokenized_single_reference_references = [
            "This is a reference sentence 1. This is a reference sentence 2."
        ]

        # When tokenize_sentences = False (tokenized input) and multi_references = False
        self.tokenized_single_reference_predictions = [
            ["This is a prediction sentence 1.", "This is a prediction sentence 2."]
        ]
        self.tokenized_single_reference_references = [
            ["This is a reference sentence 1.", "This is a reference sentence 2."]
        ]

        # When tokenize_sentences = True (untokenized input) and multi_references = True
        self.untokenized_multi_reference_predictions = [
            "This is a prediction sentence 1. This is a prediction sentence 2."
        ]
        self.untokenized_multi_reference_references = [
            [
                "This is a reference sentence 1. This is a reference sentence 2.",
                "Another reference sentence."
            ]
        ]

        # When tokenize_sentences = False (tokenized input) and multi_references = True
        self.tokenized_multi_reference_predictions = [
            ["This is a prediction sentence 1.", "This is a prediction sentence 2."]
        ]
        self.tokenized_multi_reference_references = [
            [
                ["This is a reference sentence 1.", "This is a reference sentence 2."],
                ["Another reference sentence."]
            ]
        ]

    def test_tokenized_sentences_true_multi_references_true(self):
        # Invalid format should raise an error
        with self.assertRaises(ValueError):
            _validate_input_format(
                True,
                True,
                self.tokenized_single_reference_predictions,
                self.tokenized_single_reference_references,
            )
        # Valid format should pass without error
        _validate_input_format(
            True,
            True,
            self.untokenized_multi_reference_predictions,
            self.untokenized_multi_reference_references,
        )

    def test_tokenized_sentences_false_multi_references_true(self):
        # Invalid format should raise an error
        with self.assertRaises(ValueError):
            _validate_input_format(
                False,
                True,
                self.untokenized_single_reference_predictions,
                self.untokenized_multi_reference_references,
            )
        # Valid format should pass without error
        _validate_input_format(
            False,
            True,
            self.tokenized_multi_reference_predictions,
            self.tokenized_multi_reference_references,
        )

    def test_tokenized_sentences_true_multi_references_false(self):
        # Invalid format should raise an error
        with self.assertRaises(ValueError):
            _validate_input_format(
                True,
                False,
                self.tokenized_single_reference_predictions,
                self.tokenized_single_reference_references,
            )
        # Valid format should pass without error
        _validate_input_format(
            True,
            False,
            self.untokenized_single_reference_predictions,
            self.untokenized_single_reference_references,
        )

    def test_tokenized_sentences_false_multi_references_false(self):
        # Invalid format should raise an error
        with self.assertRaises(ValueError):
            _validate_input_format(
                False,
                False,
                self.untokenized_single_reference_predictions,
                self.untokenized_single_reference_references,
            )
        # Valid format should pass without error
        _validate_input_format(
            False,
            False,
            self.tokenized_single_reference_predictions,
            self.tokenized_single_reference_references,
        )

    def test_mismatched_lengths(self):
        """A different number of prediction and reference samples raises ``ValueError``."""
        # Both sides are well-formed for their mode, so only the length check can fail.
        # Multi-reference, untokenized: 1 prediction vs 2 reference groups.
        with self.assertRaises(ValueError):
            _validate_input_format(
                True,
                True,
                self.untokenized_multi_reference_predictions,
                self.untokenized_multi_reference_references * 2,
            )
        # Single-reference, untokenized: 1 prediction vs 2 references.
        with self.assertRaises(ValueError):
            _validate_input_format(
                True,
                False,
                self.untokenized_single_reference_predictions,
                self.untokenized_single_reference_references * 2,
            )
if __name__ == "__main__":
    # Executed as a script: run every TestCase in this module with per-test output.
    unittest.main(verbosity=2)