themanas021 commited on
Commit
365a9d8
·
1 Parent(s): ebef8da

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +82 -0
model.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os, shutil
3
+ import random
4
+
5
+
6
+ from PIL import Image
7
+ import jax
8
+ from transformers import FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
9
+ from huggingface_hub import hf_hub_download
10
+
11
+
12
+ # create target model directory
13
+ model_dir = './models/'
14
+ os.makedirs(model_dir, exist_ok=True)
15
+
16
+ files_to_download = [
17
+ "config.json",
18
+ "flax_model.msgpack",
19
+ "merges.txt",
20
+ "special_tokens_map.json",
21
+ "tokenizer.json",
22
+ "tokenizer_config.json",
23
+ "vocab.json",
24
+ "preprocessor_config.json",
25
+ ]
26
+
27
+ # copy files from checkpoint hub:
28
+ for fn in files_to_download:
29
+ file_path = hf_hub_download("ydshieh/vit-gpt2-coco-en-ckpts", f"ckpt_epoch_3_step_6900/{fn}")
30
+ shutil.copyfile(file_path, os.path.join(model_dir, fn))
31
+
32
+ model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
33
+ feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
34
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
35
+
36
+ max_length = 16
37
+ num_beams = 4
38
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
39
+
40
+
41
+ @jax.jit
42
+ def generate(pixel_values):
43
+ output_ids = model.generate(pixel_values, **gen_kwargs).sequences
44
+ return output_ids
45
+
46
+
47
+ def predict(image):
48
+
49
+ if image.mode != "RGB":
50
+ image = image.convert(mode="RGB")
51
+
52
+ pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
53
+
54
+ output_ids = generate(pixel_values)
55
+ preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
56
+ preds = [pred.strip() for pred in preds]
57
+
58
+ return preds[0]
59
+
60
+
61
+ def _compile():
62
+
63
+ image_path = 'samples/val_000000039769.jpg'
64
+ image = Image.open(image_path)
65
+ predict(image)
66
+ image.close()
67
+
68
+
69
+ _compile()
70
+
71
+
72
+ sample_dir = './samples/'
73
+ sample_image_ids = tuple(["None"] + [int(f.replace('COCO_val2017_', '').replace('.jpg', '')) for f in os.listdir(sample_dir) if f.startswith('COCO_val2017_')])
74
+
75
+ with open(os.path.join(sample_dir, "coco-val2017-img-ids.json"), "r", encoding="UTF-8") as fp:
76
+ coco_2017_val_image_ids = json.load(fp)
77
+
78
+
79
+ def get_random_image_id():
80
+
81
+ image_id = random.sample(coco_2017_val_image_ids, k=1)[0]
82
+ return image_id