pdich2085 committed on
Commit 55095dc
1 Parent(s): c3839c3

update handler.py

Files changed (1)
  1. handler.py +68 -34
handler.py CHANGED
@@ -1,54 +1,88 @@
-from typing import Dict, Any, List
 from PIL import Image
 import torch
+import base64
 from io import BytesIO
-from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
-
-
-# Source: https://www.philschmid.de/custom-inference-handler
-
-
-class EndpointHandler:
-    def __init__(self, path="nlpconnect/vit-gpt2-image-captioning"):
-        self.model = VisionEncoderDecoderModel.from_pretrained(path)
-
-        # Using ViTImageProcessor instead of ViTFeatureExtractor
-        self.feature_extractor = ViTImageProcessor.from_pretrained(path)
-
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
-
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model.to(self.device)
-
-        self.max_length = 16
-        self.num_beams = 4
-
-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
-        """
-        Args:
-            data (:obj:):
-                includes the input image data.
-        Return:
-            A :obj:`dict` with the caption.
-        """
-        image_bytes = data.get("inputs", None)
-
-        # Convert image bytes to PIL Image
-        image = Image.open(BytesIO(image_bytes))
-        if image.mode != "RGB":
-            image = image.convert(mode="RGB")
-
-        pixel_values = self.feature_extractor(
-            images=image, return_tensors="pt"
-        ).pixel_values
-        pixel_values = pixel_values.to(self.device)
-
-        gen_kwargs = {"max_length": self.max_length, "num_beams": self.num_beams}
-        output_ids = self.model.generate(pixel_values, **gen_kwargs)
-
-        caption = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
-
-        return {"caption": caption}
-
-
+from transformers import BlipForConditionalGeneration, BlipProcessor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+class EndpointHandler():
+    def __init__(self):
+        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+        self.model = BlipForConditionalGeneration.from_pretrained(
+            "Salesforce/blip-image-captioning-large"
+        ).to(device)
+        self.model.eval()
+
+    def __call__(self, image_data: str) -> dict:
+        try:
+            raw_image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
+
+            processed_input = self.processor(raw_image, return_tensors="pt").to(device)
+
+            with torch.no_grad():
+                out = self.model.generate(**processed_input)
+
+            caption = self.processor.batch_decode(out, skip_special_tokens=True)[0]
+            return {"caption": caption}
+        except Exception as e:
+            print(f"Error during processing: {str(e)}")
+            return {"caption": "", "error": str(e)}
+
+
+# from typing import Dict, Any, List
+# from PIL import Image
+# import torch
+# from io import BytesIO
+# from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
+
+
+# # Source: https://www.philschmid.de/custom-inference-handler
+
+
+# class EndpointHandler:
+#     def __init__(self, path="nlpconnect/vit-gpt2-image-captioning"):
+#         self.model = VisionEncoderDecoderModel.from_pretrained(path)
+
+#         # Using ViTImageProcessor instead of ViTFeatureExtractor
+#         self.feature_extractor = ViTImageProcessor.from_pretrained(path)
+
+#         self.tokenizer = AutoTokenizer.from_pretrained(path)
+
+#         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#         self.model.to(self.device)
+
+#         self.max_length = 16
+#         self.num_beams = 4
+
+#     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+#         """
+#         Args:
+#             data (:obj:):
+#                 includes the input image data.
+#         Return:
+#             A :obj:`dict` with the caption.
+#         """
+#         image_bytes = data.get("inputs", None)
+
+#         # Convert image bytes to PIL Image
+#         image = Image.open(BytesIO(image_bytes))
+#         if image.mode != "RGB":
+#             image = image.convert(mode="RGB")
+
+#         pixel_values = self.feature_extractor(
+#             images=image, return_tensors="pt"
+#         ).pixel_values
+#         pixel_values = pixel_values.to(self.device)
+
+#         gen_kwargs = {"max_length": self.max_length, "num_beams": self.num_beams}
+#         output_ids = self.model.generate(pixel_values, **gen_kwargs)
+
+#         caption = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
+
+#         return {"caption": caption}
+
+
 # from typing import Dict, Any, List
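For anyone reviewing this change, a minimal local smoke test of the updated handler could look like the sketch below. It assumes handler.py is importable from the working directory; the image path "example.jpg" and the script itself are illustrative and not part of this commit. The key difference from the previous version is that __call__ now takes a base64-encoded image string rather than a {"inputs": <bytes>} dict.

# Hypothetical local smoke test for the updated BLIP handler (not part of this commit).
import base64

from handler import EndpointHandler

handler = EndpointHandler()

# Encode a local image the way the new __call__ expects: a base64 string.
with open("example.jpg", "rb") as f:  # "example.jpg" is an illustrative path
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

result = handler(image_b64)
print(result)  # expected shape: {"caption": "..."} or {"caption": "", "error": "..."}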