rashmi commited on
Commit
60f558e
1 Parent(s): 9cdd2d7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.system('pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')
3
+ os.system('pip install torch==1.9.0 torchvision==0.10.0')
4
+
5
+ import gradio as gr
6
+ # check pytorch installation:
7
+ import torch, torchvision
8
+ print(torch.__version__, torch.cuda.is_available())
9
+ assert torch.__version__.startswith("1.9") # please manually install torch 1.9 if Colab changes its default version
10
+ # Some basic setup:
11
+ # Setup detectron2 logger
12
+ import detectron2
13
+ from detectron2.utils.logger import setup_logger
14
+ # import some common libraries
15
+ import numpy as np
16
+ import os, json, cv2, random
17
+ # import some common detectron2 utilities
18
+ from detectron2 import model_zoo
19
+ from detectron2.engine import DefaultPredictor
20
+ from detectron2.config import get_cfg
21
+ from detectron2.utils.visualizer import Visualizer, ColorMode
22
+ from detectron2.data import MetadataCatalog, DatasetCatalog
23
+ from PIL import Image
24
+ from pathlib import Path
25
+ from detectron2.data.datasets import register_coco_instances
26
+ from matplotlib import pyplot as plt
27
+
28
+
29
+ cfg = get_cfg()
30
+ cfg.MODEL.DEVICE='cpu'
31
+ # add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
32
+ cfg.INPUT.MASK_FORMAT='bitmask'
33
+ cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3
34
+ cfg.TEST.DETECTIONS_PER_IMAGE = 1000
35
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
36
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model
37
+ # Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
38
+ cfg.MODEL.WEIGHTS = "model_final.pth"
39
+
40
+ predictor = DefaultPredictor(cfg)
41
+
42
+
43
+ def inference(img):
44
+ # im = cv2.imread(img.name)
45
+ im = cv2.imread(img)
46
+ outputs = predictor(im)
47
+
48
+ take = outputs['instances'].scores >= 0.5 #Threshold
49
+ pred_masks = outputs['instances'].pred_masks[take].cpu().numpy()
50
+
51
+ mask = np.stack(pred_masks)
52
+ mask = np.any(mask == 1, axis=0)
53
+
54
+ p = plt.imshow(im,cmap='gray')
55
+ p1 = plt.imshow(mask, alpha=0.4)
56
+
57
+ return plt
58
+
59
+
60
+
61
+ title = "Sartorius Cell Instance Segmentation"
62
+ description = "Sartorius Cell Instance Segmentation Demo: Current Kaggle competition - kaggle.com/c/sartorius-cell-instance-segmentation"
63
+ article = "<p style='text-align: center'><a href='https://ai.facebook.com/blog/-detectron2-a-pytorch-based-modular-object-detection-library-/' target='_blank'>Detectron2: A PyTorch-based modular object detection library</a> | <a href='https://github.com/facebookresearch/detectron2' target='_blank'>Github Repo</a></p>"
64
+ examples = [['0030fd0e6378.png']]
65
+ gr.Interface(inference, inputs=gr.inputs.Image(type="filepath"), outputs=gr.outputs.Image('plot') ,enable_queue=True, title=title,
66
+ description=description,
67
+ article=article,
68
+ examples=examples).launch(debug=False)