[yesha vyas] commited on
Commit
f1f8d48
β€’
1 Parent(s): 444fa81

App modified

Browse files
Files changed (1) hide show
  1. app.py +96 -25
app.py CHANGED
@@ -1,27 +1,98 @@
1
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
2
  import gradio as gr
 
 
 
 
 
3
 
4
- # Load the fine-tuned model and tokenizer from Hugging Face
5
- model_name = "yeshavyas27/moondream-ft"
6
- tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
8
-
9
- # Define the inference function
10
- def predict(input_text):
11
- inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
12
- outputs = model(**inputs)
13
- predictions = outputs.logits.argmax(dim=-1)
14
- return predictions.item()
15
-
16
- # Set up the Gradio interface
17
- iface = gr.Interface(
18
- fn=predict,
19
- inputs=gr.inputs.Textbox(lines=2, placeholder="Enter text here..."),
20
- outputs="text",
21
- title="Moondream Model Inference",
22
- description="Enter some text to get predictions from the fine-tuned Moondream model."
23
- )
24
-
25
- # Run the app
26
- if __name__ == "__main__":
27
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
  import gradio as gr
4
+ from threading import Thread
5
+ from transformers import TextIteratorStreamer, AutoTokenizer, AutoModelForCausalLM
6
+ from PIL import ImageDraw
7
+ import re
8
+ from torchvision.transforms.v2 import Resize
9
 
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument("--cpu", action="store_true")
12
+ args = parser.parse_args()
13
+
14
+ DEVICE = "cuda"
15
+ DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16 # CPU doesn't support float16
16
+ LATEST_REVISION = "2024-05-20"
17
+
18
+ model_id = "vikhyatk/moondream2"
19
+ tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION)
20
+ moondream = AutoModelForCausalLM.from_pretrained(
21
+ model_id, trust_remote_code=True, revision=LATEST_REVISION, torch_dtype=DTYPE
22
+ ).to(device=DEVICE)
23
+
24
+ moondream.eval()
25
+
26
+
27
+ def answer_question(img, prompt):
28
+ image_embeds = moondream.encode_image(img)
29
+ streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
30
+ thread = Thread(
31
+ target=moondream.answer_question,
32
+ kwargs={
33
+ "image_embeds": image_embeds,
34
+ "question": prompt,
35
+ "tokenizer": tokenizer,
36
+ "streamer": streamer,
37
+ },
38
+ )
39
+ thread.start()
40
+
41
+ buffer = ""
42
+ for new_text in streamer:
43
+ buffer += new_text
44
+ yield buffer
45
+
46
+
47
+ def extract_floats(text):
48
+ # Regular expression to match an array of four floating point numbers
49
+ pattern = r"\[\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*\]"
50
+ match = re.search(pattern, text)
51
+ if match:
52
+ # Extract the numbers and convert them to floats
53
+ return [float(num) for num in match.groups()]
54
+ return None # Return None if no match is found
55
+
56
+
57
+ def extract_bbox(text):
58
+ bbox = None
59
+ if extract_floats(text) is not None:
60
+ x1, y1, x2, y2 = extract_floats(text)
61
+ bbox = (x1, y1, x2, y2)
62
+ return bbox
63
+
64
+
65
+ def process_answer(img, answer):
66
+ if extract_bbox(answer) is not None:
67
+ x1, y1, x2, y2 = extract_bbox(answer)
68
+ draw_image = Resize(768)(img)
69
+ width, height = draw_image.size
70
+ x1, x2 = int(x1 * width), int(x2 * width)
71
+ y1, y2 = int(y1 * height), int(y2 * height)
72
+ bbox = (x1, y1, x2, y2)
73
+ ImageDraw.Draw(draw_image).rectangle(bbox, outline="red", width=3)
74
+ return gr.update(visible=True, value=draw_image)
75
+
76
+ return gr.update(visible=False, value=None)
77
+
78
+
79
+ with gr.Blocks() as demo:
80
+ gr.Markdown(
81
+ """
82
+ # πŸŒ” moondream
83
+ """
84
+ )
85
+ with gr.Row():
86
+ prompt = gr.Textbox(label="Input Prompt", placeholder="Type here...", scale=4)
87
+ submit = gr.Button("Submit")
88
+ with gr.Row():
89
+ img = gr.Image(type="pil", label="Upload an Image")
90
+ with gr.Column():
91
+ output = gr.Markdown(label="Response")
92
+ ann = gr.Image(visible=False, label="Annotated Image")
93
+
94
+ submit.click(answer_question, [img, prompt], output)
95
+ prompt.submit(answer_question, [img, prompt], output)
96
+ output.change(process_answer, [img, output], ann, show_progress=False)
97
+
98
+ demo.queue().launch(debug=True)