File size: 4,793 Bytes
ee21b96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import contextlib
import unittest
import tempfile
from io import StringIO

import numpy as np

from tests.utils import create_dummy_data, preprocess_lm_data, train_language_model

# The plasma client module and fairseq's plasma helpers are optional imports.
# Record whether they could be imported so the test class below can be skipped
# (rather than erroring at import time) when they are missing.
try:
    from pyarrow import plasma
    from fairseq.data.plasma_utils import PlasmaView, PlasmaStore

    PYARROW_AVAILABLE = True
except ImportError:
    PYARROW_AVAILABLE = False

# Placeholder value passed as the second positional argument of every
# PlasmaView constructed in these tests.
dummy_path = "dummy"


@unittest.skipUnless(
    PYARROW_AVAILABLE, "pyarrow.plasma / fairseq.data.plasma_utils not importable"
)
class TestPlasmaView(unittest.TestCase):
    """Tests for ``PlasmaView`` backed by stores started via ``PlasmaStore.start``.

    ``setUp`` starts one store per test on a fresh temporary path and connects
    a client to it; ``tearDown`` releases both. Tests that need a second store
    start it themselves and always kill it again, even when an assertion fails.
    """

    def setUp(self) -> None:
        """Start a 10000-byte store on a unique temp path and connect to it."""
        # Only the file's name is used (as the store path); the handle is kept
        # so tearDown can close it.
        self.tmp_file = tempfile.NamedTemporaryFile()  # noqa: P201
        self.path = self.tmp_file.name
        self.server = PlasmaStore.start(path=self.path, nbytes=10000)
        self.client = plasma.connect(self.path, num_retries=10)

    def tearDown(self) -> None:
        """Disconnect the client, close the temp file and kill the store."""
        try:
            self.client.disconnect()
            self.tmp_file.close()
        finally:
            # Kill the store even if releasing the client or file failed, so a
            # broken test cannot leave a store running.
            self.server.kill()

    def test_two_servers_do_not_share_object_id_space(self) -> None:
        """The same hash key on two different stores yields independent data."""
        data_server_1 = np.array([0, 1])
        data_server_2 = np.array([2, 3])
        server_2_path = self.path
        with tempfile.NamedTemporaryFile() as server_1_path:
            server = PlasmaStore.start(path=server_1_path.name, nbytes=10000)
            try:
                arr1 = PlasmaView(
                    data_server_1, dummy_path, 1, plasma_path=server_1_path.name
                )
                assert len(arr1.client.list()) == 1
                assert (arr1.array == data_server_1).all()
                # Same key (dummy_path, 1) but on the store created in setUp.
                arr2 = PlasmaView(
                    data_server_2, dummy_path, 1, plasma_path=server_2_path
                )
                assert (arr2.array == data_server_2).all()
                # Writing to the second store must not disturb the first one.
                assert (arr1.array == data_server_1).all()
            finally:
                # Always stop the extra store, including on assertion failure.
                server.kill()

    def test_hash_collision(self) -> None:
        """Reusing a hash key returns the first array; a new key adds an object."""
        data_server_1 = np.array([0, 1])
        data_server_2 = np.array([2, 3])
        arr1 = PlasmaView(data_server_1, dummy_path, 1, plasma_path=self.path)
        assert len(arr1.client.list()) == 1
        # Same key as arr1: no new object is stored, and arr2 therefore reads
        # back arr1's data rather than data_server_2.
        arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=self.path)
        assert len(arr1.client.list()) == 1
        assert len(arr2.client.list()) == 1
        assert (arr2.array == data_server_1).all()
        # New hash key based on tuples
        arr3 = PlasmaView(
            data_server_2, dummy_path, (1, 12312312312, None), plasma_path=self.path
        )
        assert (
            len(arr2.client.list()) == 2
        ), "No new object was created by using a novel hash key"
        assert (
            arr3.object_id in arr2.client.list()
        ), "No new object was created by using a novel hash key"
        assert (
            arr3.object_id in arr3.client.list()
        ), "No new object was created by using a novel hash key"
        del arr3, arr2, arr1

    @staticmethod
    def _assert_view_equal(pv1, pv2) -> None:
        """Assert that two views expose element-wise equal arrays."""
        np.testing.assert_array_equal(pv1.array, pv2.array)

    def test_putting_same_array_twice(self) -> None:
        """Re-putting under the same key is a no-op; a new key adds one entry."""
        data = np.array([4, 4, 4])
        arr1 = PlasmaView(data, dummy_path, 1, plasma_path=self.path)
        assert len(self.client.list()) == 1
        arr1b = PlasmaView(
            data, dummy_path, 1, plasma_path=self.path
        )  # should not change contents of store
        arr1c = PlasmaView(
            None, dummy_path, 1, plasma_path=self.path
        )  # should not change contents of store

        assert len(self.client.list()) == 1
        self._assert_view_equal(arr1, arr1b)
        self._assert_view_equal(arr1, arr1c)
        PlasmaView(
            data, dummy_path, 2, plasma_path=self.path
        )  # new object id, adds new entry
        assert len(self.client.list()) == 2

        new_client = plasma.connect(self.path)
        try:
            # new client can access same objects
            assert len(new_client.list()) == 2
        finally:
            # Release the extra connection instead of leaving it open.
            new_client.disconnect()
        assert isinstance(arr1.object_id, plasma.ObjectID)
        del arr1b
        del arr1c

    def test_plasma_store_full_raises(self) -> None:
        """Putting an array larger than the store raises ``PlasmaStoreFull``."""
        with tempfile.NamedTemporaryFile() as new_path:
            server = PlasmaStore.start(path=new_path.name, nbytes=10000)
            try:
                with self.assertRaises(plasma.PlasmaStoreFull):
                    # 10000 float64 values (80000 bytes) exceed the
                    # 10000-byte store started above.
                    PlasmaView(
                        np.random.rand(10000, 1),
                        dummy_path,
                        1,
                        plasma_path=new_path.name,
                    )
            finally:
                # Always stop the extra store, including when the expected
                # exception is not raised.
                server.kill()

    def test_object_id_overflow(self) -> None:
        """``get_object_id`` must not raise for a hash value of ``2 ** 21``."""
        PlasmaView.get_object_id("", 2 ** 21)

    def test_training_lm_plasma(self) -> None:
        """Train a small ``transformer_lm`` on dummy data with plasma views on."""
        # Training output is captured so it does not clutter the test report.
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
                create_dummy_data(data_dir)
                preprocess_lm_data(data_dir)
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    # Point training at the store started in setUp.
                    ["--use-plasma-view", "--plasma-path", self.path],
                    run_validation=True,
                )