Cristiants committed
Commit 9b57564 • 1 Parent(s): b9cb95c

Update app.py

Files changed (1): app.py (+9, -49)
app.py CHANGED
@@ -1,26 +1,13 @@
-import sys
-# if 'google.colab' in sys.modules:
-# print('Running in Colab.')
-# !pip3 install transformers==4.15.0 timm==0.4.12 fairscale==0.4.4
-# !git clone https://github.com/salesforce/BLIP
-# %cd BLIP
-import gradio as gr
-import torch
 import requests
-from torchvision import transforms
 from PIL import Image
-import requests
+from transformers import AutoProcessor, Blip2ForConditionalGeneration
 import torch
-from torchvision import transforms
-from torchvision.transforms.functional import InterpolationMode
-
 
-#@title
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
+model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
 
-model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
-response = requests.get("https://git.io/JJkYN")
-labels = response.text.split("\n")
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model.to(device)
 
 def predict(inp):
     inp = transforms.ToTensor()(inp).unsqueeze(0)
@@ -34,38 +21,11 @@ demo = gr.Interface(fn=predict,
                     outputs=gr.outputs.Label(num_top_classes=3)
                     )
 
-def load_demo_image(image_size,device,imageurl):
-    img_url = imageurl
-    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
-
-    w,h = raw_image.size
-    display(raw_image.resize((w//5,h//5)))
-
-    transform = transforms.Compose([
-        transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
-        transforms.ToTensor(),
-        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-        ])
-    image = transform(raw_image).unsqueeze(0).to(device)
-    return image
-from models.blip import blip_decoder
-
 def predict(imageurl):
-    image_size = 384
-    image = load_demo_image(image_size=image_size, device=device,imageurl=imageurl)
-
-    model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
-
-    model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
-    model.eval()
-    model = model.to(device)
-
-    with torch.no_grad():
-        # beam search
-        caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
-        # nucleus sampling
-        # caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
-    return('caption: '+caption[0])
+    inputs = processor(image, return_tensors="pt")
+    generated_ids = model.generate(**inputs, max_new_tokens=20)
+    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+    return('caption: '+generated_text)
 
 demo = gr.Interface(fn=predict,
                     inputs="text",