Spaces:
Build error
Build error
init
Browse files- .gitignore +2 -0
- CrossEncoder/cross_encoder.py +122 -0
- CrossEncoder/cross_encoder_env.yml +53 -0
- DiT_Extractor/base_utils.py +378 -0
- DiT_Extractor/dit_object_detection/README.md +120 -0
- DiT_Extractor/dit_object_detection/ditod/__init__.py +11 -0
- DiT_Extractor/dit_object_detection/ditod/backbone.py +156 -0
- DiT_Extractor/dit_object_detection/ditod/beit.py +671 -0
- DiT_Extractor/dit_object_detection/ditod/config.py +32 -0
- DiT_Extractor/dit_object_detection/ditod/deit.py +476 -0
- DiT_Extractor/dit_object_detection/publaynet_configs/Base-RCNN-FPN.yaml +69 -0
- DiT_Extractor/dit_object_detection/publaynet_configs/cascade/cascade_dit_base.yaml +20 -0
- DiT_Extractor/dit_object_detection/publaynet_configs/cascade/cascade_dit_large.yaml +28 -0
- DiT_Extractor/dit_object_detection/publaynet_configs/maskrcnn/maskrcnn_dit_base.yaml +15 -0
- DiT_Extractor/dit_object_detection/publaynet_configs/maskrcnn/maskrcnn_dit_large.yaml +22 -0
- DiT_Extractor/dit_runner.py +158 -0
- DiT_Extractor/sentence_extractor.py +136 -0
- LICENSE +207 -0
- NOTICE +21 -0
- README.md +14 -4
- UnifiedQA/demo_QA.py +180 -0
- app.py +120 -0
- env_setup.sh +32 -0
- examples/1810.04805.pdf +0 -0
- examples/1909.00694.pdf +0 -0
- examples/2105.03011.pdf +0 -0
- ms-marco-electra-base/CEBinaryClassificationEvaluator_MS-Marco_results.csv +43 -0
- ms-marco-electra-base/README.md +64 -0
- ms-marco-electra-base/config.json +31 -0
- ms-marco-electra-base/pytorch_model.bin +3 -0
- ms-marco-electra-base/special_tokens_map.json +1 -0
- ms-marco-electra-base/tokenizer_config.json +1 -0
- ms-marco-electra-base/vocab.txt +0 -0
- packages.txt +1 -0
- requirements.txt +13 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.ipynb_checkpoints
|
2 |
+
__pycache__
|
CrossEncoder/cross_encoder.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
|
2 |
+
# All rights reserved.
|
3 |
+
# See the top-level LICENSE and NOTICE files for details.
|
4 |
+
# LLNL-CODE-838964
|
5 |
+
|
6 |
+
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
|
7 |
+
|
8 |
+
from sentence_transformers.cross_encoder import CrossEncoder as CE
|
9 |
+
import numpy as np
|
10 |
+
from typing import List, Dict, Tuple
|
11 |
+
import json
|
12 |
+
from collections import defaultdict
|
13 |
+
import os
|
14 |
+
|
15 |
+
|
16 |
+
class CrossEncoder:
|
17 |
+
def __init__(self,
|
18 |
+
model_path: str = None,
|
19 |
+
max_length: int = None,
|
20 |
+
**kwargs):
|
21 |
+
|
22 |
+
if max_length != None:
|
23 |
+
self.model = CE(model_path, max_length = max_length, **kwargs)
|
24 |
+
|
25 |
+
self.model = CE(model_path, **kwargs)
|
26 |
+
|
27 |
+
|
28 |
+
def predict(self,
|
29 |
+
sentences: List[Tuple[str, str]],
|
30 |
+
batch_size: int = 32,
|
31 |
+
show_progress_bar: bool = False) -> List[float]:
|
32 |
+
|
33 |
+
return self.model.predict(sentences = sentences,
|
34 |
+
batch_size = batch_size,
|
35 |
+
show_progress_bar = show_progress_bar)
|
36 |
+
|
37 |
+
|
38 |
+
class CERank:
|
39 |
+
|
40 |
+
def __init__(self, model, batch_size: int =128, **kwargs):
|
41 |
+
self.cross_encoder = model
|
42 |
+
self.batch_size = batch_size
|
43 |
+
|
44 |
+
|
45 |
+
def flatten_examples(self, contexts: Dict[str, Dict], question: str):
|
46 |
+
|
47 |
+
text_pairs, pair_ids = [], []
|
48 |
+
for context_id, context in contexts.items():
|
49 |
+
pair_ids.append(['question_0', context_id])
|
50 |
+
text_pairs.append([question, context['text']])
|
51 |
+
|
52 |
+
return text_pairs, pair_ids
|
53 |
+
|
54 |
+
def group_questionrank(self, pair_ids, rank_scores):
|
55 |
+
|
56 |
+
unsorted = defaultdict(list)
|
57 |
+
for pair, score in zip(pair_ids, rank_scores):
|
58 |
+
query_id, paragraph_id = pair[0], pair[1]
|
59 |
+
unsorted[query_id].append((paragraph_id, score))
|
60 |
+
|
61 |
+
|
62 |
+
return unsorted
|
63 |
+
|
64 |
+
def get_rankings(self, pair_ids, rank_scores, text_pairs):
|
65 |
+
|
66 |
+
unsorted_ranks = self.group_questionrank(pair_ids, rank_scores)
|
67 |
+
rankings = defaultdict(dict)
|
68 |
+
|
69 |
+
for idx, (query_id, ranks) in enumerate(unsorted_ranks.items()):
|
70 |
+
sort_ranks = sorted(ranks, key = lambda item: item[1], reverse = True)
|
71 |
+
sorted_ranks, scores = list(zip(*sort_ranks))
|
72 |
+
rankings[query_id]['text'] = text_pairs[idx][0]
|
73 |
+
rankings[query_id]['scores'] = list(scores)
|
74 |
+
rankings[query_id]['ranks'] = list(sorted_ranks)
|
75 |
+
|
76 |
+
return rankings
|
77 |
+
|
78 |
+
|
79 |
+
def rank(self,
|
80 |
+
contexts: Dict[str, Dict],
|
81 |
+
question: str):
|
82 |
+
|
83 |
+
|
84 |
+
text_pairs, pair_ids = self.flatten_examples(contexts, question)
|
85 |
+
rank_scores = [float(score) for score in self.cross_encoder.predict(text_pairs, batch_size = self.batch_size)]
|
86 |
+
full_results = self.get_rankings(pair_ids, rank_scores, text_pairs)
|
87 |
+
|
88 |
+
return full_results
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
def get_ranked_contexts(context_json, question):
|
93 |
+
|
94 |
+
dirname = 'examples'
|
95 |
+
model_path = '/data/actici/pretrained_weights/ms-marco-electra-base'
|
96 |
+
max_length = 512
|
97 |
+
|
98 |
+
# Can't use use_fast (fast tokenizers) while gradio is running, causes conflict with tokenizer multiprocessing/parallelism.
|
99 |
+
cross_encoder = CrossEncoder(model_path, max_length, tokenizer_args={'use_fast':False})
|
100 |
+
ranker = CERank(cross_encoder)
|
101 |
+
|
102 |
+
with open(context_json, 'r') as fin:
|
103 |
+
contexts = json.load(fin)
|
104 |
+
|
105 |
+
rankings = ranker.rank(contexts, question)
|
106 |
+
|
107 |
+
with open('ranked_{0}.json'.format(context_json[:-5]), 'w') as fout:
|
108 |
+
json.dump(rankings, fout)
|
109 |
+
|
110 |
+
def get_ranked_contexts_in_memory(contexts, question):
|
111 |
+
|
112 |
+
dirname = 'examples'
|
113 |
+
model_path = '/data/actici/pretrained_weights/ms-marco-electra-base'
|
114 |
+
max_length = 512
|
115 |
+
|
116 |
+
# Can't use use_fast (fast tokenizers) while gradio is running, causes conflict with tokenizer multiprocessing/parallelism.
|
117 |
+
cross_encoder = CrossEncoder(model_path, max_length, tokenizer_args={'use_fast':False})
|
118 |
+
ranker = CERank(cross_encoder)
|
119 |
+
|
120 |
+
rankings = ranker.rank(contexts, question)
|
121 |
+
|
122 |
+
return rankings
|
CrossEncoder/cross_encoder_env.yml
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: cross_encoder_env
|
2 |
+
channels:
|
3 |
+
- defaults
|
4 |
+
dependencies:
|
5 |
+
- _libgcc_mutex=0.1=main
|
6 |
+
- _openmp_mutex=5.1=1_gnu
|
7 |
+
- ca-certificates=2022.4.26=h06a4308_0
|
8 |
+
- certifi=2022.6.15=py39h06a4308_0
|
9 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
10 |
+
- libffi=3.3=he6710b0_2
|
11 |
+
- libgcc-ng=11.2.0=h1234567_1
|
12 |
+
- libgomp=11.2.0=h1234567_1
|
13 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
14 |
+
- ncurses=6.3=h7f8727e_2
|
15 |
+
- openssl=1.1.1o=h7f8727e_0
|
16 |
+
- pip=21.2.4=py39h06a4308_0
|
17 |
+
- python=3.9.12=h12debd9_1
|
18 |
+
- readline=8.1.2=h7f8727e_1
|
19 |
+
- setuptools=61.2.0=py39h06a4308_0
|
20 |
+
- sqlite=3.38.5=hc218d9a_0
|
21 |
+
- tk=8.6.12=h1ccaba5_0
|
22 |
+
- tzdata=2022a=hda174b7_0
|
23 |
+
- wheel=0.37.1=pyhd3eb1b0_0
|
24 |
+
- xz=5.2.5=h7f8727e_1
|
25 |
+
- zlib=1.2.12=h7f8727e_2
|
26 |
+
- pip:
|
27 |
+
- charset-normalizer==2.0.12
|
28 |
+
- click==8.1.3
|
29 |
+
- filelock==3.7.1
|
30 |
+
- huggingface-hub==0.8.1
|
31 |
+
- idna==3.3
|
32 |
+
- joblib==1.1.0
|
33 |
+
- nltk==3.7
|
34 |
+
- numpy==1.23.0
|
35 |
+
- packaging==21.3
|
36 |
+
- pillow==9.1.1
|
37 |
+
- pyparsing==3.0.9
|
38 |
+
- pyyaml==6.0
|
39 |
+
- regex==2022.6.2
|
40 |
+
- requests==2.28.0
|
41 |
+
- scikit-learn==1.1.1
|
42 |
+
- scipy==1.8.1
|
43 |
+
- sentence-transformers==2.2.2
|
44 |
+
- sentencepiece==0.1.96
|
45 |
+
- threadpoolctl==3.1.0
|
46 |
+
- tokenizers==0.12.1
|
47 |
+
- torch==1.11.0
|
48 |
+
- torchvision==0.12.0
|
49 |
+
- tqdm==4.64.0
|
50 |
+
- transformers==4.20.1
|
51 |
+
- typing-extensions==4.2.0
|
52 |
+
- urllib3==1.26.9
|
53 |
+
prefix: /home/ordonez2/miniconda3/envs/cross_encoder
|
DiT_Extractor/base_utils.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
|
2 |
+
# All rights reserved.
|
3 |
+
# See the top-level LICENSE and NOTICE files for details.
|
4 |
+
# LLNL-CODE-838964
|
5 |
+
|
6 |
+
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
|
7 |
+
|
8 |
+
from pdfminer.pdfpage import PDFParser
|
9 |
+
from pdfminer.pdfpage import PDFDocument
|
10 |
+
from pdfminer.pdfpage import PDFPage
|
11 |
+
from pdfminer.layout import LTTextBoxHorizontal
|
12 |
+
from pdfminer.layout import LTTextLineHorizontal
|
13 |
+
from pdfminer.layout import LTChar
|
14 |
+
from pdfminer.layout import LAParams
|
15 |
+
from pdfminer.layout import LTRect
|
16 |
+
from pdfminer.layout import LTFigure
|
17 |
+
|
18 |
+
from pdfminer.converter import PDFPageAggregator
|
19 |
+
from pdfminer.pdfinterp import PDFResourceManager
|
20 |
+
from pdfminer.pdfinterp import PDFPageInterpreter
|
21 |
+
from pdfminer import pdfinterp
|
22 |
+
|
23 |
+
from collections.abc import Iterable
|
24 |
+
from collections import Counter
|
25 |
+
from collections import OrderedDict
|
26 |
+
|
27 |
+
import os
|
28 |
+
|
29 |
+
# This is use for highlighting in PDFs
|
30 |
+
from PyPDF2.generic import (
|
31 |
+
DictionaryObject,
|
32 |
+
NumberObject,
|
33 |
+
FloatObject,
|
34 |
+
NameObject,
|
35 |
+
TextStringObject,
|
36 |
+
ArrayObject
|
37 |
+
)
|
38 |
+
|
39 |
+
# Used to extract pages
|
40 |
+
from PyPDF2 import PdfFileReader, PdfFileWriter
|
41 |
+
|
42 |
+
def get_page_sizes(document):
|
43 |
+
parser = PDFParser(open(document, 'rb'))
|
44 |
+
doc = PDFDocument(parser)
|
45 |
+
pageSizesList = []
|
46 |
+
for page in PDFPage.create_pages(doc):
|
47 |
+
# the media box that is the page size as list of 4 integers x0 y0 x1 y1
|
48 |
+
pageSizesList.append(page.mediabox) # <- appending
|
49 |
+
return pageSizesList
|
50 |
+
|
51 |
+
def get_page_count(document):
|
52 |
+
# Is there a better way of getting the page count than doing this?
|
53 |
+
parser = PDFParser(document)
|
54 |
+
tmpdoc = PDFDocument(parser)
|
55 |
+
page_count = pdfinterp.resolve1(tmpdoc.catalog['Pages'])['Count']
|
56 |
+
return page_count
|
57 |
+
|
58 |
+
def get_pdf_page_count(filename):
|
59 |
+
with open(filename, 'rb') as document:
|
60 |
+
return get_page_count(document)
|
61 |
+
|
62 |
+
def get_pages(document, page_numbers = None):
|
63 |
+
#Create resource manager
|
64 |
+
rsrcmgr = PDFResourceManager()
|
65 |
+
# Set parameters for analysis.
|
66 |
+
laparams = LAParams()
|
67 |
+
# Create a PDF page aggregator object.
|
68 |
+
device = PDFPageAggregator(rsrcmgr, laparams=laparams)
|
69 |
+
interpreter = PDFPageInterpreter(rsrcmgr, device)
|
70 |
+
|
71 |
+
page_count = get_page_count(document)
|
72 |
+
|
73 |
+
if page_numbers is None:
|
74 |
+
page_numbers = range(page_count)
|
75 |
+
|
76 |
+
for page, page_number in zip(PDFPage.get_pages(document, page_numbers), page_numbers):
|
77 |
+
interpreter.process_page(page)
|
78 |
+
# receive the LTPage object for the page.
|
79 |
+
layout = device.get_result()
|
80 |
+
#print("Yield page:", page_number)
|
81 |
+
yield layout, page_number
|
82 |
+
|
83 |
+
def partial_overlaps(box, other):
|
84 |
+
"""
|
85 |
+
Determine if the two bounding boxes overlap eachother.
|
86 |
+
TODO: Really should just use a standard Python library for this.
|
87 |
+
|
88 |
+
box -- 2 coordinate bounding box (x1,y1,x2,y2)
|
89 |
+
other -- 2 coordinate bounding box (x1,y1,x2,y2)
|
90 |
+
"""
|
91 |
+
# a1 x1 a2 x2
|
92 |
+
# <------------------>
|
93 |
+
x_intersects = (other[0] < box[0] and other[2] > box[0]) or (
|
94 |
+
other[0] < box[2] and other[2] > box[2])
|
95 |
+
y_intersects = (other[1] < box[1] and other[3] > box[1]) or (
|
96 |
+
other[1] < box[3] and other[3] > box[3])
|
97 |
+
|
98 |
+
intersects = x_intersects or y_intersects
|
99 |
+
# TODO: Simplify?
|
100 |
+
return intersects and overlaps(box, other)
|
101 |
+
#return intersects
|
102 |
+
|
103 |
+
def overlaps(box, other):
|
104 |
+
"""
|
105 |
+
Determine if the two bounding boxes overlap eachother.
|
106 |
+
TODO: Really should just use a standard Python library for this.
|
107 |
+
|
108 |
+
box -- 2 coordinate bounding box (x1,y1,x2,y2)
|
109 |
+
other -- 2 coordinate bounding box (x1,y1,x2,y2)
|
110 |
+
"""
|
111 |
+
x_intersects = box[0] > other[2] or box[2] < other[0]
|
112 |
+
y_intersects = box[1] > other[3] or box[3] < other[1]
|
113 |
+
|
114 |
+
intersects = not (x_intersects or y_intersects)
|
115 |
+
return intersects
|
116 |
+
|
117 |
+
def union(src, other):
|
118 |
+
"""
|
119 |
+
Expand src by union of other bbox
|
120 |
+
|
121 |
+
src -- 2 coordinate bounding box (x1,y1,x2,y2)
|
122 |
+
other -- 2 coordinate bounding box (x1,y1,x2,y2)
|
123 |
+
|
124 |
+
returns union of src and other
|
125 |
+
"""
|
126 |
+
xmin = min(src[0], other[0])
|
127 |
+
ymin = min(src[1], other[1])
|
128 |
+
xmax = max(src[2], other[2])
|
129 |
+
ymax = max(src[3], other[3])
|
130 |
+
|
131 |
+
return [xmin, ymin, xmax, ymax]
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
# See: https://gist.github.com/agentcooper/4c55133f5d95866acdee5017cd318558#file-pypdf2highlight-py
|
136 |
+
# x1, y1 starts in bottom left corner
|
137 |
+
def createHighlight(x1, y1, x2, y2, meta, color = [1, 0, 0]):
|
138 |
+
newHighlight = DictionaryObject()
|
139 |
+
|
140 |
+
newHighlight.update({
|
141 |
+
NameObject("/F"): NumberObject(4),
|
142 |
+
NameObject("/Type"): NameObject("/Annot"),
|
143 |
+
NameObject("/Subtype"): NameObject("/Highlight"),
|
144 |
+
|
145 |
+
NameObject("/T"): TextStringObject(meta["author"]),
|
146 |
+
NameObject("/Contents"): TextStringObject(meta["contents"]),
|
147 |
+
|
148 |
+
NameObject("/C"): ArrayObject([FloatObject(c) for c in color]),
|
149 |
+
NameObject("/Rect"): ArrayObject([
|
150 |
+
FloatObject(x1),
|
151 |
+
FloatObject(y1),
|
152 |
+
FloatObject(x2),
|
153 |
+
FloatObject(y2)
|
154 |
+
]),
|
155 |
+
NameObject("/QuadPoints"): ArrayObject([
|
156 |
+
FloatObject(x1),
|
157 |
+
FloatObject(y2),
|
158 |
+
FloatObject(x2),
|
159 |
+
FloatObject(y2),
|
160 |
+
FloatObject(x1),
|
161 |
+
FloatObject(y1),
|
162 |
+
FloatObject(x2),
|
163 |
+
FloatObject(y1)
|
164 |
+
]),
|
165 |
+
})
|
166 |
+
|
167 |
+
return newHighlight
|
168 |
+
|
169 |
+
def addHighlightToPage(highlight, page, output):
|
170 |
+
highlight_ref = output._addObject(highlight);
|
171 |
+
|
172 |
+
if "/Annots" in page:
|
173 |
+
page[NameObject("/Annots")].append(highlight_ref)
|
174 |
+
else:
|
175 |
+
page[NameObject("/Annots")] = ArrayObject([highlight_ref])
|
176 |
+
|
177 |
+
def get_pdf_words(document, page_numbers=None):
|
178 |
+
"""
|
179 |
+
Get all words from LTChar or LTTextLineHorizontal objects from the document.
|
180 |
+
|
181 |
+
:param document: string path of the PDF file to process
|
182 |
+
:returns: A map of page #'s containing lists of coordinates and PDFMiner
|
183 |
+
objects. Ex.: {page_number: [[x1, y1, x2, y2, <LTTextLineHorizontal>],]}
|
184 |
+
"""
|
185 |
+
pdf_doc = open(document, 'rb')
|
186 |
+
|
187 |
+
bboxes = {}
|
188 |
+
for layout, page in get_pages(pdf_doc, page_numbers):
|
189 |
+
#print(element.get_text())
|
190 |
+
bboxes[page] = []
|
191 |
+
for element in layout:
|
192 |
+
if not isinstance(element, Iterable):
|
193 |
+
continue # not iterable
|
194 |
+
for subElement in element:
|
195 |
+
#print('Subelement type:', type(subElement))
|
196 |
+
if isinstance(subElement, LTChar):
|
197 |
+
if (subElement.get_text() == ' '):
|
198 |
+
pass # TODO: Handle word deliminator
|
199 |
+
# Print the character in this class
|
200 |
+
# print(subElement.get_text(), end='')
|
201 |
+
item = list(subElement.bbox)
|
202 |
+
item.append(subElement)
|
203 |
+
bboxes[page].append(item)
|
204 |
+
elif isinstance(subElement, LTTextLineHorizontal):
|
205 |
+
#print(subElement.bbox)
|
206 |
+
item = list(subElement.bbox)
|
207 |
+
item.append(subElement)
|
208 |
+
bboxes[page].append(item)
|
209 |
+
else:
|
210 |
+
pass
|
211 |
+
return bboxes
|
212 |
+
|
213 |
+
def get_paragraphs(words):
|
214 |
+
paragraph_tolerance = 0.1
|
215 |
+
max_height_diff = 1
|
216 |
+
paragraphs = []
|
217 |
+
|
218 |
+
for page, elements in words.items():
|
219 |
+
# Find nominal font size
|
220 |
+
# Round to int
|
221 |
+
freq = Counter()
|
222 |
+
for element in elements:
|
223 |
+
height = int(element[3] - element[1])
|
224 |
+
#print(height,end=' ')
|
225 |
+
freq[height] += 1
|
226 |
+
|
227 |
+
nominal_font = freq.most_common(1)[0][0]
|
228 |
+
print("Nominal font is:", nominal_font)
|
229 |
+
|
230 |
+
print("Page:", page)
|
231 |
+
x_offset_prev_line = None
|
232 |
+
prev_x_offset = None
|
233 |
+
prev_y_offset = None
|
234 |
+
paragraph_content = ""
|
235 |
+
#print("Element count:", len(elements))
|
236 |
+
first_line = False
|
237 |
+
processed_first_line = False
|
238 |
+
|
239 |
+
for element in elements:
|
240 |
+
x_offset = element[0]
|
241 |
+
y_offset = element[1]
|
242 |
+
height = int(element[3] - element[1])
|
243 |
+
text = element[4].get_text()
|
244 |
+
|
245 |
+
if x_offset_prev_line != None:
|
246 |
+
large_x_offset = (abs(x_offset_prev_line - x_offset) > paragraph_tolerance)
|
247 |
+
|
248 |
+
# Font size mismatch?
|
249 |
+
if abs(height - nominal_font) > max_height_diff:
|
250 |
+
if len(paragraph_content) > 0:
|
251 |
+
print("Content append:", len(paragraph_content))
|
252 |
+
paragraphs.append(paragraph_content)
|
253 |
+
paragraph_content = ""
|
254 |
+
print("Continue due to height != nominal_font")
|
255 |
+
continue
|
256 |
+
|
257 |
+
print("ELEMENT:", element[0:4], text[0:15])
|
258 |
+
if prev_y_offset is not None and len(paragraph_content) > 0:
|
259 |
+
if y_offset < prev_y_offset - height * 1.5:
|
260 |
+
print("Content append:", len(paragraph_content))
|
261 |
+
if len(paragraph_content) > 0:
|
262 |
+
paragraphs.append(paragraph_content)
|
263 |
+
paragraph_content = text
|
264 |
+
prev_y_offset = None
|
265 |
+
continue
|
266 |
+
|
267 |
+
prev_y_offset = y_offset
|
268 |
+
|
269 |
+
prev_y_offset = y_offset
|
270 |
+
#print("element:", element)
|
271 |
+
if not isinstance(element[4], LTTextLineHorizontal):
|
272 |
+
continue
|
273 |
+
|
274 |
+
#print("Running text:", text)
|
275 |
+
#print(f"x_offset_prev_line , x_offset]: {x_offset_prev_line, x_offset}")
|
276 |
+
|
277 |
+
|
278 |
+
# Find first paragraph
|
279 |
+
if x_offset_prev_line is None:
|
280 |
+
#print("x_offset_prev is none")
|
281 |
+
x_offset_prev_line = x_offset
|
282 |
+
if not processed_first_line:
|
283 |
+
first_line = True
|
284 |
+
processed_first_line = True
|
285 |
+
if height == nominal_font:
|
286 |
+
paragraph_content += text
|
287 |
+
#print("Continue due to x_offset_prev_line is none")
|
288 |
+
continue
|
289 |
+
|
290 |
+
|
291 |
+
|
292 |
+
# Check case if first line was indented
|
293 |
+
if x_offset_prev_line > x_offset and first_line:
|
294 |
+
#print("x_offset < element[0]")
|
295 |
+
first_line = False
|
296 |
+
paragraph_content += text
|
297 |
+
x_offset_prev_line = x_offset
|
298 |
+
#print("Continue due to x_offset_prev_line > x_offset and first_line")
|
299 |
+
continue
|
300 |
+
|
301 |
+
# is this indented?
|
302 |
+
# and ignore small changes
|
303 |
+
if x_offset_prev_line < x_offset and large_x_offset:
|
304 |
+
#print(f"x_offset_prev_line > x_offset: {x_offset_prev_line, x_offset}")
|
305 |
+
if height == nominal_font and len(paragraph_content) > 0:
|
306 |
+
paragraphs.append(paragraph_content)
|
307 |
+
|
308 |
+
paragraph_content = text
|
309 |
+
# Reset at next line read
|
310 |
+
# What if next paragraph is also indented???
|
311 |
+
x_offset_prev_line = None
|
312 |
+
#print("Continue due to x_offset_prev_line < x_offset and large_x_offset")
|
313 |
+
continue
|
314 |
+
|
315 |
+
#print(element[0:4])
|
316 |
+
if height == nominal_font:
|
317 |
+
paragraph_content += text
|
318 |
+
#print("End of loop")
|
319 |
+
|
320 |
+
# TODO: Remove redundant space
|
321 |
+
if paragraph_content != "":
|
322 |
+
paragraphs.append(paragraph_content)
|
323 |
+
|
324 |
+
# Find paragraph indexes
|
325 |
+
c = 0
|
326 |
+
indexes = []
|
327 |
+
for p in paragraphs:
|
328 |
+
c += len(p)
|
329 |
+
indexes.append(c)
|
330 |
+
|
331 |
+
return paragraphs, indexes
|
332 |
+
|
333 |
+
def get_pdf_elements(document, element_type, page_numbers=None):
|
334 |
+
pdf_doc = open(document, 'rb')
|
335 |
+
|
336 |
+
items = {}
|
337 |
+
for layout, page in get_pages(pdf_doc, page_numbers):
|
338 |
+
#print(element.get_text())
|
339 |
+
items[page] = []
|
340 |
+
for element in layout:
|
341 |
+
if isinstance(element, element_type):
|
342 |
+
item = list(element.bbox)
|
343 |
+
if hasattr(element, 'non_stroking_color'):
|
344 |
+
item.append(element.non_stroking_color)
|
345 |
+
items[page].append(item)
|
346 |
+
print(items)
|
347 |
+
return items
|
348 |
+
|
349 |
+
def get_large_colored_background_rectangles(document, page_numbers=None):
|
350 |
+
# Only include rectangles that are at least 4" x 1" in size
|
351 |
+
min_size = (288.0, 72.0)
|
352 |
+
|
353 |
+
elements = get_pdf_elements(document, LTRect, page_numbers)
|
354 |
+
rects_out = {}
|
355 |
+
for page, rects in elements.items():
|
356 |
+
print("Rects:", rects)
|
357 |
+
for rect in rects:
|
358 |
+
width = rect[2] - rect[0]
|
359 |
+
height = rect[3] - rect[1]
|
360 |
+
print("Dimensions:", width, height)
|
361 |
+
if (width > min_size[0] and
|
362 |
+
height > min_size[1]):
|
363 |
+
if not page in rects_out:
|
364 |
+
rects_out[page] = []
|
365 |
+
rects_out[page].append(rect)
|
366 |
+
return rects_out
|
367 |
+
|
368 |
+
def extract_pages(document, output, page_numbers=None):
|
369 |
+
pdf = PdfFileReader(document)
|
370 |
+
|
371 |
+
pdf_writer = PdfFileWriter()
|
372 |
+
for page in page_numbers:
|
373 |
+
current_page = pdf.getPage(page)
|
374 |
+
pdf_writer.addPage(current_page)
|
375 |
+
|
376 |
+
with open(output, "wb") as out:
|
377 |
+
pdf_writer.write(out)
|
378 |
+
|
DiT_Extractor/dit_object_detection/README.md
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DiT for Object Detection
|
2 |
+
|
3 |
+
This folder contains Mask R-CNN Cascade Mask R-CNN running instructions on top of [Detectron2](https://github.com/facebookresearch/detectron2) for PubLayNet and ICDAR 2019 cTDaR.
|
4 |
+
|
5 |
+
## Usage
|
6 |
+
|
7 |
+
### Inference
|
8 |
+
|
9 |
+
The quickest way to try out DiT for document layout analysis is the web demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/nielsr/dit-document-layout-analysis).
|
10 |
+
|
11 |
+
One can run inference using the `inference.py` script. It can be run as follows (from the root of the unilm repository):
|
12 |
+
|
13 |
+
```
|
14 |
+
python ./dit/object_detection/inference.py \
|
15 |
+
--image_path ./dit/object_detection/publaynet_example.jpeg \
|
16 |
+
--output_file_name output.jpg \
|
17 |
+
--config ./dit/object_detection/publaynet_configs/maskrcnn/maskrcnn_dit_base.yaml \
|
18 |
+
--opts MODEL.WEIGHTS https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_mrcnn.pth \
|
19 |
+
```
|
20 |
+
|
21 |
+
Make sure that the configuration file (YAML) and PyTorch checkpoint match. The example above uses DiT-base with the Mask R-CNN framework fine-tuned on PubLayNet.
|
22 |
+
|
23 |
+
### Data Preparation
|
24 |
+
|
25 |
+
**PubLayNet**
|
26 |
+
|
27 |
+
Download the data from this [link](https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz?_ga=2.218138265.1825957955.1646384196-1495010506.1633610665) (~96GB). Then extract it to `PATH-to-PubLayNet`.
|
28 |
+
|
29 |
+
A soft link needs to be created to make the data accessible for the program:`ln -s PATH-to-PubLayNet publaynet_data`.
|
30 |
+
|
31 |
+
**ICDAR 2019 cTDaR**
|
32 |
+
|
33 |
+
Download the data from this [link](https://github.com/cndplab-founder/ICDAR2019_cTDaR) (~4GB). Assume path to this repository is named as `PATH-to-ICDARrepo`.
|
34 |
+
|
35 |
+
Then run `python convert_to_coco_format.py --root_dir=PATH-to-ICDARrepo --target_dir=PATH-toICDAR`. Now the path to processed data is `PATH-to-ICDAR`.
|
36 |
+
|
37 |
+
Run the following command to get the adaptively binarized images for archival subset.
|
38 |
+
|
39 |
+
```
|
40 |
+
cp -r PATH-to-ICDAR/trackA_archival PATH-to-ICDAR/at_trackA_archival
|
41 |
+
python adaptive_binarize.py --root_dir PATH-to-ICDAR/at_trackA_archival
|
42 |
+
```
|
43 |
+
|
44 |
+
The binarized archival subset will be in `PATH-to-ICDAR/at_trackA_archival`.
|
45 |
+
|
46 |
+
According to the subset you want to evaluate/fine-tune, a soft link should be created:`ln -s PATH-to-ICDAR/trackA_modern data` or `ln -s PATH-to-ICDAR/at_trackA_archival data`.
|
47 |
+
|
48 |
+
### Evaluation
|
49 |
+
|
50 |
+
Following commands provide two examples to evaluate the fine-tuned checkpoints.
|
51 |
+
|
52 |
+
The config files can be found in `icdar19_configs` and `publaynet_configs`.
|
53 |
+
|
54 |
+
1) Evaluate the fine-tuned checkpoint of DiT-Base with Mask R-CNN on PublayNet:
|
55 |
+
```bash
|
56 |
+
python train_net.py --config-file publaynet_configs/maskrcnn/maskrcnn_dit_base.yaml --eval-only --num-gpus 8 MODEL.WEIGHTS <finetuned_checkpoint_file_path or link> OUTPUT_DIR <your_output_dir>
|
57 |
+
```
|
58 |
+
|
59 |
+
2) Evaluate the fine-tuned checkpoint of DiT-Large with Cascade Mask R-CNN on ICDAR 2019 cTDaR archival subset (make sure you have created a soft link from `PATH-to-ICDAR/at_trackA_archival` to `data`):
|
60 |
+
```bash
|
61 |
+
python train_net.py --config-file icdar19_configs/cascade/cascade_dit_large.yaml --eval-only --num-gpus 8 MODEL.WEIGHTS <finetuned_checkpoint_file_path or link> OUTPUT_DIR <your_output_dir>
|
62 |
+
```
|
63 |
+
|
64 |
+
**Note**: We have fixed the **bug** in the [ICDAR2019 measurement tool](https://github.com/cndplab-founder/ctdar_measurement_tool) during integrating the tool into our code. If you use the tool to get the evaluation score, please modify the [code](https://github.com/cndplab-founder/ctdar_measurement_tool/blob/738456d3164a838ffaeefe7d1b5e64f3a4368a0e/evaluate.py#L146
|
65 |
+
) as follows:
|
66 |
+
```bash
|
67 |
+
...
|
68 |
+
# print(each_file)
|
69 |
+
|
70 |
+
# for file in gt_file_lst:
|
71 |
+
# if file.split(".") != "xml":
|
72 |
+
# gt_file_lst.remove(file)
|
73 |
+
# # print(gt_file_lst)
|
74 |
+
|
75 |
+
# Comment the code above and add the code below
|
76 |
+
for i in range(len(gt_file_lst) - 1, -1, -1):
|
77 |
+
if gt_file_lst[i].split(".")[-1] != "xml":
|
78 |
+
del gt_file_lst[i]
|
79 |
+
|
80 |
+
if len(gt_file_lst) > 0:
|
81 |
+
...
|
82 |
+
```
|
83 |
+
|
84 |
+
### Training
|
85 |
+
The following commands provide two examples to train the Mask R-CNN/Cascade Mask R-CNN with DiT backbone on 8 32GB Nvidia V100 GPUs.
|
86 |
+
|
87 |
+
1) Fine-tune DiT-Base with Cascade Mask R-CNN on PublayNet:
|
88 |
+
```bash
|
89 |
+
python train_net.py --config-file publaynet_configs/cascade/cascade_dit_base.yaml --num-gpus 8 MODEL.WEIGHTS <DiT-Base_file_path or link> OUTPUT_DIR <your_output_dir>
|
90 |
+
```
|
91 |
+
|
92 |
+
|
93 |
+
2) Fine-tune DiT-Large with Mask R-CNN on ICDAR 2019 cTDaR modern:
|
94 |
+
```bash
|
95 |
+
python train_net.py --config-file icdar19_configs/markrcnn/maskrcnn_dit_large.yaml --num-gpus 8 MODEL.WEIGHTS <DiT-Large_file_path or link> OUTPUT_DIR <your_output_dir>
|
96 |
+
```
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
[Detectron2's document](https://detectron2.readthedocs.io/en/latest/tutorials/getting_started.html) may help you for more details.
|
101 |
+
|
102 |
+
|
103 |
+
## Citation
|
104 |
+
|
105 |
+
If you find this repository useful, please consider citing our work:
|
106 |
+
```
|
107 |
+
@misc{li2022dit,
|
108 |
+
title={DiT: Self-supervised Pre-training for Document Image Transformer},
|
109 |
+
author={Junlong Li and Yiheng Xu and Tengchao Lv and Lei Cui and Cha Zhang and Furu Wei},
|
110 |
+
year={2022},
|
111 |
+
eprint={2203.02378},
|
112 |
+
archivePrefix={arXiv},
|
113 |
+
primaryClass={cs.CV}
|
114 |
+
}
|
115 |
+
```
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
## Acknowledgment
|
120 |
+
Thanks to [Detectron2](https://github.com/facebookresearch/detectron2) for Mask R-CNN and Cascade Mask R-CNN implementation.
|
DiT_Extractor/dit_object_detection/ditod/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------------------------------
|
2 |
+
# MPViT: Multi-Path Vision Transformer for Dense Prediction
|
3 |
+
# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
|
4 |
+
# All Rights Reserved.
|
5 |
+
# Written by Youngwan Lee
|
6 |
+
# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
|
7 |
+
# LICENSE file in the root directory of this source tree.
|
8 |
+
# --------------------------------------------------------------------------------
|
9 |
+
|
10 |
+
from .config import add_vit_config
|
11 |
+
from .backbone import build_vit_fpn_backbone
|
DiT_Extractor/dit_object_detection/ditod/backbone.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------------------------------
|
2 |
+
# VIT: Multi-Path Vision Transformer for Dense Prediction
|
3 |
+
# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
|
4 |
+
# All Rights Reserved.
|
5 |
+
# Written by Youngwan Lee
|
6 |
+
# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
|
7 |
+
# LICENSE file in the root directory of this source tree.
|
8 |
+
# --------------------------------------------------------------------------------
|
9 |
+
# References:
|
10 |
+
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
11 |
+
# CoaT: https://github.com/mlpc-ucsd/CoaT
|
12 |
+
# --------------------------------------------------------------------------------
|
13 |
+
|
14 |
+
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from detectron2.layers import (
|
18 |
+
ShapeSpec,
|
19 |
+
)
|
20 |
+
from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN
|
21 |
+
from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool
|
22 |
+
|
23 |
+
from .beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16
|
24 |
+
from .deit import deit_base_patch16, mae_base_patch16
|
25 |
+
|
26 |
+
__all__ = [
|
27 |
+
"build_vit_fpn_backbone",
|
28 |
+
]
|
29 |
+
|
30 |
+
|
31 |
+
class VIT_Backbone(Backbone):
|
32 |
+
"""
|
33 |
+
Implement VIT backbone.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs):
|
37 |
+
super().__init__()
|
38 |
+
self._out_features = out_features
|
39 |
+
if 'base' in name:
|
40 |
+
self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
|
41 |
+
else:
|
42 |
+
self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
|
43 |
+
|
44 |
+
if name == 'beit_base_patch16':
|
45 |
+
model_func = beit_base_patch16
|
46 |
+
self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
|
47 |
+
elif name == 'dit_base_patch16':
|
48 |
+
model_func = dit_base_patch16
|
49 |
+
self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
|
50 |
+
elif name == "deit_base_patch16":
|
51 |
+
model_func = deit_base_patch16
|
52 |
+
self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
|
53 |
+
elif name == "mae_base_patch16":
|
54 |
+
model_func = mae_base_patch16
|
55 |
+
self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
|
56 |
+
elif name == "dit_large_patch16":
|
57 |
+
model_func = dit_large_patch16
|
58 |
+
self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
|
59 |
+
elif name == "beit_large_patch16":
|
60 |
+
model_func = beit_large_patch16
|
61 |
+
self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
|
62 |
+
else:
|
63 |
+
raise ValueError("Unsupported VIT name yet.")
|
64 |
+
|
65 |
+
if 'beit' in name or 'dit' in name:
|
66 |
+
if pos_type == "abs":
|
67 |
+
self.backbone = model_func(img_size=img_size,
|
68 |
+
out_features=out_features,
|
69 |
+
drop_path_rate=drop_path,
|
70 |
+
use_abs_pos_emb=True,
|
71 |
+
**model_kwargs)
|
72 |
+
elif pos_type == "shared_rel":
|
73 |
+
self.backbone = model_func(img_size=img_size,
|
74 |
+
out_features=out_features,
|
75 |
+
drop_path_rate=drop_path,
|
76 |
+
use_shared_rel_pos_bias=True,
|
77 |
+
**model_kwargs)
|
78 |
+
elif pos_type == "rel":
|
79 |
+
self.backbone = model_func(img_size=img_size,
|
80 |
+
out_features=out_features,
|
81 |
+
drop_path_rate=drop_path,
|
82 |
+
use_rel_pos_bias=True,
|
83 |
+
**model_kwargs)
|
84 |
+
else:
|
85 |
+
raise ValueError()
|
86 |
+
else:
|
87 |
+
self.backbone = model_func(img_size=img_size,
|
88 |
+
out_features=out_features,
|
89 |
+
drop_path_rate=drop_path,
|
90 |
+
**model_kwargs)
|
91 |
+
|
92 |
+
def forward(self, x):
|
93 |
+
"""
|
94 |
+
Args:
|
95 |
+
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
dict[str->Tensor]: names and the corresponding features
|
99 |
+
"""
|
100 |
+
assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
|
101 |
+
return self.backbone.forward_features(x)
|
102 |
+
|
103 |
+
def output_shape(self):
|
104 |
+
return {
|
105 |
+
name: ShapeSpec(
|
106 |
+
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
|
107 |
+
)
|
108 |
+
for name in self._out_features
|
109 |
+
}
|
110 |
+
|
111 |
+
|
112 |
+
def build_VIT_backbone(cfg):
|
113 |
+
"""
|
114 |
+
Create a VIT instance from config.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
cfg: a detectron2 CfgNode
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
A VIT backbone instance.
|
121 |
+
"""
|
122 |
+
# fmt: off
|
123 |
+
name = cfg.MODEL.VIT.NAME
|
124 |
+
out_features = cfg.MODEL.VIT.OUT_FEATURES
|
125 |
+
drop_path = cfg.MODEL.VIT.DROP_PATH
|
126 |
+
img_size = cfg.MODEL.VIT.IMG_SIZE
|
127 |
+
pos_type = cfg.MODEL.VIT.POS_TYPE
|
128 |
+
|
129 |
+
model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", ""))
|
130 |
+
|
131 |
+
return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs)
|
132 |
+
|
133 |
+
|
134 |
+
@BACKBONE_REGISTRY.register()
|
135 |
+
def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
|
136 |
+
"""
|
137 |
+
Create a VIT w/ FPN backbone.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
cfg: a detectron2 CfgNode
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
|
144 |
+
"""
|
145 |
+
bottom_up = build_VIT_backbone(cfg)
|
146 |
+
in_features = cfg.MODEL.FPN.IN_FEATURES
|
147 |
+
out_channels = cfg.MODEL.FPN.OUT_CHANNELS
|
148 |
+
backbone = FPN(
|
149 |
+
bottom_up=bottom_up,
|
150 |
+
in_features=in_features,
|
151 |
+
out_channels=out_channels,
|
152 |
+
norm=cfg.MODEL.FPN.NORM,
|
153 |
+
top_block=LastLevelMaxPool(),
|
154 |
+
fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
|
155 |
+
)
|
156 |
+
return backbone
|
DiT_Extractor/dit_object_detection/ditod/beit.py
ADDED
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Vision Transformer (ViT) in PyTorch
|
2 |
+
|
3 |
+
A PyTorch implement of Vision Transformers as described in
|
4 |
+
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
|
5 |
+
|
6 |
+
The official jax code is released and available at https://github.com/google-research/vision_transformer
|
7 |
+
|
8 |
+
Status/TODO:
|
9 |
+
* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
|
10 |
+
* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
|
11 |
+
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
|
12 |
+
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
|
13 |
+
|
14 |
+
Acknowledgments:
|
15 |
+
* The paper authors for releasing code and weights, thanks!
|
16 |
+
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
|
17 |
+
for some einops/einsum fun
|
18 |
+
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
|
19 |
+
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
|
20 |
+
|
21 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
22 |
+
"""
|
23 |
+
import warnings
|
24 |
+
import math
|
25 |
+
import torch
|
26 |
+
from functools import partial
|
27 |
+
import torch.nn as nn
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import torch.utils.checkpoint as checkpoint
|
30 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
31 |
+
|
32 |
+
|
33 |
+
def _cfg(url='', **kwargs):
|
34 |
+
return {
|
35 |
+
'url': url,
|
36 |
+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
37 |
+
'crop_pct': .9, 'interpolation': 'bicubic',
|
38 |
+
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
39 |
+
**kwargs
|
40 |
+
}
|
41 |
+
|
42 |
+
|
43 |
+
class DropPath(nn.Module):
|
44 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
45 |
+
"""
|
46 |
+
|
47 |
+
def __init__(self, drop_prob=None):
|
48 |
+
super(DropPath, self).__init__()
|
49 |
+
self.drop_prob = drop_prob
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
return drop_path(x, self.drop_prob, self.training)
|
53 |
+
|
54 |
+
def extra_repr(self) -> str:
|
55 |
+
return 'p={}'.format(self.drop_prob)
|
56 |
+
|
57 |
+
|
58 |
+
class Mlp(nn.Module):
|
59 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
60 |
+
super().__init__()
|
61 |
+
out_features = out_features or in_features
|
62 |
+
hidden_features = hidden_features or in_features
|
63 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
64 |
+
self.act = act_layer()
|
65 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
66 |
+
self.drop = nn.Dropout(drop)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
x = self.fc1(x)
|
70 |
+
x = self.act(x)
|
71 |
+
# x = self.drop(x)
|
72 |
+
# commit this for the orignal BERT implement
|
73 |
+
x = self.fc2(x)
|
74 |
+
x = self.drop(x)
|
75 |
+
return x
|
76 |
+
|
77 |
+
|
78 |
+
class Attention(nn.Module):
|
79 |
+
def __init__(
|
80 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
81 |
+
proj_drop=0., window_size=None, attn_head_dim=None):
|
82 |
+
super().__init__()
|
83 |
+
self.num_heads = num_heads
|
84 |
+
head_dim = dim // num_heads
|
85 |
+
if attn_head_dim is not None:
|
86 |
+
head_dim = attn_head_dim
|
87 |
+
all_head_dim = head_dim * self.num_heads
|
88 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
89 |
+
self.scale = qk_scale or head_dim ** -0.5
|
90 |
+
|
91 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
92 |
+
if qkv_bias:
|
93 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
94 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
95 |
+
else:
|
96 |
+
self.q_bias = None
|
97 |
+
self.v_bias = None
|
98 |
+
|
99 |
+
if window_size:
|
100 |
+
self.window_size = window_size
|
101 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
102 |
+
self.relative_position_bias_table = nn.Parameter(
|
103 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
104 |
+
# cls to token & token 2 cls & cls to cls
|
105 |
+
|
106 |
+
# get pair-wise relative position index for each token inside the window
|
107 |
+
coords_h = torch.arange(window_size[0])
|
108 |
+
coords_w = torch.arange(window_size[1])
|
109 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
110 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
111 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
112 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
113 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
114 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
115 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
116 |
+
relative_position_index = \
|
117 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
118 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
119 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
120 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
121 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
122 |
+
|
123 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
124 |
+
|
125 |
+
# trunc_normal_(self.relative_position_bias_table, std=.0)
|
126 |
+
else:
|
127 |
+
self.window_size = None
|
128 |
+
self.relative_position_bias_table = None
|
129 |
+
self.relative_position_index = None
|
130 |
+
|
131 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
132 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
133 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
134 |
+
|
135 |
+
def forward(self, x, rel_pos_bias=None, training_window_size=None):
|
136 |
+
B, N, C = x.shape
|
137 |
+
qkv_bias = None
|
138 |
+
if self.q_bias is not None:
|
139 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
140 |
+
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
141 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
142 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
143 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
144 |
+
|
145 |
+
q = q * self.scale
|
146 |
+
attn = (q @ k.transpose(-2, -1))
|
147 |
+
|
148 |
+
if self.relative_position_bias_table is not None:
|
149 |
+
if training_window_size == self.window_size:
|
150 |
+
relative_position_bias = \
|
151 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
152 |
+
self.window_size[0] * self.window_size[1] + 1,
|
153 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
154 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
155 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
156 |
+
else:
|
157 |
+
training_window_size = tuple(training_window_size.tolist())
|
158 |
+
new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
|
159 |
+
# new_num_relative_dis 为 所有可能的相对位置选项,包含cls-cls,tok-cls,与cls-tok
|
160 |
+
new_relative_position_bias_table = F.interpolate(
|
161 |
+
self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
|
162 |
+
2 * self.window_size[0] - 1,
|
163 |
+
2 * self.window_size[1] - 1),
|
164 |
+
size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
|
165 |
+
align_corners=False)
|
166 |
+
new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
|
167 |
+
new_num_relative_distance - 3).permute(
|
168 |
+
1, 0)
|
169 |
+
new_relative_position_bias_table = torch.cat(
|
170 |
+
[new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
|
171 |
+
|
172 |
+
# get pair-wise relative position index for each token inside the window
|
173 |
+
coords_h = torch.arange(training_window_size[0])
|
174 |
+
coords_w = torch.arange(training_window_size[1])
|
175 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
176 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
177 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
178 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
179 |
+
relative_coords[:, :, 0] += training_window_size[0] - 1 # shift to start from 0
|
180 |
+
relative_coords[:, :, 1] += training_window_size[1] - 1
|
181 |
+
relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
|
182 |
+
relative_position_index = \
|
183 |
+
torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
|
184 |
+
dtype=relative_coords.dtype)
|
185 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
186 |
+
relative_position_index[0, 0:] = new_num_relative_distance - 3
|
187 |
+
relative_position_index[0:, 0] = new_num_relative_distance - 2
|
188 |
+
relative_position_index[0, 0] = new_num_relative_distance - 1
|
189 |
+
|
190 |
+
relative_position_bias = \
|
191 |
+
new_relative_position_bias_table[relative_position_index.view(-1)].view(
|
192 |
+
training_window_size[0] * training_window_size[1] + 1,
|
193 |
+
training_window_size[0] * training_window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
194 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
195 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
196 |
+
|
197 |
+
if rel_pos_bias is not None:
|
198 |
+
attn = attn + rel_pos_bias
|
199 |
+
|
200 |
+
attn = attn.softmax(dim=-1)
|
201 |
+
attn = self.attn_drop(attn)
|
202 |
+
|
203 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
204 |
+
x = self.proj(x)
|
205 |
+
x = self.proj_drop(x)
|
206 |
+
return x
|
207 |
+
|
208 |
+
|
209 |
+
class Block(nn.Module):
|
210 |
+
|
211 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
212 |
+
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
213 |
+
window_size=None, attn_head_dim=None):
|
214 |
+
super().__init__()
|
215 |
+
self.norm1 = norm_layer(dim)
|
216 |
+
self.attn = Attention(
|
217 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
218 |
+
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
|
219 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
220 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
221 |
+
self.norm2 = norm_layer(dim)
|
222 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
223 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
224 |
+
|
225 |
+
if init_values is not None:
|
226 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
227 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
228 |
+
else:
|
229 |
+
self.gamma_1, self.gamma_2 = None, None
|
230 |
+
|
231 |
+
def forward(self, x, rel_pos_bias=None, training_window_size=None):
|
232 |
+
if self.gamma_1 is None:
|
233 |
+
x = x + self.drop_path(
|
234 |
+
self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size))
|
235 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
236 |
+
else:
|
237 |
+
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias,
|
238 |
+
training_window_size=training_window_size))
|
239 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
240 |
+
return x
|
241 |
+
|
242 |
+
|
243 |
+
class PatchEmbed(nn.Module):
|
244 |
+
""" Image to Patch Embedding
|
245 |
+
"""
|
246 |
+
|
247 |
+
def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
|
248 |
+
super().__init__()
|
249 |
+
img_size = to_2tuple(img_size)
|
250 |
+
patch_size = to_2tuple(patch_size)
|
251 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
252 |
+
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
253 |
+
self.num_patches_w = self.patch_shape[0]
|
254 |
+
self.num_patches_h = self.patch_shape[1]
|
255 |
+
# the so-called patch_shape is the patch shape during pre-training
|
256 |
+
self.img_size = img_size
|
257 |
+
self.patch_size = patch_size
|
258 |
+
self.num_patches = num_patches
|
259 |
+
|
260 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
261 |
+
|
262 |
+
def forward(self, x, position_embedding=None, **kwargs):
|
263 |
+
# FIXME look at relaxing size constraints
|
264 |
+
# assert H == self.img_size[0] and W == self.img_size[1], \
|
265 |
+
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
266 |
+
x = self.proj(x)
|
267 |
+
Hp, Wp = x.shape[2], x.shape[3]
|
268 |
+
|
269 |
+
if position_embedding is not None:
|
270 |
+
# interpolate the position embedding to the corresponding size
|
271 |
+
position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3,
|
272 |
+
1, 2)
|
273 |
+
position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')
|
274 |
+
x = x + position_embedding
|
275 |
+
|
276 |
+
x = x.flatten(2).transpose(1, 2)
|
277 |
+
return x, (Hp, Wp)
|
278 |
+
|
279 |
+
|
280 |
+
class HybridEmbed(nn.Module):
|
281 |
+
""" CNN Feature Map Embedding
|
282 |
+
Extract feature map from CNN, flatten, project to embedding dim.
|
283 |
+
"""
|
284 |
+
|
285 |
+
def __init__(self, backbone, img_size=[224, 224], feature_size=None, in_chans=3, embed_dim=768):
|
286 |
+
super().__init__()
|
287 |
+
assert isinstance(backbone, nn.Module)
|
288 |
+
img_size = to_2tuple(img_size)
|
289 |
+
self.img_size = img_size
|
290 |
+
self.backbone = backbone
|
291 |
+
if feature_size is None:
|
292 |
+
with torch.no_grad():
|
293 |
+
# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
|
294 |
+
# map for all networks, the feature metadata has reliable channel and stride info, but using
|
295 |
+
# stride to calc feature dim requires info about padding of each stage that isn't captured.
|
296 |
+
training = backbone.training
|
297 |
+
if training:
|
298 |
+
backbone.eval()
|
299 |
+
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
|
300 |
+
feature_size = o.shape[-2:]
|
301 |
+
feature_dim = o.shape[1]
|
302 |
+
backbone.train(training)
|
303 |
+
else:
|
304 |
+
feature_size = to_2tuple(feature_size)
|
305 |
+
feature_dim = self.backbone.feature_info.channels()[-1]
|
306 |
+
self.num_patches = feature_size[0] * feature_size[1]
|
307 |
+
self.proj = nn.Linear(feature_dim, embed_dim)
|
308 |
+
|
309 |
+
def forward(self, x):
|
310 |
+
x = self.backbone(x)[-1]
|
311 |
+
x = x.flatten(2).transpose(1, 2)
|
312 |
+
x = self.proj(x)
|
313 |
+
return x
|
314 |
+
|
315 |
+
|
316 |
+
class RelativePositionBias(nn.Module):
|
317 |
+
|
318 |
+
def __init__(self, window_size, num_heads):
|
319 |
+
super().__init__()
|
320 |
+
self.window_size = window_size
|
321 |
+
self.num_heads = num_heads
|
322 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
323 |
+
self.relative_position_bias_table = nn.Parameter(
|
324 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
325 |
+
# cls to token & token 2 cls & cls to cls
|
326 |
+
|
327 |
+
# get pair-wise relative position index for each token inside the window
|
328 |
+
coords_h = torch.arange(window_size[0])
|
329 |
+
coords_w = torch.arange(window_size[1])
|
330 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
331 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
332 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
333 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
334 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
335 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
336 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
337 |
+
relative_position_index = \
|
338 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
339 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
340 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
341 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
342 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
343 |
+
|
344 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
345 |
+
|
346 |
+
# trunc_normal_(self.relative_position_bias_table, std=.02)
|
347 |
+
|
348 |
+
def forward(self, training_window_size):
|
349 |
+
if training_window_size == self.window_size:
|
350 |
+
relative_position_bias = \
|
351 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
352 |
+
self.window_size[0] * self.window_size[1] + 1,
|
353 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
354 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
355 |
+
else:
|
356 |
+
training_window_size = tuple(training_window_size.tolist())
|
357 |
+
new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
|
358 |
+
# new_num_relative_dis 为 所有可能的相对位置选项,包含cls-cls,tok-cls,与cls-tok
|
359 |
+
new_relative_position_bias_table = F.interpolate(
|
360 |
+
self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
|
361 |
+
2 * self.window_size[0] - 1,
|
362 |
+
2 * self.window_size[1] - 1),
|
363 |
+
size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
|
364 |
+
align_corners=False)
|
365 |
+
new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
|
366 |
+
new_num_relative_distance - 3).permute(
|
367 |
+
1, 0)
|
368 |
+
new_relative_position_bias_table = torch.cat(
|
369 |
+
[new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
|
370 |
+
|
371 |
+
# get pair-wise relative position index for each token inside the window
|
372 |
+
coords_h = torch.arange(training_window_size[0])
|
373 |
+
coords_w = torch.arange(training_window_size[1])
|
374 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
375 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
376 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
377 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
378 |
+
relative_coords[:, :, 0] += training_window_size[0] - 1 # shift to start from 0
|
379 |
+
relative_coords[:, :, 1] += training_window_size[1] - 1
|
380 |
+
relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
|
381 |
+
relative_position_index = \
|
382 |
+
torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
|
383 |
+
dtype=relative_coords.dtype)
|
384 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
385 |
+
relative_position_index[0, 0:] = new_num_relative_distance - 3
|
386 |
+
relative_position_index[0:, 0] = new_num_relative_distance - 2
|
387 |
+
relative_position_index[0, 0] = new_num_relative_distance - 1
|
388 |
+
|
389 |
+
relative_position_bias = \
|
390 |
+
new_relative_position_bias_table[relative_position_index.view(-1)].view(
|
391 |
+
training_window_size[0] * training_window_size[1] + 1,
|
392 |
+
training_window_size[0] * training_window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
393 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
394 |
+
|
395 |
+
return relative_position_bias
|
396 |
+
|
397 |
+
|
398 |
+
class BEiT(nn.Module):
|
399 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
400 |
+
"""
|
401 |
+
|
402 |
+
def __init__(self,
|
403 |
+
img_size=[224, 224],
|
404 |
+
patch_size=16,
|
405 |
+
in_chans=3,
|
406 |
+
num_classes=80,
|
407 |
+
embed_dim=768,
|
408 |
+
depth=12,
|
409 |
+
num_heads=12,
|
410 |
+
mlp_ratio=4.,
|
411 |
+
qkv_bias=False,
|
412 |
+
qk_scale=None,
|
413 |
+
drop_rate=0.,
|
414 |
+
attn_drop_rate=0.,
|
415 |
+
drop_path_rate=0.,
|
416 |
+
hybrid_backbone=None,
|
417 |
+
norm_layer=None,
|
418 |
+
init_values=None,
|
419 |
+
use_abs_pos_emb=False,
|
420 |
+
use_rel_pos_bias=False,
|
421 |
+
use_shared_rel_pos_bias=False,
|
422 |
+
use_checkpoint=True,
|
423 |
+
pretrained=None,
|
424 |
+
out_features=None,
|
425 |
+
):
|
426 |
+
|
427 |
+
super(BEiT, self).__init__()
|
428 |
+
|
429 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
430 |
+
self.num_classes = num_classes
|
431 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
432 |
+
self.use_checkpoint = use_checkpoint
|
433 |
+
|
434 |
+
if hybrid_backbone is not None:
|
435 |
+
self.patch_embed = HybridEmbed(
|
436 |
+
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
|
437 |
+
else:
|
438 |
+
self.patch_embed = PatchEmbed(
|
439 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
440 |
+
num_patches = self.patch_embed.num_patches
|
441 |
+
self.out_features = out_features
|
442 |
+
self.out_indices = [int(name[5:]) for name in out_features]
|
443 |
+
|
444 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
445 |
+
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
446 |
+
if use_abs_pos_emb:
|
447 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
448 |
+
else:
|
449 |
+
self.pos_embed = None
|
450 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
451 |
+
|
452 |
+
self.use_shared_rel_pos_bias = use_shared_rel_pos_bias
|
453 |
+
if use_shared_rel_pos_bias:
|
454 |
+
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
|
455 |
+
else:
|
456 |
+
self.rel_pos_bias = None
|
457 |
+
|
458 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
459 |
+
self.use_rel_pos_bias = use_rel_pos_bias
|
460 |
+
self.blocks = nn.ModuleList([
|
461 |
+
Block(
|
462 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
463 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
464 |
+
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
|
465 |
+
for i in range(depth)])
|
466 |
+
|
467 |
+
# trunc_normal_(self.mask_token, std=.02)
|
468 |
+
|
469 |
+
if patch_size == 16:
|
470 |
+
self.fpn1 = nn.Sequential(
|
471 |
+
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
|
472 |
+
# nn.SyncBatchNorm(embed_dim),
|
473 |
+
nn.BatchNorm2d(embed_dim),
|
474 |
+
nn.GELU(),
|
475 |
+
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
|
476 |
+
)
|
477 |
+
|
478 |
+
self.fpn2 = nn.Sequential(
|
479 |
+
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
|
480 |
+
)
|
481 |
+
|
482 |
+
self.fpn3 = nn.Identity()
|
483 |
+
|
484 |
+
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
485 |
+
elif patch_size == 8:
|
486 |
+
self.fpn1 = nn.Sequential(
|
487 |
+
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
|
488 |
+
)
|
489 |
+
|
490 |
+
self.fpn2 = nn.Identity()
|
491 |
+
|
492 |
+
self.fpn3 = nn.Sequential(
|
493 |
+
nn.MaxPool2d(kernel_size=2, stride=2),
|
494 |
+
)
|
495 |
+
|
496 |
+
self.fpn4 = nn.Sequential(
|
497 |
+
nn.MaxPool2d(kernel_size=4, stride=4),
|
498 |
+
)
|
499 |
+
|
500 |
+
if self.pos_embed is not None:
|
501 |
+
trunc_normal_(self.pos_embed, std=.02)
|
502 |
+
trunc_normal_(self.cls_token, std=.02)
|
503 |
+
self.apply(self._init_weights)
|
504 |
+
self.fix_init_weight()
|
505 |
+
|
506 |
+
def fix_init_weight(self):
|
507 |
+
def rescale(param, layer_id):
|
508 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
509 |
+
|
510 |
+
for layer_id, layer in enumerate(self.blocks):
|
511 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
512 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
513 |
+
|
514 |
+
def _init_weights(self, m):
|
515 |
+
if isinstance(m, nn.Linear):
|
516 |
+
trunc_normal_(m.weight, std=.02)
|
517 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
518 |
+
nn.init.constant_(m.bias, 0)
|
519 |
+
elif isinstance(m, nn.LayerNorm):
|
520 |
+
nn.init.constant_(m.bias, 0)
|
521 |
+
nn.init.constant_(m.weight, 1.0)
|
522 |
+
|
523 |
+
'''
|
524 |
+
def init_weights(self):
|
525 |
+
"""Initialize the weights in backbone.
|
526 |
+
|
527 |
+
Args:
|
528 |
+
pretrained (str, optional): Path to pre-trained weights.
|
529 |
+
Defaults to None.
|
530 |
+
"""
|
531 |
+
logger = get_root_logger()
|
532 |
+
|
533 |
+
if self.pos_embed is not None:
|
534 |
+
trunc_normal_(self.pos_embed, std=.02)
|
535 |
+
trunc_normal_(self.cls_token, std=.02)
|
536 |
+
self.apply(self._init_weights)
|
537 |
+
self.fix_init_weight()
|
538 |
+
|
539 |
+
if self.init_cfg is None:
|
540 |
+
logger.warn(f'No pre-trained weights for '
|
541 |
+
f'{self.__class__.__name__}, '
|
542 |
+
f'training start from scratch')
|
543 |
+
else:
|
544 |
+
assert 'checkpoint' in self.init_cfg, f'Only support ' \
|
545 |
+
f'specify `Pretrained` in ' \
|
546 |
+
f'`init_cfg` in ' \
|
547 |
+
f'{self.__class__.__name__} '
|
548 |
+
logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
|
549 |
+
load_checkpoint(self,
|
550 |
+
filename=self.init_cfg['checkpoint'],
|
551 |
+
strict=False,
|
552 |
+
logger=logger,
|
553 |
+
beit_spec_expand_rel_pos = self.use_rel_pos_bias,
|
554 |
+
)
|
555 |
+
'''
|
556 |
+
|
557 |
+
def get_num_layers(self):
|
558 |
+
return len(self.blocks)
|
559 |
+
|
560 |
+
@torch.jit.ignore
|
561 |
+
def no_weight_decay(self):
|
562 |
+
return {'pos_embed', 'cls_token'}
|
563 |
+
|
564 |
+
def forward_features(self, x):
|
565 |
+
B, C, H, W = x.shape
|
566 |
+
x, (Hp, Wp) = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
|
567 |
+
# Hp, Wp are HW for patches
|
568 |
+
batch_size, seq_len, _ = x.size()
|
569 |
+
|
570 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
571 |
+
if self.pos_embed is not None:
|
572 |
+
cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
|
573 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
574 |
+
x = self.pos_drop(x)
|
575 |
+
|
576 |
+
features = []
|
577 |
+
training_window_size = torch.tensor([Hp, Wp])
|
578 |
+
|
579 |
+
rel_pos_bias = self.rel_pos_bias(training_window_size) if self.rel_pos_bias is not None else None
|
580 |
+
|
581 |
+
for i, blk in enumerate(self.blocks):
|
582 |
+
if self.use_checkpoint:
|
583 |
+
x = checkpoint.checkpoint(blk, x, rel_pos_bias, training_window_size)
|
584 |
+
else:
|
585 |
+
x = blk(x, rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)
|
586 |
+
if i in self.out_indices:
|
587 |
+
xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
|
588 |
+
features.append(xp.contiguous())
|
589 |
+
|
590 |
+
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
|
591 |
+
for i in range(len(features)):
|
592 |
+
features[i] = ops[i](features[i])
|
593 |
+
|
594 |
+
feat_out = {}
|
595 |
+
|
596 |
+
for name, value in zip(self.out_features, features):
|
597 |
+
feat_out[name] = value
|
598 |
+
|
599 |
+
return feat_out
|
600 |
+
|
601 |
+
def forward(self, x):
|
602 |
+
x = self.forward_features(x)
|
603 |
+
return x
|
604 |
+
|
605 |
+
|
606 |
+
def beit_base_patch16(pretrained=False, **kwargs):
|
607 |
+
model = BEiT(
|
608 |
+
patch_size=16,
|
609 |
+
embed_dim=768,
|
610 |
+
depth=12,
|
611 |
+
num_heads=12,
|
612 |
+
mlp_ratio=4,
|
613 |
+
qkv_bias=True,
|
614 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
615 |
+
init_values=None,
|
616 |
+
**kwargs)
|
617 |
+
model.default_cfg = _cfg()
|
618 |
+
return model
|
619 |
+
|
620 |
+
def beit_large_patch16(pretrained=False, **kwargs):
|
621 |
+
model = BEiT(
|
622 |
+
patch_size=16,
|
623 |
+
embed_dim=1024,
|
624 |
+
depth=24,
|
625 |
+
num_heads=16,
|
626 |
+
mlp_ratio=4,
|
627 |
+
qkv_bias=True,
|
628 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
629 |
+
init_values=None,
|
630 |
+
**kwargs)
|
631 |
+
model.default_cfg = _cfg()
|
632 |
+
return model
|
633 |
+
|
634 |
+
def dit_base_patch16(pretrained=False, **kwargs):
|
635 |
+
model = BEiT(
|
636 |
+
patch_size=16,
|
637 |
+
embed_dim=768,
|
638 |
+
depth=12,
|
639 |
+
num_heads=12,
|
640 |
+
mlp_ratio=4,
|
641 |
+
qkv_bias=True,
|
642 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
643 |
+
init_values=0.1,
|
644 |
+
**kwargs)
|
645 |
+
model.default_cfg = _cfg()
|
646 |
+
return model
|
647 |
+
|
648 |
+
def dit_large_patch16(pretrained=False, **kwargs):
|
649 |
+
model = BEiT(
|
650 |
+
patch_size=16,
|
651 |
+
embed_dim=1024,
|
652 |
+
depth=24,
|
653 |
+
num_heads=16,
|
654 |
+
mlp_ratio=4,
|
655 |
+
qkv_bias=True,
|
656 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
657 |
+
init_values=1e-5,
|
658 |
+
**kwargs)
|
659 |
+
model.default_cfg = _cfg()
|
660 |
+
return model
|
661 |
+
|
662 |
+
if __name__ == '__main__':
|
663 |
+
model = BEiT(use_checkpoint=True, use_shared_rel_pos_bias=True)
|
664 |
+
model = model.to("cuda:0")
|
665 |
+
input1 = torch.rand(2, 3, 512, 762).to("cuda:0")
|
666 |
+
input2 = torch.rand(2, 3, 800, 1200).to("cuda:0")
|
667 |
+
input3 = torch.rand(2, 3, 720, 1000).to("cuda:0")
|
668 |
+
output1 = model(input1)
|
669 |
+
output2 = model(input2)
|
670 |
+
output3 = model(input3)
|
671 |
+
print("all done")
|
DiT_Extractor/dit_object_detection/ditod/config.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from detectron2.config import CfgNode as CN
|
2 |
+
|
3 |
+
|
4 |
+
def add_vit_config(cfg):
|
5 |
+
"""
|
6 |
+
Add config for VIT.
|
7 |
+
"""
|
8 |
+
_C = cfg
|
9 |
+
|
10 |
+
_C.MODEL.VIT = CN()
|
11 |
+
|
12 |
+
# CoaT model name.
|
13 |
+
_C.MODEL.VIT.NAME = ""
|
14 |
+
|
15 |
+
# Output features from CoaT backbone.
|
16 |
+
_C.MODEL.VIT.OUT_FEATURES = ["layer3", "layer5", "layer7", "layer11"]
|
17 |
+
|
18 |
+
_C.MODEL.VIT.IMG_SIZE = [224, 224]
|
19 |
+
|
20 |
+
_C.MODEL.VIT.POS_TYPE = "shared_rel"
|
21 |
+
|
22 |
+
_C.MODEL.VIT.DROP_PATH = 0.
|
23 |
+
|
24 |
+
_C.MODEL.VIT.MODEL_KWARGS = "{}"
|
25 |
+
|
26 |
+
_C.SOLVER.OPTIMIZER = "ADAMW"
|
27 |
+
|
28 |
+
_C.SOLVER.BACKBONE_MULTIPLIER = 1.0
|
29 |
+
|
30 |
+
_C.AUG = CN()
|
31 |
+
|
32 |
+
_C.AUG.DETR = False
|
DiT_Extractor/dit_object_detection/ditod/deit.py
ADDED
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Mostly copy-paste from DINO and timm library:
|
3 |
+
https://github.com/facebookresearch/dino
|
4 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
5 |
+
"""
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
import math
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.utils.checkpoint as checkpoint
|
12 |
+
from timm.models.layers import trunc_normal_, drop_path, to_2tuple
|
13 |
+
from functools import partial
|
14 |
+
|
15 |
+
def _cfg(url='', **kwargs):
|
16 |
+
return {
|
17 |
+
'url': url,
|
18 |
+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
19 |
+
'crop_pct': .9, 'interpolation': 'bicubic',
|
20 |
+
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
21 |
+
**kwargs
|
22 |
+
}
|
23 |
+
|
24 |
+
class DropPath(nn.Module):
|
25 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, drop_prob=None):
|
29 |
+
super(DropPath, self).__init__()
|
30 |
+
self.drop_prob = drop_prob
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
return drop_path(x, self.drop_prob, self.training)
|
34 |
+
|
35 |
+
def extra_repr(self) -> str:
|
36 |
+
return 'p={}'.format(self.drop_prob)
|
37 |
+
|
38 |
+
|
39 |
+
class Mlp(nn.Module):
|
40 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
41 |
+
super().__init__()
|
42 |
+
out_features = out_features or in_features
|
43 |
+
hidden_features = hidden_features or in_features
|
44 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
45 |
+
self.act = act_layer()
|
46 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
47 |
+
self.drop = nn.Dropout(drop)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
x = self.fc1(x)
|
51 |
+
x = self.act(x)
|
52 |
+
x = self.drop(x)
|
53 |
+
x = self.fc2(x)
|
54 |
+
x = self.drop(x)
|
55 |
+
return x
|
56 |
+
|
57 |
+
|
58 |
+
class Attention(nn.Module):
|
59 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
60 |
+
super().__init__()
|
61 |
+
self.num_heads = num_heads
|
62 |
+
head_dim = dim // num_heads
|
63 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
64 |
+
self.scale = qk_scale or head_dim ** -0.5
|
65 |
+
|
66 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
67 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
68 |
+
self.proj = nn.Linear(dim, dim)
|
69 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
B, N, C = x.shape
|
73 |
+
q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads,
|
74 |
+
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
75 |
+
|
76 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
77 |
+
attn = attn.softmax(dim=-1)
|
78 |
+
attn = self.attn_drop(attn)
|
79 |
+
|
80 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
81 |
+
x = self.proj(x)
|
82 |
+
x = self.proj_drop(x)
|
83 |
+
return x
|
84 |
+
|
85 |
+
|
86 |
+
class Block(nn.Module):
|
87 |
+
|
88 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
89 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
90 |
+
super().__init__()
|
91 |
+
self.norm1 = norm_layer(dim)
|
92 |
+
self.attn = Attention(
|
93 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
94 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
95 |
+
self.drop_path = DropPath(
|
96 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
97 |
+
self.norm2 = norm_layer(dim)
|
98 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
99 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
100 |
+
act_layer=act_layer, drop=drop)
|
101 |
+
|
102 |
+
def forward(self, x):
|
103 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
104 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
105 |
+
return x
|
106 |
+
|
107 |
+
|
108 |
+
class PatchEmbed(nn.Module):
|
109 |
+
""" Image to Patch Embedding
|
110 |
+
"""
|
111 |
+
|
112 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
113 |
+
super().__init__()
|
114 |
+
img_size = to_2tuple(img_size)
|
115 |
+
patch_size = to_2tuple(patch_size)
|
116 |
+
|
117 |
+
self.window_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
118 |
+
|
119 |
+
self.num_patches_w, self.num_patches_h = self.window_size
|
120 |
+
|
121 |
+
self.num_patches = self.window_size[0] * self.window_size[1]
|
122 |
+
self.img_size = img_size
|
123 |
+
self.patch_size = patch_size
|
124 |
+
|
125 |
+
self.proj = nn.Conv2d(in_chans, embed_dim,
|
126 |
+
kernel_size=patch_size, stride=patch_size)
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
x = self.proj(x)
|
130 |
+
return x
|
131 |
+
|
132 |
+
|
133 |
+
class HybridEmbed(nn.Module):
|
134 |
+
""" CNN Feature Map Embedding
|
135 |
+
Extract feature map from CNN, flatten, project to embedding dim.
|
136 |
+
"""
|
137 |
+
|
138 |
+
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
|
139 |
+
super().__init__()
|
140 |
+
assert isinstance(backbone, nn.Module)
|
141 |
+
img_size = to_2tuple(img_size)
|
142 |
+
self.img_size = img_size
|
143 |
+
self.backbone = backbone
|
144 |
+
if feature_size is None:
|
145 |
+
with torch.no_grad():
|
146 |
+
# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
|
147 |
+
# map for all networks, the feature metadata has reliable channel and stride info, but using
|
148 |
+
# stride to calc feature dim requires info about padding of each stage that isn't captured.
|
149 |
+
training = backbone.training
|
150 |
+
if training:
|
151 |
+
backbone.eval()
|
152 |
+
o = self.backbone(torch.zeros(
|
153 |
+
1, in_chans, img_size[0], img_size[1]))[-1]
|
154 |
+
feature_size = o.shape[-2:]
|
155 |
+
feature_dim = o.shape[1]
|
156 |
+
backbone.train(training)
|
157 |
+
else:
|
158 |
+
feature_size = to_2tuple(feature_size)
|
159 |
+
feature_dim = self.backbone.feature_info.channels()[-1]
|
160 |
+
self.num_patches = feature_size[0] * feature_size[1]
|
161 |
+
self.proj = nn.Linear(feature_dim, embed_dim)
|
162 |
+
|
163 |
+
def forward(self, x):
|
164 |
+
x = self.backbone(x)[-1]
|
165 |
+
x = x.flatten(2).transpose(1, 2)
|
166 |
+
x = self.proj(x)
|
167 |
+
return x
|
168 |
+
|
169 |
+
|
170 |
+
class ViT(nn.Module):
|
171 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
172 |
+
"""
|
173 |
+
|
174 |
+
def __init__(self,
|
175 |
+
model_name='vit_base_patch16_224',
|
176 |
+
img_size=384,
|
177 |
+
patch_size=16,
|
178 |
+
in_chans=3,
|
179 |
+
embed_dim=1024,
|
180 |
+
depth=24,
|
181 |
+
num_heads=16,
|
182 |
+
num_classes=19,
|
183 |
+
mlp_ratio=4.,
|
184 |
+
qkv_bias=True,
|
185 |
+
qk_scale=None,
|
186 |
+
drop_rate=0.1,
|
187 |
+
attn_drop_rate=0.,
|
188 |
+
drop_path_rate=0.,
|
189 |
+
hybrid_backbone=None,
|
190 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
191 |
+
norm_cfg=None,
|
192 |
+
pos_embed_interp=False,
|
193 |
+
random_init=False,
|
194 |
+
align_corners=False,
|
195 |
+
use_checkpoint=False,
|
196 |
+
num_extra_tokens=1,
|
197 |
+
out_features=None,
|
198 |
+
**kwargs,
|
199 |
+
):
|
200 |
+
|
201 |
+
super(ViT, self).__init__()
|
202 |
+
self.model_name = model_name
|
203 |
+
self.img_size = img_size
|
204 |
+
self.patch_size = patch_size
|
205 |
+
self.in_chans = in_chans
|
206 |
+
self.embed_dim = embed_dim
|
207 |
+
self.depth = depth
|
208 |
+
self.num_heads = num_heads
|
209 |
+
self.num_classes = num_classes
|
210 |
+
self.mlp_ratio = mlp_ratio
|
211 |
+
self.qkv_bias = qkv_bias
|
212 |
+
self.qk_scale = qk_scale
|
213 |
+
self.drop_rate = drop_rate
|
214 |
+
self.attn_drop_rate = attn_drop_rate
|
215 |
+
self.drop_path_rate = drop_path_rate
|
216 |
+
self.hybrid_backbone = hybrid_backbone
|
217 |
+
self.norm_layer = norm_layer
|
218 |
+
self.norm_cfg = norm_cfg
|
219 |
+
self.pos_embed_interp = pos_embed_interp
|
220 |
+
self.random_init = random_init
|
221 |
+
self.align_corners = align_corners
|
222 |
+
self.use_checkpoint = use_checkpoint
|
223 |
+
self.num_extra_tokens = num_extra_tokens
|
224 |
+
self.out_features = out_features
|
225 |
+
self.out_indices = [int(name[5:]) for name in out_features]
|
226 |
+
|
227 |
+
# self.num_stages = self.depth
|
228 |
+
# self.out_indices = tuple(range(self.num_stages))
|
229 |
+
|
230 |
+
if self.hybrid_backbone is not None:
|
231 |
+
self.patch_embed = HybridEmbed(
|
232 |
+
self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
|
233 |
+
else:
|
234 |
+
self.patch_embed = PatchEmbed(
|
235 |
+
img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
|
236 |
+
self.num_patches = self.patch_embed.num_patches
|
237 |
+
|
238 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
239 |
+
|
240 |
+
if self.num_extra_tokens == 2:
|
241 |
+
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
242 |
+
|
243 |
+
self.pos_embed = nn.Parameter(torch.zeros(
|
244 |
+
1, self.num_patches + self.num_extra_tokens, self.embed_dim))
|
245 |
+
self.pos_drop = nn.Dropout(p=self.drop_rate)
|
246 |
+
|
247 |
+
# self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
|
248 |
+
dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate,
|
249 |
+
self.depth)] # stochastic depth decay rule
|
250 |
+
self.blocks = nn.ModuleList([
|
251 |
+
Block(
|
252 |
+
dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,
|
253 |
+
qk_scale=self.qk_scale,
|
254 |
+
drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer)
|
255 |
+
for i in range(self.depth)])
|
256 |
+
|
257 |
+
# NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
|
258 |
+
# self.repr = nn.Linear(embed_dim, representation_size)
|
259 |
+
# self.repr_act = nn.Tanh()
|
260 |
+
|
261 |
+
if patch_size == 16:
|
262 |
+
self.fpn1 = nn.Sequential(
|
263 |
+
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
|
264 |
+
nn.SyncBatchNorm(embed_dim),
|
265 |
+
nn.GELU(),
|
266 |
+
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
|
267 |
+
)
|
268 |
+
|
269 |
+
self.fpn2 = nn.Sequential(
|
270 |
+
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
|
271 |
+
)
|
272 |
+
|
273 |
+
self.fpn3 = nn.Identity()
|
274 |
+
|
275 |
+
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
276 |
+
elif patch_size == 8:
|
277 |
+
self.fpn1 = nn.Sequential(
|
278 |
+
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
|
279 |
+
)
|
280 |
+
|
281 |
+
self.fpn2 = nn.Identity()
|
282 |
+
|
283 |
+
self.fpn3 = nn.Sequential(
|
284 |
+
nn.MaxPool2d(kernel_size=2, stride=2),
|
285 |
+
)
|
286 |
+
|
287 |
+
self.fpn4 = nn.Sequential(
|
288 |
+
nn.MaxPool2d(kernel_size=4, stride=4),
|
289 |
+
)
|
290 |
+
|
291 |
+
trunc_normal_(self.pos_embed, std=.02)
|
292 |
+
trunc_normal_(self.cls_token, std=.02)
|
293 |
+
if self.num_extra_tokens==2:
|
294 |
+
trunc_normal_(self.dist_token, std=0.2)
|
295 |
+
self.apply(self._init_weights)
|
296 |
+
# self.fix_init_weight()
|
297 |
+
|
298 |
+
def fix_init_weight(self):
|
299 |
+
def rescale(param, layer_id):
|
300 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
301 |
+
|
302 |
+
for layer_id, layer in enumerate(self.blocks):
|
303 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
304 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
305 |
+
|
306 |
+
def _init_weights(self, m):
|
307 |
+
if isinstance(m, nn.Linear):
|
308 |
+
trunc_normal_(m.weight, std=.02)
|
309 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
310 |
+
nn.init.constant_(m.bias, 0)
|
311 |
+
elif isinstance(m, nn.LayerNorm):
|
312 |
+
nn.init.constant_(m.bias, 0)
|
313 |
+
nn.init.constant_(m.weight, 1.0)
|
314 |
+
|
315 |
+
'''
|
316 |
+
def init_weights(self):
|
317 |
+
logger = get_root_logger()
|
318 |
+
|
319 |
+
trunc_normal_(self.pos_embed, std=.02)
|
320 |
+
trunc_normal_(self.cls_token, std=.02)
|
321 |
+
self.apply(self._init_weights)
|
322 |
+
|
323 |
+
if self.init_cfg is None:
|
324 |
+
logger.warn(f'No pre-trained weights for '
|
325 |
+
f'{self.__class__.__name__}, '
|
326 |
+
f'training start from scratch')
|
327 |
+
else:
|
328 |
+
assert 'checkpoint' in self.init_cfg, f'Only support ' \
|
329 |
+
f'specify `Pretrained` in ' \
|
330 |
+
f'`init_cfg` in ' \
|
331 |
+
f'{self.__class__.__name__} '
|
332 |
+
logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
|
333 |
+
load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger)
|
334 |
+
'''
|
335 |
+
|
336 |
+
def get_num_layers(self):
|
337 |
+
return len(self.blocks)
|
338 |
+
|
339 |
+
@torch.jit.ignore
|
340 |
+
def no_weight_decay(self):
|
341 |
+
return {'pos_embed', 'cls_token'}
|
342 |
+
|
343 |
+
def _conv_filter(self, state_dict, patch_size=16):
|
344 |
+
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
345 |
+
out_dict = {}
|
346 |
+
for k, v in state_dict.items():
|
347 |
+
if 'patch_embed.proj.weight' in k:
|
348 |
+
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
|
349 |
+
out_dict[k] = v
|
350 |
+
return out_dict
|
351 |
+
|
352 |
+
def to_2D(self, x):
|
353 |
+
n, hw, c = x.shape
|
354 |
+
h = w = int(math.sqrt(hw))
|
355 |
+
x = x.transpose(1, 2).reshape(n, c, h, w)
|
356 |
+
return x
|
357 |
+
|
358 |
+
def to_1D(self, x):
|
359 |
+
n, c, h, w = x.shape
|
360 |
+
x = x.reshape(n, c, -1).transpose(1, 2)
|
361 |
+
return x
|
362 |
+
|
363 |
+
def interpolate_pos_encoding(self, x, w, h):
|
364 |
+
npatch = x.shape[1] - self.num_extra_tokens
|
365 |
+
N = self.pos_embed.shape[1] - self.num_extra_tokens
|
366 |
+
if npatch == N and w == h:
|
367 |
+
return self.pos_embed
|
368 |
+
|
369 |
+
class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens]
|
370 |
+
|
371 |
+
patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:]
|
372 |
+
|
373 |
+
dim = x.shape[-1]
|
374 |
+
w0 = w // self.patch_embed.patch_size[0]
|
375 |
+
h0 = h // self.patch_embed.patch_size[1]
|
376 |
+
# we add a small number to avoid floating point error in the interpolation
|
377 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
378 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
379 |
+
patch_pos_embed = nn.functional.interpolate(
|
380 |
+
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
381 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
382 |
+
mode='bicubic',
|
383 |
+
)
|
384 |
+
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
|
385 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
386 |
+
|
387 |
+
return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1)
|
388 |
+
|
389 |
+
def prepare_tokens(self, x, mask=None):
|
390 |
+
B, nc, w, h = x.shape
|
391 |
+
# patch linear embedding
|
392 |
+
x = self.patch_embed(x)
|
393 |
+
|
394 |
+
# mask image modeling
|
395 |
+
if mask is not None:
|
396 |
+
x = self.mask_model(x, mask)
|
397 |
+
x = x.flatten(2).transpose(1, 2)
|
398 |
+
|
399 |
+
# add the [CLS] token to the embed patch tokens
|
400 |
+
all_tokens = [self.cls_token.expand(B, -1, -1)]
|
401 |
+
|
402 |
+
if self.num_extra_tokens == 2:
|
403 |
+
dist_tokens = self.dist_token.expand(B, -1, -1)
|
404 |
+
all_tokens.append(dist_tokens)
|
405 |
+
all_tokens.append(x)
|
406 |
+
|
407 |
+
x = torch.cat(all_tokens, dim=1)
|
408 |
+
|
409 |
+
# add positional encoding to each token
|
410 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
411 |
+
|
412 |
+
return self.pos_drop(x)
|
413 |
+
|
414 |
+
def forward_features(self, x):
|
415 |
+
# print(f"==========shape of x is {x.shape}==========")
|
416 |
+
B, _, H, W = x.shape
|
417 |
+
Hp, Wp = H // self.patch_size, W // self.patch_size
|
418 |
+
x = self.prepare_tokens(x)
|
419 |
+
|
420 |
+
features = []
|
421 |
+
for i, blk in enumerate(self.blocks):
|
422 |
+
if self.use_checkpoint:
|
423 |
+
x = checkpoint.checkpoint(blk, x)
|
424 |
+
else:
|
425 |
+
x = blk(x)
|
426 |
+
if i in self.out_indices:
|
427 |
+
xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
|
428 |
+
features.append(xp.contiguous())
|
429 |
+
|
430 |
+
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
|
431 |
+
for i in range(len(features)):
|
432 |
+
features[i] = ops[i](features[i])
|
433 |
+
|
434 |
+
feat_out = {}
|
435 |
+
|
436 |
+
for name, value in zip(self.out_features, features):
|
437 |
+
feat_out[name] = value
|
438 |
+
|
439 |
+
return feat_out
|
440 |
+
|
441 |
+
def forward(self, x):
|
442 |
+
x = self.forward_features(x)
|
443 |
+
return x
|
444 |
+
|
445 |
+
|
446 |
+
def deit_base_patch16(pretrained=False, **kwargs):
|
447 |
+
model = ViT(
|
448 |
+
patch_size=16,
|
449 |
+
drop_rate=0.,
|
450 |
+
embed_dim=768,
|
451 |
+
depth=12,
|
452 |
+
num_heads=12,
|
453 |
+
num_classes=1000,
|
454 |
+
mlp_ratio=4.,
|
455 |
+
qkv_bias=True,
|
456 |
+
use_checkpoint=True,
|
457 |
+
num_extra_tokens=2,
|
458 |
+
**kwargs)
|
459 |
+
model.default_cfg = _cfg()
|
460 |
+
return model
|
461 |
+
|
462 |
+
def mae_base_patch16(pretrained=False, **kwargs):
|
463 |
+
model = ViT(
|
464 |
+
patch_size=16,
|
465 |
+
drop_rate=0.,
|
466 |
+
embed_dim=768,
|
467 |
+
depth=12,
|
468 |
+
num_heads=12,
|
469 |
+
num_classes=1000,
|
470 |
+
mlp_ratio=4.,
|
471 |
+
qkv_bias=True,
|
472 |
+
use_checkpoint=True,
|
473 |
+
num_extra_tokens=1,
|
474 |
+
**kwargs)
|
475 |
+
model.default_cfg = _cfg()
|
476 |
+
return model
|
DiT_Extractor/dit_object_detection/publaynet_configs/Base-RCNN-FPN.yaml
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL:
|
2 |
+
MASK_ON: True
|
3 |
+
META_ARCHITECTURE: "GeneralizedRCNN"
|
4 |
+
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
5 |
+
PIXEL_STD: [58.395, 57.120, 57.375]
|
6 |
+
BACKBONE:
|
7 |
+
NAME: "build_vit_fpn_backbone"
|
8 |
+
VIT:
|
9 |
+
OUT_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
|
10 |
+
DROP_PATH: 0.1
|
11 |
+
IMG_SIZE: [224,224]
|
12 |
+
POS_TYPE: "abs"
|
13 |
+
FPN:
|
14 |
+
IN_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
|
15 |
+
ANCHOR_GENERATOR:
|
16 |
+
SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map
|
17 |
+
ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps)
|
18 |
+
RPN:
|
19 |
+
IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
|
20 |
+
PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
|
21 |
+
PRE_NMS_TOPK_TEST: 1000 # Per FPN level
|
22 |
+
# Detectron1 uses 2000 proposals per-batch,
|
23 |
+
# (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
|
24 |
+
# which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
|
25 |
+
POST_NMS_TOPK_TRAIN: 1000
|
26 |
+
POST_NMS_TOPK_TEST: 1000
|
27 |
+
ROI_HEADS:
|
28 |
+
NAME: "StandardROIHeads"
|
29 |
+
IN_FEATURES: ["p2", "p3", "p4", "p5"]
|
30 |
+
NUM_CLASSES: 5
|
31 |
+
ROI_BOX_HEAD:
|
32 |
+
NAME: "FastRCNNConvFCHead"
|
33 |
+
NUM_FC: 2
|
34 |
+
POOLER_RESOLUTION: 7
|
35 |
+
ROI_MASK_HEAD:
|
36 |
+
NAME: "MaskRCNNConvUpsampleHead"
|
37 |
+
NUM_CONV: 4
|
38 |
+
POOLER_RESOLUTION: 14
|
39 |
+
DATASETS:
|
40 |
+
TRAIN: ("publaynet_train",)
|
41 |
+
TEST: ("publaynet_val",)
|
42 |
+
SOLVER:
|
43 |
+
LR_SCHEDULER_NAME: "WarmupCosineLR"
|
44 |
+
AMP:
|
45 |
+
ENABLED: True
|
46 |
+
OPTIMIZER: "ADAMW"
|
47 |
+
BACKBONE_MULTIPLIER: 1.0
|
48 |
+
CLIP_GRADIENTS:
|
49 |
+
ENABLED: True
|
50 |
+
CLIP_TYPE: "full_model"
|
51 |
+
CLIP_VALUE: 1.0
|
52 |
+
NORM_TYPE: 2.0
|
53 |
+
WARMUP_FACTOR: 0.01
|
54 |
+
BASE_LR: 0.0004
|
55 |
+
WEIGHT_DECAY: 0.05
|
56 |
+
IMS_PER_BATCH: 32
|
57 |
+
INPUT:
|
58 |
+
CROP:
|
59 |
+
ENABLED: True
|
60 |
+
TYPE: "absolute_range"
|
61 |
+
SIZE: (384, 600)
|
62 |
+
MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
|
63 |
+
FORMAT: "RGB"
|
64 |
+
DATALOADER:
|
65 |
+
FILTER_EMPTY_ANNOTATIONS: False
|
66 |
+
VERSION: 2
|
67 |
+
AUG:
|
68 |
+
DETR: True
|
69 |
+
SEED: 42
|
DiT_Extractor/dit_object_detection/publaynet_configs/cascade/cascade_dit_base.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: "../Base-RCNN-FPN.yaml"
|
2 |
+
MODEL:
|
3 |
+
PIXEL_MEAN: [ 127.5, 127.5, 127.5 ]
|
4 |
+
PIXEL_STD: [ 127.5, 127.5, 127.5 ]
|
5 |
+
WEIGHTS: "https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-base-224-p16-500k-62d53a.pth"
|
6 |
+
VIT:
|
7 |
+
NAME: "dit_base_patch16"
|
8 |
+
ROI_HEADS:
|
9 |
+
NAME: CascadeROIHeads
|
10 |
+
ROI_BOX_HEAD:
|
11 |
+
CLS_AGNOSTIC_BBOX_REG: True
|
12 |
+
RPN:
|
13 |
+
POST_NMS_TOPK_TRAIN: 2000
|
14 |
+
SOLVER:
|
15 |
+
WARMUP_ITERS: 1000
|
16 |
+
IMS_PER_BATCH: 16
|
17 |
+
MAX_ITER: 60000
|
18 |
+
CHECKPOINT_PERIOD: 2000
|
19 |
+
TEST:
|
20 |
+
EVAL_PERIOD: 2000
|
DiT_Extractor/dit_object_detection/publaynet_configs/cascade/cascade_dit_large.yaml
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: "../Base-RCNN-FPN.yaml"
|
2 |
+
MODEL:
|
3 |
+
PIXEL_MEAN: [ 127.5, 127.5, 127.5 ]
|
4 |
+
PIXEL_STD: [ 127.5, 127.5, 127.5 ]
|
5 |
+
WEIGHTS: "https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-large-224-p16-500k-d7a2fb.pth"
|
6 |
+
VIT:
|
7 |
+
NAME: "dit_large_patch16"
|
8 |
+
OUT_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ]
|
9 |
+
DROP_PATH: 0.2
|
10 |
+
FPN:
|
11 |
+
IN_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ]
|
12 |
+
ROI_HEADS:
|
13 |
+
NAME: CascadeROIHeads
|
14 |
+
ROI_BOX_HEAD:
|
15 |
+
CLS_AGNOSTIC_BBOX_REG: True
|
16 |
+
RPN:
|
17 |
+
POST_NMS_TOPK_TRAIN: 2000
|
18 |
+
SOLVER:
|
19 |
+
WARMUP_ITERS: 1000
|
20 |
+
IMS_PER_BATCH: 16
|
21 |
+
MAX_ITER: 60000
|
22 |
+
CHECKPOINT_PERIOD: 2000
|
23 |
+
BASE_LR: 0.0001
|
24 |
+
STEPS: (40000, 53333)
|
25 |
+
AMP:
|
26 |
+
ENABLED: False
|
27 |
+
TEST:
|
28 |
+
EVAL_PERIOD: 2000
|
DiT_Extractor/dit_object_detection/publaynet_configs/maskrcnn/maskrcnn_dit_base.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: "../Base-RCNN-FPN.yaml"
|
2 |
+
MODEL:
|
3 |
+
PIXEL_MEAN: [ 127.5, 127.5, 127.5 ]
|
4 |
+
PIXEL_STD: [ 127.5, 127.5, 127.5 ]
|
5 |
+
WEIGHTS: "https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-base-224-p16-500k-62d53a.pth"
|
6 |
+
VIT:
|
7 |
+
NAME: "dit_base_patch16"
|
8 |
+
SOLVER:
|
9 |
+
WARMUP_ITERS: 1000
|
10 |
+
IMS_PER_BATCH: 16
|
11 |
+
MAX_ITER: 60000
|
12 |
+
CHECKPOINT_PERIOD: 2000
|
13 |
+
TEST:
|
14 |
+
EVAL_PERIOD: 2000
|
15 |
+
OUTPUT_DIR: $AMLT_OUTPUT_DIR
|
DiT_Extractor/dit_object_detection/publaynet_configs/maskrcnn/maskrcnn_dit_large.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: "../Base-RCNN-FPN.yaml"
|
2 |
+
MODEL:
|
3 |
+
PIXEL_MEAN: [ 127.5, 127.5, 127.5 ]
|
4 |
+
PIXEL_STD: [ 127.5, 127.5, 127.5 ]
|
5 |
+
WEIGHTS: "https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-large-224-p16-500k-d7a2fb.pth"
|
6 |
+
VIT:
|
7 |
+
NAME: "dit_large_patch16"
|
8 |
+
OUT_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ]
|
9 |
+
DROP_PATH: 0.2
|
10 |
+
FPN:
|
11 |
+
IN_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ]
|
12 |
+
SOLVER:
|
13 |
+
WARMUP_ITERS: 1000
|
14 |
+
IMS_PER_BATCH: 16
|
15 |
+
MAX_ITER: 60000
|
16 |
+
CHECKPOINT_PERIOD: 2000
|
17 |
+
BASE_LR: 0.0001
|
18 |
+
AMP:
|
19 |
+
ENABLED: False
|
20 |
+
TEST:
|
21 |
+
EVAL_PERIOD: 2000
|
22 |
+
OUTPUT_DIR: "output/publaynet/mask_rcnn/dit_base_multistep_3x_ms"
|
DiT_Extractor/dit_runner.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
|
2 |
+
# All rights reserved.
|
3 |
+
# See the top-level LICENSE and NOTICE files for details.
|
4 |
+
# LLNL-CODE-838964
|
5 |
+
|
6 |
+
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
|
7 |
+
|
8 |
+
import cv2
|
9 |
+
from pathlib import Path
|
10 |
+
import torch
|
11 |
+
import json
|
12 |
+
|
13 |
+
from detectron2.config import CfgNode as CN
|
14 |
+
from detectron2.config import get_cfg
|
15 |
+
from detectron2.utils.visualizer import ColorMode, Visualizer
|
16 |
+
from detectron2.data import MetadataCatalog
|
17 |
+
from detectron2.engine import DefaultPredictor
|
18 |
+
|
19 |
+
from pdf2image import convert_from_path
|
20 |
+
|
21 |
+
from PIL import Image
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
from dit_object_detection.ditod import add_vit_config
|
25 |
+
import base_utils
|
26 |
+
from pdfminer.layout import LTTextLineHorizontal, LTTextBoxHorizontal, LTAnno, LTChar
|
27 |
+
|
28 |
+
from tokenizers.pre_tokenizers import Whitespace
|
29 |
+
|
30 |
+
import warnings
|
31 |
+
warnings.filterwarnings("ignore")
|
32 |
+
|
33 |
+
dit_path = Path('DiT_Extractor/dit_object_detection')
|
34 |
+
|
35 |
+
cfg = get_cfg()
|
36 |
+
add_vit_config(cfg)
|
37 |
+
cfg.merge_from_file(dit_path / "publaynet_configs/cascade/cascade_dit_base.yaml")
|
38 |
+
|
39 |
+
cfg.MODEL.WEIGHTS = "https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_cascade.pth"
|
40 |
+
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
41 |
+
|
42 |
+
predictor = DefaultPredictor(cfg)
|
43 |
+
|
44 |
+
thing_classes = ["text","title","list","table","figure"]
|
45 |
+
thing_map = dict(map(reversed, enumerate(thing_classes)))
|
46 |
+
md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
|
47 |
+
md.set(thing_classes=thing_classes)
|
48 |
+
|
49 |
+
|
50 |
+
def get_pdf_image(pdf_file, page):
|
51 |
+
image = convert_from_path(pdf_file, dpi=200, first_page=page, last_page=page)
|
52 |
+
return image
|
53 |
+
|
54 |
+
def get_characters(subelement):
|
55 |
+
all_chars = []
|
56 |
+
if isinstance(subelement, LTTextLineHorizontal):
|
57 |
+
for char in subelement:
|
58 |
+
if isinstance(char, LTChar):
|
59 |
+
all_chars.append((char.bbox, char.get_text()))
|
60 |
+
if isinstance(char, LTAnno):
|
61 |
+
# No bbox, just a space, so make a thin slice after previous text
|
62 |
+
bbox = all_chars[-1][0]
|
63 |
+
bbox = (bbox[2],bbox[1],bbox[2],bbox[3])
|
64 |
+
all_chars.append((bbox, char.get_text()))
|
65 |
+
return all_chars
|
66 |
+
|
67 |
+
|
68 |
+
def get_dit_preds(pdf, score_threshold=0.5):
|
69 |
+
|
70 |
+
page_count = base_utils.get_pdf_page_count(pdf)
|
71 |
+
|
72 |
+
# Input is numpy array of PIL image
|
73 |
+
page_sizes = base_utils.get_page_sizes(pdf)
|
74 |
+
|
75 |
+
sections = {}
|
76 |
+
viz_images = []
|
77 |
+
page_words = base_utils.get_pdf_words(pdf)
|
78 |
+
for page in range(1, page_count+1): #range(2, page_count + 1):
|
79 |
+
image = get_pdf_image(pdf, page)
|
80 |
+
image = np.array(image[0])
|
81 |
+
# Get prediction
|
82 |
+
output = predictor(image)["instances"]
|
83 |
+
output = output.to('cpu')
|
84 |
+
|
85 |
+
# Visualize predictions
|
86 |
+
v = Visualizer(image[:, :, ::-1],
|
87 |
+
md,
|
88 |
+
scale=1.0,
|
89 |
+
instance_mode=ColorMode.SEGMENTATION)
|
90 |
+
result = v.draw_instance_predictions(output)
|
91 |
+
result_image = result.get_image()[:, :, ::-1]
|
92 |
+
viz_img = Image.fromarray(result_image)
|
93 |
+
viz_images.append(viz_img)
|
94 |
+
|
95 |
+
words = page_words[page-1]
|
96 |
+
|
97 |
+
# Convert from image_size to page size
|
98 |
+
pdf_dimensions = page_sizes[page-1][2:]
|
99 |
+
# Swap height/width
|
100 |
+
pdf_image_size = (output.image_size[1], output.image_size[0])
|
101 |
+
|
102 |
+
scale = np.array(pdf_dimensions) / np.array(pdf_image_size)
|
103 |
+
scale_box = np.hstack((scale,scale))
|
104 |
+
# Words are in page coordinates
|
105 |
+
|
106 |
+
id = 0
|
107 |
+
sections[page-1] = []
|
108 |
+
draw = image.copy()
|
109 |
+
for box_t, clazz, score in zip(output.get('pred_boxes'), output.get('pred_classes'), output.get('scores')):
|
110 |
+
|
111 |
+
if score < score_threshold:
|
112 |
+
continue
|
113 |
+
|
114 |
+
box = box_t.numpy()
|
115 |
+
# Flip along Y axis
|
116 |
+
box[1] = pdf_image_size[1] - box[1]
|
117 |
+
box[3] = pdf_image_size[1] - box[3]
|
118 |
+
# Scale
|
119 |
+
scaled = box * scale_box
|
120 |
+
# This is the correct order
|
121 |
+
scaled = [scaled[0], scaled[3], scaled[2], scaled[1]]
|
122 |
+
if clazz != thing_map['text']:
|
123 |
+
continue
|
124 |
+
|
125 |
+
start = box[0:2].tolist()
|
126 |
+
end = box[2:4].tolist()
|
127 |
+
start = [int(x) for x in start]
|
128 |
+
end = [int(x) for x in end]
|
129 |
+
|
130 |
+
out = {}
|
131 |
+
|
132 |
+
for word in words.copy():
|
133 |
+
if base_utils.partial_overlaps(word[0:4], scaled):
|
134 |
+
if out == {}:
|
135 |
+
id += 1
|
136 |
+
out['coord'] = word[0:4]
|
137 |
+
out['subelements'] = []
|
138 |
+
out['type'] = 'content_block'
|
139 |
+
out['id']= id
|
140 |
+
out['text'] = ''
|
141 |
+
|
142 |
+
out['coord'] = base_utils.union(out['coord'], word[0:4])
|
143 |
+
out['text'] = out['text'] + word[4].get_text()
|
144 |
+
|
145 |
+
characters = get_characters(word[4])
|
146 |
+
out['subelements'].append(characters)
|
147 |
+
words.remove(word)
|
148 |
+
|
149 |
+
if len(out) != 0:
|
150 |
+
sections[page-1].append(out)
|
151 |
+
|
152 |
+
# Write final annotation
|
153 |
+
|
154 |
+
out_name = Path(pdf).name[:-4] + ".json"
|
155 |
+
with open(out_name, 'w', encoding='utf8') as json_out:
|
156 |
+
json.dump(sections, json_out, ensure_ascii=False, indent=4)
|
157 |
+
|
158 |
+
return viz_images
|
DiT_Extractor/sentence_extractor.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
|
2 |
+
# All rights reserved.
|
3 |
+
# See the top-level LICENSE and NOTICE files for details.
|
4 |
+
# LLNL-CODE-838964
|
5 |
+
|
6 |
+
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
|
7 |
+
|
8 |
+
import json
|
9 |
+
from tokenizers.pre_tokenizers import Whitespace
|
10 |
+
import base_utils
|
11 |
+
import spacy
|
12 |
+
|
13 |
+
def guess_sentences(tokens, text):
|
14 |
+
sentence_delems = ('.', '?', ').', '!')
|
15 |
+
sentences = []
|
16 |
+
sentence = []
|
17 |
+
maybe_delem = None
|
18 |
+
for token in tokens:
|
19 |
+
# check next token to see if there is space after prev delem
|
20 |
+
if maybe_delem != None:
|
21 |
+
if maybe_delem[1][1] < token[1][0]:
|
22 |
+
sentences.append(sentence)
|
23 |
+
sentence = []
|
24 |
+
maybe_delem = None
|
25 |
+
|
26 |
+
sentence.append(token)
|
27 |
+
if token[0] in sentence_delems:
|
28 |
+
maybe_delem = token
|
29 |
+
if sentence != []:
|
30 |
+
sentences.append(sentence)
|
31 |
+
return sentences
|
32 |
+
|
33 |
+
def spacey_sentences(text):
|
34 |
+
nlp = spacy.blank('en')
|
35 |
+
nlp.add_pipe('sentencizer')
|
36 |
+
sentences = [s.text for s in nlp(text).sents]
|
37 |
+
return sentences
|
38 |
+
|
39 |
+
def add_coords(sentences, all_coords):
|
40 |
+
sentences_out = []
|
41 |
+
for sentence in sentences:
|
42 |
+
new_sentence = []
|
43 |
+
for token in sentence:
|
44 |
+
indexes = token[1]
|
45 |
+
bbox = all_coords[indexes[0]]
|
46 |
+
for i in range(indexes[0]+1, indexes[1]):
|
47 |
+
bbox = base_utils.union(bbox, all_coords[i])
|
48 |
+
new_sentence.append((token[0],token[1],bbox))
|
49 |
+
sentences_out.append(new_sentence)
|
50 |
+
return sentences_out
|
51 |
+
|
52 |
+
def sentence_extract(document):
|
53 |
+
"""
|
54 |
+
Convert extract .PDF result .pkl into tokens with max length of 384 tokens, seperated
|
55 |
+
on sentence delimiter boundaries such as .!?
|
56 |
+
"""
|
57 |
+
max_tokens = 384
|
58 |
+
document_tree = json.load(open(document,'r'))
|
59 |
+
sections_per_page = {}
|
60 |
+
for page_num, page in document_tree.items():
|
61 |
+
# Tokenize per section (rectangular block that was detected by DIT)
|
62 |
+
word_sections = []
|
63 |
+
text_sections = []
|
64 |
+
for section in page:
|
65 |
+
text_sections.append(section['text'])
|
66 |
+
all_text = ''
|
67 |
+
all_coord = []
|
68 |
+
if 'subelements' not in section:
|
69 |
+
continue
|
70 |
+
for subelement in section['subelements']:
|
71 |
+
for char in subelement:
|
72 |
+
all_text += char[1]
|
73 |
+
all_coord.append(char[0])
|
74 |
+
# check for weird characters, e.g. "(cid:206)", "ff", "fi", etc
|
75 |
+
# if string isn't just 1 character, it's an irregular LTChar (character) from pdfminer.
|
76 |
+
# instead of skipping them, we can just create extra duplicate coordinates for the additional characters.
|
77 |
+
if len(char[1]) > 1:
|
78 |
+
bad_char_len = len(char[1])
|
79 |
+
dupe_coord_amt = (bad_char_len - 1)
|
80 |
+
for dupe_i in range(dupe_coord_amt):
|
81 |
+
all_coord.append(char[0])
|
82 |
+
|
83 |
+
pre_tokenizer = Whitespace()
|
84 |
+
|
85 |
+
sentences_pre_tok = spacey_sentences(all_text)
|
86 |
+
sentences = []
|
87 |
+
for sentence in sentences_pre_tok:
|
88 |
+
tokenized = pre_tokenizer.pre_tokenize_str(sentence)
|
89 |
+
sentences.append(tokenized)
|
90 |
+
|
91 |
+
sentences = add_coords(sentences, all_coord)
|
92 |
+
|
93 |
+
word_section = []
|
94 |
+
t = 0
|
95 |
+
for sentence in sentences:
|
96 |
+
t += len(sentence)
|
97 |
+
if t <= max_tokens:
|
98 |
+
word_section += sentence
|
99 |
+
else:
|
100 |
+
word_sections.append(word_section)
|
101 |
+
word_section = sentence
|
102 |
+
t = len(sentence)
|
103 |
+
word_sections.append(word_section)
|
104 |
+
sections = {'text_sections':text_sections, 'word_sections':word_sections}
|
105 |
+
sections_per_page[page_num] = sections
|
106 |
+
return sections_per_page
|
107 |
+
|
108 |
+
def format_output_contexts(sections_per_page):
|
109 |
+
|
110 |
+
all_contexts = {}
|
111 |
+
|
112 |
+
for page_idx in sections_per_page.keys():
|
113 |
+
|
114 |
+
text_sections = sections_per_page[page_idx]['text_sections']
|
115 |
+
word_sections = sections_per_page[page_idx]['word_sections']
|
116 |
+
|
117 |
+
for text_section, word_section in zip(text_sections, word_sections):
|
118 |
+
whitespaced_text = ' '.join([word[0] for word in word_section])
|
119 |
+
words_info = []
|
120 |
+
for word in word_section:
|
121 |
+
words_info.append({'word_text:':word[0], 'char_indices':word[1], 'word_bbox':word[2]})
|
122 |
+
|
123 |
+
context_row = {'text':text_section, 'whitespaced_text':whitespaced_text, 'page_idx':int(page_idx), 'words_info':words_info}
|
124 |
+
context_id = 'context_{0}'.format(len(all_contexts))
|
125 |
+
all_contexts[context_id] = context_row
|
126 |
+
|
127 |
+
return all_contexts
|
128 |
+
|
129 |
+
def get_contexts(json_input):
|
130 |
+
json_output = 'contexts_{0}'.format(json_input)
|
131 |
+
sections_per_page = sentence_extract(json_input)
|
132 |
+
|
133 |
+
all_contexts = format_output_contexts(sections_per_page)
|
134 |
+
|
135 |
+
with open(json_output, 'w', encoding='utf8') as json_out:
|
136 |
+
json.dump(all_contexts, json_out, ensure_ascii=False, indent=4)
|
LICENSE
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, August 2022
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2018, Lawrence Livermore National Security, LLC
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
202 |
+
|
203 |
+
---- LLVM Exceptions to the Apache 2.0 License ----
|
204 |
+
|
205 |
+
As an exception, if, as a result of your compiling your source code, portions of this Software are embedded into an Object form of such source code, you may redistribute such embedded portions in such Object form without complying with the conditions of Sections 4(a), 4(b) and 4(d) of the License.
|
206 |
+
|
207 |
+
In addition, if you combine or link compiled forms of this Software with software that is licensed under the GPLv2 ("Combined Software") and if a court of competent jurisdiction determines that the patent provision (Section 3), the indemnity provision (Section 9) or other Section of the License conflicts with the conditions of the GPLv2, you may retroactively and prospectively choose to deem waived or otherwise exclude such Section(s) of the License, but only in their entirety and only with respect to the Combined Software.
|
NOTICE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
This work was produced under the auspices of the U.S. Department of
|
2 |
+
Energy by Lawrence Livermore National Laboratory under Contract
|
3 |
+
DE-AC52-07NA27344.
|
4 |
+
|
5 |
+
This work was prepared as an account of work sponsored by an agency of
|
6 |
+
the United States Government. Neither the United States Government nor
|
7 |
+
Lawrence Livermore National Security, LLC, nor any of their employees
|
8 |
+
makes any warranty, expressed or implied, or assumes any legal liability
|
9 |
+
or responsibility for the accuracy, completeness, or usefulness of any
|
10 |
+
information, apparatus, product, or process disclosed, or represents that
|
11 |
+
its use would not infringe privately owned rights.
|
12 |
+
|
13 |
+
Reference herein to any specific commercial product, process, or service
|
14 |
+
by trade name, trademark, manufacturer, or otherwise does not necessarily
|
15 |
+
constitute or imply its endorsement, recommendation, or favoring by the
|
16 |
+
United States Government or Lawrence Livermore National Security, LLC.
|
17 |
+
|
18 |
+
The views and opinions of authors expressed herein do not necessarily
|
19 |
+
state or reflect those of the United States Government or Lawrence
|
20 |
+
Livermore National Security, LLC, and shall not be used for advertising
|
21 |
+
or product endorsement purposes.
|
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
title: Detect Retrieve Comprehend
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.1.7
|
8 |
app_file: app.py
|
@@ -10,4 +10,14 @@ pinned: false
|
|
10 |
license: apache-2.0
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: Detect Retrieve Comprehend
|
3 |
+
emoji: 📚
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: pink
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.1.7
|
8 |
app_file: app.py
|
|
|
10 |
license: apache-2.0
|
11 |
---
|
12 |
|
13 |
+
# Release
|
14 |
+
|
15 |
+
---
|
16 |
+
|
17 |
+
**Detect, Retrieve, Comprehend** is distributed under the terms of Apache 2.0 license with LLVM exception.
|
18 |
+
|
19 |
+
See [LICENSE]() and [NOTICE]() for details.
|
20 |
+
|
21 |
+
SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
|
22 |
+
|
23 |
+
LLNL-CODE-838964
|
UnifiedQA/demo_QA.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
|
2 |
+
# All rights reserved.
|
3 |
+
# See the top-level LICENSE and NOTICE files for details.
|
4 |
+
# LLNL-CODE-838964
|
5 |
+
|
6 |
+
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
|
7 |
+
|
8 |
+
import sys
|
9 |
+
import json
|
10 |
+
from math import ceil
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import numpy as np
|
14 |
+
from torch import tensor
|
15 |
+
from torch.nn.functional import log_softmax
|
16 |
+
from torch.distributions.categorical import Categorical
|
17 |
+
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
18 |
+
|
19 |
+
# load UnifiedQA onto device
|
20 |
+
model_name = "allenai/unifiedqa-v2-t5-large-1363200"
|
21 |
+
tokenizer = T5Tokenizer.from_pretrained(model_name)
|
22 |
+
model = T5ForConditionalGeneration.from_pretrained(model_name)
|
23 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
24 |
+
model.to(device)
|
25 |
+
|
26 |
+
def get_inputs(contexts_json, ranked_contexts_json):
|
27 |
+
with open(contexts_json, 'rt') as fp:
|
28 |
+
contexts = json.load(fp)
|
29 |
+
|
30 |
+
with open(ranked_contexts_json, 'rt') as fp:
|
31 |
+
ranked_contexts = json.load(fp)
|
32 |
+
|
33 |
+
question_id = list(ranked_contexts.keys())[0]
|
34 |
+
# assert len(questions) == 1, f'JSON should only have 1 question but found {len(questions)}: {questions}'
|
35 |
+
question = ranked_contexts[question_id]['text']
|
36 |
+
context_ids_sorted = ranked_contexts[question_id]['ranks']
|
37 |
+
context_scores = ranked_contexts[question_id]['scores']
|
38 |
+
contexts = [contexts[context_id]['text'] for context_id in context_ids_sorted]
|
39 |
+
|
40 |
+
# returns the question (str) and its contexts (sequence)
|
41 |
+
return question, contexts, context_scores
|
42 |
+
|
43 |
+
def get_tokens(text, tokenizer, max_tokens):
|
44 |
+
return tokenizer.encode_plus(text, return_tensors='pt', max_length=max_tokens, padding='max_length', truncation=True)['input_ids']
|
45 |
+
|
46 |
+
def prepare_inputs(tokenizer, max_tokens, context, question):
|
47 |
+
input_str = f'{question} \\n {context}'
|
48 |
+
inputs = get_tokens(input_str, tokenizer, max_tokens)
|
49 |
+
return inputs
|
50 |
+
|
51 |
+
def get_outputs(model, tokenizer, input_tokens, max_tokens):
|
52 |
+
output_dict = model.generate(input_tokens, output_scores=True, return_dict_in_generate=True, **{'max_length': max_tokens})
|
53 |
+
pred_tokens = output_dict['sequences'].squeeze().tolist()
|
54 |
+
|
55 |
+
# initialize metrics
|
56 |
+
logit_entropy = []
|
57 |
+
sentence_probs = []
|
58 |
+
|
59 |
+
# accumulate metrics over logit_sequence
|
60 |
+
logit_sequence = output_dict['scores'][:-1] # discard end token
|
61 |
+
for logit in logit_sequence:
|
62 |
+
log_probs = log_softmax(logit, dim=-1)
|
63 |
+
|
64 |
+
# update metrics
|
65 |
+
logit_entropy.append(Categorical(log_probs.exp()).entropy())
|
66 |
+
sentence_probs.append(log_probs.max())
|
67 |
+
|
68 |
+
# finish metrics calculation
|
69 |
+
logit_entropy = tensor(logit_entropy)
|
70 |
+
sentence_probs = tensor(sentence_probs)
|
71 |
+
entropy = logit_entropy.mean()
|
72 |
+
sentence_std = 0 if len(logit_sequence) == 1 else sentence_probs.std(unbiased=True).exp()
|
73 |
+
|
74 |
+
# use entropy * sentence_std as uncertainty
|
75 |
+
uncertainty = (entropy * sentence_std).item()
|
76 |
+
|
77 |
+
# convert answer tokens to str
|
78 |
+
pred_str = tokenizer.decode(pred_tokens, skip_special_tokens=True).lower()
|
79 |
+
|
80 |
+
return pred_str, uncertainty
|
81 |
+
|
82 |
+
# k_percent: percentage of contexts to use, cannot be less than min_k or greater than max_k
|
83 |
+
# min_k: minimum number of contexts to use, if possible. Setting this too small reduces recall
|
84 |
+
# max_k: maximum number of contexts to use. Setting this too big reduces precision
|
85 |
+
# recommended uncertainty thresholds are 2,3,4, and 5. The lower the threshold, the more aggressive the filtering
|
86 |
+
def run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=0.1, min_k=10, max_k=25, uncertainty_thresh=3):
|
87 |
+
k = min(max(ceil(k_percent * len(contexts)), min_k), max_k)
|
88 |
+
contexts = contexts[:k]
|
89 |
+
context_scores = context_scores[:k]
|
90 |
+
|
91 |
+
# iterate through top-k contexts
|
92 |
+
answers = []
|
93 |
+
uncertainty = []
|
94 |
+
for context in contexts:
|
95 |
+
input_tokens = prepare_inputs(tokenizer, 512, context, question).to(device)
|
96 |
+
pred_str, uncertainty_1 = get_outputs(model, tokenizer, input_tokens, 512)
|
97 |
+
answers.append(pred_str)
|
98 |
+
uncertainty.append(uncertainty_1)
|
99 |
+
|
100 |
+
# contexts = np.array(contexts)
|
101 |
+
# answers = np.array(answers)
|
102 |
+
# uncertainty = np.array(uncertainty)
|
103 |
+
|
104 |
+
# sort by uncertainty, ascending order
|
105 |
+
# order = np.argsort(uncertainty)
|
106 |
+
# contexts = contexts[order]
|
107 |
+
# answers = answers[order]
|
108 |
+
# uncertainty = uncertainty[order]
|
109 |
+
|
110 |
+
# init lists for threshed answers
|
111 |
+
# weak_contexts = []
|
112 |
+
# weak_answers = []
|
113 |
+
# weak_uncertainty = []
|
114 |
+
|
115 |
+
# filter by uncertainty
|
116 |
+
# if len(answers) > min_k:
|
117 |
+
# weak = np.argwhere(uncertainty > uncertainty_thresh) # exceeds threshold
|
118 |
+
# weak_contexts = contexts[weak].tolist()
|
119 |
+
# weak_answers = answers[weak].tolist()
|
120 |
+
# weak_uncertainty = uncertainty[weak].tolist()
|
121 |
+
|
122 |
+
# strong = np.argwhere(uncertainty <= uncertainty_thresh) # within threshold
|
123 |
+
# contexts = contexts[strong]
|
124 |
+
# answers = answers[strong]
|
125 |
+
# uncertainty = uncertainty[strong]
|
126 |
+
|
127 |
+
# contexts = contexts.tolist()
|
128 |
+
# answers = answers.tolist()
|
129 |
+
# uncertainty = uncertainty.tolist()
|
130 |
+
|
131 |
+
# return {'contexts': contexts, 'answers': answers, 'uncertainty': uncertainty}, \
|
132 |
+
# {'contexts': weak_contexts, 'answers': weak_answers, 'uncertainty': weak_uncertainty}
|
133 |
+
|
134 |
+
return {'contexts': contexts, 'answers': answers, 'context_scores':context_scores, 'uncertainty': uncertainty}
|
135 |
+
|
136 |
+
def get_qa_results(contexts_json, ranked_contexts_json, topk):
|
137 |
+
|
138 |
+
# extract question and contexts from json
|
139 |
+
question, contexts, context_scores = get_inputs(contexts_json, ranked_contexts_json)
|
140 |
+
|
141 |
+
# infer answers
|
142 |
+
with torch.inference_mode(True):
|
143 |
+
# strong_answers, weak_answers = run_model(model, tokenizer, device, question, contexts, k_percent=k_percent)
|
144 |
+
qa_results = run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=1.0, min_k=1, max_k=topk)
|
145 |
+
|
146 |
+
return qa_results
|
147 |
+
|
148 |
+
def get_qa_results_in_memory(contexts, ranked_contexts, topk):
|
149 |
+
|
150 |
+
question_id = list(ranked_contexts.keys())[0]
|
151 |
+
# assert len(questions) == 1, f'JSON should only have 1 question but found {len(questions)}: {questions}'
|
152 |
+
question = ranked_contexts[question_id]['text']
|
153 |
+
context_ids_sorted = ranked_contexts[question_id]['ranks']
|
154 |
+
context_scores = ranked_contexts[question_id]['scores']
|
155 |
+
contexts = [contexts[context_id]['text'] for context_id in context_ids_sorted]
|
156 |
+
|
157 |
+
# infer answers
|
158 |
+
with torch.inference_mode(True):
|
159 |
+
# strong_answers, weak_answers = run_model(model, tokenizer, device, question, contexts, k_percent=k_percent)
|
160 |
+
qa_results = run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=1.0, min_k=1, max_k=topk)
|
161 |
+
|
162 |
+
return qa_results
|
163 |
+
|
164 |
+
def load_custom_model(finetuned_model_path):
|
165 |
+
global tokenizer
|
166 |
+
global model
|
167 |
+
|
168 |
+
# load UnifiedQA onto device
|
169 |
+
tokenizer = T5Tokenizer.from_pretrained(finetuned_model_path)
|
170 |
+
model = T5ForConditionalGeneration.from_pretrained(finetuned_model_path)
|
171 |
+
model.to(device)
|
172 |
+
|
173 |
+
def get_qa_results_in_memory_finetuned_unifiedqa(question, context_scores, contexts, topk):
|
174 |
+
|
175 |
+
# infer answers
|
176 |
+
with torch.inference_mode(True):
|
177 |
+
# strong_answers, weak_answers = run_model(model, tokenizer, device, question, contexts, k_percent=k_percent)
|
178 |
+
qa_results = run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=1.0, min_k=1, max_k=topk)
|
179 |
+
|
180 |
+
return qa_results
|
app.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
|
2 |
+
# All rights reserved.
|
3 |
+
# See the top-level LICENSE and NOTICE files for details.
|
4 |
+
# LLNL-CODE-838964
|
5 |
+
|
6 |
+
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import gradio as gr
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
from torchvision.transforms import ToPILImage, ToTensor
|
13 |
+
tensor_to_image = ToPILImage()
|
14 |
+
image_to_tensor = ToTensor()
|
15 |
+
|
16 |
+
import sys
|
17 |
+
sys.path.append('DiT_Extractor/')
|
18 |
+
sys.path.append('CrossEncoder/')
|
19 |
+
sys.path.append('UnifiedQA/')
|
20 |
+
|
21 |
+
import dit_runner
|
22 |
+
import sentence_extractor
|
23 |
+
import cross_encoder
|
24 |
+
import demo_QA
|
25 |
+
|
26 |
+
from torchvision.transforms import ToPILImage
|
27 |
+
tensor_to_image = ToPILImage()
|
28 |
+
|
29 |
+
def run_fn(pdf_file_obj, question_text, input_topk):
|
30 |
+
|
31 |
+
pdf = pdf_file_obj.name
|
32 |
+
viz_images = dit_runner.get_dit_preds(pdf, score_threshold=0.5)
|
33 |
+
entity_json = '{0}.json'.format(Path(pdf).name[:-4])
|
34 |
+
|
35 |
+
sentence_extractor.get_contexts(entity_json)
|
36 |
+
|
37 |
+
contexts_json = 'contexts_{0}'.format(entity_json)
|
38 |
+
# contexts_json = 'contexts_2105u2iwiwxh.03011.json'
|
39 |
+
|
40 |
+
cross_encoder.get_ranked_contexts(contexts_json, question_text)
|
41 |
+
|
42 |
+
ranked_contexts_json = 'ranked_{0}'.format(contexts_json)
|
43 |
+
# ranked_contexts_json = 'ranked_contexts_2105u2iwiwxh.03011.json'
|
44 |
+
|
45 |
+
input_topk = int(input_topk)
|
46 |
+
|
47 |
+
# viz_images = [tensor_to_image(x) for x in torch.randn(4, 3, 256, 256)]
|
48 |
+
|
49 |
+
qa_results = demo_QA.get_qa_results(contexts_json, ranked_contexts_json, input_topk)
|
50 |
+
|
51 |
+
history = [('<<< [Retrieval Score: {0:.02f}] >>> {1}'.format(s, c), a) for c, s, a in zip(qa_results['contexts'], qa_results['context_scores'], qa_results['answers'])]
|
52 |
+
|
53 |
+
# Show in ascending order of score, since results box is already scrolled down.
|
54 |
+
history = history[::-1]
|
55 |
+
|
56 |
+
return viz_images, contexts_json, ranked_contexts_json, history
|
57 |
+
|
58 |
+
demo = gr.Blocks()
|
59 |
+
|
60 |
+
with demo:
|
61 |
+
|
62 |
+
gr.Markdown("<h1><center>Document-based Question Answering</center></h1>")
|
63 |
+
gr.Markdown("<center>This is a supplemental demo for our publication, [Document-based Question Answering](https://www.google.com). In this system, our input is a PDF file with a specific question of interest. The output is a set of most probable answers. There are 4 main components in our deployed pipeline: (1) DiT Layout Analysis (2) Context Extraction (3) Cross-Encoder Retrieval (4) UnifiedQA. See below for example uses with further explanation.</center>")
|
64 |
+
|
65 |
+
with gr.Row():
|
66 |
+
with gr.Column():
|
67 |
+
with gr.Row():
|
68 |
+
input_pdf_file = gr.File(file_count='single', label='PDF File')
|
69 |
+
with gr.Row():
|
70 |
+
input_question_text = gr.Textbox(label='Question')
|
71 |
+
with gr.Row():
|
72 |
+
input_k_percent = gr.Slider(minimum=1, maximum=24, step=1, value=8, label='Top K')
|
73 |
+
with gr.Row():
|
74 |
+
button_run = gr.Button('Run QA on Document')
|
75 |
+
|
76 |
+
gr.Markdown("<h3><center>Summary</center></h3>")
|
77 |
+
with gr.Row():
|
78 |
+
gr.Markdown('''
|
79 |
+
- <u>**DiT - Document Image Transformer**</u>: PDF -> converted into a list of images -> each image receives Entity Predictions
|
80 |
+
- Note that using this computer vision approach allows us to ignore things like *page numbers, footnotes, references*, etc
|
81 |
+
- <u>**Paragraph-based Text Extraction**</u>: DiT Bounding Boxes -> Convert into PDF-Space Coordinates -> Text Extraction using PDFMiner6 -> Tokenize & Sentence Split if tokenizer max length is exceeded
|
82 |
+
- <u>**CrossEncoder Context Retrieval**</u>: All Contexts + Question -> Top K Relevant Contexts best suited for answering question
|
83 |
+
- <u>**UnifiedQA**</u>: Most Relevant Contexts + Supplied Question -> Predict Set of Probable Answers
|
84 |
+
''')
|
85 |
+
|
86 |
+
with gr.Column():
|
87 |
+
with gr.Row():
|
88 |
+
output_gallery = gr.Gallery(label='DiT Predicted Entities')
|
89 |
+
with gr.Row():
|
90 |
+
gr.Markdown('''
|
91 |
+
- The `DiT predicted Entities` output box is scrollable! Scroll to see different page predictions. Note that predictions with confidence scores < 0.5 are not passed forward for text extraction.
|
92 |
+
- If an image is clicked, the output box will switch to a gallery view. To view these outputs in much higher resolution, right-click and choose "open image in new tab"
|
93 |
+
''')
|
94 |
+
with gr.Row():
|
95 |
+
output_contexts = gr.File(label='Detected Contexts', interactive=False)
|
96 |
+
output_ranked_contexts = gr.File(label='CrossEncoder Ranked Contexts', interactive=False)
|
97 |
+
with gr.Row():
|
98 |
+
output_qa_results = gr.Chatbot(color_map=['blue', 'green'], label='UnifiedQA Results').style()
|
99 |
+
|
100 |
+
gr.Markdown("<h3><center>Related Work & Code</center></h3>")
|
101 |
+
gr.Markdown("<center>DiT (Document Image Transformer) - <a href=https://arxiv.org/abs/2203.02378>Arxiv Page</a> | <a href=https://github.com/microsoft/unilm/tree/master/dit>Github Repo</a></center>")
|
102 |
+
gr.Markdown("<center>CrossEncoder - <a href=https://arxiv.org/abs/2203.02378>Arxiv Page</a> | <a href=https://github.com/microsoft/unilm/tree/master/dit>Github Repo</a></center>")
|
103 |
+
gr.Markdown("<center>UnifiedQA - <a href=https://arxiv.org/abs/2005.00700>Arxiv Page</a> | <a href=https://github.com/allenai/unifiedqa>Github Repo</a></center>")
|
104 |
+
|
105 |
+
button_run.click(fn=run_fn, inputs=[input_pdf_file, input_question_text, input_k_percent], outputs=[output_gallery, output_contexts, output_ranked_contexts, output_qa_results])
|
106 |
+
|
107 |
+
examples = [
|
108 |
+
['examples/1909.00694.pdf', 'What is the seed lexicon?', 5],
|
109 |
+
['examples/1909.00694.pdf', 'How big is seed lexicon used for training?', 5],
|
110 |
+
['examples/1810.04805.pdf', 'What is this paper about?', 5],
|
111 |
+
['examples/1810.04805.pdf', 'What is the model size?', 5],
|
112 |
+
['examples/2105.03011.pdf', 'How many questions are in this dataset?', 5],
|
113 |
+
['examples/1909.00694.pdf', 'How are relations used to propagate polarity?', 5],
|
114 |
+
|
115 |
+
]
|
116 |
+
gr.Examples(examples=examples,
|
117 |
+
inputs=[input_pdf_file, input_question_text, input_k_percent])
|
118 |
+
|
119 |
+
# examples = gr.Dataset(components=[input_pdf_file, input_question_text], samples=[[open('examples/1810.04805.pdf', mode='rb'), 'How many parameters are in the model?']])
|
120 |
+
demo.launch(enable_queue=True)
|
env_setup.sh
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
conda create --name llnl_actici_env python=3.9
|
2 |
+
conda activate llnl_actici_env
|
3 |
+
|
4 |
+
conda install pytorch=1.10 torchvision torchaudio cudatoolkit=11.3 -c pytorch
|
5 |
+
|
6 |
+
# For DiT
|
7 |
+
python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
|
8 |
+
|
9 |
+
|
10 |
+
# For DiT
|
11 |
+
pip install opencv-python
|
12 |
+
pip install timm
|
13 |
+
pip install pdfminer.six
|
14 |
+
conda install -c conda-forge poppler
|
15 |
+
pip install pdf2image
|
16 |
+
pip install pypdf2
|
17 |
+
pip install spacy
|
18 |
+
# pytesseract, in case we need in future
|
19 |
+
pip install pytesseract
|
20 |
+
|
21 |
+
# For Retrieval & QA
|
22 |
+
pip install transformers==4.20
|
23 |
+
pip install sentence-transformers
|
24 |
+
|
25 |
+
# For Demo
|
26 |
+
pip install gradio
|
27 |
+
|
28 |
+
# If Jupyter is allowed
|
29 |
+
pip install jupyter
|
30 |
+
|
31 |
+
# (Optional, adding this custom env to the base environment's jupyter)
|
32 |
+
python -m ipykernel install --user --name llnl_actici_env --display-name "Python (llnl_actici_env)"
|
examples/1810.04805.pdf
ADDED
Binary file (775 kB). View file
|
|
examples/1909.00694.pdf
ADDED
Binary file (540 kB). View file
|
|
examples/2105.03011.pdf
ADDED
Binary file (507 kB). View file
|
|
ms-marco-electra-base/CEBinaryClassificationEvaluator_MS-Marco_results.csv
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
epoch,steps,Accuracy,Accuracy_Threshold,F1,F1_Threshold,Precision,Recall,Average_Precision
|
2 |
+
0,5000,0.9297070292970703,0.25256121158599854,0.8307839388145314,0.19771124422550201,0.7957875457875457,0.869,0.8904110467492587
|
3 |
+
0,10000,0.939006099390061,0.5306986570358276,0.8460807600950118,0.28808051347732544,0.8058823529411765,0.8905,0.910544278892506
|
4 |
+
0,15000,0.9393060693930607,0.5750397443771362,0.8566081871345029,0.48249387741088867,0.8048351648351648,0.9155,0.9132147986720082
|
5 |
+
0,20000,0.9405059494050595,0.591253936290741,0.8546298558514537,0.570050835609436,0.8356426182513139,0.8745,0.9073685536522613
|
6 |
+
0,25000,0.9436056394360564,0.5074090957641602,0.8603960396039605,0.5057582855224609,0.8519607843137255,0.869,0.9167379821993755
|
7 |
+
0,30000,0.9396060393960604,0.8262588381767273,0.8542471042471043,0.7406325340270996,0.8255597014925373,0.885,0.8979176130668384
|
8 |
+
0,35000,0.9425057494250575,0.46686679124832153,0.8596070915189268,0.28302955627441406,0.8252069917203312,0.897,0.9163289965092976
|
9 |
+
0,40000,0.9417058294170583,0.6763133406639099,0.8575602629656682,0.6603987216949463,0.8357854769814903,0.8805,0.9173776247925393
|
10 |
+
0,45000,0.9426057394260574,0.4643915295600891,0.8605042016806723,0.29147765040397644,0.8277136258660508,0.896,0.9120726077810245
|
11 |
+
0,50000,0.945005499450055,0.5493776798248291,0.8624535315985131,0.4713650643825531,0.855036855036855,0.87,0.9209400105864155
|
12 |
+
0,55000,0.9454054594540546,0.6156725287437439,0.864585893339887,0.5604670643806458,0.8501691638472693,0.8795,0.9206262233464874
|
13 |
+
0,60000,0.9421057894210579,0.39554399251937866,0.8605827112930412,0.3811936378479004,0.8300046446818393,0.8935,0.9193948306076224
|
14 |
+
0,65000,0.9428057194280572,0.5363738536834717,0.8629682313892841,0.32784485816955566,0.8205590622182146,0.91,0.9227492855045069
|
15 |
+
0,70000,0.9438056194380562,0.38333064317703247,0.8628501827040195,0.3524332344532013,0.8413301662707838,0.8855,0.9236299441431376
|
16 |
+
0,75000,0.9468053194680532,0.48936331272125244,0.8696717295443409,0.48936331272125244,0.8525456292026897,0.8875,0.9254413650794524
|
17 |
+
0,80000,0.9454054594540546,0.3127445578575134,0.8651851851851852,0.3127445578575134,0.8546341463414634,0.876,0.9213706944185774
|
18 |
+
0,85000,0.9443055694430557,0.31547677516937256,0.8655280250180418,0.21403872966766357,0.8340287436254057,0.8995,0.9237103419372517
|
19 |
+
0,90000,0.9465053494650535,0.3857932686805725,0.8702401164200824,0.3761560022830963,0.8450306170513424,0.897,0.9258501989030058
|
20 |
+
0,95000,0.9453054694530547,0.3604514002799988,0.8669713735867213,0.29048818349838257,0.8354195642095503,0.901,0.9226658871253511
|
21 |
+
0,100000,0.9453054694530547,0.6748594045639038,0.8686288585786074,0.4552273154258728,0.8329508949059201,0.9075,0.9252677323330876
|
22 |
+
0,105000,0.9435056494350565,0.40062007308006287,0.8639551192145862,0.1210024282336235,0.8112379280070237,0.924,0.9237990563267019
|
23 |
+
0,110000,0.944905509449055,0.4197750985622406,0.8656429942418427,0.27975988388061523,0.8321033210332104,0.902,0.9247201058651281
|
24 |
+
0,115000,0.9464053594640536,0.4172205924987793,0.8698167791706846,0.2961992919445038,0.839851024208566,0.902,0.927117403879296
|
25 |
+
0,120000,0.9474052594740526,0.44686269760131836,0.8712047012732614,0.4383932948112488,0.8536468330134357,0.8895,0.9279628711835812
|
26 |
+
0,125000,0.945005499450055,0.4358792304992676,0.8655339805825243,0.28539055585861206,0.8410377358490566,0.8915,0.9268525722856882
|
27 |
+
0,130000,0.9462053794620537,0.21194982528686523,0.8703747911195989,0.16292141377925873,0.8328003654636821,0.9115,0.925512309638313
|
28 |
+
0,135000,0.9454054594540546,0.2292814701795578,0.8678621991505427,0.11477036774158478,0.82171581769437,0.9195,0.9268551457216524
|
29 |
+
0,140000,0.9482051794820517,0.31556186079978943,0.8758076094759513,0.26744428277015686,0.8398347865993575,0.915,0.9275073681003255
|
30 |
+
0,145000,0.9478052194780522,0.3485147953033447,0.8719556305763203,0.12995882332324982,0.8421052631578947,0.904,0.9278250006342896
|
31 |
+
0,150000,0.9483051694830517,0.32228657603263855,0.8726037369570493,0.21710461378097534,0.8477133427628477,0.899,0.9259328370035781
|
32 |
+
0,155000,0.9474052594740526,0.1903868019580841,0.8731307284129282,0.18298938870429993,0.8434296365330848,0.905,0.9261096325445609
|
33 |
+
0,160000,0.9473052694730527,0.5740681886672974,0.872194660996929,0.17134147882461548,0.8266905508284819,0.923,0.927973529121574
|
34 |
+
0,165000,0.9495050494950505,0.38968273997306824,0.87591956841589,0.34622055292129517,0.8594802694898941,0.893,0.9241440163389828
|
35 |
+
0,170000,0.9459054094590541,0.47478723526000977,0.8706669854171647,0.11328981816768646,0.8341731562070546,0.9105,0.9289979858500923
|
36 |
+
0,175000,0.9473052694730527,0.5903739929199219,0.8703747911195989,0.15506823360919952,0.8328003654636821,0.9115,0.9305074303915251
|
37 |
+
0,180000,0.9463053694630537,0.23235449194908142,0.8702585165498912,0.23235449194908142,0.841982234689107,0.9005,0.9291547676197442
|
38 |
+
0,185000,0.9478052194780522,0.174373060464859,0.8734852157052836,0.171615868806839,0.8476011288805269,0.901,0.9280170204346545
|
39 |
+
0,190000,0.949005099490051,0.5715193748474121,0.8747241971071341,0.5108739137649536,0.8581048581048581,0.892,0.9271410745170057
|
40 |
+
0,195000,0.9461053894610539,0.5194154977798462,0.8679334916864608,0.170893132686615,0.8266968325791855,0.9135,0.9271023702066649
|
41 |
+
0,200000,0.9468053194680532,0.3094758987426758,0.8707931277947754,0.11578939855098724,0.82258781680747,0.925,0.9290083868621436
|
42 |
+
0,205000,0.9461053894610539,0.6028298139572144,0.8679067577113257,0.13052904605865479,0.8202047174009791,0.9215,0.9276186176796931
|
43 |
+
0,210000,0.9459054094590541,0.49049288034439087,0.8694616484040019,0.16249723732471466,0.8303002729754322,0.9125,0.9285170114050436
|
ms-marco-electra-base/README.md
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
---
|
4 |
+
# Cross-Encoder for MS Marco
|
5 |
+
|
6 |
+
This model was trained on the [MS Marco Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) task.
|
7 |
+
|
8 |
+
The model can be used for Information Retrieval: Given a query, encode the query will all possible passages (e.g. retrieved with ElasticSearch). Then sort the passages in a decreasing order. See [SBERT.net Retrieve & Re-rank](https://www.sbert.net/examples/applications/retrieve_rerank/README.html) for more details. The training code is available here: [SBERT.net Training MS Marco](https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/ms_marco)
|
9 |
+
|
10 |
+
|
11 |
+
## Usage with Transformers
|
12 |
+
|
13 |
+
```python
|
14 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
15 |
+
import torch
|
16 |
+
|
17 |
+
model = AutoModelForSequenceClassification.from_pretrained('model_name')
|
18 |
+
tokenizer = AutoTokenizer.from_pretrained('model_name')
|
19 |
+
|
20 |
+
features = tokenizer(['How many people live in Berlin?', 'How many people live in Berlin?'], ['Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.'], padding=True, truncation=True, return_tensors="pt")
|
21 |
+
|
22 |
+
model.eval()
|
23 |
+
with torch.no_grad():
|
24 |
+
scores = model(**features).logits
|
25 |
+
print(scores)
|
26 |
+
```
|
27 |
+
|
28 |
+
|
29 |
+
## Usage with SentenceTransformers
|
30 |
+
|
31 |
+
The usage becomes easier when you have [SentenceTransformers](https://www.sbert.net/) installed. Then, you can use the pre-trained models like this:
|
32 |
+
```python
|
33 |
+
from sentence_transformers import CrossEncoder
|
34 |
+
model = CrossEncoder('model_name', max_length=512)
|
35 |
+
scores = model.predict([('Query', 'Paragraph1'), ('Query', 'Paragraph2') , ('Query', 'Paragraph3')])
|
36 |
+
```
|
37 |
+
|
38 |
+
|
39 |
+
## Performance
|
40 |
+
In the following table, we provide various pre-trained Cross-Encoders together with their performance on the [TREC Deep Learning 2019](https://microsoft.github.io/TREC-2019-Deep-Learning/) and the [MS Marco Passage Reranking](https://github.com/microsoft/MSMARCO-Passage-Ranking/) dataset.
|
41 |
+
|
42 |
+
|
43 |
+
| Model-Name | NDCG@10 (TREC DL 19) | MRR@10 (MS Marco Dev) | Docs / Sec |
|
44 |
+
| ------------- |:-------------| -----| --- |
|
45 |
+
| **Version 2 models** | | |
|
46 |
+
| cross-encoder/ms-marco-TinyBERT-L-2-v2 | 69.84 | 32.56 | 9000
|
47 |
+
| cross-encoder/ms-marco-MiniLM-L-2-v2 | 71.01 | 34.85 | 4100
|
48 |
+
| cross-encoder/ms-marco-MiniLM-L-4-v2 | 73.04 | 37.70 | 2500
|
49 |
+
| cross-encoder/ms-marco-MiniLM-L-6-v2 | 74.30 | 39.01 | 1800
|
50 |
+
| cross-encoder/ms-marco-MiniLM-L-12-v2 | 74.31 | 39.02 | 960
|
51 |
+
| **Version 1 models** | | |
|
52 |
+
| cross-encoder/ms-marco-TinyBERT-L-2 | 67.43 | 30.15 | 9000
|
53 |
+
| cross-encoder/ms-marco-TinyBERT-L-4 | 68.09 | 34.50 | 2900
|
54 |
+
| cross-encoder/ms-marco-TinyBERT-L-6 | 69.57 | 36.13 | 680
|
55 |
+
| cross-encoder/ms-marco-electra-base | 71.99 | 36.41 | 340
|
56 |
+
| **Other models** | | |
|
57 |
+
| nboost/pt-tinybert-msmarco | 63.63 | 28.80 | 2900
|
58 |
+
| nboost/pt-bert-base-uncased-msmarco | 70.94 | 34.75 | 340
|
59 |
+
| nboost/pt-bert-large-msmarco | 73.36 | 36.48 | 100
|
60 |
+
| Capreolus/electra-base-msmarco | 71.23 | 36.89 | 340
|
61 |
+
| amberoad/bert-multilingual-passage-reranking-msmarco | 68.40 | 35.54 | 330
|
62 |
+
| sebastian-hofstaetter/distilbert-cat-margin_mse-T2-msmarco | 72.82 | 37.88 | 720
|
63 |
+
|
64 |
+
Note: Runtime was computed on a V100 GPU.
|
ms-marco-electra-base/config.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "google/electra-base-discriminator",
|
3 |
+
"architectures": [
|
4 |
+
"ElectraForSequenceClassification"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"embedding_size": 768,
|
8 |
+
"hidden_act": "gelu",
|
9 |
+
"hidden_dropout_prob": 0.1,
|
10 |
+
"hidden_size": 768,
|
11 |
+
"id2label": {
|
12 |
+
"0": "LABEL_0"
|
13 |
+
},
|
14 |
+
"initializer_range": 0.02,
|
15 |
+
"intermediate_size": 3072,
|
16 |
+
"label2id": {
|
17 |
+
"LABEL_0": 0
|
18 |
+
},
|
19 |
+
"layer_norm_eps": 1e-12,
|
20 |
+
"max_position_embeddings": 512,
|
21 |
+
"model_type": "electra",
|
22 |
+
"num_attention_heads": 12,
|
23 |
+
"num_hidden_layers": 12,
|
24 |
+
"pad_token_id": 0,
|
25 |
+
"summary_activation": "gelu",
|
26 |
+
"summary_last_dropout": 0.1,
|
27 |
+
"summary_type": "first",
|
28 |
+
"summary_use_proj": true,
|
29 |
+
"type_vocab_size": 2,
|
30 |
+
"vocab_size": 30522
|
31 |
+
}
|
ms-marco-electra-base/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c554473d61458bf2969566b1bb464eb280ef7de9cacb6ec787b4fe7f0a9a80d9
|
3 |
+
size 438022601
|
ms-marco-electra-base/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
|
ms-marco-electra-base/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "name_or_path": "google/electra-base-discriminator"}
|
ms-marco-electra-base/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
poppler-utils
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.10.0
|
2 |
+
torchvision
|
3 |
+
opencv-python
|
4 |
+
timm
|
5 |
+
pdfminer.six
|
6 |
+
pdf2image
|
7 |
+
pypdf2
|
8 |
+
spacy
|
9 |
+
pytesseract
|
10 |
+
transformers==4.20
|
11 |
+
sentence-transformers
|
12 |
+
https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.10/detectron2-0.6%2Bcpu-cp38-cp38-linux_x86_64.whl
|
13 |
+
gradio
|