Update README.md
README.md CHANGED
@@ -249,209 +249,27 @@ Based on AltCLIP, we have also developed the AltDiffusion model, visualized as follows.

![](https://raw.githubusercontent.com/920232796/test/master/image7.png)

## Model Inference

Removed:

```python
import torch
from PIL import Image
from flagai.auto_model.auto_loader import AutoLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One line of code downloads the weights to './checkpoints/clip-xlmr-large'
# and loads the CLIP model weights automatically.
# Modelhub: https://model.baai.ac.cn/models
loader = AutoLoader(
    task_name="txt_img_matching",
    model_dir="./checkpoints",
    model_name="AltCLIP-XLMR-L"
)
# Get the loaded model
model = loader.get_model()
# Get the tokenizer
tokenizer = loader.get_tokenizer()
# Get the transform used to preprocess images
transform = loader.get_transform()

model.eval()
model.to(device)

# Inference: match one image against several texts
image = Image.open("./dog.jpeg")
image = transform(image)
image = torch.tensor(image["pixel_values"]).to(device)
text = tokenizer(["a rat", "a dog", "a cat"])["input_ids"]
text = torch.tensor(text).to(device)

with torch.no_grad():
    image_features = model.get_image_features(image)
    text_features = model.get_text_features(text)
    text_probs = (image_features @ text_features.T).softmax(dim=-1)

print(text_probs.cpu().numpy()[0].tolist())
```
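
Note that the dot product above is taken on the features as returned by the model. CLIP-style models usually compare L2-normalized embeddings scaled by a learned temperature; if the features come back unnormalized, a cosine-similarity variant of the matching step might look like the sketch below. The fixed `100.0` scale stands in for CLIP's learned logit scale and is an assumption, not part of the original example.

```python
# Variant of the matching step with explicit L2 normalization (a sketch).
with torch.no_grad():
    image_features = model.get_image_features(image)
    text_features = model.get_text_features(text)
    # Normalize so the dot product becomes cosine similarity.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # 100.0 approximates CLIP's learned logit scale (assumption).
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
```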

## CLIP Finetuning

Fine-tuning uses the CIFAR-10 dataset, with FlagAI's Trainer to quickly start the training process.

```python
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import os

import torch
from flagai.auto_model.auto_loader import AutoLoader
from flagai.trainer import Trainer
from torchvision.datasets import CIFAR10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset_root = "./clip_benchmark_datasets"
dataset_name = "cifar10"

batch_size = 4
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

auto_loader = AutoLoader(
    task_name="txt_img_matching",
    model_dir="./checkpoints/",
    model_name="AltCLIP-XLMR-L"  # Load the checkpoints from Modelhub(model.baai.ac.cn/models)
)

model = auto_loader.get_model()
model.to(device)
model.eval()
tokenizer = auto_loader.get_tokenizer()
transform = auto_loader.get_transform()

trainer = Trainer(env_type="pytorch",
                  pytorch_device=device,
                  experiment_name="clip_finetuning",
                  batch_size=batch_size,
                  lr=1e-4,
                  epochs=10,
                  log_interval=10)

dataset = CIFAR10(root=os.path.join(dataset_root, dataset_name),
                  transform=transform,
                  download=True)

def cifar10_collate_fn(batch):
    # image shape is (batch, 3, 224, 224)
    images = torch.tensor([b[0]["pixel_values"][0] for b in batch])
    # text_id shape is (batch, n). Map the integer CIFAR-10 label to its class
    # name, and pad every prompt to the same length so the batch stacks cleanly.
    input_ids = torch.tensor([
        tokenizer(f"a photo of a {classes[b[1]]}",
                  padding="max_length",
                  truncation=True,
                  max_length=77)["input_ids"]
        for b in batch
    ])

    return {
        "pixel_values": images,
        "input_ids": input_ids
    }

if __name__ == "__main__":
    trainer.train(model=model, train_dataset=dataset, collate_fn=cifar10_collate_fn)
```
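
Before launching a full run, it can help to sanity-check the collate function on a few samples. This continues from the script above; the shapes in the comments are what a 224x224 CLIP transform on CIFAR-10 would produce.

```python
# Continues from the fine-tuning script above.
sample_batch = [dataset[i] for i in range(4)]
out = cifar10_collate_fn(sample_batch)
print(out["pixel_values"].shape)  # expected: torch.Size([4, 3, 224, 224])
print(out["input_ids"].shape)     # expected: torch.Size([4, 77]) with max_length padding
```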

We provide validation scripts that can be run directly on the CIFAR-10 dataset.

```python
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import json
import os

import torch
from flagai.auto_model.auto_loader import AutoLoader
from metrics import zeroshot_classification
from torchvision.datasets import CIFAR10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
maxlen = 256

dataset_root = "./clip_benchmark_datasets"
dataset_name = "cifar10"

auto_loader = AutoLoader(
    task_name="txt_img_matching",
    model_dir="./checkpoints/",
    model_name="AltCLIP-XLMR-L"
)

model = auto_loader.get_model()
model.to(device)
model.eval()
tokenizer = auto_loader.get_tokenizer()
transform = auto_loader.get_transform()

dataset = CIFAR10(root=os.path.join(dataset_root, dataset_name),
                  transform=transform,
                  download=True)
batch_size = 128
num_workers = 4

template = {"cifar10": [
    "a photo of a {c}.",
    "a blurry photo of a {c}.",
    "a black and white photo of a {c}.",
    "a low contrast photo of a {c}.",
    "a high contrast photo of a {c}.",
    "a bad photo of a {c}.",
    "a good photo of a {c}.",
    "a photo of a small {c}.",
    "a photo of a big {c}.",
    "a photo of the {c}.",
    "a blurry photo of the {c}.",
    "a black and white photo of the {c}.",
    "a low contrast photo of the {c}.",
    "a high contrast photo of the {c}.",
    "a bad photo of the {c}.",
    "a good photo of the {c}.",
    "a photo of the small {c}.",
    "a photo of the big {c}."
],
}

def evaluate():
    if dataset:
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )
        classnames = dataset.classes if hasattr(dataset, "classes") else None

        zeroshot_templates = template["cifar10"]
        metrics = zeroshot_classification.evaluate(
            model,
            dataloader,
            tokenizer,
            classnames,
            zeroshot_templates,
            device=device,
            amp=True,
        )

        dump = {
            "dataset": dataset_name,
            "metrics": metrics
        }

        print(dump)
        with open("./result.txt", "w") as f:
            json.dump(dump, f)
        return metrics

if __name__ == "__main__":
    evaluate()
```
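
The `metrics.zeroshot_classification` module ships with the example code and is not reproduced here. As a rough idea of what such an evaluation computes, here is a simplified, hypothetical stand-in that reuses `model`, `tokenizer`, and the `template` prompts from above and assumes the features come back unnormalized:

```python
import torch

def build_zeroshot_classifier(model, tokenizer, classnames, templates, device):
    """Average the text embeddings of all prompt templates for each class."""
    class_embs = []
    with torch.no_grad():
        for name in classnames:
            texts = [t.format(c=name) for t in templates]
            input_ids = torch.tensor(
                tokenizer(texts, padding=True, truncation=True, max_length=77)["input_ids"]
            ).to(device)
            emb = model.get_text_features(input_ids)
            emb = emb / emb.norm(dim=-1, keepdim=True)  # cosine-normalize
            class_embs.append(emb.mean(dim=0))
    return torch.stack(class_embs, dim=1)  # shape: (dim, num_classes)

def zeroshot_top1(model, dataloader, classifier, device):
    """Fraction of images whose nearest class embedding matches the true label."""
    correct = total = 0
    with torch.no_grad():
        for images, targets in dataloader:
            pixels = images["pixel_values"]        # unpacking depends on the transform/collate
            if isinstance(pixels, (list, tuple)):  # some transforms wrap the tensor in a list
                pixels = pixels[0]
            feats = model.get_image_features(pixels.to(device))
            feats = feats / feats.norm(dim=-1, keepdim=True)
            preds = (feats @ classifier).argmax(dim=-1).cpu()
            correct += (preds == targets).sum().item()
            total += targets.numel()
    return correct / total
```

This sketch reports plain top-1 accuracy; the real `zeroshot_classification.evaluate` also supports AMP (the script passes `amp=True`) and returns a metrics dictionary.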

Added:

![](https://raw.githubusercontent.com/920232796/test/master/image7.png)

## Model Inference

Please download the code from [FlagAI AltCLIP](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/AltCLIP).

```python
from PIL import Image
import requests

# transformers version >= 4.21.0
from modeling_altclip import AltCLIP
from processing_altclip import AltCLIPProcessor

# The repo is currently private, so `use_auth_token=True` is required.
model = AltCLIP.from_pretrained("BAAI/AltCLIP", use_auth_token=True)
processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP", use_auth_token=True)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=["a photo of a cat", "a photo of a dog"],
                   images=image, return_tensors="pt", padding=True)

outputs = model(**inputs)
logits_per_image = outputs.logits_per_image  # image-text similarity scores
probs = logits_per_image.softmax(dim=1)      # softmax over texts gives label probabilities
```
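
Since AltCLIP's text encoder is based on XLM-R (the checkpoint above is named AltCLIP-XLMR-L), the same pipeline also accepts Chinese prompts. A short usage sketch, reusing `model`, `processor`, and `image` from the snippet above; the prompts are illustrative:

```python
# Reuses `model`, `processor`, and `image` from the snippet above.
inputs = processor(text=["一张猫的照片", "一张狗的照片"],  # "a photo of a cat" / "a photo of a dog"
                   images=image, return_tensors="pt", padding=True)
outputs = model(**inputs)
print(outputs.logits_per_image.softmax(dim=1))  # probabilities over the two prompts
```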