File size: 7,844 Bytes
331412c 039cd66 331412c 92d4bc4 331412c b1045a7 331412c 9c7dc56 331412c a951e4b 331412c a951e4b 331412c 039cd66 b1045a7 9c7dc56 e64e782 b1045a7 92d4bc4 b1045a7 92d4bc4 b1045a7 92d4bc4 b1045a7 92d4bc4 b1045a7 92d4bc4 b1045a7 92d4bc4 b1045a7 92d4bc4 b1045a7 9c7dc56 b1045a7 9c7dc56 039cd66 b1045a7 e64e782 b1045a7 e64e782 039cd66 b1045a7 e64e782 b1045a7 9c7dc56 b1045a7 9c7dc56 b1045a7 9c7dc56 b1045a7 9c7dc56 b1045a7 9c7dc56 b1045a7 92d4bc4 9c7dc56 331412c b1045a7 331412c a951e4b b1045a7 92d4bc4 b1045a7 92d4bc4 b1045a7 92d4bc4 b1045a7 92d4bc4 b1045a7 80a7762 b1045a7 92d4bc4 b1045a7 92d4bc4 b1045a7 92d4bc4 b1045a7 92d4bc4 b1045a7 fe7995a 92d4bc4 00210e6 b1045a7 92d4bc4 b1045a7 92d4bc4 b1045a7 92d4bc4 039cd66 331412c b1045a7 182bc8c 7ff9ff1 331412c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 |
import gradio as gr
import torch
# Markdown template for the code example placed at the top of the output.
# The placeholders {n1}, {dim1}, {n2}, {dim2} and {out_shape} are filled in
# with str.format by generate_example.
EXAMPLE_MD = """
```python
import torch
t1 = torch.arange({n1}).view({dim1})
t2 = torch.arange({n2}).view({dim2})
(t1 @ t2).shape = {out_shape}
```
"""
# Markdown snippet that predict appends when both inputs are one-dimensional;
# it illustrates the multiplication as an element-wise multiply-and-sum loop.
# NOTE(review): the loop body inside the snippet appears unindented in this
# view of the file -- confirm it renders as intended.
matrix_loop = """```python
out = 0
for i, j in zip(t1, t2):
out += i * j
```
"""
def generate_example(dim1: list, dim2: list):
    """Build the markdown code example for multiplying two tensors.

    Two ``torch.arange`` tensors are created with the requested shapes and
    multiplied with ``@`` so that the real output shape can be reported.

    Args:
        dim1: shape of the first tensor, as a list of ints.
        dim2: shape of the second tensor, as a list of ints.

    Returns:
        A tuple ``(dim1, dim2, code)`` where ``dim1`` and ``dim2`` are the
        arguments passed in and ``code`` is ``EXAMPLE_MD`` with its
        placeholders filled. The reported output shape is the string
        ``"error"`` when the multiplication raises ``RuntimeError``.
    """

    def element_count(shape):
        # Number of elements needed to fill a tensor of this shape.
        total = 1
        for extent in shape:
            total *= extent
        return total

    n1 = element_count(dim1)
    n2 = element_count(dim2)
    t1 = torch.arange(n1).view(dim1)
    t2 = torch.arange(n2).view(dim2)

    try:
        product = t1 @ t2
    except RuntimeError:
        out_shape = "error"
    else:
        out_shape = list(product.shape)

    fields = {
        "n1": n1,
        "dim1": dim1,
        "n2": n2,
        "dim2": dim2,
        "out_shape": out_shape,
    }
    code = EXAMPLE_MD.format(**{key: str(value) for key, value in fields.items()})
    return dim1, dim2, code
