File size: 1,728 Bytes
3b96cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from opencd.registry import TRANSFORMS


@TRANSFORMS.register_module()
class MultiImgLoadImageFromFile(MMCV_LoadImageFromFile):
	"""Load a group of images (e.g. an image pair) from files.

	Every path in ``results['img_path']`` is fetched and decoded with the
	file-client / backend, colour flag, decode backend and ``to_float32``
	options configured on the parent ``LoadImageFromFile`` transform.

	Required Keys:

	- img_path (iterable of image file paths)

	Modified Keys:

	- img (list of decoded images, in ``img_path`` order)
	- img_shape (height and width of the first image)
	- ori_shape (height and width of the first image)

	"""

	def __init__(self, **kwargs) -> None:
		super().__init__(**kwargs)

	def _read_image(self, filename):
		"""Fetch and decode one image file, raising if decoding fails."""
		if self.file_client_args is not None:
			file_client = fileio.FileClient.infer_client(
				self.file_client_args, filename)
			img_bytes = file_client.get(filename)
		else:
			img_bytes = fileio.get(filename, backend_args=self.backend_args)
		img = mmcv.imfrombytes(
			img_bytes, flag=self.color_type, backend=self.imdecode_backend)
		# The decoder may hand back None instead of raising on unreadable
		# bytes; surface that here rather than as a confusing error later.
		if img is None:
			raise ValueError(f'failed to load image: {filename}')
		return img

	def transform(self, results: dict) -> Optional[dict]:
		"""Load every image listed in ``results['img_path']``.

		Args:
			results (dict): Result dict from
				:class:`mmengine.dataset.BaseDataset`.

		Returns:
			dict | None: The dict containing the loaded images and meta
			information, or ``None`` if an image (or the whole path list)
			could not be loaded and ``self.ignore_empty`` is set.

		Raises:
			ValueError: If an image decodes to nothing or ``img_path`` is
				empty, and ``self.ignore_empty`` is not set. Errors raised
				while fetching a file are re-raised unchanged.
		"""
		imgs = []
		for filename in results['img_path']:
			# Only the I/O + decode step is subject to ``ignore_empty``.
			try:
				img = self._read_image(filename)
			except Exception:
				if self.ignore_empty:
					return None
				raise
			if self.to_float32:
				img = img.astype(np.float32)
			imgs.append(img)

		if not imgs:
			if self.ignore_empty:
				return None
			raise ValueError('`img_path` must contain at least one image path')

		results['img'] = imgs
		results['img_shape'] = imgs[0].shape[:2]
		results['ori_shape'] = imgs[0].shape[:2]
		return results

@TRANSFORMS.register_module()
class LoadMultiImageFromNDArray(MultiImgLoadImageFromFile):
    """Take already-decoded image data from ``results['img']``.

    Accepts either a single array (handled exactly as before) or a
    list/tuple of arrays, matching the list that
    ``MultiImgLoadImageFromFile`` stores in ``results['img']``.

    Required Keys:

    - img

    Modified Keys:

    - img
    - img_path (set to ``None``)
    - img_shape
    - ori_shape
    """

    def transform(self, results: dict) -> dict:
        """Normalise ``results['img']`` and fill in the image meta keys.

        Args:
            results (dict): Result dict whose ``'img'`` entry is an array,
                or a list/tuple of arrays.

        Returns:
            dict: ``results`` with ``img_path`` set to ``None`` and
            ``img_shape`` / ``ori_shape`` taken from ``shape[:2]`` of the
            single array, or of the first array in a sequence.

        Raises:
            ValueError: If ``results['img']`` is an empty list/tuple.
        """
        img = results['img']
        if isinstance(img, (list, tuple)):
            if not img:
                raise ValueError('`img` must contain at least one image')
            if self.to_float32:
                img = [item.astype(np.float32) for item in img]
            else:
                img = list(img)
            reference = img[0]
        else:
            if self.to_float32:
                img = img.astype(np.float32)
            reference = img

        results['img_path'] = None
        results['img'] = img
        results['img_shape'] = reference.shape[:2]
        results['ori_shape'] = reference.shape[:2]
        return results