ydshieh commited on
Commit
943681e
1 Parent(s): 1a9cd94

try init model

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. model.py +38 -1
app.py CHANGED
@@ -44,7 +44,7 @@ st.sidebar.write('\n')
44
 
45
  with st.spinner('Generating image caption ...'):
46
 
47
- caption, tokens, token_ids = predict(image)
48
 
49
  st.success(f'caption: {caption}')
50
  st.success(f'tokens: {tokens}')
 
44
 
45
  with st.spinner('Generating image caption ...'):
46
 
47
+ caption, tokens, token_ids = predict_dummy(image)
48
 
49
  st.success(f'caption: {caption}')
50
  st.success(f'tokens: {tokens}')
model.py CHANGED
@@ -25,9 +25,46 @@ shutil.copyfile(filepath, os.path.join(model_dir, 'flax_model.msgpack'))
25
 
26
  flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_dir)
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def predict(image):
29
- return 'dummy caption!', ['dummy', 'caption', '!'], [1, 2, 3]
30
 
 
 
 
31
 
 
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
 
 
25
 
26
  flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_dir)
27
 
28
+ vit_model_name = 'google/vit-base-patch16-224-in21k'
29
+ feature_extractor = ViTFeatureExtractor.from_pretrained(vit_model_name)
30
+
31
+ gpt2_model_name = 'asi/gpt-fr-cased-small'
32
+ tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
33
+
34
+ max_length = 32
35
+ num_beams = 8
36
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
37
+
38
+
39
+ @jax.jit
40
+ def predict_fn(pixel_values):
41
+
42
+ return flax_vit_gpt2_lm.generate(pixel_values, **gen_kwargs)
43
+
44
  def predict(image):
 
45
 
46
+ # batch dim is added automatically
47
+ encoder_inputs = feature_extractor(images=image, return_tensors="jax")
48
+ pixel_values = encoder_inputs.pixel_values
49
 
50
+ # generation
51
+ generation = predict_fn(pixel_values)
52
 
53
+ token_ids = np.array(generation.sequences)[0]
54
+ caption = tokenizer.decode(token_ids)
55
+
56
+ return caption, token_ids
57
+
58
+ def init():
59
+
60
+ image_path = 'samples/val_000000039769.jpg'
61
+ image = Image.open(image_path)
62
+
63
+ caption, token_ids = predict(image)
64
+ image.close()
65
+
66
+ def predict_dummy(image):
67
+
68
+ return 'dummy caption!', ['dummy', 'caption', '!'], [1, 2, 3]
69
 
70
+ init()