visheratin committed
Commit 8e8fbc1
1 Parent(s): 122865a

Update new model

Files changed (1)
  1. processing_llava.py +58 -12
processing_llava.py CHANGED
@@ -31,27 +31,72 @@ from transformers.utils import TensorType
 import torch
 from open_clip.transform import PreprocessCfg, image_transform_v2
 from modeling_llava import LlavaForConditionalGeneration
+from PIL import Image
+import math
 
 
 class OpenCLIPImageProcessor:
-    def __init__(self, config):
+    def __init__(self, config, crop_size=384, max_tokens=100):
         cfg = PreprocessCfg(**config)
         transform = image_transform_v2(cfg=cfg, is_train=False)
         self.transform = transform
+        self.crop_size = crop_size
+        self.max_tokens = max_tokens
 
-    def __call__(self, image, return_tensors):
-        if isinstance(image, list):
-            outputs = []
-            for item in image:
-                outputs.append(self.transform(item))
-            return {
-                "pixel_values": torch.tensor(outputs),
-            }
-        output = self.transform(image)
+    def __call__(self, image: Image.Image):
+        output = self.transform_func(image)
         return {
-            "pixel_values": output.unsqueeze(0),
+            "pixel_values": output,
         }
 
+    def transform_func(self, image: Image.Image):
+        outputs = []
+        outputs.append(self.transform(image))
+        width, height = image.size
+        crop_size = self.crop_size
+        if width <= crop_size and height <= crop_size:
+            outputs = torch.stack(outputs, dim=0)
+            return outputs
+        total_tokens = math.inf
+        while total_tokens > self.max_tokens:
+            total_tokens = math.floor(
+                (2 * width - crop_size)
+                / crop_size
+                * (2 * height - crop_size)
+                / crop_size
+            )
+            if total_tokens > self.max_tokens:
+                crop_size += 10
+        stride = crop_size // 2
+        x_steps = int(round((2 * width - crop_size) / crop_size))
+        if x_steps < 1:
+            x_steps = 1
+        y_steps = int(round((2 * height - crop_size) / crop_size))
+        if y_steps < 1:
+            y_steps = 1
+        x_coords = []
+        y_coords = []
+        for i in range(x_steps):
+            x_coords.append([i * stride, i * stride + crop_size])
+        if x_coords[-1][1] != width:
+            x_coords[-1][1] = width
+        for i in range(y_steps):
+            y_coords.append([i * stride, i * stride + crop_size])
+        if y_coords[-1][1] != height:
+            y_coords[-1][1] = height
+        image_parts = []
+        for i in range(len(x_coords)):
+            for j in range(len(y_coords)):
+                image_parts.append(
+                    image.crop(
+                        (x_coords[i][0], y_coords[j][0], x_coords[i][1], y_coords[j][1])
+                    )
+                )
+        for image_part in image_parts:
+            outputs.append(self.transform(image_part))
+        outputs = torch.stack(outputs, dim=0)
+        return outputs
+
     @property
     def model_input_names(self):
         return ["pixel_values"]
@@ -75,12 +120,13 @@ class LlavaProcessor:
         return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
     ) -> BatchFeature:
         if images is not None:
-            pixel_values = self.image_processor(images, return_tensors=return_tensors)[
+            pixel_values = self.image_processor(images)[
                 "pixel_values"
             ]
             pixel_values = pixel_values.to(model.device).to(model.dtype)
             image_outputs = model.vision_model(pixel_values)
             image_features = model.multi_modal_projector(image_outputs)
+            image_features = image_features.unsqueeze(0)
         else:
             image_features = None
         text_inputs = self.tokenizer(
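
For reference, a rough usage sketch (not part of the commit) of how the updated OpenCLIPImageProcessor behaves after this change. It assumes the repository files are on the import path; the preprocess config values and the image size below are illustrative assumptions, not values taken from the repository.

# Illustrative only: the preprocess config is an assumption; the real values
# ship with the model's configuration on the Hub.
from PIL import Image

from processing_llava import OpenCLIPImageProcessor

preprocess_cfg = {
    "size": (384, 384),       # assumed to match the default crop_size=384
    "mean": (0.5, 0.5, 0.5),  # assumed normalization constants
    "std": (0.5, 0.5, 0.5),
}

processor = OpenCLIPImageProcessor(preprocess_cfg, crop_size=384, max_tokens=100)

# An image no larger than crop_size in both dimensions yields a single
# transformed tensor. A larger image is additionally tiled into overlapping
# crops with stride crop_size // 2; crop_size is grown in steps of 10 until
# the estimated tile count floor((2W - c)/c * (2H - c)/c) is at most max_tokens.
image = Image.new("RGB", (1024, 768))
pixel_values = processor(image)["pixel_values"]
print(pixel_values.shape)  # (1 + number_of_crops, 3, 384, 384) under these assumptions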