hongyiyang committed
Commit 02212ce
Parent: 9c503ae

use LlavaForConditionalGeneration to load instead

Files changed (1): handler.py (+14, -4)
handler.py CHANGED
@@ -3,9 +3,17 @@ from transformers import pipeline
 from PIL import Image
 import requests
 
+import torch
+from transformers import AutoProcessor, LlavaForConditionalGeneration
+
 class EndpointHandler():
     def __init__(self, path="."):
-        self.pipeline = pipeline("image-to-text", model=path)
+        self.model = LlavaForConditionalGeneration.from_pretrained(
+            path,
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True,
+        ).to(0)
+        self.processor = AutoProcessor.from_pretrained(path)
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
@@ -24,8 +32,10 @@ class EndpointHandler():
         image = Image.open(requests.get(url, stream=True).raw)
 
 
+        inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16)
+
        # run normal prediction
-        outputs = self.pipeline(image, prompt=prompt, generate_kwargs={"max_new_tokens": 200})
-        print(outputs)
+        output = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
+        print(output)
 
-        return outputs
+        return output
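
Note that model.generate returns a tensor of token IDs, so after this commit the handler returns raw IDs rather than decoded text. Below is a minimal standalone sketch of the load-generate-decode flow the new handler uses; the checkpoint ID, image URL, and prompt template are illustrative assumptions, not values taken from this repo.

# Standalone sketch of the flow this commit introduces, plus the decode
# step that turns generate()'s token IDs back into text. The checkpoint
# ID, image URL, and prompt template below are illustrative assumptions.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

path = "llava-hf/llava-1.5-7b-hf"  # hypothetical checkpoint
model = LlavaForConditionalGeneration.from_pretrained(
    path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(0)
processor = AutoProcessor.from_pretrained(path)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"

inputs = processor(prompt, image, return_tensors="pt").to(0, torch.float16)
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)

# generate() returns token IDs; the handler above returns them raw, but a
# caller will usually want the decoded string:
print(processor.decode(output[0], skip_special_tokens=True))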