Abubakar Abid commited on
Commit
b9e6b57
·
1 Parent(s): a10ea79
Files changed (7) hide show
  1. .gitattributes +0 -16
  2. .gitignore +1 -0
  3. app.py +90 -0
  4. img1.jpg +0 -0
  5. img2.jpg +0 -0
  6. requirements.txt +7 -0
  7. thumbnail.png +0 -0
.gitattributes DELETED
@@ -1,16 +0,0 @@
1
- *.bin.* filter=lfs diff=lfs merge=lfs -text
2
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.h5 filter=lfs diff=lfs merge=lfs -text
5
- *.tflite filter=lfs diff=lfs merge=lfs -text
6
- *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
- *.ot filter=lfs diff=lfs merge=lfs -text
8
- *.onnx filter=lfs diff=lfs merge=lfs -text
9
- *.arrow filter=lfs diff=lfs merge=lfs -text
10
- *.ftz filter=lfs diff=lfs merge=lfs -text
11
- *.joblib filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.pb filter=lfs diff=lfs merge=lfs -text
15
- *.pt filter=lfs diff=lfs merge=lfs -text
16
- *.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ weights/*
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, os.path
2
+ from os.path import splitext
3
+ import numpy as np
4
+ import sys
5
+ import matplotlib.pyplot as plt
6
+ import torch
7
+ import torchvision
8
+ import wget
9
+
10
+ destination_folder = "output"
11
+ destination_for_weights = "weights"
12
+
13
+ if os.path.exists(destination_for_weights):
14
+ print("The weights are at", destination_for_weights)
15
+ else:
16
+ print("Creating folder at ", destination_for_weights, " to store weights")
17
+ os.mkdir(destination_for_weights)
18
+
19
+ segmentationWeightsURL = 'https://github.com/douyang/EchoNetDynamic/releases/download/v1.0.0/deeplabv3_resnet50_random.pt'
20
+
21
+ if not os.path.exists(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL))):
22
+ print("Downloading Segmentation Weights, ", segmentationWeightsURL," to ",os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)))
23
+ filename = wget.download(segmentationWeightsURL, out = destination_for_weights)
24
+ else:
25
+ print("Segmentation Weights already present")
26
+
27
+ torch.cuda.empty_cache()
28
+
29
+ def collate_fn(x):
30
+ x, f = zip(*x)
31
+ i = list(map(lambda t: t.shape[1], x))
32
+ x = torch.as_tensor(np.swapaxes(np.concatenate(x, 1), 0, 1))
33
+ return x, f, i
34
+
35
+ model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False, aux_loss=False)
36
+ model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size)
37
+
38
+ print("loading weights from ", os.path.join(destination_for_weights, "deeplabv3_resnet50_random"))
39
+
40
+ if torch.cuda.is_available():
41
+ print("cuda is available, original weights")
42
+ device = torch.device("cuda")
43
+ model = torch.nn.DataParallel(model)
44
+ model.to(device)
45
+ checkpoint = torch.load(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)))
46
+ model.load_state_dict(checkpoint['state_dict'])
47
+ else:
48
+ print("cuda is not available, cpu weights")
49
+ device = torch.device("cpu")
50
+ checkpoint = torch.load(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)), map_location = "cpu")
51
+ state_dict_cpu = {k[7:]: v for (k, v) in checkpoint['state_dict'].items()}
52
+ model.load_state_dict(state_dict_cpu)
53
+
54
+ model.eval()
55
+
56
+ def segment(inp):
57
+ x = inp.transpose([2, 0, 1]) # channels-first
58
+ x = np.expand_dims(x, axis=0) # adding a batch dimension
59
+
60
+ mean = x.mean(axis=(0, 2, 3))
61
+ std = x.std(axis=(0, 2, 3))
62
+ x = x - mean.reshape(1, 3, 1, 1)
63
+ x = x / std.reshape(1, 3, 1, 1)
64
+
65
+ with torch.no_grad():
66
+ x = torch.from_numpy(x).type('torch.FloatTensor').to(device)
67
+ output = model(x)
68
+
69
+ y = output['out'].numpy()
70
+ y = y.squeeze()
71
+
72
+ out = y>0
73
+
74
+ mask = inp.copy()
75
+ mask[out] = np.array([0, 0, 255])
76
+
77
+ return mask
78
+
79
+ import gradio as gr
80
+
81
+ i = gr.inputs.Image(shape=(112, 112))
82
+ o = gr.outputs.Image()
83
+
84
+ examples = [["img1.jpg"], ["img2.jpg"]]
85
+ title = "Left Ventricle Segmentation"
86
+ description = "This semantic segmentation model identifies the left ventricle in echocardiogram videos. Accurate evaluation of the motion and size of the left ventricle is crucial for the assessment of cardiac function and ejection fraction. In this interface, the user inputs apical-4-chamber images from echocardiography videos and the model will output a prediction of the localization of the left ventricle in blue. This model was trained on the publicly released EchoNet-Dynamic dataset of 10k echocardiogram videos with 20k expert annotations of the left ventricle and published as part of ‘Video-based AI for beat-to-beat assessment of cardiac function’ by Ouyang et al. in Nature, 2020."
87
+ thumbnail = "https://raw.githubusercontent.com/gradio-app/hub-echonet/master/thumbnail.png"
88
+
89
+ gr.Interface(segment, i, o, examples=examples, allow_flagging=False, analytics_enabled=False,
90
+ title=title, description=description, thumbnail=thumbnail).launch()
img1.jpg ADDED
img2.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ -f https://download.pytorch.org/whl/torch_stable.html
2
+ numpy
3
+ matplotlib
4
+ wget
5
+ torch==1.6.0+cpu
6
+ torchvision==0.7.0+cpu
7
+
thumbnail.png ADDED