bhavitvyamalik commited on
Commit
adb2f7c
1 Parent(s): 965885f

force download

Browse files
Files changed (1) hide show
  1. apps/mic.py +2 -2
apps/mic.py CHANGED
@@ -44,7 +44,7 @@ def app(state):
44
 
45
  @st.cache
46
  def load_model(ckpt):
47
- return FlaxCLIPVisionMBartForConditionalGeneration.from_pretrained(ckpt)
48
 
49
  @st.cache
50
  def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p, do_sample, top_k, max_length):
@@ -54,7 +54,7 @@ def app(state):
54
  output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=max_length)
55
  return output_sequence
56
 
57
- mic_checkpoints = ["./ckpts/ckpt-51999"] # TODO: Maybe add more checkpoints?
58
  dummy_data = pd.read_csv("reference.tsv", sep="\t")
59
 
60
  first_index = 25
 
44
 
45
  @st.cache
46
  def load_model(ckpt):
47
+ return FlaxCLIPVisionMBartForConditionalGeneration.from_pretrained(ckpt, cache_dir="./", force_download=True)
48
 
49
  @st.cache
50
  def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p, do_sample, top_k, max_length):
 
54
  output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=max_length)
55
  return output_sequence
56
 
57
+ mic_checkpoints = ["flax-community/clip-vit-base-patch32_mbart-large-50"] # TODO: Maybe add more checkpoints?
58
  dummy_data = pd.read_csv("reference.tsv", sep="\t")
59
 
60
  first_index = 25