not-lain commited on
Commit
92d4bc4
·
1 Parent(s): a41f85a

fix corner cases

Browse files
Files changed (1) hide show
  1. app.py +94 -46
app.py CHANGED
@@ -16,6 +16,12 @@ t2 = torch.arange({n2}).view({dim2})
16
 
17
  """
18
 
 
 
 
 
 
 
19
 
20
  def generate_example(dim1: list, dim2: list):
21
  n1 = 1
@@ -58,25 +64,35 @@ def sanitize_dimention(dim):
58
  return out
59
 
60
 
61
- def create_row(dim,is_dim=None,checks=None):
62
  out = "| "
63
  n_dim = len(dim)
64
  for i in range(n_dim):
65
- # infered last dims
66
- if (is_dim ==1 and i == n_dim-2) or (is_dim ==2 and i ==n_dim-1):
67
- color = "green"
68
- out += f"<strong style='color: {color}'> {dim[i]} </strong>| "
69
- # check every normal dimension
70
- elif (is_dim ==1 and i != n_dim-1) or (is_dim ==2 and i ==n_dim-1):
71
- color = "green" if checks[i] == "V" else "red"
72
- out += f"<strong style='color: {color}'> {dim[i]} </strong>| "
73
- # checks last 2 dims
74
- elif (is_dim ==1 and i == n_dim-1) or (is_dim ==2 and i ==n_dim-2):
75
- color = "blue" if checks[i] == "V" else "yellow"
76
- out += f"<strong style='color: {color}'> {dim[i]} </strong>| "
77
- # when using this function without checks
78
- else :
79
- out+= f"{dim[i]} | "
 
 
 
 
 
 
 
 
 
 
80
  return out + "\n"
81
 
82
 
@@ -89,20 +105,20 @@ def create_header(n_dim, checks=None):
89
  return out
90
 
91
 
92
- def generate_table(dim1, dim2, checks=None):
93
  n_dim = len(dim1)
94
  table = create_header(n_dim, checks)
95
  # tensor 1
96
  if not checks :
97
  table += create_row(dim1)
98
  else :
99
- table += create_row(dim1,1,checks)
100
 
101
  # tensor 2
102
  if not checks :
103
  table += create_row(dim2)
104
  else :
105
- table += create_row(dim2,2,checks)
106
  return table
107
 
108
 
@@ -122,8 +138,6 @@ def alignment_and_fill_with_ones(dim1, dim2):
122
  return dim1, dim2
123
 
124
  def check_validity(dim1,dim2):
125
- if len(dim1) < 2:
126
- return ["WIP"] * len(dim1)
127
  out = []
128
  for i in range(len(dim1)-2):
129
  if dim1[i] == dim2[i]:
@@ -138,8 +152,9 @@ def check_validity(dim1,dim2):
138
  return out
139
 
140
 
141
- def substitute_ones_with_concat(dim1,dim2):
142
- for i in range(len(dim1)-1):
 
143
  dim1[i] = dim2[i] if dim1[i] == 1 else dim1[i]
144
  dim2[i] = dim1[i] if dim2[i] == 1 else dim2[i]
145
  return dim1, dim2
@@ -147,30 +162,63 @@ def substitute_ones_with_concat(dim1,dim2):
147
  def predict(dim1, dim2):
148
  dim1 = sanitize_dimention(dim1)
149
  dim2 = sanitize_dimention(dim2)
150
- dim1, dim2, code = generate_example(dim1, dim2)
 
151
  # TODO
152
- # fix for dims if one or both have dimensions is 1
153
- # Table 1
154
- dim1, dim2 = alignment_and_fill_with_ones(dim1, dim2)
155
- table1 = generate_table(dim1, dim2)
156
- # Table 2
157
- dim1, dim2 = substitute_ones_with_concat(dim1,dim2)
158
- table2 = generate_table(dim1, dim2)
159
- # Table 3
160
- checks = check_validity(dim1,dim2)
161
- table3 = generate_table(dim1,dim2,checks)
162
-
163
- out = code
164
- out += "\n# Step1 (alignment and pre_append with ones)\n" + table1
165
- out += "\n# Step2 (susbtitute columns that have 1 with concat)\nexcept for last 2 dimensions\n" + table2
166
- out += "\n# Step3 (check if matrix multiplication is valid)\n"
167
- out += "* last dimension of dim1 should equal before last dimension of dim2 (blue or yellow colors)\n"
168
- out += "* all the other dimensions should be equal to one another (green or red colors)\n\n" + table3
169
- if "X" not in checks :
170
- dim1[-1] = dim2[-1]
171
- out += "\n# Final dimension\n"
172
- out+="as highlighted in <strong style='color:green'> green </strong> \n\n"
173
- out+= f"`output.shape = {dim1}`"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  return out
175
 
176
 
 
16
 
17
  """
18
 
19
+ matrix_loop = """```python
20
+ out = 0
21
+ for i, j in zip(t1, t2):
22
+ out += i * j
23
+ ```
24
+ """
25
 
26
  def generate_example(dim1: list, dim2: list):
27
  n1 = 1
 
64
  return out
65
 
66
 
67
+ def create_row(dim,is_dim=None,checks=None,version=1):
68
  out = "| "
69
  n_dim = len(dim)
70
  for i in range(n_dim):
