chats-bug committed
Commit 0d08077
1 Parent(s): f4c7af7

Initial model with number of captions control

Files changed (3)
  1. app.py +57 -0
  2. model.py +54 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,57 @@
+ import gradio as gr
+ import torch
+ from PIL import Image
+
+ from model import GitBaseCocoModel
+
+
+ def generate_captions(
+     image: Image.Image,
+     max_len: int = 50,
+     num_captions: int = 1,
+ ):
+     """
+     Generates captions for the given image.
+
+     -----
+     Parameters:
+     image: PIL.Image
+         The image to generate captions for.
+     max_len: int
+         The maximum length of the caption.
+     num_captions: int
+         The number of captions to generate.
+
+     -----
+     Returns:
+     list[str]
+         The generated captions.
+     """
+     # Run on the GPU when one is available.
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     checkpoint = "microsoft/git-base-coco"
+
+     model = GitBaseCocoModel(device, checkpoint)
+
+     # Gradio Number inputs arrive as floats, so cast to int before generating.
+     captions = model.generate(image, int(max_len), int(num_captions))
+     return captions
+
+
+ inputs = [
+     gr.inputs.Image(type="pil", label="Image"),
+     gr.inputs.Number(default=50, label="Maximum Caption Length"),
+     gr.inputs.Number(default=1, label="Number of Captions to Generate"),
+ ]
+ outputs = gr.outputs.Textbox()
+
+ title = "Git-Base-COCO Image Captioning"
+ description = "A model for generating captions for images."
+
+ gr.Interface(
+     fn=generate_captions,
+     inputs=inputs,
+     outputs=outputs,
+     title=title,
+     description=description,
+ ).launch()
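
The `gr.inputs`/`gr.outputs` namespaces used above are the legacy Gradio API: on Gradio 3.x they still work but emit deprecation warnings, and they were removed in Gradio 4. A sketch of the same interface with the current top-level components (note that `value=` replaces `default=`):

    inputs = [
        gr.Image(type="pil", label="Image"),
        gr.Number(value=50, label="Maximum Caption Length"),
        gr.Number(value=1, label="Number of Captions to Generate"),
    ]
    outputs = gr.Textbox()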
model.py ADDED
@@ -0,0 +1,54 @@
+ from transformers import AutoProcessor, AutoModelForCausalLM
+
+
+ class GitBaseCocoModel:
+     def __init__(self, device, checkpoint="microsoft/git-base-coco"):
+         """
+         A wrapper around the pretrained Git-Base-COCO image-captioning model.
+
+         -----
+         Parameters:
+         device: torch.device
+             The device to run the model on.
+         checkpoint: str
+             The checkpoint to load the model from.
+
+         -----
+         Returns:
+         None
+         """
+         self.checkpoint = checkpoint
+         self.device = device
+         self.processor = AutoProcessor.from_pretrained(self.checkpoint)
+         self.model = AutoModelForCausalLM.from_pretrained(self.checkpoint).to(self.device)
+
+     def generate(self, image, max_len=50, num_captions=1):
+         """
+         Generates captions for the given image.
+
+         -----
+         Parameters:
+         image: PIL.Image
+             The image to generate captions for.
+         max_len: int
+             The maximum length of the caption.
+         num_captions: int
+             The number of captions to generate.
+
+         -----
+         Returns:
+         list[str]
+             The decoded captions, one per returned beam.
+         """
+         pixel_values = self.processor(
+             images=image, return_tensors="pt"
+         ).pixel_values.to(self.device)
+         # Beam search: num_return_sequences may not exceed num_beams,
+         # so both are tied to num_captions here.
+         generated_ids = self.model.generate(
+             pixel_values=pixel_values,
+             max_length=max_len,
+             num_beams=num_captions,
+             num_return_sequences=num_captions,
+         )
+         return self.processor.batch_decode(generated_ids, skip_special_tokens=True)
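
Because `generate` ties both `num_beams` and `num_return_sequences` to `num_captions`, the method returns the `num_captions` highest-scoring beams from a single beam search; these are often close paraphrases of one another rather than genuinely diverse captions. A minimal usage sketch of the wrapper on its own (the image path is hypothetical):

    from PIL import Image
    from model import GitBaseCocoModel

    model = GitBaseCocoModel("cpu")    # checkpoint defaults to microsoft/git-base-coco
    image = Image.open("example.jpg")  # hypothetical local image
    captions = model.generate(image, max_len=40, num_captions=2)
    print(captions)                    # list of two decoded caption strings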
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ open_clip_torch
+ accelerate
+ bitsandbytes
+ git+https://github.com/huggingface/transformers.git@main
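
Note that `gradio` itself is not pinned here; on Hugging Face Spaces the runtime typically supplies it based on the Space's SDK setting, while a local run needs it installed separately (`pip install gradio`) before `python app.py` will start the demo.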