def sanitize_dimension(dim):
    """Parse a user-typed shape string into a list of positive integers.

    Square brackets are ignored and both commas and whitespace are accepted
    as separators, so ``"9,2,1,3"``, ``"[9, 2, 1, 3]"``, ``"9 2 1 3"`` and
    ``"7"`` are all valid.

    Args:
        dim: the raw text entered for one tensor shape (may be ``None``).

    Returns:
        The parsed dimension sizes, in the order they were written.

    Raises:
        gr.Error: if the input is missing or empty, contains something that
            is not an integer, or contains a zero or negative size.
    """
    # gr.Error only reaches the UI when it is *raised*; merely constructing
    # it (as the previous version did) silently does nothing.
    if dim is None or not dim.strip():
        raise gr.Error("one of the dimensions is empty, please fill it")

    # Brackets are decoration, commas are separators.
    cleaned = dim.replace("[", "").replace("]", "").replace(",", " ")
    try:
        out = [int(token) for token in cleaned.split()]
    except ValueError as err:
        raise gr.Error(
            f"could not read {dim!r} as a list of integers, use a format like 2,3,4"
        ) from err

    # Input such as "[]" or "," leaves nothing to parse.
    if not out:
        raise gr.Error("one of the dimensions is empty, please fill it")
    if 0 in out:
        raise gr.Error(
            "Found the number 0 in one of the dimensions which is not allowed, consider using 1 instead"
        )
    # Negative sizes cannot be used to build the example tensors.
    if any(size < 0 for size in out):
        raise gr.Error("Found a negative number in one of the dimensions which is not allowed")
    return out
def create_row(dim, is_dim=None, checks=None, version=1):
    """Render one tensor shape as a markdown table row.

    Args:
        dim: the dimension sizes to render, one table cell each.
        is_dim: ``1`` when rendering the first tensor, ``2`` for the second,
            ``None`` to render every cell without colouring.
        checks: per-column ``"V"``/``"X"`` markers used to pick colours;
            only read for the cells that are coloured by a check.
        version: ``1`` for the layout used when both tensors have more than
            one dimension, ``2`` for the layout used when one tensor has a
            single dimension. Any other value produces no cells.

    Returns:
        The row text, starting with ``"| "`` and ending with a newline.
    """

    def plain(value):
        return f"{value} | "

    def colored(value, color):
        return f"<strong style='color: {color}'> {value} </strong>| "

    last = len(dim) - 1
    before_last = last - 1
    first_tensor = is_dim == 1
    second_tensor = is_dim == 2

    cells = []
    for pos, value in enumerate(dim):
        if version == 1:
            if (first_tensor and pos == before_last) or (second_tensor and pos == last):
                # dimension that survives into the output: always green
                cells.append(colored(value, "green"))
            elif first_tensor and pos != last:
                # leading dimensions of the first tensor follow their check
                # (the second tensor's last column was already handled above)
                cells.append(colored(value, "green" if checks[pos] == "V" else "red"))
            elif (first_tensor and pos == last) or (second_tensor and pos == before_last):
                # the pair of dimensions that get contracted
                cells.append(colored(value, "blue" if checks[pos] == "V" else "yellow"))
            else:
                cells.append(plain(value))
        elif version == 2:
            if first_tensor and pos != last:
                cells.append(colored(value, "green"))
            elif pos == last:
                cells.append(colored(value, "blue" if checks[pos] == "V" else "yellow"))
            else:
                cells.append(plain(value))

    return "| " + "".join(cells) + "\n"
def create_header(n_dim, checks=None):
    """Render the header and separator lines of a markdown table.

    Args:
        n_dim: number of columns in the table.
        checks: the text of each header cell; when ``None`` every header
            cell is an empty HTML comment so the header renders blank.

    Returns:
        Two lines of markdown: the header row followed by the ``|---|``
        separator row, each terminated by a newline.
    """
    if checks is None:
        labels = ["<!-- -->"] * n_dim
    else:
        labels = checks
    header_row = "| " + "".join(label + " | " for label in labels)
    separator_row = "|---" * n_dim + "|"
    return header_row + "\n" + separator_row + "\n"
