refactor: add imports
Browse files- modeling_clip.py +3 -2
modeling_clip.py
CHANGED
@@ -5,8 +5,9 @@
|
|
5 |
# and adjusted for Jina CLIP
|
6 |
|
7 |
from functools import partial
|
8 |
-
from typing import Optional, Tuple, Union
|
9 |
|
|
|
10 |
import torch
|
11 |
import torch.nn.functional as f
|
12 |
import torch.utils.checkpoint
|
@@ -258,7 +259,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
|
258 |
device: Optional[torch.device] = None,
|
259 |
normalize_embeddings: bool = False,
|
260 |
**tokenizer_kwargs,
|
261 |
-
) -> Union[
|
262 |
|
263 |
self.eval()
|
264 |
|
|
|
5 |
# and adjusted for Jina CLIP
|
6 |
|
7 |
from functools import partial
|
8 |
+
from typing import Optional, Tuple, Union, List
|
9 |
|
10 |
+
import numpy as np
|
11 |
import torch
|
12 |
import torch.nn.functional as f
|
13 |
import torch.utils.checkpoint
|
|
|
259 |
device: Optional[torch.device] = None,
|
260 |
normalize_embeddings: bool = False,
|
261 |
**tokenizer_kwargs,
|
262 |
+
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]::
|
263 |
|
264 |
self.eval()
|
265 |
|