rjiang12 committed on
Commit
7a72735
1 Parent(s): 72f18ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -31
app.py CHANGED
@@ -31,40 +31,26 @@ git_model_base.to(device)
31
  # vilt_model.to(device)
32
 
33
  def generate_answer_git(processor, model, image, question):
34
- # # prepare image
35
- # pixel_values = processor(images=image, return_tensors="pt").pixel_values
36
 
37
- # # prepare question
38
- # input_ids = processor(text=question, add_special_tokens=False).input_ids
39
- # input_ids = [processor.tokenizer.cls_token_id] + input_ids
40
- # input_ids = torch.tensor(input_ids).unsqueeze(0)
41
 
42
- # generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50, return_dict_in_generate=True, output_scores=True)
43
- # print('scores:')
44
- # print(generated_ids.scores)
45
- # idx = generated_ids.scores[0].argmax(-1).item()
46
- # idx1 = generated_ids.scores[1].argmax(-1).item()
47
- # print(idx, idx1)
48
- # print(model.config.id2label)
49
- # ans = model.config.id2label[idx]
50
- # ans1 = model.config.id2label[idx1]
51
- # print(ans, ans1)
52
- # print('sequences:')
53
- # print(generated_ids.sequences)
54
- # print(generated_ids)
55
- # generated_answer = processor.batch_decode(generated_ids.sequences, skip_special_tokens=True)
56
- # print(generated_answer)
57
-
58
- encoding = processor(images=image, text=question, return_tensors="pt")
59
-
60
- with torch.no_grad():
61
- outputs = model(**encoding)
62
- print(outputs.logits)
63
- predicted_class_idx = outputs.logits[0].argmax(-1).item()
64
- # return model.config.id2label[predicted_class_idx]
65
- print(predicted_class_idx)
66
  print(model.config.id2label)
67
- print(model.config.id2label[predicted_class_idx])
 
 
 
 
68
 
69
 
70
  return 'haha'
 
31
  # vilt_model.to(device)
32
 
33
  def generate_answer_git(processor, model, image, question):
34
+ # prepare image
35
+ pixel_values = processor(images=image, return_tensors="pt").pixel_values
36
 
37
+ # prepare question
38
+ input_ids = processor(text=question, add_special_tokens=False).input_ids
39
+ input_ids = [processor.tokenizer.cls_token_id] + input_ids
40
+ input_ids = torch.tensor(input_ids).unsqueeze(0)
41
 
42
+ generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50, return_dict_in_generate=True, output_scores=True)
43
+ print('scores:')
44
+ print(generated_ids.scores)
45
+ idx = generated_ids.scores[0].argmax(-1).item()
46
+ idx1 = generated_ids.scores[1].argmax(-1).item()
47
+ print(idx, idx1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  print(model.config.id2label)
49
+ print('sequences:')
50
+ print(generated_ids.sequences)
51
+ print(generated_ids)
52
+ generated_answer = processor.batch_decode(generated_ids.sequences, skip_special_tokens=True)
53
+ print(generated_answer)
54
 
55
 
56
  return 'haha'