def generate_table(dim1, dim2, checks=None, version=1):
    """Render both tensor shapes as a two-row markdown table.

    Args:
        dim1: dimension sizes of the first tensor; its length sets the
            number of columns.
        dim2: dimension sizes of the second tensor.
        checks: optional per-column markers; when given (and non-empty) they
            become the header cells and drive the row colouring.
        version: forwarded to ``create_row`` when ``checks`` is used.

    Returns:
        The markdown for the header, the first tensor row and the second
        tensor row, concatenated in that order.
    """
    pieces = [create_header(len(dim1), checks)]
    if checks:
        # coloured rows: tell create_row which tensor each row belongs to
        pieces.append(create_row(dim1, 1, checks, version))
        pieces.append(create_row(dim2, 2, checks, version))
    else:
        # no checks available: plain, uncoloured rows
        pieces.append(create_row(dim1))
        pieces.append(create_row(dim2))
    return "".join(pieces)
def alignment_and_fill_with_ones(dim1, dim2):
    """Left-pad the shorter shape with ones so both have the same length.

    Args:
        dim1: dimension sizes of the first tensor.
        dim2: dimension sizes of the second tensor.

    Returns:
        ``(dim1, dim2)``. The longer shape (or both, when the lengths are
        equal) is returned as the very same object; the shorter one is
        replaced by a new list with leading ``1`` entries.
    """
    missing = len(dim2) - len(dim1)
    if missing > 0:
        # first shape is shorter
        padded = [1] * missing
        padded += dim1
        dim1 = padded
    elif missing < 0:
        # second shape is shorter
        padded = [1] * (-missing)
        padded += dim2
        dim2 = padded
    return dim1, dim2
def check_validity(dim1, dim2):
    """Mark, column by column, whether two aligned shapes can be multiplied.

    Every column except the last two is marked ``"V"`` when both shapes
    hold the same value there and ``"X"`` otherwise. The last two columns
    share one verdict: ``"V"`` when the last entry of ``dim1`` equals the
    second-to-last entry of ``dim2``, else ``"X"``.

    Args:
        dim1: dimension sizes of the first tensor.
        dim2: dimension sizes of the second tensor; indexed with the same
            positions as ``dim1``.

    Returns:
        A list of ``"V"``/``"X"`` strings.
    """
    leading = len(dim1) - 2
    marks = ["V" if dim1[pos] == dim2[pos] else "X" for pos in range(leading)]
    # the contracted pair: last of dim1 against before-last of dim2
    inner = "V" if dim1[-1] == dim2[-2] else "X"
    marks += [inner, inner]
    return marks
def substitute_ones_with_concat(dim1, dim2, version=1):
    """Replace size-1 entries with the other shape's value, in place.

    With ``version == 1`` the last two columns are left alone; with any
    other ``version`` only the last column is left alone.

    Args:
        dim1: dimension sizes of the first tensor (modified in place).
        dim2: dimension sizes of the second tensor (modified in place).
        version: selects how many trailing columns are skipped.

    Returns:
        The same two list objects, ``(dim1, dim2)``.
    """
    skipped = 2 if version == 1 else 1
    for pos in range(len(dim1) - skipped):
        if dim1[pos] == 1:
            dim1[pos] = dim2[pos]
        # evaluated after dim1 was updated, same order as a plain two-step swap
        if dim2[pos] == 1:
            dim2[pos] = dim1[pos]
    return dim1, dim2
