csaybar commited on
Commit
307a330
1 Parent(s): caa772d

Update benchmark.py

Browse files
Files changed (1) hide show
  1. benchmark.py +18 -8
benchmark.py CHANGED
@@ -3,15 +3,16 @@ import pathlib
3
  import opensr_test
4
  import matplotlib.pyplot as plt
5
 
6
- from typing import Callable
7
 
8
 
9
  def create_geotiff(
10
  model: Callable,
11
  fn: Callable,
12
- datasets: list,
13
  output_path: str,
14
- force: bool = False
 
15
  ) -> None:
16
  """Create all the GeoTIFFs for a specific dataset snippet
17
 
@@ -26,13 +27,18 @@ def create_geotiff(
26
  force (bool, optional): If True, the dataset is redownloaded. Defaults
27
  to False.
28
  """
 
 
 
 
29
  for snippet in datasets:
30
  create_geotiff_batch(
31
  model=model,
32
  fn=fn,
33
  snippet=snippet,
34
  output_path=output_path,
35
- force=force
 
36
  )
37
 
38
  return None
@@ -42,7 +48,8 @@ def create_geotiff_batch(
42
  fn: Callable,
43
  snippet: str,
44
  output_path: str,
45
- force: bool = False
 
46
  ) -> pathlib.Path:
47
  """Create all the GeoTIFFs for a specific dataset snippet
48
 
@@ -72,7 +79,7 @@ def create_geotiff_batch(
72
  output_path_dataset_png.mkdir(parents=True, exist_ok=True)
73
 
74
  # Load the dataset
75
- dataset = opensr_test.load(snippet, force=False)
76
  lr_dataset, hr_dataset, metadata = dataset["L2A"], dataset["HRharm"], dataset["metadata"]
77
  for index in range(len(lr_dataset)):
78
  print(f"Processing {index}/{len(lr_dataset)}")
@@ -81,7 +88,8 @@ def create_geotiff_batch(
81
  results = fn(
82
  model=model,
83
  lr=lr_dataset[index],
84
- hr=hr_dataset[index]
 
85
  )
86
 
87
  # Get the image name
@@ -165,4 +173,6 @@ def plot(
165
  pathlib.Path: The output path where the plots and tables are
166
  saved.
167
  """
168
- pass
 
 
 
3
  import opensr_test
4
  import matplotlib.pyplot as plt
5
 
6
+ from typing import Callable, Union
7
 
8
 
9
  def create_geotiff(
10
  model: Callable,
11
  fn: Callable,
12
+ datasets: Union[str, list],
13
  output_path: str,
14
+ force: bool = False,
15
+ **kwargs
16
  ) -> None:
17
  """Create all the GeoTIFFs for a specific dataset snippet
18
 
 
27
  force (bool, optional): If True, the dataset is redownloaded. Defaults
28
  to False.
29
  """
30
+
31
+ if datasets == "all":
32
+ datasets = opensr_test.datasets
33
+
34
  for snippet in datasets:
35
  create_geotiff_batch(
36
  model=model,
37
  fn=fn,
38
  snippet=snippet,
39
  output_path=output_path,
40
+ force=force,
41
+ **kwargs
42
  )
43
 
44
  return None
 
48
  fn: Callable,
49
  snippet: str,
50
  output_path: str,
51
+ force: bool = False,
52
+ **kwargs
53
  ) -> pathlib.Path:
54
  """Create all the GeoTIFFs for a specific dataset snippet
55
 
 
79
  output_path_dataset_png.mkdir(parents=True, exist_ok=True)
80
 
81
  # Load the dataset
82
+ dataset = opensr_test.load(snippet, force=force)
83
  lr_dataset, hr_dataset, metadata = dataset["L2A"], dataset["HRharm"], dataset["metadata"]
84
  for index in range(len(lr_dataset)):
85
  print(f"Processing {index}/{len(lr_dataset)}")
 
88
  results = fn(
89
  model=model,
90
  lr=lr_dataset[index],
91
+ hr=hr_dataset[index],
92
+ **kwargs
93
  )
94
 
95
  # Get the image name
 
173
  pathlib.Path: The output path where the plots and tables are
174
  saved.
175
  """
176
+ pass
177
+
178
+