Spaces: Runtime error

Commit · 1212b6f
1 Parent(s): cb08135

Upload 2 files

- app.py +102 -0
- requirements.txt +19 -0
app.py
ADDED
@@ -0,0 +1,102 @@
+import requests
+
+import gradio as gr
+
+import torch
+from transformers import ViTFeatureExtractor, AutoTokenizer, CLIPFeatureExtractor, AutoModel, AutoModelForCausalLM
+from transformers.models.auto.configuration_auto import AutoConfig
+from src.vision_encoder_decoder import SmallCap, SmallCapConfig
+from src.gpt2 import ThisGPT2Config, ThisGPT2LMHeadModel
+from src.utils import prep_strings, postprocess_preds
+import json
+
+from src.retrieve_caps import *
+from PIL import Image
+from torchvision import transforms
+
+from src.opt import ThisOPTConfig, ThisOPTForCausalLM
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# load feature extractor
+feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
+
+# load and configure tokenizer
+tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125M")
+tokenizer.pad_token = '!'
+tokenizer.eos_token = '.'
+
+# load model
+# AutoConfig.register("this_gpt2", ThisGPT2Config)
+# AutoModel.register(ThisGPT2Config, ThisGPT2LMHeadModel)
+# AutoModelForCausalLM.register(ThisGPT2Config, ThisGPT2LMHeadModel)
+# AutoConfig.register("smallcap", SmallCapConfig)
+# AutoModel.register(SmallCapConfig, SmallCap)
+# model = AutoModel.from_pretrained("Yova/SmallCap7M")
+
+
+AutoConfig.register("this_opt", ThisOPTConfig)
+AutoModel.register(ThisOPTConfig, ThisOPTForCausalLM)
+AutoModelForCausalLM.register(ThisOPTConfig, ThisOPTForCausalLM)
+AutoConfig.register("smallcap", SmallCapConfig)
+AutoModel.register(SmallCapConfig, SmallCap)
+model = AutoModel.from_pretrained("Yova/SmallCapOPT7M")
+
+model = model.to(device)
+
+template = open('src/template.txt').read().strip() + ' '
+
+# load precomputed captions and the retrieval index
+captions = json.load(open('datastore/coco_index_captions.json'))
+retrieval_model, feature_extractor_retrieval = clip.load("RN50x64", device=device)
+retrieval_index = faiss.read_index('datastore/coco_index')
+#res = faiss.StandardGpuResources()
+#retrieval_index = faiss.index_cpu_to_gpu(res, 0, retrieval_index)
+
+# Download human-readable labels for ImageNet.
+response = requests.get("https://git.io/JJkYN")
+labels = response.text.split("\n")
+
+
+def retrieve_caps(image_embedding, index, k=4):
+    xq = image_embedding.astype(np.float32)
+    faiss.normalize_L2(xq)
+    D, I = index.search(xq, k)
+    return I
+
+def classify_image(image):
+    inp = transforms.ToTensor()(image)
+
+    pixel_values_retrieval = feature_extractor_retrieval(image).to(device)
+    with torch.no_grad():
+        image_embedding = retrieval_model.encode_image(pixel_values_retrieval.unsqueeze(0)).cpu().numpy()
+
+    nns = retrieve_caps(image_embedding, retrieval_index)[0]
+    caps = [captions[i] for i in nns][:4]
+
+    # prepare prompt
+    decoder_input_ids = prep_strings('', tokenizer, template=template, retrieved_caps=caps, k=4, is_test=True)
+
+    # generate caption
+    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
+    with torch.no_grad():
+        pred = model.generate(pixel_values.to(device),
+                              decoder_input_ids=torch.tensor([decoder_input_ids]).to(device),
+                              max_new_tokens=25, no_repeat_ngram_size=0, length_penalty=0,
+                              min_length=1, num_beams=3, eos_token_id=tokenizer.eos_token_id)
+    #inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
+    #prediction = inception_net.predict(inp).flatten()
+    retrieved_caps = "Retrieved captions: \n{}\n{}\n{}\n{}".format(*caps)
+    return retrieved_caps + "\n\n\n Generated caption:\n" + str(postprocess_preds(tokenizer.decode(pred[0]), tokenizer))
+
+
+image = gr.Image(type="pil")
+
+textbox = gr.Textbox(placeholder="Retrieved captions and generated caption...", lines=4)
+
+
+title = "SmallCap Demo"
+gr.Interface(
+    fn=classify_image, inputs=image, outputs=textbox, title=title
+).launch(share=True)
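
For reference, the retrieval step in app.py above is a plain FAISS nearest-neighbour search over L2-normalized CLIP image embeddings. Below is a minimal, self-contained sketch of the same pattern using a toy in-memory index in place of the committed datastore/coco_index; the dimensions and data here are illustrative only, not part of the commit.

import numpy as np
import faiss

# Toy stand-in for the COCO datastore: 1,000 random 64-d "caption embeddings".
d = 64
xb = np.random.rand(1000, d).astype(np.float32)
faiss.normalize_L2(xb)
index = faiss.IndexFlatIP(d)  # inner product equals cosine similarity after L2-normalization
index.add(xb)

def retrieve_caps(image_embedding, index, k=4):
    # Same logic as in app.py: normalize the query, return ids of the k nearest captions.
    xq = image_embedding.astype(np.float32)
    faiss.normalize_L2(xq)
    D, I = index.search(xq, k)
    return I

query = np.random.rand(1, d)        # stands in for a CLIP image embedding
print(retrieve_caps(query, index))  # indices of the 4 nearest captions in the toy index
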
requirements.txt
ADDED
@@ -0,0 +1,19 @@
+datasets==2.4.0
+faiss-gpu==1.7.2
+h5py==3.7.0
+huggingface-hub==0.8.1
+pandas==1.4.3
+Pillow==9.2.0
+pyarrow==9.0.0
+pyparsing==3.0.9
+PyYAML==6.0
+tokenizers==0.12.1
+torch==1.12.1
+torchaudio==0.12.1
+torchvision==0.13.1
+tqdm==4.64.0
+transformers==4.21.1
+ftfy
+regex
+tqdm
+git+https://github.com/openai/CLIP.git