def predict(dim1, dim2):
    """Explain step by step which shape ``t1 @ t2`` has for two given shapes.

    Three situations are covered:

    * both shapes have more than one dimension,
    * both shapes have exactly one dimension,
    * exactly one of the shapes has a single dimension.

    Args:
        dim1: shape of the first tensor as typed by the user, e.g. ``"9,2,1,3"``.
        dim2: shape of the second tensor as typed by the user.

    Returns:
        A markdown string: the code example produced by ``generate_example``
        followed by the explanation tables for the matching situation.
    """
    dim1 = sanitize_dimension(dim1)
    dim2 = sanitize_dimension(dim2)
    n1, n2 = len(dim1), len(dim2)
    dim1, dim2, out = generate_example(dim1, dim2)
    # TODO
    if n1 > 1 and n2 > 1:
        # Table 1
        dim1, dim2 = alignment_and_fill_with_ones(dim1, dim2)
        table1 = generate_table(dim1, dim2)
        # Table 2
        dim1, dim2 = substitute_ones_with_concat(dim1, dim2)
        table2 = generate_table(dim1, dim2)
        # Table 3
        checks = check_validity(dim1, dim2)
        table3 = generate_table(dim1, dim2, checks)
        out += "\n# Step1 (alignment and pre_append with ones)\n" + table1
        out += (
            "\n# Step2 (substitute columns that have 1 with concat)\nexcept for last 2 dimensions\n"
            + table2
        )
        out += "\n# Step3 (check if matrix multiplication is valid)\n"
        out += "* last dimension of dim1 should equal before last dimension of dim2 (blue or yellow colors)\n"
        out += (
            "* all the other dimensions should be equal to one another (green or red colors)\n\n"
            + table3
        )
        if "X" not in checks:
            # the output keeps dim1's leading dims and takes dim2's last dim
            dim1[-1] = dim2[-1]
            out += "\n# Final dimension\n"
            out += "as highlighted in <strong style='color:green'> green </strong> \n\n"
            out += f"`output.shape = {dim1}`"
    # case single dims
    elif n1 == 1 and n2 == 1:
        out += "# Single Dimensional Cases\n"
        out += "When both matricies have only single dims they should both have the same number of values in the first dimension\n"
        out += "meaning that `t1.shape == t2.shape`\n"
        out += "the output is a single value, think : \n"
        out += matrix_loop
    else:
        # NOTE(review): this branch compares the last dimension of both
        # shapes. torch.matmul documents that a 1-D *first* operand is
        # matched against the second-to-last dimension of the other tensor,
        # so the explanation may not hold when `t1` is the 1-D one -- confirm.
        out += "# One of the tensors has a single dimension\n"
        out += "In this case we need to assert that the last dimension of `t1` "
        out += "is equal to the last dimension of `t2`\n"
        out += "Once the assertion is valid then we get rid of the last dimension and keep the rest\n"
        out += "# Step 1 (alignment and fill with ones)\n"
        dim1, dim2 = alignment_and_fill_with_ones(dim1, dim2)
        table = generate_table(dim1, dim2)
        out += table
        out += "\n# Step2 (susbtitute columns that have 1 with concat)\n"
        out += "fill all previous columns with ones\n"
        dim1, dim2 = substitute_ones_with_concat(dim1, dim2, 2)
        # every column but the last is accepted as-is; only the last is checked
        checks = ["V"] * (len(dim1) - 1)
        if dim1[-1] == dim2[-1]:
            checks.append("V")
        else:
            checks.append("X")
        table = generate_table(dim1, dim2, checks, 2)
        out += table
        if "X" not in checks:
            # A markdown heading needs "# " and its own line; without them
            # the heading fused with the sentence that follows.
            out += "\n# Final dimension\n"
            out += "The final dimension is everything colored in <strong style='color:green'> green </strong> \n"
            out += f"\nfinal dimension = `{dim1[:-1]}` "
    return out
# Gradio UI: two free-text inputs (one shape per tensor) passed to predict,
# whose markdown result is rendered as the single output.
demo = gr.Interface(
    predict,
    inputs=["text", "text"],
    outputs=["markdown"],
    # Each example is a [shape of t1, shape of t2] pair.
    examples=[
        ["9,2,1,3,3", "5,3,7"],
        ["7,4,2,3", "5,2,7"],
        ["4,5,6,7", "7"],
        ["7,5,3", "4"],
        ["5", "5"],
        ["8", "2"],
    ],
    title= "Pytorch Matrix Multiplication",
    description= """There are 3 cases which are covered in the examples:
* Both matricies have dimensions bigger than 1
* One of the matracies have a single dimension
* Both Matracies have a single dimension
""",
)
# Start the app as soon as the module is executed.
demo.launch(debug=True)
|