saylee-m committed on
Commit
f50a20b
1 Parent(s): 59e7947

added more comments

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -26,8 +26,8 @@ def load_donut_model():
26
  return model, processor
27
 
28
  def load_paligemma_docvqa():
29
- model_id = "google/paligemma-3b-ft-docvqa-896"
30
- # model_id = "google/paligemma-3b-mix-448"
31
  processor = AutoProcessor.from_pretrained(model_id)
32
  model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
33
  model.to(device)
@@ -53,6 +53,7 @@ def load_models():
53
  }
54
 
55
  loaded_models = load_models()
 
56
 
57
  def base64_encoded_image(image_array):
58
  im = Image.fromarray(image_array)
@@ -108,10 +109,14 @@ def process_document_donut(image_array, question):
108
  return op
109
 
110
def process_document_pg(image_array, question):
    """Answer `question` about a document image using the PaliGemma model.

    Encodes the image/question pair with the model's processor, generates up
    to 100 new tokens, and strips the echoed prompt from the decoded text so
    only the answer is returned.
    """
    # NOTE(review): assumes loaded_models holds a (model, processor) pair under
    # the "paligemma" key — .get() would return None otherwise and the unpack
    # would raise; confirm load_models() always populates this entry.
    model, processor = loaded_models.get("paligemma")

    encoded = processor(images=image_array, text=question, return_tensors="pt").to(device)
    generated_ids = model.generate(**encoded, max_new_tokens=100)

    decoded = processor.decode(generated_ids[0], skip_special_tokens=True)
    # The decoded output starts with the prompt itself; drop it along with any
    # leading newlines so only the generated answer remains.
    return decoded[len(question):].lstrip("\n")
116
 
117
  def process_document_idf(image_array, question):
 
26
  return model, processor
27
 
28
  def load_paligemma_docvqa():
29
+ # model_id = "google/paligemma-3b-ft-docvqa-896"
30
+ model_id = "google/paligemma-3b-mix-448"
31
  processor = AutoProcessor.from_pretrained(model_id)
32
  model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
33
  model.to(device)
 
53
  }
54
 
55
  loaded_models = load_models()
56
+ print("models loaded")
57
 
58
  def base64_encoded_image(image_array):
59
  im = Image.fromarray(image_array)
 
109
  return op
110
 
111
def process_document_pg(image_array, question):
    """Answer `question` about a document image with the PaliGemma model.

    Each stage is traced to stdout (debug scaffolding kept intact from the
    original). Returns the decoded answer with the echoed prompt removed.
    """
    print("called loaded model")
    # NOTE(review): assumes "paligemma" is always present in loaded_models —
    # .get() yields None on a miss, which would fail the unpack; verify.
    model, processor = loaded_models.get("paligemma")

    print("converting inputs")
    batch = processor(images=image_array, text=question, return_tensors="pt").to(device)

    print("get predictions")
    output_ids = model.generate(**batch, max_new_tokens=100)

    print("returning decoding")
    # Generation echoes the prompt first; slice it off and strip any leading
    # newlines so only the answer text is returned.
    answer = processor.decode(output_ids[0], skip_special_tokens=True)
    return answer[len(question):].lstrip("\n")
121
 
122
  def process_document_idf(image_array, question):