aliabd HF staff commited on
Commit
a461fb8
·
1 Parent(s): bc3aef8

Upload with huggingface_hub

Browse files
Files changed (5) hide show
  1. README.md +5 -6
  2. img1.jpg +0 -0
  3. img2.jpg +0 -0
  4. requirements.txt +8 -0
  5. run.py +91 -0
README.md CHANGED
@@ -1,12 +1,11 @@
 
1
  ---
2
- title: Echocardiogram-Segmentation Main
3
- emoji: 👀
4
- colorFrom: green
5
  colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 3.6
8
- app_file: app.py
9
  pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+
2
  ---
3
+ title: Echocardiogram-Segmentation_main
4
+ emoji: 🔥
5
+ colorFrom: indigo
6
  colorTo: indigo
7
  sdk: gradio
8
  sdk_version: 3.6
9
+ app_file: run.py
10
  pinned: false
11
  ---
 
 
img1.jpg ADDED
img2.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ -f https://download.pytorch.org/whl/torch_stable.html
2
+ numpy
3
+ matplotlib
4
+ wget
5
+ torch==1.6.0+cpu
6
+ torchvision==0.7.0+cpu
7
+
8
+ https://gradio-main-build.s3.amazonaws.com/c3bec6153737855510542e8154391f328ac72606/gradio-3.6-py3-none-any.whl
run.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from os.path import splitext
3
+ import numpy as np
4
+ import sys
5
+ import matplotlib.pyplot as plt
6
+ import torch
7
+ import torchvision
8
+ import wget
9
+
10
+
11
+ destination_folder = "output"
12
+ destination_for_weights = "weights"
13
+
14
+ if os.path.exists(destination_for_weights):
15
+ print("The weights are at", destination_for_weights)
16
+ else:
17
+ print("Creating folder at ", destination_for_weights, " to store weights")
18
+ os.mkdir(destination_for_weights)
19
+
20
+ segmentationWeightsURL = 'https://github.com/douyang/EchoNetDynamic/releases/download/v1.0.0/deeplabv3_resnet50_random.pt'
21
+
22
+ if not os.path.exists(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL))):
23
+ print("Downloading Segmentation Weights, ", segmentationWeightsURL," to ",os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)))
24
+ filename = wget.download(segmentationWeightsURL, out = destination_for_weights)
25
+ else:
26
+ print("Segmentation Weights already present")
27
+
28
+ torch.cuda.empty_cache()
29
+
30
+ def collate_fn(x):
31
+ x, f = zip(*x)
32
+ i = list(map(lambda t: t.shape[1], x))
33
+ x = torch.as_tensor(np.swapaxes(np.concatenate(x, 1), 0, 1))
34
+ return x, f, i
35
+
36
+ model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False, aux_loss=False)
37
+ model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size)
38
+
39
+ print("loading weights from ", os.path.join(destination_for_weights, "deeplabv3_resnet50_random"))
40
+
41
+ if torch.cuda.is_available():
42
+ print("cuda is available, original weights")
43
+ device = torch.device("cuda")
44
+ model = torch.nn.DataParallel(model)
45
+ model.to(device)
46
+ checkpoint = torch.load(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)))
47
+ model.load_state_dict(checkpoint['state_dict'])
48
+ else:
49
+ print("cuda is not available, cpu weights")
50
+ device = torch.device("cpu")
51
+ checkpoint = torch.load(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)), map_location = "cpu")
52
+ state_dict_cpu = {k[7:]: v for (k, v) in checkpoint['state_dict'].items()}
53
+ model.load_state_dict(state_dict_cpu)
54
+
55
+ model.eval()
56
+
57
+ def segment(input):
58
+ inp = input
59
+ x = inp.transpose([2, 0, 1]) # channels-first
60
+ x = np.expand_dims(x, axis=0) # adding a batch dimension
61
+
62
+ mean = x.mean(axis=(0, 2, 3))
63
+ std = x.std(axis=(0, 2, 3))
64
+ x = x - mean.reshape(1, 3, 1, 1)
65
+ x = x / std.reshape(1, 3, 1, 1)
66
+
67
+ with torch.no_grad():
68
+ x = torch.from_numpy(x).type('torch.FloatTensor').to(device)
69
+ output = model(x)
70
+
71
+ y = output['out'].numpy()
72
+ y = y.squeeze()
73
+
74
+ out = y>0
75
+
76
+ mask = inp.copy()
77
+ mask[out] = np.array([0, 0, 255])
78
+
79
+ return mask
80
+
81
+ import gradio as gr
82
+
83
+ i = gr.inputs.Image(shape=(112, 112), label="Echocardiogram")
84
+ o = gr.outputs.Image(label="Segmentation Mask")
85
+
86
+ examples = [["img1.jpg"], ["img2.jpg"]]
87
+ title = None #"Left Ventricle Segmentation"
88
+ description = "This semantic segmentation model identifies the left ventricle in echocardiogram images."
89
+ # videos. Accurate evaluation of the motion and size of the left ventricle is crucial for the assessment of cardiac function and ejection fraction. In this interface, the user inputs apical-4-chamber images from echocardiography videos and the model will output a prediction of the localization of the left ventricle in blue. This model was trained on the publicly released EchoNet-Dynamic dataset of 10k echocardiogram videos with 20k expert annotations of the left ventricle and published as part of ‘Video-based AI for beat-to-beat assessment of cardiac function’ by Ouyang et al. in Nature, 2020."
90
+ thumbnail = "https://raw.githubusercontent.com/gradio-app/hub-echonet/master/thumbnail.png"
91
+ gr.Interface(segment, i, o, examples=examples, allow_flagging=False, analytics_enabled=False, thumbnail=thumbnail, cache_examples=False).launch()