sergeipetrov commited on
Commit
a712455
·
verified ·
1 Parent(s): e6b188b

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +19 -33
handler.py CHANGED
@@ -1,5 +1,5 @@
1
  from typing import Dict, List, Any
2
- from transformers import DonutProcessor, VisionEncoderDecoderModel
3
  import torch
4
 
5
 
@@ -10,37 +10,23 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
  # load the model
13
- self.processor = DonutProcessor.from_pretrained(path)
14
- self.model = VisionEncoderDecoderModel.from_pretrained(path)
15
  # move model to device
16
  self.model.to(device)
17
- self.decoder_input_ids = self.processor.tokenizer(
18
- "<s_cord-v2>", add_special_tokens=False, return_tensors="pt"
19
- ).input_ids
20
-
21
- def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
22
-
23
- inputs = data.pop("inputs", data)
24
-
25
-
26
- # preprocess the input
27
- pixel_values = self.processor(inputs, return_tensors="pt").pixel_values
28
-
29
- # forward pass
30
- outputs = self.model.generate(
31
- pixel_values.to(device),
32
- decoder_input_ids=self.decoder_input_ids.to(device),
33
- max_length=self.model.decoder.config.max_position_embeddings,
34
- early_stopping=True,
35
- pad_token_id=self.processor.tokenizer.pad_token_id,
36
- eos_token_id=self.processor.tokenizer.eos_token_id,
37
- use_cache=True,
38
- num_beams=1,
39
- bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
40
- return_dict_in_generate=True,
41
- )
42
- # process output
43
- prediction = self.processor.batch_decode(outputs.sequences)[0]
44
- prediction = self.processor.token2json(prediction)
45
-
46
- return prediction
 
1
  from typing import Dict, List, Any
2
+ from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
3
  import torch
4
 
5
 
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
  # load the model
13
+ self.processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
14
+ self.model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
15
  # move model to device
16
  self.model.to(device)
17
+
18
+ def __call__(self, image: Any) -> List[List[Dict[str, float]]]:
19
+
20
+ inputs = self.processor(image, return_tensors="pt")
21
+ outputs = self.model(**inputs)
22
+
23
+ output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
24
+ output = np.moveaxis(output, source=0, destination=-1)
25
+ output = (output * 255.0).round().astype(np.uint8)
26
+
27
+ img = Image.fromarray(output)
28
+ buffered = BytesIO()
29
+ img.save(buffered, format="JPEG")
30
+ img_str = base64.b64encode(buffered.getvalue())
31
+
32
+ return img_str