shunk031 commited on
Commit
1b569b1
·
1 Parent(s): 69b70f4

deploy: fc0c10e734116107123b6dce81a6df2cbbf84dfe

Browse files
Files changed (2) hide show
  1. layout-overlay.py +126 -0
  2. requirements.txt +89 -0
layout-overlay.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+
3
+ import datasets as ds
4
+ import evaluate
5
+ import numpy as np
6
+ import numpy.typing as npt
7
+
8
+ _DESCRIPTION = r"""\
9
+ Computes the average IoU of all pairs of elements except for underlay.
10
+ """
11
+
12
+ _KWARGS_DESCRIPTION = """\
13
+ FIXME
14
+ """
15
+
16
+ _CITATION = """\
17
+ @inproceedings{hsu2023posterlayout,
18
+ title={Posterlayout: A new benchmark and approach for content-aware visual-textual presentation layout},
19
+ author={Hsu, Hsiao Yuan and He, Xiangteng and Peng, Yuxin and Kong, Hao and Zhang, Qing},
20
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
21
+ pages={6018--6026},
22
+ year={2023}
23
+ }
24
+ """
25
+
26
+
27
+ class LayoutOverlay(evaluate.Metric):
28
+ def __init__(
29
+ self,
30
+ canvas_width: int,
31
+ canvas_height: int,
32
+ **kwargs,
33
+ ) -> None:
34
+ super().__init__(**kwargs)
35
+ self.canvas_width = canvas_width
36
+ self.canvas_height = canvas_height
37
+
38
+ def _info(self) -> evaluate.EvaluationModuleInfo:
39
+ return evaluate.MetricInfo(
40
+ description=_DESCRIPTION,
41
+ citation=_CITATION,
42
+ inputs_description=_KWARGS_DESCRIPTION,
43
+ features=ds.Features(
44
+ {
45
+ "predictions": ds.Sequence(ds.Sequence(ds.Value("float64"))),
46
+ "gold_labels": ds.Sequence(ds.Sequence(ds.Value("int64"))),
47
+ }
48
+ ),
49
+ codebase_urls=[
50
+ "https://github.com/PKU-ICST-MIPL/PosterLayout-CVPR2023/blob/main/eval.py#L205-L222",
51
+ ],
52
+ )
53
+
54
+ def get_rid_of_invalid(
55
+ self, predictions: npt.NDArray[np.float64], gold_labels: npt.NDArray[np.int64]
56
+ ) -> npt.NDArray[np.int64]:
57
+ assert len(predictions) == len(gold_labels)
58
+
59
+ w = self.canvas_width / 100
60
+ h = self.canvas_height / 100
61
+
62
+ for i, prediction in enumerate(predictions):
63
+ for j, b in enumerate(prediction):
64
+ xl, yl, xr, yr = b
65
+ xl = max(0, xl)
66
+ yl = max(0, yl)
67
+ xr = min(self.canvas_width, xr)
68
+ yr = min(self.canvas_height, yr)
69
+ if abs((xr - xl) * (yr - yl)) < w * h * 10:
70
+ if gold_labels[i, j]:
71
+ gold_labels[i, j] = 0
72
+ return gold_labels
73
+
74
+ def metrics_iou(
75
+ self, bb1: npt.NDArray[np.float64], bb2: npt.NDArray[np.float64]
76
+ ) -> float:
77
+ # shape: bb1 = (4,), bb2 = (4,)
78
+ xl_1, yl_1, xr_1, yr_1 = bb1
79
+ xl_2, yl_2, xr_2, yr_2 = bb2
80
+
81
+ w_1 = xr_1 - xl_1
82
+ w_2 = xr_2 - xl_2
83
+ h_1 = yr_1 - yl_1
84
+ h_2 = yr_2 - yl_2
85
+
86
+ w_inter = min(xr_1, xr_2) - max(xl_1, xl_2)
87
+ h_inter = min(yr_1, yr_2) - max(yl_1, yl_2)
88
+
89
+ a_1 = w_1 * h_1
90
+ a_2 = w_2 * h_2
91
+ a_inter = w_inter * h_inter
92
+ if w_inter <= 0 or h_inter <= 0:
93
+ a_inter = 0
94
+
95
+ return a_inter / (a_1 + a_2 - a_inter)
96
+
97
+ def _compute(
98
+ self,
99
+ *,
100
+ predictions: Union[npt.NDArray[np.float64], List[List[float]]],
101
+ gold_labels: Union[npt.NDArray[np.int64], List[int]],
102
+ ) -> float:
103
+ predictions = np.array(predictions)
104
+ gold_labels = np.array(gold_labels)
105
+
106
+ predictions[:, :, ::2] *= self.canvas_width
107
+ predictions[:, :, 1::2] *= self.canvas_height
108
+
109
+ gold_labels = self.get_rid_of_invalid(
110
+ predictions=predictions, gold_labels=gold_labels
111
+ )
112
+
113
+ score = 0.0
114
+
115
+ for gold_label, prediction in zip(gold_labels, predictions):
116
+ ove = 0.0
117
+ mask = (gold_label > 0).reshape(-1) & (gold_label != 3).reshape(-1)
118
+ mask_box = prediction[mask]
119
+ n = len(mask_box)
120
+ for i in range(n):
121
+ bb1 = mask_box[i]
122
+ for j in range(i + 1, n):
123
+ bb2 = mask_box[j]
124
+ ove += self.metrics_iou(bb1, bb2)
125
+ score += ove / n
126
+ return score / len(gold_labels)
requirements.txt ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1 ; python_version >= "3.9" and python_version < "4.0"
2
+ aiohttp==3.9.3 ; python_version >= "3.9" and python_version < "4.0"
3
+ aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0"
4
+ altair==5.2.0 ; python_version >= "3.9" and python_version < "4.0"
5
+ annotated-types==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
6
+ anyio==4.2.0 ; python_version >= "3.9" and python_version < "4.0"
7
+ arrow==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
8
+ async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11"
9
+ attrs==23.2.0 ; python_version >= "3.9" and python_version < "4.0"
10
+ binaryornot==0.4.4 ; python_version >= "3.9" and python_version < "4.0"
11
+ certifi==2024.2.2 ; python_version >= "3.9" and python_version < "4.0"
12
+ chardet==5.2.0 ; python_version >= "3.9" and python_version < "4.0"
13
+ charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "4.0"
14
+ click==8.1.7 ; python_version >= "3.9" and python_version < "4.0"
15
+ colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0"
16
+ contourpy==1.2.0 ; python_version >= "3.9" and python_version < "4.0"
17
+ cookiecutter==2.5.0 ; python_version >= "3.9" and python_version < "4.0"
18
+ cycler==0.12.1 ; python_version >= "3.9" and python_version < "4.0"
19
+ datasets==2.17.0 ; python_version >= "3.9" and python_version < "4.0"
20
+ dill==0.3.8 ; python_version >= "3.9" and python_version < "4.0"
21
+ evaluate[template]==0.4.1 ; python_version >= "3.9" and python_version < "4.0"
22
+ exceptiongroup==1.2.0 ; python_version >= "3.9" and python_version < "3.11"
23
+ fastapi==0.109.2 ; python_version >= "3.9" and python_version < "4.0"
24
+ ffmpy==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
25
+ filelock==3.13.1 ; python_version >= "3.9" and python_version < "4.0"
26
+ fonttools==4.48.1 ; python_version >= "3.9" and python_version < "4.0"
27
+ frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "4.0"
28
+ fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "4.0"
29
+ fsspec[http]==2023.10.0 ; python_version >= "3.9" and python_version < "4.0"
30
+ gradio-client==0.10.0 ; python_version >= "3.9" and python_version < "4.0"
31
+ gradio==4.18.0 ; python_version >= "3.9" and python_version < "4.0"
32
+ h11==0.14.0 ; python_version >= "3.9" and python_version < "4.0"
33
+ httpcore==1.0.2 ; python_version >= "3.9" and python_version < "4.0"
34
+ httpx==0.26.0 ; python_version >= "3.9" and python_version < "4.0"
35
+ huggingface-hub==0.20.3 ; python_version >= "3.9" and python_version < "4.0"
36
+ idna==3.6 ; python_version >= "3.9" and python_version < "4.0"
37
+ importlib-resources==6.1.1 ; python_version >= "3.9" and python_version < "4.0"
38
+ jinja2==3.1.3 ; python_version >= "3.9" and python_version < "4.0"
39
+ jsonschema-specifications==2023.12.1 ; python_version >= "3.9" and python_version < "4.0"
40
+ jsonschema==4.21.1 ; python_version >= "3.9" and python_version < "4.0"
41
+ kiwisolver==1.4.5 ; python_version >= "3.9" and python_version < "4.0"
42
+ markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "4.0"
43
+ markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "4.0"
44
+ matplotlib==3.8.2 ; python_version >= "3.9" and python_version < "4.0"
45
+ mdurl==0.1.2 ; python_version >= "3.9" and python_version < "4.0"
46
+ multidict==6.0.5 ; python_version >= "3.9" and python_version < "4.0"
47
+ multiprocess==0.70.16 ; python_version >= "3.9" and python_version < "4.0"
48
+ numpy==1.26.4 ; python_version >= "3.9" and python_version < "4.0"
49
+ orjson==3.9.13 ; python_version >= "3.9" and python_version < "4.0"
50
+ packaging==23.2 ; python_version >= "3.9" and python_version < "4.0"
51
+ pandas==2.2.0 ; python_version >= "3.9" and python_version < "4.0"
52
+ pillow==10.2.0 ; python_version >= "3.9" and python_version < "4.0"
53
+ pyarrow-hotfix==0.6 ; python_version >= "3.9" and python_version < "4.0"
54
+ pyarrow==15.0.0 ; python_version >= "3.9" and python_version < "4.0"
55
+ pydantic-core==2.16.2 ; python_version >= "3.9" and python_version < "4.0"
56
+ pydantic==2.6.1 ; python_version >= "3.9" and python_version < "4.0"
57
+ pydub==0.25.1 ; python_version >= "3.9" and python_version < "4.0"
58
+ pygments==2.17.2 ; python_version >= "3.9" and python_version < "4.0"
59
+ pyparsing==3.1.1 ; python_version >= "3.9" and python_version < "4.0"
60
+ python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0"
61
+ python-multipart==0.0.9 ; python_version >= "3.9" and python_version < "4.0"
62
+ python-slugify==8.0.4 ; python_version >= "3.9" and python_version < "4.0"
63
+ pytz==2024.1 ; python_version >= "3.9" and python_version < "4.0"
64
+ pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "4.0"
65
+ referencing==0.33.0 ; python_version >= "3.9" and python_version < "4.0"
66
+ requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
67
+ responses==0.18.0 ; python_version >= "3.9" and python_version < "4.0"
68
+ rich==13.7.0 ; python_version >= "3.9" and python_version < "4.0"
69
+ rpds-py==0.17.1 ; python_version >= "3.9" and python_version < "4.0"
70
+ ruff==0.2.1 ; python_version >= "3.9" and python_version < "4.0"
71
+ semantic-version==2.10.0 ; python_version >= "3.9" and python_version < "4.0"
72
+ shellingham==1.5.4 ; python_version >= "3.9" and python_version < "4.0"
73
+ six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"
74
+ sniffio==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
75
+ starlette==0.36.3 ; python_version >= "3.9" and python_version < "4.0"
76
+ text-unidecode==1.3 ; python_version >= "3.9" and python_version < "4.0"
77
+ tomlkit==0.12.0 ; python_version >= "3.9" and python_version < "4.0"
78
+ toolz==0.12.1 ; python_version >= "3.9" and python_version < "4.0"
79
+ tqdm==4.66.2 ; python_version >= "3.9" and python_version < "4.0"
80
+ typer[all]==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
81
+ types-python-dateutil==2.8.19.20240106 ; python_version >= "3.9" and python_version < "4.0"
82
+ typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "4.0"
83
+ tzdata==2024.1 ; python_version >= "3.9" and python_version < "4.0"
84
+ urllib3==2.2.0 ; python_version >= "3.9" and python_version < "4.0"
85
+ uvicorn==0.27.1 ; python_version >= "3.9" and python_version < "4.0"
86
+ websockets==11.0.3 ; python_version >= "3.9" and python_version < "4.0"
87
+ xxhash==3.4.1 ; python_version >= "3.9" and python_version < "4.0"
88
+ yarl==1.9.4 ; python_version >= "3.9" and python_version < "4.0"
89
+ zipp==3.17.0 ; python_version >= "3.9" and python_version < "3.10"