martinkropf committed
Commit
aad5fe7
1 Parent(s): c99985e

Update app.py

Files changed (1)
  1. app.py +100 -28
app.py CHANGED
@@ -1,30 +1,102 @@
 
 
- from turtle import title
- import gradio as gr
- from transformers import pipeline
- import numpy as np
- from PIL import Image
-
-
- pipe = pipeline("zero-shot-image-classification", model="mkaichristensen/echo-clip")
- images="dog.jpg"
-
- def shot(image, labels_text):
-     PIL_image = Image.fromarray(np.uint8(image)).convert('RGB')
-     labels = labels_text.split(",")
-     res = pipe(images=PIL_image,
-                candidate_labels=labels,
-                hypothesis_template="This is a photo of a {}")
-     return {dic["label"]: dic["score"] for dic in res}
-
- iface = gr.Interface(shot,
-                      ["image", "text"],
-                      "label",
-                      examples=[["dog.jpg", "dog,cat,bird"],
-                                ["germany.jpg", "germany,belgium,colombia"],
-                                ["colombia.jpg", "germany,belgium,colombia"]],
-                      description="Add a picture and a list of labels separated by commas",
-                      title="Zero-shot Image Classification")
-
- iface.launch()
+ from open_clip import tokenize, create_model_and_transforms
+ import torchvision.transforms as T
+ import torch
+ import torch.nn.functional as F
+ from utils import (
+     zero_shot_prompts,
+     compute_binary_metric,
+     compute_regression_metric,
+     read_avi,
+ )
+
+ # You'll need to log in to the Hugging Face Hub CLI to download the models.
+ # You can do this with the terminal command "huggingface-cli login".
+ # You'll be asked to paste your Hugging Face API token, which you can find at https://huggingface.co/settings/token
+
+ # Use EchoCLIP for zero-shot tasks like ejection fraction prediction
+ # or pacemaker detection. It has a short context window because it
+ # uses the CLIP BPE tokenizer, so it can't process an entire report at once.
+ echo_clip, _, preprocess_val = create_model_and_transforms(
+     "hf-hub:mkaichristensen/echo-clip", precision="bf16"
+ )
+
+ # We'll load a sample echo video and preprocess its frames, but you can
+ # substitute any video or image sequence of your own.
+ test_video = read_avi(
+     "example_video.avi",
+     (224, 224),
+ )
+ test_video = torch.stack(
+     [preprocess_val(T.ToPILImage()(frame)) for frame in test_video], dim=0
+ )
+ test_video = test_video.cpu()
+ test_video = test_video.to(torch.bfloat16)
+
+ # Be sure to normalize the CLIP embedding after computing it so that
+ # cosine similarity between embeddings reduces to a simple dot product.
+ test_video_embedding = F.normalize(echo_clip.encode_image(test_video), dim=-1)
+
+ # Add a batch dimension, because the zero-shot functions expect one.
+ test_video_embedding = test_video_embedding.unsqueeze(0)
+
+
+ # To perform zero-shot prediction on our echo video, we'll need
+ # prompts that describe the task we want to perform. For example,
+ # to zero-shot detect pacemakers, we'll use the following prompts.
+ pacemaker_prompts = zero_shot_prompts["pacemaker"]
+ print(pacemaker_prompts)
+
+ # We'll use the CLIP BPE tokenizer to tokenize the prompts.
+ pacemaker_prompts = tokenize(pacemaker_prompts).cpu()
+ print(pacemaker_prompts)
+
+ # Now we can encode the prompts into embeddings.
+ pacemaker_prompt_embeddings = F.normalize(
+     echo_clip.encode_text(pacemaker_prompts), dim=-1
+ )
+ print(pacemaker_prompt_embeddings.shape)
+
+ # Now we can compute the similarity between the video and the prompts
+ # to get a prediction for whether the video contains a pacemaker. It's
+ # important to note that this prediction is not calibrated, and can
+ # range from -1 to 1.
+ pacemaker_predictions = compute_binary_metric(
+     test_video_embedding, pacemaker_prompt_embeddings
+ )
+
+ # If we use a pacemaker detection threshold calibrated using its F1 score on
+ # our test set, we can get a proper true/false prediction.
+ f1_calibrated_threshold = 0.298
+ print(f"Pacemaker detected: {pacemaker_predictions.item() > f1_calibrated_threshold}")
+
+
+ # We can also do the same thing for predicting continuous values,
+ # like ejection fraction. We'll use the following prompts for
+ # zero-shot ejection fraction prediction:
+ ejection_fraction_prompts = zero_shot_prompts["ejection_fraction"]
+ print(ejection_fraction_prompts)
+
+ # However, since ejection fraction can range from 0 to 100, we'll need
+ # to make 101 versions of each prompt, one for every integer value.
+ prompts = []
+ prompt_values = []
+
+ for prompt in ejection_fraction_prompts:
+     for i in range(101):
+         prompts.append(prompt.replace("<#>", str(i)))
+         prompt_values.append(i)
+
+ ejection_fraction_prompts = prompts
+
+ # We'll once again tokenize and embed the prompts.
+ ejection_fraction_prompts = tokenize(ejection_fraction_prompts).cpu()
+ ejection_fraction_embeddings = F.normalize(
+     echo_clip.encode_text(ejection_fraction_prompts), dim=-1
+ )
+
+ # And we'll compute the similarity between the video and the prompts
+ # to get a prediction for the ejection fraction.
+ ejection_fraction_predictions = compute_regression_metric(
+     test_video_embedding, ejection_fraction_embeddings, prompt_values
+ )
+ print(f"Predicted ejection fraction is {ejection_fraction_predictions.item():.1f}%")