pain committed on
Commit 58f7258 · verified · 1 Parent(s): 1664c3c

Upload 7 files

Files changed (5)
  1. .gitignore +16 -0
  2. app.py +46 -37
  3. logo_araclip.png +0 -0
  4. requirements.txt +95 -3
  5. utils.py +9 -21
.gitignore ADDED
@@ -0,0 +1,16 @@
+ cashed_pickles/*
+ photos/*
+ .env/*
+ */__pycache__/*
+ .gradio/*
+ */.ipynb_checkpoints/*
+ */.vscode/*
+ */.git/*
+ */.gitignore
+ */.gitattributes
+ */.gitmodules
+ */.gitkeep
+ */.gitlab-ci.yml
+ */.gitlab/*
+ */.github/*
+ */
app.py CHANGED
@@ -12,8 +12,8 @@ with gr.Blocks() as demo_araclip:
 
  gr.Markdown("## Input parameters")
 
- txt = gr.Textbox(label="Text Query (Caption)")
- num = gr.Slider(label="Number of retrieved image", value=1, minimum=1)
+ txt = gr.Textbox(label="Text Query")
+ num = gr.Slider(label="Number of retrieved image", value=1, minimum=1, step=1)
 
 
  with gr.Row():
@@ -22,26 +22,15 @@ with gr.Blocks() as demo_araclip:
  gr.Markdown("## Retrieved Images")
 
  gallery = gr.Gallery(
- label="Generated images", show_label=True, elem_id="gallery"
+ show_label=False, elem_id="gallery"
  , columns=[5], rows=[1], object_fit="contain", height="auto")
 
 
  with gr.Row():
- lables = gr.Label(label="Text image similarity")
+ lables = gr.Label(label="Text-image similarity")
 
- with gr.Row():
-
- with gr.Column(scale=1):
- gr.Markdown("<div style='text-align: center; font-size: 24px; font-weight: bold;'>Data Retrieved based on Images Similarity</div>")
-
- json_output = gr.JSON()
-
- with gr.Column(scale=1):
- gr.Markdown("<div style='text-align: center; font-size: 24px; font-weight: bold;'>Data Retrieved based on Text similarity</div>")
- json_text = gr.JSON()
-
-
- btn.click(utils.predict, inputs=[txt, num, dadtaset_select], outputs=[gallery,lables, json_output, json_text])
+
+ btn.click(utils.predict, inputs=[txt, num, dadtaset_select], outputs=[gallery,lables])
 
 
  gr.Examples(
@@ -49,7 +38,7 @@ with gr.Blocks() as demo_araclip:
  ["وقوف قطة بمخالبها على فأرة حاسوب على المكتب", 10],
  ["صحن به شوربة صينية بالخضار، وإلى جانبه بطاطس مقلية وزجاجة ماء", 7]],
  inputs=[txt, num, dadtaset_select],
- outputs=[gallery,lables, json_output, json_text],
+ outputs=[gallery,lables],
  fn=utils.predict,
  cache_examples=False,
  )
@@ -64,8 +53,8 @@ with gr.Blocks() as demo_mclip:
 
  gr.Markdown("## Input parameters")
 
- txt = gr.Textbox(label="Text Query (Caption)")
- num = gr.Slider(label="Number of retrieved image", value=1, minimum=1)
+ txt = gr.Textbox(label="Text Query")
+ num = gr.Slider(label="Number of retrieved image", value=1, minimum=1, step=1)
 
  with gr.Row():
  btn = gr.Button("Retrieve images", scale=1)
@@ -79,37 +68,57 @@ with gr.Blocks() as demo_mclip:
 
  lables = gr.Label()
 
- with gr.Row():
-
- with gr.Column(scale=1):
- gr.Markdown("## Images Retrieved")
- json_output = gr.JSON()
-
- with gr.Column(scale=1):
- gr.Markdown("## Text Retrieved")
- json_text = gr.JSON()
-
- btn.click(utils.predict_mclip, inputs=[txt, num, dadtaset_select], outputs=[gallery,lables, json_output, json_text])
+ btn.click(utils.predict_mclip, inputs=[txt, num, dadtaset_select], outputs=[gallery,lables])
 
  gr.Examples(
  examples=[["تخطي لاعب فريق بيتسبرج بايرتس منطقة اللوحة الرئيسية في مباراة بدوري البيسبول", 5],
  ["وقوف قطة بمخالبها على فأرة حاسوب على المكتب", 10],
  ["صحن به شوربة صينية بالخضار، وإلى جانبه بطاطس مقلية وزجاجة ماء", 7]],
  inputs=[txt, num, dadtaset_select],
- outputs=[gallery,lables, json_output, json_text],
+ outputs=[gallery,lables],
  fn=utils.predict_mclip,
  cache_examples=False,
  )
 
 
+ # Define custom CSS to increase the size of the tabs
+ custom_css = """
+ .gr-tabbed-interface .gr-tab {
+ font-size: 50px; /* Increase the font size */
+ padding: 10px; /* Increase the padding */
+ }
+ """
+
  # Group the demos in a TabbedInterface
  with gr.Blocks() as demo:
 
- gr.Markdown("<font color=red size=10><center>AraClip: Arabic Image Retrieval Application</center></font>")
-
- gr.TabbedInterface([demo_araclip, demo_mclip], ["Our Model", "Mclip model"])
-
-
+ # gr.Image("statics/logo_araclip.png")
+ gr.Markdown("""
+ <center> <img src="/file=statics/logo_araclip.png" alt="Imgur" style="width:200px"></center>
+ """)
+ gr.Markdown("<center> <font color=red size=10>AraClip: Arabic Image Retrieval Application</font></center>")
+
+ gr.Markdown("""
+ <font size=4> To run the demo 🤗, please select the model, then the dataset you would like to search in, enter a text query, and specify the number of retrieved images.</font>
+
+ """)
+
+
+
+ gr.TabbedInterface([demo_araclip, demo_mclip], ["Our Model", "Mclip model"], css=custom_css)
+
+ gr.Markdown(
+ """
+ If you find this work helpful, please help us to ⭐ the repositories in <a href='https://github.com/Arabic-Clip' target='_blank'>Github Organization</a>. Thank you!
+
+ ---
+ 📝 **Citation**
+
+ To be shared soon.
+
+ 📋 **License**
+ """
+ )
  if __name__ == "__main__":
 
  demo.launch()
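For readers skimming the diff, here is a minimal, self-contained Gradio sketch of the wiring app.py moves to after this commit: the predict callback now feeds only a gallery and a label dictionary, with the two JSON panels removed. The dummy_predict function, the Radio choices, and the generated placeholder images are illustrative assumptions, not code from the repository.

import gradio as gr
import numpy as np

def dummy_predict(text, num, dataset_select):
    # Stand-in for utils.predict: the real callback runs AraClip/mCLIP retrieval.
    # gr.Gallery accepts (image, caption) tuples; gr.Label accepts a name -> score dict.
    images = [(np.full((96, 96, 3), 30 * (i + 1) % 256, dtype=np.uint8), "") for i in range(int(num))]
    scores = {f" Image # {i + 1}": 1.0 / (i + 1) for i in range(int(num))}
    return images, scores

with gr.Blocks() as demo:
    txt = gr.Textbox(label="Text Query")
    num = gr.Slider(label="Number of retrieved image", value=1, minimum=1, step=1)
    dataset_select = gr.Radio(["XTD dataset", "Flicker8k dataset"], value="XTD dataset", label="Dataset")
    btn = gr.Button("Retrieve images")
    gallery = gr.Gallery(show_label=False, columns=[5], object_fit="contain", height="auto")
    labels = gr.Label(label="Text-image similarity")
    # Only two outputs now; the json_output / json_text panels are gone.
    btn.click(dummy_predict, inputs=[txt, num, dataset_select], outputs=[gallery, labels])

if __name__ == "__main__":
    demo.launch()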
logo_araclip.png ADDED
requirements.txt CHANGED
@@ -1,5 +1,97 @@
+ aiofiles==23.2.1
+ altair==5.2.0
+ annotated-types==0.6.0
+ anyio==3.7.1
+ attrs==23.1.0
+ certifi==2023.11.17
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ contourpy==1.1.1
+ cycler==0.12.1
+ exceptiongroup==1.2.0
+ fastapi==0.105.0
+ ffmpy==0.3.1
+ filelock==3.13.1
+ fonttools==4.46.0
+ fsspec==2023.12.2
+ ftfy==6.1.3
+ gradio==4.38.1
+ gradio-client==1.1.0
+ h11==0.14.0
+ httpcore==1.0.5
+ httpx==0.27.0
+ huggingface-hub==0.19.4
+ idna==3.6
+ importlib-resources==6.1.1
+ Jinja2==3.1.2
+ jsonschema==4.20.0
+ jsonschema-specifications==2023.11.2
+ kiwisolver==1.4.5
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.3
+ matplotlib==3.7.4
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multilingual-clip==1.0.10
+ networkx==3.1
+ numpy==1.24.4
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==8.9.2.26
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-nccl-cu12==2.18.1
+ nvidia-nvjitlink-cu12==12.3.101
+ nvidia-nvtx-cu12==12.1.105
  open-clip-torch==2.23.0
- transformers==4.36.1
+ orjson==3.9.10
+ packaging==23.2
+ pandas==2.0.3
+ Pillow==10.1.0
+ pkgutil-resolve-name==1.3.10
+ protobuf==4.25.1
+ pydantic==2.5.2
+ pydantic-core==2.14.5
+ pydub==0.25.1
+ pygments==2.17.2
+ pyparsing==3.1.1
+ python-dateutil==2.8.2
+ python-multipart==0.0.9
+ pytz==2023.3.post1
+ PyYAML==6.0.1
+ referencing==0.32.0
+ regex==2023.10.3
+ requests==2.31.0
+ rich==13.7.0
+ rpds-py==0.13.2
+ ruff==0.5.4
+ safetensors==0.4.1
+ semantic-version==2.10.0
+ sentencepiece==0.1.99
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.0
+ starlette==0.27.0
+ sympy==1.12
+ timm==0.9.12
+ tokenizers==0.15.0
+ tomlkit==0.12.0
+ toolz==0.12.0
  torch==2.1.1
- gradio==4.9.0
- multilingual-clip==1.0.10
+ torchvision==0.16.1
+ tqdm==4.66.1
+ transformers==4.36.1
+ triton==2.1.0
+ typer==0.12.3
+ typing-extensions==4.9.0
+ tzdata==2023.3
+ urllib3==2.1.0
+ uvicorn==0.24.0.post1
+ wcwidth==0.2.12
+ websockets==11.0.3
+ zipp==3.17.0
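As a quick, illustrative sanity check (not part of the repository), the key pins in the expanded requirements.txt, notably the Gradio bump from 4.9.0 to 4.38.1, can be compared against the active environment with importlib.metadata:

from importlib.metadata import version

# Hypothetical spot-check of a few pins from requirements.txt against the installed environment.
for pkg, pinned in [("gradio", "4.38.1"), ("torch", "2.1.1"),
                    ("transformers", "4.36.1"), ("open-clip-torch", "2.23.0")]:
    installed = version(pkg)
    status = "OK" if installed == pinned else "MISMATCH"
    print(f"{pkg}: installed {installed}, pinned {pinned} -> {status}")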
utils.py CHANGED
@@ -106,32 +106,20 @@ def find_image(language_model,clip_model, text_query, dataset, image_features, t
  probs = txt_logits.softmax(dim=-1).cpu().detach().numpy().T
 
  file_paths = []
- labels, json_data = {}, {}
+ labels = {}
 
  for i in range(1, num+1):
  idx = np.argsort(probs, axis=0)[-i, 0]
  path = images_path + dataset.get_image_name(idx)
 
- path_l = (path,f"{sorted_data[idx]['caption_ar']}")
+ path_l = (path, "")
 
  labels[f" Image # {i}"] = probs[idx]
- json_data[f" Image # {i}"] = sorted_data[idx]
 
  file_paths.append(path_l)
 
 
- json_text = {}
-
- for _, txt_logits_full in text_logits.items():
-
- probs_text = txt_logits_full.softmax(dim=-1).cpu().detach().numpy().T
-
- for j in range(1, num+1):
-
- idx = np.argsort(probs_text, axis=0)[-j, 0]
- json_text[f" Text # {j}"] = sorted_data[idx]
-
- return file_paths, labels, json_data, json_text
+ return file_paths, labels
 
 
 
@@ -163,12 +151,12 @@ araclip = AraClip()
  def predict(text, num, dadtaset_select):
 
  if dadtaset_select == "XTD dataset":
- image_paths, labels, json_data, json_text = find_image(araclip.language_model,araclip.clip_model, text, araclip.load_xtd_dataset(), araclip.load_pickle_file("cashed_pickles/XTD_pickles/araclip/image_features_XTD_1000_images_arabert_siglib_best_model.pickle") , araclip.load_pickle_file("cashed_pickles/XTD_pickles/araclip/image_features_XTD_1000_images_arabert_siglib_best_model.pickle"), araclip.sorted_data_xtd, 'photos/XTD10_dataset/', num=int(num))
+ image_paths, labels = find_image(araclip.language_model,araclip.clip_model, text, araclip.load_xtd_dataset(), araclip.load_pickle_file("cashed_pickles/XTD_pickles/araclip/image_features_XTD_1000_images_arabert_siglib_best_model.pickle") , araclip.load_pickle_file("cashed_pickles/XTD_pickles/araclip/image_features_XTD_1000_images_arabert_siglib_best_model.pickle"), araclip.sorted_data_xtd, 'photos/XTD10_dataset/', num=int(num))
 
  else:
- image_paths, labels, json_data, json_text = find_image(araclip.language_model,araclip.clip_model, text, araclip.load_flicker8k_dataset(), araclip.load_pickle_file("cashed_pickles/flicker_8k/araclip/image_features_flicker_8k_images_arabert_siglib_best_model.pickle") , araclip.load_pickle_file("cashed_pickles/flicker_8k/araclip/text_features_flicker_8k_images_arabert_siglib_best_model.pickle"), araclip.sorted_data_flicker8k, "photos/Flicker8k_Dataset/", num=int(num))
+ image_paths, labels = find_image(araclip.language_model,araclip.clip_model, text, araclip.load_flicker8k_dataset(), araclip.load_pickle_file("cashed_pickles/flicker_8k/araclip/image_features_flicker_8k_images_arabert_siglib_best_model.pickle") , araclip.load_pickle_file("cashed_pickles/flicker_8k/araclip/text_features_flicker_8k_images_arabert_siglib_best_model.pickle"), araclip.sorted_data_flicker8k, "photos/Flicker8k_Dataset/", num=int(num))
 
- return image_paths, labels, json_data, json_text
+ return image_paths, labels
 
 
  class Mclip():
@@ -203,10 +191,10 @@ def predict_mclip(text, num, dadtaset_select):
 
 
  if dadtaset_select == "XTD dataset":
- image_paths, labels, json_data, json_text = find_image(mclip.language_model_mclip,mclip.clip_model_mclip, text, mclip.load_xtd_dataset() , mclip.load_pickle_file("cashed_pickles/XTD_pickles/mclip/image_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle") , mclip.load_pickle_file("cashed_pickles/XTD_pickles/mclip/text_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle") , mclip.sorted_data_xtd , 'photos/XTD10_dataset/', num=int(num))
+ image_paths, labels = find_image(mclip.language_model_mclip,mclip.clip_model_mclip, text, mclip.load_xtd_dataset() , mclip.load_pickle_file("cashed_pickles/XTD_pickles/mclip/image_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle") , mclip.load_pickle_file("cashed_pickles/XTD_pickles/mclip/text_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle") , mclip.sorted_data_xtd , 'photos/XTD10_dataset/', num=int(num))
 
  else:
- image_paths, labels, json_data, json_text = find_image(mclip.language_model_mclip,mclip.clip_model_mclip, text, mclip.load_flicker8k_dataset() , mclip.load_pickle_file("cashed_pickles/flicker_8k/mclip/image_features_flicker_8k_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle") , mclip.load_pickle_file("cashed_pickles/flicker_8k/mclip/text_features_flicker_8k_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle") , mclip.sorted_data_flicker8k , 'photos/Flicker8k_Dataset/', num=int(num))
+ image_paths, labels = find_image(mclip.language_model_mclip,mclip.clip_model_mclip, text, mclip.load_flicker8k_dataset() , mclip.load_pickle_file("cashed_pickles/flicker_8k/mclip/image_features_flicker_8k_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle") , mclip.load_pickle_file("cashed_pickles/flicker_8k/mclip/text_features_flicker_8k_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle") , mclip.sorted_data_flicker8k , 'photos/Flicker8k_Dataset/', num=int(num))
 
 
- return image_paths, labels, json_data, json_text
+ return image_paths, labels
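A small sketch of the reduced contract that find_image now satisfies after this commit: the JSON side outputs are gone, and it returns only (path, caption) tuples for the gallery plus a name-to-score dict for the label. The rank_images helper, the toy probability vector, and the file names below are illustrative assumptions, not the repository's code.

import numpy as np

def rank_images(probs, image_names, images_path, num):
    # probs: softmaxed text-image scores with shape (n_images, 1), as in find_image.
    file_paths, labels = [], {}
    for i in range(1, num + 1):
        idx = np.argsort(probs, axis=0)[-i, 0]                    # index of the i-th highest score
        file_paths.append((images_path + image_names[idx], ""))   # (path, empty caption) for gr.Gallery
        labels[f" Image # {i}"] = float(probs[idx, 0])            # name -> score for gr.Label
    return file_paths, labels                                     # no json_data / json_text anymore

probs = np.array([[0.1], [0.6], [0.3]])                           # toy softmax output
print(rank_images(probs, ["a.jpg", "b.jpg", "c.jpg"], "photos/", num=2))
# -> ([('photos/b.jpg', ''), ('photos/c.jpg', '')], {' Image # 1': 0.6, ' Image # 2': 0.3})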