dustalov commited on
Commit
ee1c835
·
verified ·
1 Parent(s): c4d6746

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -46,7 +46,7 @@ def aggregate(wins: npt.NDArray[np.int64], ties: npt.NDArray[np.int64],
46
  )
47
 
48
  v = v_numerator / v_denominator
49
- v = np.nan_to_num(v, copy=False)
50
 
51
  pi_old = pi.copy()
52
 
@@ -63,14 +63,18 @@ def aggregate(wins: npt.NDArray[np.int64], ties: npt.NDArray[np.int64],
63
  )
64
 
65
  pi = pi_numerator / pi_denominator
66
- pi = np.nan_to_num(pi, copy=False)
67
 
68
- converged = bool(np.all(np.abs(pi / (pi + 1) - pi_old / (pi_old + 1)) < tolerance)) or (iterations >= limit)
 
69
 
70
  return pi
71
 
72
 
73
  def handler(file: typing.IO[bytes], seed: int) -> pd.DataFrame:
 
 
 
74
  try:
75
  df = pd.read_csv(file.name, dtype=str)
76
  except ValueError as e:
 
46
  )
47
 
48
  v = v_numerator / v_denominator
49
+ v = np.nan_to_num(v, copy=False, nan=tolerance)
50
 
51
  pi_old = pi.copy()
52
 
 
63
  )
64
 
65
  pi = pi_numerator / pi_denominator
66
+ pi = np.nan_to_num(pi, copy=False, nan=tolerance)
67
 
68
+ converged = np.allclose(pi / (pi + 1), pi_old / (pi_old + 1),
69
+ rtol=tolerance, atol=tolerance) or (iterations >= limit)
70
 
71
  return pi
72
 
73
 
74
  def handler(file: typing.IO[bytes], seed: int) -> pd.DataFrame:
75
+ if file is None:
76
+ raise gr.Error('File must be uploaded')
77
+
78
  try:
79
  df = pd.read_csv(file.name, dtype=str)
80
  except ValueError as e: