austinmw committed on
Commit
2b2a664
1 Parent(s): c9bff96

Upload tool

Browse files
Files changed (4) hide show
  1. app.py +4 -0
  2. blip_tool.py +79 -0
  3. requirements.txt +5 -0
  4. tool_config.json +5 -0
app.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
# Entry point for the demo: wraps the InstructBLIP visual-question-answering
# tool in an auto-generated Gradio UI.
from transformers import launch_gradio_demo
from blip_tool import InstructBLIPImageQuestionAnsweringTool

# launch_gradio_demo builds the interface from the tool's declared
# `inputs` / `outputs` and starts the Gradio server (blocking call).
launch_gradio_demo(InstructBLIPImageQuestionAnsweringTool)
blip_tool.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import requests

# Redirect the model cache. This must run BEFORE `transformers` is imported,
# since the cache location is read at import time.
# NOTE(review): this SageMaker-specific path will not exist on other hosts
# (e.g. a Hugging Face Space) — confirm the intended deployment target.
os.environ['TRANSFORMERS_CACHE'] = "/home/ec2-user/SageMaker/blip/cache"

from PIL import Image

import torch
from transformers import AutoModelForVision2Seq, AutoProcessor
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
from transformers.tools import PipelineTool
from transformers.tools.base import get_default_device
from transformers.utils import requires_backends
15
class InstructBLIPImageQuestionAnsweringTool(PipelineTool):
    """Visual question answering backed by InstructBLIP.

    Given an image and an English question, generates a free-form text
    answer. Intended to be registered as a `transformers` agent tool
    (see `tool_config.json`): `encode` tokenizes the inputs, `forward`
    runs beam-search generation, `decode` turns token ids back into text.
    """

    # Alternative checkpoints tried during development:
    #   Salesforce/blip2-opt-2.7b
    #   Salesforce/instructblip-flan-t5-xl
    #   Salesforce/instructblip-vicuna-13b
    default_checkpoint = "Salesforce/instructblip-vicuna-7b"

    description = (
        "This is a tool that answers a question about an image. It takes an input named `image` which should be the "
        "image containing the information, as well as a `question` which should be the question in English. It "
        "returns a text that is the answer to the question."
    )
    name = "image_qa"
    pre_processor_class = AutoProcessor
    model_class = AutoModelForVision2Seq
    inputs = ["image", "text"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        # Image inputs require the vision backend (Pillow); fail fast if absent.
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    def setup(self):
        """Instantiate the `pre_processor`, `model` and `post_processor` if necessary."""
        if isinstance(self.pre_processor, str):
            self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)

        if isinstance(self.model, str):
            # 4-bit quantization (requires the `bitsandbytes` package) keeps the
            # 7B checkpoint within a single GPU's memory; fp16 for the rest.
            self.model = self.model_class.from_pretrained(
                self.model,
                **self.model_kwargs,
                **self.hub_kwargs,
                load_in_4bit=True,
                torch_dtype=torch.float16,
            )

        if self.post_processor is None:
            self.post_processor = self.pre_processor
        elif isinstance(self.post_processor, str):
            self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)

        # Resolve the device the encoded inputs must be sent to: first entry of
        # the accelerate device map when one exists, otherwise the default device.
        if self.device is None:
            if self.device_map is not None:
                self.device = list(self.model.hf_device_map.values())[0]
            else:
                self.device = get_default_device()

        # The model is deliberately NOT moved with `.to(self.device)`:
        # `load_in_4bit` already places the quantized weights via accelerate.
        self.is_initialized = True

    def encode(self, image, question: str):
        """Tokenize the (image, question) pair into fp16 model inputs.

        Bug fix: the device was hard-coded to "cuda", bypassing the device
        resolved in `setup()` and breaking non-CUDA / device-mapped runs.
        """
        return self.pre_processor(images=image, text=question, return_tensors="pt").to(
            device=self.device, dtype=torch.float16
        )

    def forward(self, inputs):
        """Run beam-search generation and return the raw output token ids."""
        outputs = self.model.generate(
            **inputs,
            num_beams=5,
            max_new_tokens=256,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.5,
            length_penalty=1.0,
            temperature=1,
        )
        return outputs

    def decode(self, outputs):
        """Convert generated token ids into a single stripped answer string."""
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ requests
2
+ torch
3
+ transformers
4
+ bitsandbytes
5
+ Pillow
tool_config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "description": "This is a tool that answers a question about an image. It takes an input named `image` which should be the image containing the information, as well as a `question` which should be the question in English. It returns a text that is the answer to the question.",
3
+ "name": "image_qa",
4
+ "tool_class": "blip_tool.InstructBLIPImageQuestionAnsweringTool"
5
+ }