|
|
|
|
|
|
|
import numpy as np |
|
|
|
class ConfusionMatrix(object):
    """Accumulate a square confusion matrix over batches of labels.

    The matrix is indexed as ``confusion_matrix[predicted, true]``: rows are
    predicted classes and columns are ground-truth classes.
    """

    def __init__(self, n_classes):
        """Create an all-zero ``n_classes`` x ``n_classes`` matrix.

        Args:
            n_classes: Number of distinct class labels; valid labels are
                ``0 .. n_classes - 1``.
        """
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    @staticmethod
    def _to_index(label):
        """Return ``label.item()`` when available, else ``label`` unchanged."""
        item = getattr(label, "item", None)
        return item() if callable(item) else label

    def _fast_hist(self, label_true, label_pred, n_class):
        """Build a fresh ``n_class`` x ``n_class`` histogram of label pairs.

        Args:
            label_true: Ground-truth label(s); a scalar or an integer array.
            label_pred: Predicted label(s); a scalar or an integer array.
            n_class: Side length of the returned matrix.

        Returns:
            A float array where entry ``[p, t]`` counts how many times the
            pair (predicted ``p``, true ``t``) occurs in the inputs.
        """
        hist = np.zeros((n_class, n_class))
        pred_idx = np.atleast_1d(label_pred)
        true_idx = np.atleast_1d(label_true)
        if pred_idx.size and true_idx.size:
            # np.add.at is unbuffered, so a pair that appears several times is
            # counted several times; ``hist[p, t] += 1`` would count it once.
            np.add.at(hist, (pred_idx, true_idx), 1)
        return hist

    def update(self, label_trues, label_preds):
        """Add one count per (true, predicted) pair to the running matrix.

        Args:
            label_trues: Iterable of ground-truth labels.
            label_preds: Iterable of predicted labels, paired element-wise
                with ``label_trues`` (pairing stops at the shorter input).

        Each element may be a plain integer or any object exposing ``.item()``
        (for example a numpy scalar).

        Raises:
            IndexError: If a label is outside the matrix bounds or is not an
                integer. NOTE: negative labels follow numpy indexing rules and
                wrap around instead of raising.
        """
        to_index = self._to_index
        for lt, lp in zip(label_trues, label_preds):
            # Increment the cell directly rather than allocating a temporary
            # n x n matrix for every single sample.
            self.confusion_matrix[to_index(lp), to_index(lt)] += 1

    def get_scores(self):
        """Return the overall accuracy of the accumulated matrix.

        Returns:
            The sum of the diagonal divided by the total count, or ``0.0``
            when nothing has been accumulated yet.
        """
        hist = self.confusion_matrix
        total = hist.sum()
        if total != 0:
            return np.trace(hist) / total
        return 0.0

    def plotcm(self):
        """Print the raw confusion matrix to standard output."""
        print(self.confusion_matrix)

    def reset(self):
        """Discard all accumulated counts, restoring an all-zero matrix."""
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))