71
+ if version == 1 :
72
+ # infered last dims
73
+ if (is_dim ==1 and i == n_dim-2) or (is_dim ==2 and i ==n_dim-1):
74
+ color = "green"
75
+ out += f"<strong style='color: {color}'> {dim[i]} </strong>| "
76
+ # check every normal dimension
77
+ elif (is_dim ==1 and i != n_dim-1) or (is_dim ==2 and i ==n_dim-1):
78
+ color = "green" if checks[i] == "V" else "red"
79
+ out += f"<strong style='color: {color}'> {dim[i]} </strong>| "
80
+ # checks last 2 dims
81
+ elif (is_dim ==1 and i == n_dim-1) or (is_dim ==2 and i ==n_dim-2):
82
+ color = "blue" if checks[i] == "V" else "yellow"
83
+ out += f"<strong style='color: {color}'> {dim[i]} </strong>| "
84
+ # when using this function without checks
85
+ else :
86
+ out+= f"{dim[i]} | "
87
+ if version == 2 :
88
+ if (is_dim == 1 and i != n_dim-1) :
89
+ out += f"<strong style='color: green'> {dim[i]} </strong>| "
90
+ elif i == n_dim-1 :
91
+ color = "blue" if checks[i] == "V" else "yellow"
92
+ out += f"<strong style='color: {color}'> {dim[i]} </strong>| "
93
+ else :
94
+ out += f"{dim[i]} | "
95
+
96
  return out + "\n"
97
 
98
 
 
105
  return out
106
 
107
 
108
+ def generate_table(dim1, dim2, checks=None,version=1):
109
  n_dim = len(dim1)
110
  table = create_header(n_dim, checks)
111
  # tensor 1
112
  if not checks :
113
  table += create_row(dim1)
114
  else :
115
+ table += create_row(dim1,1,checks,version)
116
 
117
  # tensor 2
118
  if not checks :
119
  table += create_row(dim2)
120
  else :
121
+ table += create_row(dim2,2,checks,version)
122
  return table
123
 
124
 
 
138
  return dim1, dim2
139
 
140
  def check_validity(dim1,dim2):
 
 
141
  out = []
142
  for i in range(len(dim1)-2):
143
  if dim1[i] == dim2[i]:
 
152
  return out
153
 
154
 
155
+ def substitute_ones_with_concat(dim1,dim2,version = 1):
156
+ n = len(dim1)-2 if version ==1 else len(dim1)-1
157
+ for i in range(n):
158
  dim1[i] = dim2[i] if dim1[i] == 1 else dim1[i]
159
  dim2[i] = dim1[i] if dim2[i] == 1 else dim2[i]
160
  return dim1, dim2
 
162
  def predict(dim1, dim2):
163
  dim1 = sanitize_dimention(dim1)
164
  dim2 = sanitize_dimention(dim2)
165
+ n1 , n2 = len(dim1) , len(dim2)
166
+ dim1, dim2, out = generate_example(dim1, dim2)
167
  # TODO
168
+ if n1 >1 and n2 > 1 :
169
+ # Table 1
170
+ dim1, dim2 = alignment_and_fill_with_ones(dim1, dim2)
171
+ table1 = generate_table(dim1, dim2)
172
+ # Table 2
173
+ dim1, dim2 = substitute_ones_with_concat(dim1,dim2)
174
+ table2 = generate_table(dim1, dim2)
175
+ # Table 3
176
+ checks = check_validity(dim1,dim2)
177
+ table3 = generate_table(dim1,dim2,checks)
178
+
179
+ out += "\n# Step1 (alignment and pre_append with ones)\n" + table1
180
+ out += "\n# Step2 (susbtitute columns that have 1 with concat)\nexcept for last 2 dimensions\n" + table2
181
+ out += "\n# Step3 (check if matrix multiplication is valid)\n"
182
+ out += "* last dimension of dim1 should equal before last dimension of dim2 (blue or yellow colors)\n"
183
+ out += "* all the other dimensions should be equal to one another (green or red colors)\n\n" + table3
184
+ if "X" not in checks :
185
+ dim1[-1] = dim2[-1]
186
+ out += "\n# Final dimension\n"
187
+ out+="as highlighted in <strong style='color:green'> green </strong> \n\n"
188
+ out+= f"`output.shape = {dim1}`"
189
+ # case single dims
190
+ elif n1 == 1 and n2 == 1 :
191
+ out += "# Single Dimensional Cases\n"
192
+ out += "When both matricies have only single dims they should both have the same number of values in the first dimension\n"
193
+ out += "meaning that `t1.shape == t2.shape`\n"
194
+ out += "the output is a single value, think : \n"
195
+ out += matrix_loop
196
+ else :
197
+ out += "# One of the dimensions is a single dimension\n"
198
+ out += "In this case we need to assert that the last dimension of `t1` "
199
+ out += "is equal to the last dimension of `t2`\n"
200
+ out += "Once the assertion is valid then we get rid of the last dimension and keep the rest\n"
201
+ out += "# Step 1 (alignment and fill with ones)\n"
202
+ dim1, dim2 = alignment_and_fill_with_ones(dim1, dim2)
203
+ table = generate_table(dim1, dim2)
204
+ out += table
205
+ out += "\n# Step2 (susbtitute columns that have 1 with concat)\n"
206
+ out += ""
207
+ dim1, dim2 = substitute_ones_with_concat(dim1,dim2,2)
208
+ checks = ["V"] * (len(dim1)-1)
209
+ if dim1[-1] == dim2[-1] :
210
+ checks.append("V")
211
+ else :
212
+ checks.append("X")
213
+ table = generate_table(dim1, dim2,checks,2)
214
+ out+= table
215
+ if "X" not in checks :
216
+ out += "\n#Final dimension"
217
+ out += "The final dimension is everything colored in <strong style='color:green'> green </strong> \n"
218
+ out += f"\nfinal dimension = `{dim1[:-1]}` "
219
+
220
+
221
+
222
  return out
223
 
224