|
"""Array cache test.""" |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import pytest |
|
|
|
from manifest.caches.array_cache import ArrayCache |
|
|
|
|
|
def test_init(tmpdir: Path) -> None:
    """Verify the state of a newly constructed ``ArrayCache``.

    After construction on an empty directory the test expects a
    ``hash2arrloc.sqlite`` file inside that directory and both position
    counters (``cur_file_idx`` and ``cur_offset``) to read zero.
    """
    cache_dir = Path(tmpdir)
    fresh_cache = ArrayCache(cache_dir)

    # The index file is looked up relative to the fixture directory itself.
    index_file = tmpdir / "hash2arrloc.sqlite"
    assert index_file.exists()

    # Nothing has been stored yet, so both counters must still be zero.
    position = (fresh_cache.cur_file_idx, fresh_cache.cur_offset)
    assert position == (0, 0)
|
|
|
|
|
def test_put_get(tmpdir: Path) -> None:
    """Test putting and getting arrays, including re-opening the cache.

    The test walks through four stages:

    1. ``put`` raises ``ValueError`` with a fixed message when
       ``max_memmap_size`` (5) is smaller than the 100-element array.
    2. With ``max_memmap_size`` raised to 120, a float64 array round-trips
       through ``put``/``get`` with its dtype intact, and its location is
       recorded under ``hash2arrloc["key"]``.
    3. A second 100-element int64 array is expected to be recorded under
       file index 1 at offset 0.
    4. A second ``ArrayCache`` built on the same directory must report the
       same ``hash2arrloc`` entries and return the same array contents.
    """
    # pytest's ``tmpdir`` fixture yields a ``py.path.local``; normalise it to
    # the annotated ``pathlib.Path`` before handing it to the cache.
    cache_dir = Path(tmpdir)

    cache = ArrayCache(cache_dir)
    cache.max_memmap_size = 5
    arr = np.random.rand(10, 10)

    # Stage 1: a 10x10 array (100 elements) exceeds the limit of 5.
    with pytest.raises(ValueError) as exc_info:
        cache.put("key", arr)
    # Checked after the ``with`` block so the assertion actually executes.
    assert str(exc_info.value) == ("Array is too large to be cached. Max is 5")

    # Stage 2: the array now fits and is expected at the start of file 0.
    cache.max_memmap_size = 120
    cache.put("key", arr)
    expected_key_loc = {
        "file_idx": 0,
        "offset": 0,
        "flatten_size": 100,
        "shape": (10, 10),
        "dtype": np.dtype("float64"),
    }
    assert np.allclose(cache.get("key"), arr)
    assert cache.get("key").dtype == arr.dtype
    assert cache.cur_file_idx == 0
    assert cache.cur_offset == 100
    assert cache.hash2arrloc["key"] == expected_key_loc

    # Stage 3: request int64 explicitly. NumPy's default integer type for
    # ``randint`` is platform-dependent (32-bit on Windows with NumPy 1.x),
    # which would make the "int64" expectation below fail there.
    arr2 = np.random.randint(0, 3, size=(10, 10), dtype=np.int64)
    cache.put("key2", arr2)
    expected_key2_loc = {
        "file_idx": 1,
        "offset": 0,
        "flatten_size": 100,
        "shape": (10, 10),
        "dtype": np.dtype("int64"),
    }
    assert np.allclose(cache.get("key2"), arr2)
    assert cache.get("key2").dtype == arr2.dtype
    # 120 - 100 leaves fewer than 100 free slots in file 0; the test expects
    # the second array to be placed at the start of file 1.
    assert cache.cur_file_idx == 1
    assert cache.cur_offset == 100
    assert cache.hash2arrloc["key2"] == expected_key2_loc

    # Stage 4: a new instance on the same directory must see both entries.
    reopened = ArrayCache(cache_dir)
    assert reopened.hash2arrloc["key"] == expected_key_loc
    assert reopened.hash2arrloc["key2"] == expected_key2_loc
    assert np.allclose(reopened.get("key"), arr)
    assert np.allclose(reopened.get("key2"), arr2)
|
|
|
|
|
def test_contains_key(tmpdir: Path) -> None:
    """Test that ``contains_key`` reflects whether a key has been stored.

    A key must be reported as absent on a new cache and as present once an
    array has been ``put`` under it.
    """
    # pytest's ``tmpdir`` fixture yields a ``py.path.local``; normalise it to
    # the annotated ``pathlib.Path`` before handing it to the cache.
    cache = ArrayCache(Path(tmpdir))

    # Nothing stored yet.
    assert not cache.contains_key("key")

    arr = np.random.rand(10, 10)
    cache.put("key", arr)

    # The same key must now be reported as present.
    assert cache.contains_key("key")
|
|