Update README.md
Browse files
README.md
CHANGED
@@ -39,7 +39,6 @@ import requests
|
|
39 |
from PIL import Image
|
40 |
import torch
|
41 |
|
42 |
-
device = torch.device('cuda')
|
43 |
image_urls = [
|
44 |
"https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
|
45 |
"http://images.cocodataset.org/val2017/000000039769.jpg"]
|
@@ -49,13 +48,13 @@ texts = [
|
|
49 |
images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]
|
50 |
|
51 |
processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm")
|
52 |
-
model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
|
53 |
|
54 |
-
inputs = processor(images, texts, padding=True, return_tensors="pt")
|
55 |
outputs = model(**inputs, labels=torch.ones(2,device=device))
|
56 |
|
57 |
-
inputs = processor(images, texts[::-1], padding=True, return_tensors="pt")
|
58 |
-
outputs_swapped = model(**inputs, labels=torch.ones(2
|
59 |
|
60 |
print('Loss', outputs.loss.item())
|
61 |
print('Loss with swapped images', outputs_swapped.loss.item())
|
|
|
39 |
from PIL import Image
|
40 |
import torch
|
41 |
|
|
|
42 |
image_urls = [
|
43 |
"https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
|
44 |
"http://images.cocodataset.org/val2017/000000039769.jpg"]
|
|
|
48 |
images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]
|
49 |
|
50 |
processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm")
|
51 |
+
model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
|
52 |
|
53 |
+
inputs = processor(images, texts, padding=True, return_tensors="pt")
|
54 |
outputs = model(**inputs, labels=torch.ones(2,device=device))
|
55 |
|
56 |
+
inputs = processor(images, texts[::-1], padding=True, return_tensors="pt")
|
57 |
+
outputs_swapped = model(**inputs, labels=torch.ones(2))
|
58 |
|
59 |
print('Loss', outputs.loss.item())
|
60 |
print('Loss with swapped images', outputs_swapped.loss.item())
|