SaiBrahmam committed
Commit: a2a2c87
1 Parent(s): 3cd9f43
Upload Untitled16.ipynb

Files changed: Untitled16.ipynb (+1 -0)
Untitled16.ipynb
ADDED
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyMyRM8leYh4swXBQGUCusd+"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"OVqlrt1P7QNE"},"outputs":[],"source":["# install requirements\n","import sys\n","if 'google.colab' in sys.modules:\n"," print('Running in Colab.')\n"," !pip install transformers timm fairscale\n"," !git clone https://github.com/salesforce/BLIP\n"," %cd BLIP\n","\n","from PIL import Image\n","import requests\n","import torch\n","from torchvision import transforms\n","from torchvision.transforms.functional import InterpolationMode\n","\n","device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n","\n","def load_demo_image(image_size,device):\n"," img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg' \n"," raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB') \n","\n"," w,h = raw_image.size\n"," display(raw_image.resize((w//5,h//5)))\n"," \n"," transform = transforms.Compose([\n"," transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),\n"," transforms.ToTensor(),\n"," transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n"," ]) \n"," image = transform(raw_image).unsqueeze(0).to(device) \n"," return image\n","\n","from models.blip import blip_decoder\n","\n","image_size = 384\n","image = load_demo_image(image_size=image_size, device=device)\n","\n","model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'\n"," \n","model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')\n","model.eval()\n","model = model.to(device)\n","\n","with torch.no_grad():\n","\n"," # beam search\n"," #captions = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5, num_return_sequences=3) \n"," # nucleus sampling\n"," num_captions = 3\n"," captions = []\n"," for i in range(num_captions):\n"," caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)\n"," captions.append(caption[0])\n"," for i, caption in enumerate(captions):\n"," print(f'caption {i+1}: {caption}') "]}]}
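Since the cell already installs transformers, an alternative worth noting (not part of this notebook) is to skip the git clone and load BLIP through the transformers port instead; the sketch below assumes the Salesforce/blip-image-captioning-base checkpoint on the Hugging Face Hub:

# alternative path via transformers (no repo clone or manual model_url)
import requests
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = BlipProcessor.from_pretrained('Salesforce/blip-image-captioning-base')
model = BlipForConditionalGeneration.from_pretrained('Salesforce/blip-image-captioning-base').to(device)

img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

# the processor applies the resize and normalization internally,
# replacing the manual torchvision transform above
inputs = processor(images=raw_image, return_tensors='pt').to(device)
with torch.no_grad():
    out = model.generate(**inputs, max_length=20, num_beams=3)
print(processor.decode(out[0], skip_special_tokens=True))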