schirrmacher commited on
Commit
913742e
1 Parent(s): 94305f2

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. ormbg.py +11 -0
app.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
3
  import torch
4
  import torch.nn.functional as F
5
  import gradio as gr
6
- from ormbg.models.ormbg import ORMBG
7
  from PIL import Image
8
 
9
  model_path = "models/ormbg.pth"
 
3
  import torch
4
  import torch.nn.functional as F
5
  import gradio as gr
6
+ from ormbg import ORMBG
7
  from PIL import Image
8
 
9
  model_path = "models/ormbg.pth"
ormbg.py CHANGED
@@ -357,6 +357,9 @@ class myrebnconv(nn.Module):
357
  return self.rl(self.bn(self.conv(x)))
358
 
359
 
 
 
 
360
  class ORMBG(nn.Module):
361
 
362
  def __init__(self, in_ch=3, out_ch=1):
@@ -398,6 +401,14 @@ class ORMBG(nn.Module):
398
 
399
  # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
400
 
 
 
 
 
 
 
 
 
401
  def forward(self, x):
402
 
403
  hx = x
 
357
  return self.rl(self.bn(self.conv(x)))
358
 
359
 
360
+ bce_loss = nn.BCELoss(size_average=True)
361
+
362
+
363
  class ORMBG(nn.Module):
364
 
365
  def __init__(self, in_ch=3, out_ch=1):
 
401
 
402
  # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
403
 
404
+ def compute_loss(self, predictions, ground_truth):
405
+ loss0, loss = 0.0, 0.0
406
+ for i in range(0, len(predictions)):
407
+ loss = loss + bce_loss(predictions[i], ground_truth)
408
+ if i == 0:
409
+ loss0 = loss
410
+ return loss0, loss
411
+
412
  def forward(self, x):
413
 
414
  hx = x