File size: 96,653 Bytes
239ee43 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 
1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176 2177 2178 2179 2180 2181 2182 2183 2184 2185 2186 2187 2188 2189 2190 2191 2192 2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206 2207 2208 2209 2210 2211 2212 2213 2214 2215 2216 2217 2218 2219 2220 2221 
2222 2223 2224 2225 2226 2227 2228 2229 2230 2231 2232 2233 2234 2235 2236 2237 2238 2239 2240 2241 2242 2243 2244 2245 2246 2247 2248 2249 2250 2251 2252 2253 2254 2255 2256 2257 2258 2259 2260 2261 2262 2263 2264 2265 2266 2267 2268 2269 2270 2271 2272 2273 2274 2275 2276 2277 2278 2279 2280 2281 2282 2283 2284 2285 2286 2287 2288 2289 2290 2291 2292 2293 2294 2295 2296 2297 2298 2299 2300 2301 2302 2303 2304 2305 2306 2307 2308 2309 2310 2311 2312 2313 2314 2315 2316 2317 2318 2319 2320 2321 2322 2323 2324 2325 2326 2327 2328 2329 2330 2331 2332 2333 2334 2335 2336 2337 2338 2339 2340 2341 2342 2343 2344 2345 2346 2347 2348 2349 2350 2351 2352 2353 2354 2355 2356 2357 2358 2359 2360 2361 2362 2363 2364 2365 2366 2367 2368 2369 2370 2371 2372 2373 2374 2375 2376 2377 2378 2379 2380 2381 2382 2383 2384 2385 2386 2387 2388 2389 2390 2391 2392 2393 2394 2395 2396 2397 2398 2399 2400 2401 2402 2403 2404 2405 2406 2407 2408 2409 2410 2411 2412 2413 2414 2415 2416 2417 2418 2419 2420 2421 2422 2423 2424 2425 2426 2427 2428 2429 2430 2431 2432 2433 2434 2435 2436 2437 2438 2439 2440 2441 2442 2443 2444 2445 2446 2447 2448 2449 2450 2451 2452 2453 2454 2455 2456 2457 2458 2459 2460 2461 2462 2463 2464 2465 2466 2467 2468 2469 2470 2471 2472 2473 2474 2475 2476 2477 2478 2479 2480 2481 2482 2483 2484 2485 2486 2487 2488 2489 2490 2491 2492 2493 2494 2495 2496 2497 2498 2499 2500 2501 2502 2503 2504 2505 2506 2507 2508 2509 2510 2511 2512 2513 2514 2515 2516 2517 2518 2519 2520 2521 2522 2523 2524 2525 2526 2527 2528 2529 2530 2531 2532 2533 2534 2535 2536 2537 2538 2539 2540 2541 2542 2543 2544 2545 2546 2547 2548 2549 2550 2551 2552 2553 2554 2555 2556 2557 2558 2559 2560 2561 2562 2563 2564 2565 2566 2567 2568 2569 2570 2571 2572 2573 2574 2575 2576 2577 2578 2579 2580 2581 2582 2583 2584 2585 2586 2587 2588 2589 2590 2591 2592 2593 2594 2595 2596 2597 2598 2599 2600 2601 2602 2603 2604 2605 2606 2607 2608 2609 2610 2611 2612 2613 2614 2615 2616 2617 2618 2619 2620 2621 
2622 2623 2624 2625 2626 2627 2628 2629 2630 2631 2632 2633 2634 2635 2636 2637 2638 2639 2640 2641 2642 2643 2644 2645 2646 2647 2648 2649 2650 2651 2652 2653 2654 2655 2656 2657 2658 2659 2660 2661 2662 2663 2664 2665 2666 2667 2668 2669 2670 2671 2672 2673 2674 2675 2676 2677 2678 2679 2680 2681 2682 2683 2684 2685 2686 2687 2688 2689 2690 2691 2692 2693 2694 2695 2696 2697 2698 2699 2700 2701 2702 2703 2704 2705 2706 2707 2708 2709 2710 2711 2712 2713 2714 2715 2716 2717 2718 2719 2720 2721 2722 2723 2724 2725 2726 2727 2728 2729 2730 2731 2732 |
import math
import copy
from random import random
from beartype.typing import List, Union
from beartype import beartype
from tqdm.auto import tqdm
from functools import partial, wraps
from contextlib import contextmanager, nullcontext
from collections import namedtuple
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel
from torch import nn, einsum
from torch.cuda.amp import autocast
from torch.special import expm1
import torchvision.transforms as T
import kornia.augmentation as K
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
from imagen_pytorch.imagen_video import Unet3D, resize_video_to, scale_video_time
# helper functions
def exists(val):
    """Return True for any value other than ``None`` (so ``0``, ``''`` and ``[]`` count as present)."""
    return val is not None
def identity(t, *args, **kwargs):
    """Return ``t`` unchanged; any extra positional/keyword arguments are ignored."""
    return t
def divisible_by(numer, denom):
    """Return whether ``numer`` is an exact multiple of ``denom`` (remainder compares equal to 0)."""
    remainder = numer % denom
    return remainder == 0
def first(arr, d = None):
    """Return the first element of ``arr``, or ``d`` when ``arr`` is empty."""
    return arr[0] if len(arr) > 0 else d
def maybe(fn):
    """Wrap a one-argument ``fn`` so that a ``None`` argument is returned as-is instead of being passed to ``fn``."""
    @wraps(fn)
    def inner(x):
        if exists(x):
            return fn(x)
        return x
    return inner
def once(fn):
    """Wrap ``fn`` so that only its first call runs; every later call returns ``None``.

    All positional and keyword arguments are forwarded unchanged, so the wrapper
    can stand in for ``fn`` directly (e.g. ``print_once(a, b, sep = ' ')``).
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return
        called = True
        return fn(*args, **kwargs)

    return inner

# prints only the first message it is given; later calls are ignored
print_once = once(print)
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise return ``d``.

    A callable ``d`` is treated as a lazy factory and is only invoked when the
    fallback is actually needed.
    """
    if exists(val):
        return val
    if callable(d):
        return d()
    return d
def cast_tuple(val, length = None):
    """Coerce ``val`` into a tuple.

    Lists become tuples and tuples pass through; any other value is repeated
    ``length`` times (once when ``length`` is None). When ``length`` is given,
    the result is asserted to have exactly that many elements.
    """
    if isinstance(val, list):
        val = tuple(val)

    if isinstance(val, tuple):
        result = val
    else:
        result = (val,) * default(length, 1)

    if exists(length):
        assert len(result) == length

    return result
def compact(input_dict):
    """Return a new dict holding only the entries of ``input_dict`` whose value is not ``None``."""
    kept = {}
    for key, value in input_dict.items():
        if exists(value):
            kept[key] = value
    return kept
def maybe_transform_dict_key(input_dict, key, fn):
    """Return a copy of ``input_dict`` with ``fn`` applied to the value at ``key``.

    When ``key`` is absent, the original dict object is returned untouched. The
    input dict itself is never mutated.
    """
    if key in input_dict:
        updated = input_dict.copy()
        updated[key] = fn(updated[key])
        return updated
    return input_dict
def cast_uint8_images_to_float(images):
    """Scale a uint8 tensor to floats in [0, 1] by dividing by 255; other dtypes are returned unchanged."""
    if images.dtype == torch.uint8:
        return images / 255
    return images
def module_device(module):
    """Return the device of ``module``'s first parameter (raises StopIteration if it has none)."""
    first_param = next(module.parameters())
    return first_param.device
def zero_init_(m):
    """Zero ``m.weight`` in place, and ``m.bias`` too when the module has one (bias is not None)."""
    nn.init.zeros_(m.weight)
    bias = m.bias
    if exists(bias):
        nn.init.zeros_(bias)
def eval_decorator(fn):
    """Decorate ``fn(model, ...)`` so it runs with ``model`` in eval mode.

    The model's previous ``training`` flag is restored afterwards, even when
    ``fn`` raises, so a failed call cannot leave the model stuck in eval mode.
    ``functools.wraps`` keeps the wrapped function's name and docstring.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # restore whatever mode the caller had, on success or failure
            model.train(was_training)
    return inner
def pad_tuple_to_length(t, length, fillvalue = None):
    """Right-pad ``t`` with ``fillvalue`` until it has ``length`` items.

    If ``t`` is already long enough, it is returned unchanged (the same
    object); otherwise a new tuple is built.
    """
    shortfall = length - len(t)
    if shortfall > 0:
        padding = (fillvalue,) * shortfall
        return (*t, *padding)
    return t
# helper classes
class Identity(nn.Module):
    """No-op module: accepts and ignores any constructor/forward extras and returns its input."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x
# tensor helpers
def log(t, eps: float = 1e-12):
    """Natural log of ``t`` after clamping it from below at ``eps``, so zeros give a finite result."""
    floored = t.clamp(min = eps)
    return torch.log(floored)
def l2norm(t):
    """Scale ``t`` to unit L2 norm along its last axis (``F.normalize`` defaults)."""
    normalized = F.normalize(t, dim = -1)
    return normalized
def right_pad_dims_to(x, t):
    """Append singleton axes to ``t`` until it has as many dims as ``x`` (for broadcasting).

    If ``t`` already has at least ``x.ndim`` dims, it is returned unchanged.
    """
    extra = x.ndim - t.ndim
    if extra > 0:
        return t.view(*t.shape, *([1] * extra))
    return t
def masked_mean(t, *, dim, mask = None):
    """Mean of ``t`` along ``dim``, counting only positions where ``mask`` is True.

    ``mask`` must be a 2-D boolean tensor (it is rearranged ``b n -> b n 1``).
    NOTE(review): ``t`` is presumably ``(b, n, d)`` with ``dim = 1``; confirm
    against callers. The count of kept positions is clamped at 1e-5 so a fully
    masked row does not divide by zero. Without a mask this is a plain mean.
    """
    if not exists(mask):
        return t.mean(dim = dim)

    count = mask.sum(dim = dim, keepdim = True)
    keep = rearrange(mask, 'b n -> b n 1')
    total = t.masked_fill(~keep, 0.).sum(dim = dim)
    return total / count.clamp(min = 1e-5)
def resize_image_to(
    image,
    target_image_size,
    clamp_range = None,
    mode = 'nearest'
):
    """Resize ``image`` with ``F.interpolate`` to ``target_image_size``.

    Only the last axis of ``image`` is compared with ``target_image_size`` to
    decide whether the resize can be skipped; when they match, the input is
    returned as-is. When ``clamp_range`` is given, the resized output is
    clamped to ``(min, max)``.
    NOTE(review): ``image`` is presumably ``(b, c, h, w)``; confirm.
    """
    if image.shape[-1] == target_image_size:
        return image

    resized = F.interpolate(image, target_image_size, mode = mode)
    return resized.clamp(*clamp_range) if exists(clamp_range) else resized
def calc_all_frame_dims(
    downsample_factors: List[int],
    frames
):
    """Return one ``(frames // factor,)`` tuple per downsample factor.

    If ``frames`` is None, returns a tuple of empty tuples (one per factor).
    Otherwise returns a list, and asserts that ``frames`` divides evenly by
    every factor.
    """
    if not exists(frames):
        return (tuple(),) * len(downsample_factors)

    frame_dims = []
    for factor in downsample_factors:
        assert divisible_by(frames, factor)
        reduced = frames // factor
        frame_dims.append((reduced,))
    return frame_dims
def safe_get_tuple_index(tup, index, default = None):
    """Return ``tup[index]``, or ``default`` when ``index`` is not below ``len(tup)``."""
    if index < len(tup):
        return tup[index]
    return default
# image normalization functions
# ddpms expect images to be in the range of -1 to 1
def normalize_neg_one_to_one(img):
    """Map values from [0, 1] to [-1, 1] (computes ``2 * img - 1``)."""
    return 2 * img - 1
def unnormalize_zero_to_one(normed_img):
    """Map values from [-1, 1] back to [0, 1]; inverse of ``normalize_neg_one_to_one``."""
    return 0.5 * (normed_img + 1)
# classifier free guidance functions
def prob_mask_like(shape, prob, device):
    """Return a boolean tensor of ``shape`` where each entry is True with probability ``prob``.

    ``prob == 1`` and ``prob == 0`` short-circuit to all-True / all-False
    without drawing random numbers.
    """
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)

    uniform_noise = torch.zeros(shape, device = device).float().uniform_(0, 1)
    return uniform_noise < prob
# gaussian diffusion with continuous time helper functions and classes
# large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py
@torch.jit.script
def beta_linear_log_snr(t):
    """Log-SNR of the continuous 'linear' noise schedule: -log(expm1(1e-4 + 10 t^2))."""
    expm1_arg = 1e-4 + 10 * (t ** 2)
    return -torch.log(expm1(expm1_arg))
@torch.jit.script
def alpha_cosine_log_snr(t, s: float = 0.008):
    """Log-SNR of the cosine noise schedule with offset ``s``.

    Computes -log(cos(((t + s) / (1 + s)) * pi / 2) ** -2 - 1), with the log
    clamped at 1e-5. It is not clear whether this accounts for beta being
    clipped to 0.999 in the discrete version.
    """
    angle = (t + s) / (1 + s) * math.pi * 0.5
    inv_cos_sq = torch.cos(angle) ** -2
    return -log(inv_cos_sq - 1, eps = 1e-5)
def log_snr_to_alpha_sigma(log_snr):
    """Convert log-SNR to ``(alpha, sigma)`` with alpha = sqrt(sigmoid(log_snr)) and sigma = sqrt(sigmoid(-log_snr)).

    By construction alpha**2 + sigma**2 == 1.
    """
    alpha = torch.sigmoid(log_snr).sqrt()
    sigma = torch.sigmoid(-log_snr).sqrt()
    return alpha, sigma
class GaussianDiffusionContinuousTimes(nn.Module):
    """Continuous-time Gaussian diffusion utilities driven by a log-SNR schedule.

    The module holds no learnable parameters. ``self.log_snr`` maps times to
    log signal-to-noise ratios; sampling walks time from 1 down to 0 (see
    ``get_sampling_timesteps``).

    Args:
        noise_schedule: ``"linear"`` or ``"cosine"``.
        timesteps: number of discrete steps used when sampling.

    Raises:
        ValueError: if ``noise_schedule`` is not one of the supported names.
    """

    def __init__(self, *, noise_schedule, timesteps = 1000):
        super().__init__()

        if noise_schedule == "linear":
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        self.num_timesteps = timesteps

    def get_times(self, batch_size, noise_level, *, device):
        """Return a float32 tensor of shape ``(batch_size,)`` filled with ``noise_level``."""
        return torch.full((batch_size,), noise_level, device = device, dtype = torch.float32)

    def sample_random_times(self, batch_size, *, device):
        """Sample ``batch_size`` times uniformly between 0 and 1."""
        return torch.zeros((batch_size,), device = device).float().uniform_(0, 1)

    def get_condition(self, times):
        """Return ``self.log_snr(times)``, or ``None`` when ``times`` is ``None``."""
        return maybe(self.log_snr)(times)

    def get_sampling_timesteps(self, batch, *, device):
        """Return the (current, next) time pairs for sampling from t = 1 to t = 0.

        The result is a tuple of ``num_timesteps`` tensors, each of shape
        ``(2, batch)``: row 0 is the current time and row 1 the next time.
        """
        times = torch.linspace(1., 0., self.num_timesteps + 1, device = device)
        times = repeat(times, 't -> b t', b = batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
        times = times.unbind(dim = -1)
        return times

    def q_posterior(self, x_start, x_t, t, *, t_next = None):
        """Posterior q(x_{t_next} | x_t, x_start).

        Follows https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material
        (equations near eq. 33). ``t_next`` defaults to ``t - 1 / num_timesteps``
        clamped at 0.

        Returns:
            ``(posterior_mean, posterior_variance, posterior_log_variance_clipped)``
        """
        t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))

        log_snr = self.log_snr(t)
        log_snr_next = self.log_snr(t_next)
        log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next))

        # sigma at time t is not needed for the posterior
        alpha, _ = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c - as defined near eq 33
        c = -expm1(log_snr - log_snr_next)
        posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        # following (eq. 33)
        posterior_variance = (sigma_next ** 2) * c
        # floor the variance at 1e-20 before the log so it stays finite for tiny/zero variances
        posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise = None):
        """Noise ``x_start`` to time ``t``.

        ``t`` may be a per-sample tensor or a Python float applied to the whole
        batch. ``noise`` defaults to standard normal noise shaped like ``x_start``.

        Returns:
            ``(alpha * x_start + sigma * noise, log_snr, alpha, sigma)``, where
            ``alpha``/``sigma`` are right-padded to broadcast against ``x_start``
            and ``log_snr`` is cast to ``x_start``'s dtype.
        """
        dtype = x_start.dtype

        if isinstance(t, float):
            batch = x_start.shape[0]
            t = torch.full((batch,), t, device = x_start.device, dtype = dtype)

        noise = default(noise, lambda: torch.randn_like(x_start))
        log_snr = self.log_snr(t).type(dtype)
        log_snr_padded_dim = right_pad_dims_to(x_start, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)

        return alpha * x_start + sigma * noise, log_snr, alpha, sigma

    def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
        """Move a sample noised to ``from_t`` so that it is noised to ``to_t`` instead.

        This is algebraically equivalent to recovering
        ``x_start = (x_from - sigma * noise) / alpha`` and re-noising it to
        ``to_t`` with the same ``noise``. Float times are broadcast to the batch.
        """
        shape, device, dtype = x_from.shape, x_from.device, x_from.dtype
        batch = shape[0]

        if isinstance(from_t, float):
            from_t = torch.full((batch,), from_t, device = device, dtype = dtype)

        if isinstance(to_t, float):
            to_t = torch.full((batch,), to_t, device = device, dtype = dtype)

        noise = default(noise, lambda: torch.randn_like(x_from))

        log_snr = self.log_snr(from_t)
        log_snr_padded_dim = right_pad_dims_to(x_from, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)

        log_snr_to = self.log_snr(to_t)
        log_snr_padded_dim_to = right_pad_dims_to(x_from, log_snr_to)
        alpha_to, sigma_to = log_snr_to_alpha_sigma(log_snr_padded_dim_to)

        return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha

    def predict_start_from_v(self, x_t, t, v):
        """Recover x_start from a v-prediction: ``alpha * x_t - sigma * v``."""
        log_snr = self.log_snr(t)
        log_snr = right_pad_dims_to(x_t, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        return alpha * x_t - sigma * v

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_start from a noise prediction: ``(x_t - sigma * noise) / alpha``.

        ``alpha`` is clamped at 1e-8 to avoid division by zero.
        """
        log_snr = self.log_snr(t)
        log_snr = right_pad_dims_to(x_t, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)
# norms and residuals
class LayerNorm(nn.Module):
    """Layer normalisation over one axis, with a learned gain and no bias.

    Args:
        feats: size of the normalised axis.
        stable: if True, divide the input by its detached max along the axis
            before normalising.
        dim: axis to normalise over, given as a negative index. The gain has
            shape ``(feats, 1, ..., 1)`` so that it broadcasts against that axis.
    """

    def __init__(self, feats, stable = False, dim = -1):
        super().__init__()
        self.stable = stable
        self.dim = dim

        trailing_ones = (1,) * (-dim - 1)
        self.g = nn.Parameter(torch.ones(feats, *trailing_ones))

    def forward(self, x):
        in_dtype = x.dtype
        axis = self.dim

        if self.stable:
            peak = x.amax(dim = axis, keepdim = True).detach()
            x = x / peak

        # float32 gets eps 1e-5; every other dtype (half, bfloat16, float64, ...) gets 1e-3
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        centred = x - torch.mean(x, dim = axis, keepdim = True)
        variance = torch.var(x, dim = axis, unbiased = False, keepdim = True)
        inv_std = (variance + eps).rsqrt().type(in_dtype)
        return centred * inv_std * self.g.type(in_dtype)

# variant that normalises over axis -3 (the channel axis of a (b, c, h, w) tensor)
ChanLayerNorm = partial(LayerNorm, dim = -3)
class Always():
    """Callable that ignores every argument and returns the stored ``val``."""

    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        constant = self.val
        return constant
class Residual(nn.Module):
    """Residual wrapper: returns ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        skip = x
        return self.fn(x, **kwargs) + skip
class Parallel(nn.Module):
    """Run each wrapped module on the same input and return the sum of their outputs."""

    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        # evaluate every branch first, then sum
        branch_outputs = [branch(x) for branch in self.fns]
        return sum(branch_outputs)
# attention pooling
class PerceiverAttention(nn.Module):
    """Attention from a set of latents to an input sequence plus the latents themselves.

    Queries are projected from ``latents``. Keys/values are projected from the
    concatenation of ``x`` and ``latents`` along the sequence axis. Queries and
    keys are l2-normalised, multiplied by learned per-dimension gains
    (``q_scale`` / ``k_scale``), and their dot product is multiplied by the
    fixed ``scale``.

    Args:
        dim: feature size of both ``x`` and ``latents``.
        dim_head: size of each attention head.
        heads: number of attention heads.
        scale: multiplier on the normalised q.k logits.
    """
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        scale = 8
    ):
        super().__init__()
        self.scale = scale
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        # learned per-dimension gains applied after l2-normalising q and k
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )
    def forward(self, x, latents, mask = None):
        """Update ``latents`` by attending over ``x`` and themselves.

        Args:
            x: input sequence of shape (batch, seq, dim).
            latents: latent queries of shape (batch, num_latents, dim).
            mask: optional boolean mask over the positions of ``x`` (``b j``).
                It is padded with True for the latent positions, so the
                latents are always attendable.

        Returns:
            Tensor shaped like ``latents``, projected back to ``dim`` and layer-normed.
        """
        x = self.norm(x)
        latents = self.norm_latents(latents)
        b, h = x.shape[0], self.heads
        q = self.to_q(latents)
        # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        # qk normalisation: l2-normalise q and k, then apply the learned gains
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale
        # similarities and masking
        sim = einsum('... i d, ... j d -> ... i j', q, k) * self.scale
        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            # extend the mask over the appended latent keys (always kept)
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)
        # attention (softmax computed in float32, then cast back)
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)
        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)
class PerceiverResampler(nn.Module):
    """Resample a variable-length sequence into a fixed number of latent vectors.

    A stack of ``depth`` (PerceiverAttention, FeedForward) pairs, each applied
    with a residual connection, updates a set of learned latents by attending to
    the position-embedded input.

    Args:
        dim: feature size of the input sequence and of the latents.
        depth: number of attention / feed-forward pairs.
        dim_head: size of each attention head.
        heads: number of attention heads.
        num_latents: number of learned latent vectors.
        num_latents_mean_pooled: number of extra latents derived from the mean
            of the input sequence and prepended to the learned ones (0 disables).
        max_seq_len: size of the learned positional-embedding table; the input
            sequence length must not exceed it.
        ff_mult: multiplier passed to ``FeedForward``.
    """
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, dim)
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.to_latents_from_mean_pooled_seq = None
        if num_latents_mean_pooled > 0:
            # maps the pooled (b, dim) vector to (b, num_latents_mean_pooled, dim)
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))
    def forward(self, x, mask = None):
        """Resample ``x`` into latents.

        Args:
            x: input sequence of shape (batch, seq, dim), with seq <= max_seq_len.
            mask: optional boolean mask over the positions of ``x`` (``b j``),
                forwarded to every PerceiverAttention layer.

        Returns:
            Latents of shape (batch, num_latents_mean_pooled + num_latents, dim),
            or (batch, num_latents, dim) when mean pooling is disabled.
        """
        n, device = x.shape[1], x.device
        pos_emb = self.pos_emb(torch.arange(n, device = device))
        x_with_pos = x + pos_emb
        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])
        if exists(self.to_latents_from_mean_pooled_seq):
            # NOTE(review): the pool uses an all-True mask rather than `mask`, so padded
            # positions of x still contribute to the mean — confirm this is intended
            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)
        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents
        return latents
# attention
class Attention(nn.Module):
    """Self-attention in which all query heads share one key/value head.

    Queries use ``heads`` heads of size ``dim_head``. Keys and values are a
    single ``dim_head``-sized projection shared across heads (see the einsum
    ``'b h i d, b j d'``). A learned null key/value is always prepended.
    Optional ``context`` tokens are projected to extra keys/values and prepended
    in front of those.

    Args:
        dim: feature size of ``x``.
        dim_head: size of each query head and of the shared key/value head.
        heads: number of query heads.
        context_dim: feature size of ``context``; when None, ``context`` is not supported.
        scale: multiplier on the normalised q.k logits.
    """
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        context_dim = None,
        scale = 8
    ):
        super().__init__()
        self.scale = scale
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = LayerNorm(dim)
        # one learned null key and one null value, each of size dim_head
        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))
        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )
    def forward(self, x, context = None, mask = None, attn_bias = None):
        """Attend over ``x`` (plus the null token and optional ``context``).

        Args:
            x: input of shape (batch, seq, dim).
            context: optional extra tokens of feature size ``context_dim``;
                requires ``context_dim`` to have been set.
            mask: optional boolean mask over the positions of ``x`` (``b j``).
                It is padded with one True entry for the null key.
            attn_bias: optional additive bias on the logits; presumably a T5-style
                relative position bias broadcastable to (b, h, i, j) — confirm with callers.

        Returns:
            Tensor of shape (batch, seq, dim).
        """
        b, n, device = *x.shape[:2], x.device
        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        # add null key / value for classifier free guidance in prior net
        nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)
        # add text conditioning, if present
        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)
        # qk normalisation: l2-normalise q and k, then apply the learned gains
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale
        # calculate query / key similarities
        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale
        # relative positional encoding (T5 style)
        if exists(attn_bias):
            sim = sim + attn_bias
        # masking
        max_neg_value = -torch.finfo(sim.dtype).max
        if exists(mask):
            # NOTE(review): the pad covers only the null key; if `context` was also given,
            # k has extra context positions in front and the mask length will not match
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)
        # attention (softmax computed in float32, then cast back)
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)
        # aggregate values
        out = einsum('b h i j, b j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# decoder
def Upsample(dim, dim_out = None):
    """2x nearest-neighbour upsampling followed by a 3x3 conv from ``dim`` to ``dim_out`` channels.

    ``dim_out`` defaults to ``dim``.
    """
    out_channels = default(dim_out, dim)
    layers = [
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, out_channels, 3, padding = 1),
    ]
    return nn.Sequential(*layers)
class PixelShuffleUpsample(nn.Module):
    """
    2x upsampling via a 1x1 conv (to 4x the output channels), SiLU, then PixelShuffle(2).

    Code shared by @MalumaDev at DALLE2-pytorch for addressing checkerboard artifacts:
    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
    The conv weights are initialised by repeating each of ``dim_out`` base filters
    four times, so the four sub-pixel channels that PixelShuffle merges into one
    output channel start out identical. The bias starts at zero.
    """
    def __init__(self, dim, dim_out = None):
        super().__init__()
        out_channels = default(dim_out, dim)
        proj = nn.Conv2d(dim, out_channels * 4, 1)
        self.net = nn.Sequential(proj, nn.SiLU(), nn.PixelShuffle(2))
        self.init_conv_(proj)

    def init_conv_(self, conv):
        out_ch, in_ch, kernel_h, kernel_w = conv.weight.shape
        base_filters = torch.empty(out_ch // 4, in_ch, kernel_h, kernel_w)
        nn.init.kaiming_uniform_(base_filters)
        tiled = repeat(base_filters, 'o ... -> (o 4) ...')
        conv.weight.data.copy_(tiled)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)
def Downsample(dim, dim_out = None):
    """2x downsampling by folding each 2x2 patch into channels, then a 1x1 conv to ``dim_out``.

    https://arxiv.org/abs/2208.03641 argues this is the best way to downsample
    (named SP-conv in the paper; effectively a pixel unshuffle). ``dim_out``
    defaults to ``dim``.
    """
    out_channels = default(dim_out, dim)
    space_to_depth = Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2)
    proj = nn.Conv2d(dim * 4, out_channels, 1)
    return nn.Sequential(space_to_depth, proj)
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of a 1-D tensor of positions/times.

    For input of length ``i``, the output has shape ``(i, 2 * (dim // 2))``:
    ``dim // 2`` sine features followed by ``dim // 2`` cosine features, with
    geometrically spaced frequencies from 1 down to 1 / 10000.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        log_step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device = x.device) * -log_step)
        phases = rearrange(x, 'i -> i 1') * rearrange(freqs, 'j -> 1 j')
        return torch.cat((phases.sin(), phases.cos()), dim = -1)
class LearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding with learned frequencies, following @crowsonkb's lead.

    Reference: https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    For a 1-D input of length ``b``, the output has shape ``(b, dim + 1)``: the
    raw input, then ``dim // 2`` sine features and ``dim // 2`` cosine features.
    ``dim`` must be even.
    """
    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        phases = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        return torch.cat((x, phases.sin(), phases.cos()), dim = -1)
class Block(nn.Module):
    """GroupNorm (optional), optional scale/shift modulation, SiLU, then a 3x3 conv.

    Args:
        dim: input channels.
        dim_out: output channels.
        groups: number of GroupNorm groups.
        norm: when False, GroupNorm is replaced by Identity.
    """
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
        self.activation = nn.SiLU()
        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)

    def forward(self, x, scale_shift = None):
        """``scale_shift``, if given, is a ``(scale, shift)`` pair applied as ``h * (scale + 1) + shift``."""
        h = self.groupnorm(x)
        if exists(scale_shift):
            scale, shift = scale_shift
            h = h * (scale + 1) + shift
        return self.project(self.activation(h))
class ResnetBlock(nn.Module):
    """Two-``Block`` residual unit with optional time conditioning, cross-attention and gating.

    Args:
        dim: input channels.
        dim_out: output channels.
        cond_dim: when set, cross-attention over ``cond`` is applied after the first block.
        time_cond_dim: when set, ``time_emb`` is projected to a per-channel
            (scale, shift) pair that modulates the second block.
        groups: GroupNorm groups for both blocks.
        linear_attn: use ``LinearCrossAttention`` instead of ``CrossAttention``.
        use_gca: multiply the output by a ``GlobalContext`` gate (otherwise by 1).
        squeeze_excite: accepted but not used anywhere in this block.
        **attn_kwargs: forwarded to the cross-attention constructor.
    """
    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()
        self.time_mlp = None
        if exists(time_cond_dim):
            # outputs dim_out * 2 values, later split into scale and shift
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )
        self.cross_attn = None
        if exists(cond_dim):
            attn_klass = CrossAttention if not linear_attn else LinearCrossAttention
            self.cross_attn = attn_klass(
                dim = dim_out,
                context_dim = cond_dim,
                **attn_kwargs
            )
        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        # Always(1) makes the gate a no-op multiplication when GCA is disabled
        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)
        # 1x1 conv on the skip path only when the channel count changes
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()
    def forward(self, x, time_emb = None, cond = None):
        """
        Args:
            x: feature map, rearranged below as (b, c, h, w).
            time_emb: optional (b, time_cond_dim) embedding; ignored when time_cond_dim was not set.
            cond: conditioning tokens for cross-attention; required when cond_dim was set.

        Returns:
            Feature map with ``dim_out`` channels.
        """
        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)
        h = self.block1(x)
        if exists(self.cross_attn):
            assert exists(cond)
            # flatten spatial positions into a token sequence, attend to cond with a residual, then restore
            h = rearrange(h, 'b c h w -> b h w c')
            h, ps = pack([h], 'b * c')
            h = self.cross_attn(h, context = cond) + h
            h, = unpack(h, ps, 'b * c')
            h = rearrange(h, 'b h w c -> b c h w')
        h = self.block2(h, scale_shift = scale_shift)
        h = h * self.gca(h)
        return h + self.res_conv(x)
class CrossAttention(nn.Module):
    """Multi-head attention from ``x`` to a ``context`` sequence, with a learned null key/value.

    Queries come from ``x``; keys and values come from ``context`` (optionally
    layer-normed). A single learned null key/value pair is repeated across
    heads and prepended. Queries and keys are l2-normalised, multiplied by
    learned per-dimension gains, and the logits are multiplied by ``scale``.

    Args:
        dim: feature size of ``x``.
        context_dim: feature size of ``context`` (defaults to ``dim``).
        dim_head: size of each attention head.
        heads: number of attention heads.
        norm_context: apply LayerNorm to ``context`` before projecting it.
        scale: multiplier on the normalised q.k logits.
    """
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        scale = 8
    ):
        super().__init__()
        self.scale = scale
        self.heads = heads
        inner_dim = dim_head * heads
        context_dim = default(context_dim, dim)
        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()
        # one null key and one null value of size dim_head, shared by every head
        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )
    def forward(self, x, context, mask = None):
        """
        Args:
            x: query sequence of shape (batch, seq, dim).
            context: key/value sequence of shape (batch, context_seq, context_dim).
            mask: optional boolean mask over the context positions (``b j``).
                It is padded with one True entry for the null key.

        Returns:
            Tensor of shape (batch, seq, dim).
        """
        b, n, device = *x.shape[:2], x.device
        x = self.norm(x)
        context = self.norm_context(context)
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
        # add null key / value for classifier free guidance in prior net
        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)
        # cosine sim attention
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale
        # similarities
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        # masking
        max_neg_value = -torch.finfo(sim.dtype).max
        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)
        # softmax computed in float32, then cast back to the logits' dtype
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class LinearCrossAttention(CrossAttention):
    """Linear-complexity variant of ``CrossAttention``.

    Reuses the parent's projections and null key/value, but folds heads into the
    batch axis and uses the softmax-kernel formulation: softmax over features for
    q, softmax over the sequence for k, then ``q @ (k^T v)``.
    """
    def forward(self, x, context, mask = None):
        b = x.shape[0]

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        # heads are folded into the batch axis: (b h) n d, with b outer and h inner
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net

        nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # masking

        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            # the prepended null key / value is always kept
            mask = F.pad(mask, (1, 0), value = True)
            # k / v carry a (b h) leading axis, so the per-sample mask must be
            # repeated for every head (same b-outer / h-inner order as above);
            # a plain 'b n -> b n 1' only broadcasts when b == 1
            mask = repeat(mask, 'b n -> (b h) n 1', h = self.heads)
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)
class LinearAttention(nn.Module):
    """Linear-complexity self attention over a 2D feature map.

    Queries, keys and values each come from dropout -> 1x1 conv -> depthwise
    3x3 conv. When ``context_dim`` is given, projected context tokens are
    appended to the keys/values so the map can also attend to them.
    """
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = ChanLayerNorm(dim)
        self.nonlin = nn.SiLU()

        def make_projection():
            return nn.Sequential(
                nn.Dropout(dropout),
                nn.Conv2d(dim, inner_dim, 1, bias = False),
                nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
            )

        self.to_q = make_projection()
        self.to_k = make_projection()
        self.to_v = make_projection()

        if exists(context_dim):
            self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False))
        else:
            self.to_context = None

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        heads = self.heads
        height, width = fmap.shape[-2:]

        fmap = self.norm(fmap)

        # project, then flatten space and fold heads into the batch axis
        q, k, v = (
            rearrange(project(fmap), 'b (h c) x y -> (b h) (x y) c', h = heads)
            for project in (self.to_q, self.to_k, self.to_v)
        )

        if exists(context):
            assert exists(self.to_context)
            ctx_k, ctx_v = self.to_context(context).chunk(2, dim = -1)
            ctx_k = rearrange(ctx_k, 'b n (h d) -> (b h) n d', h = heads)
            ctx_v = rearrange(ctx_v, 'b n (h d) -> (b h) n d', h = heads)
            k = torch.cat((k, ctx_k), dim = -2)
            v = torch.cat((v, ctx_v), dim = -2)

        q = q.softmax(dim = -1) * self.scale
        k = k.softmax(dim = -2)

        kv_summary = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, kv_summary)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = heads, x = height, y = width)
        return self.to_out(self.nonlin(out))
class GlobalContext(nn.Module):
    """Attention-pooled channel gate (a squeeze-excitation variant).

    A 1x1 conv produces one logit per spatial position. The softmax over
    positions pools the input into one vector per channel. A small conv MLP
    ending in a sigmoid turns that vector into per-channel gates shaped
    ``(b, dim_out, 1, 1)``.
    """
    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        self.to_k = nn.Conv2d(dim_in, 1, 1)
        hidden_dim = max(3, dim_out // 2)

        self.net = nn.Sequential(
            nn.Conv2d(dim_in, hidden_dim, 1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        logits = self.to_k(x)
        flat_x = rearrange(x, 'b n ... -> b n (...)')
        flat_logits = rearrange(logits, 'b n ... -> b n (...)')
        weights = flat_logits.softmax(dim = -1)
        pooled = einsum('b i n, b c n -> b c i', weights, flat_x)
        return self.net(rearrange(pooled, '... -> ... 1'))
def FeedForward(dim, mult = 2):
    """Token-wise MLP: LayerNorm -> Linear(dim -> dim*mult) -> GELU -> LayerNorm -> Linear back to dim."""
    inner = int(dim * mult)
    layers = [
        LayerNorm(dim),
        nn.Linear(dim, inner, bias = False),
        nn.GELU(),
        LayerNorm(inner),
        nn.Linear(inner, dim, bias = False),
    ]
    return nn.Sequential(*layers)
def ChanFeedForward(dim, mult = 2): # in paper, it seems for self attention layers they did feedforwards with twice channel width
    """Channel-wise MLP on feature maps built from 1x1 convolutions, widening by ``mult``."""
    inner = int(dim * mult)
    layers = [
        ChanLayerNorm(dim),
        nn.Conv2d(dim, inner, 1, bias = False),
        nn.GELU(),
        ChanLayerNorm(inner),
        nn.Conv2d(inner, dim, 1, bias = False),
    ]
    return nn.Sequential(*layers)
class TransformerBlock(nn.Module):
    """``depth`` residual (Attention, FeedForward) pairs applied to a feature map.

    The input is handled as ``b c h w``. Spatial positions are flattened into a
    token sequence for the layers and then restored.
    """
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                FeedForward(dim = dim, mult = ff_mult)
            ])
            for _ in range(depth)
        ])

    def forward(self, x, context = None):
        tokens = rearrange(x, 'b c h w -> b h w c')
        tokens, packed_shape = pack([tokens], 'b * c')

        for attend, feed_forward in self.layers:
            tokens = attend(tokens, context = context) + tokens
            tokens = feed_forward(tokens) + tokens

        tokens, = unpack(tokens, packed_shape, 'b * c')
        return rearrange(tokens, 'b h w c -> b c h w')
class LinearAttentionTransformerBlock(nn.Module):
    """``depth`` residual (LinearAttention, ChanFeedForward) pairs operating directly on feature maps."""
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                ChanFeedForward(dim = dim, mult = ff_mult)
            ])
            for _ in range(depth)
        ])

    def forward(self, x, context = None):
        out = x
        for attend, feed_forward in self.layers:
            out = attend(out, context = context) + out
            out = feed_forward(out) + out
        return out
class CrossEmbedLayer(nn.Module):
    """Multi-scale convolutional embedding.

    One convolution is built per kernel size, all sharing the same stride. Their
    outputs are concatenated along channels to give ``dim_out`` channels in total.
    Kernel sizes are sorted ascending. Channel counts are ``dim_out/2``,
    ``dim_out/4``, ... (as ints), and the largest kernel takes the remainder.
    """
    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        # every kernel must share the stride's parity so (kernel - stride) is even
        assert all((kernel % 2) == (stride % 2) for kernel in kernel_sizes)
        dim_out = default(dim_out, dim_in)

        ordered_kernels = sorted(kernel_sizes)

        channel_split = [int(dim_out / (2 ** i)) for i in range(1, len(ordered_kernels))]
        channel_split.append(dim_out - sum(channel_split))

        self.convs = nn.ModuleList([
            nn.Conv2d(dim_in, channels, kernel, stride = stride, padding = (kernel - stride) // 2)
            for kernel, channels in zip(ordered_kernels, channel_split)
        ])

    def forward(self, x):
        return torch.cat([conv(x) for conv in self.convs], dim = 1)
class UpsampleCombiner(nn.Module):
    """Optionally fuse feature maps from every upsampling stage into the final map.

    When enabled, each incoming feature map is resized to the width of ``x``,
    passed through its own ``Block`` and concatenated onto ``x`` along channels.
    ``dim_out`` reports the resulting channel count. When disabled, ``x`` passes
    through and ``dim_out == dim``.
    """
    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        super().__init__()
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        assert len(dim_ins) == len(dim_outs)

        self.enabled = enabled

        if self.enabled:
            self.fmap_convs = nn.ModuleList([Block(fmap_in, fmap_out) for fmap_in, fmap_out in zip(dim_ins, dim_outs)])
            self.dim_out = dim + sum(dim_outs)
        else:
            self.dim_out = dim

    def forward(self, x, fmaps = None):
        target_size = x.shape[-1]

        fmaps = default(fmaps, tuple())

        nothing_to_combine = not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0
        if nothing_to_combine:
            return x

        resized = [resize_image_to(fmap, target_size) for fmap in fmaps]
        projected = [conv(fmap) for fmap, conv in zip(resized, self.fmap_convs)]
        return torch.cat((x, *projected), dim = 1)
class Unet(nn.Module):
    """2D text-conditioned U-Net used at each stage of the Imagen cascade.

    Every constructor argument is snapshotted into ``self._locals`` so the
    network can be rebuilt with altered settings (``cast_model_parameters``)
    or serialized (``to_config_and_state_dict`` / ``persist_to_file``).
    """
    def __init__(
        self,
        *,
        dim,
        text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
        num_resnet_blocks = 1,
        cond_dim = None,
        num_image_tokens = 4,
        num_time_tokens = 2,
        learned_sinu_pos_emb_dim = 16,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        cond_images_channels = 0,
        channels = 3,
        channels_out = None,
        attn_dim_head = 64,
        attn_heads = 8,
        ff_mult = 2.,
        lowres_cond = False,                # for cascading diffusion - https://cascaded-diffusion.github.io/
        layer_attns = True,
        layer_attns_depth = 1,
        layer_mid_attns_depth = 1,
        layer_attns_add_text_cond = True,   # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
        attend_at_middle = True,            # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
        layer_cross_attns = True,
        use_linear_attn = False,
        use_linear_cross_attn = False,
        cond_on_text = True,
        max_text_len = 256,
        init_dim = None,
        resnet_groups = 8,
        init_conv_kernel_size = 7,          # kernel size of initial conv, if not using cross embed
        init_cross_embed = True,
        init_cross_embed_kernel_sizes = (3, 7, 15),
        cross_embed_downsample = False,
        cross_embed_downsample_kernel_sizes = (2, 4),
        attn_pool_text = True,
        attn_pool_num_latents = 32,
        dropout = 0.,
        memory_efficient = False,
        init_conv_to_final_conv_residual = False,
        use_global_context_attn = True,
        scale_skip_connection = True,
        final_resnet_block = True,
        final_conv_kernel_size = 3,
        self_cond = False,
        resize_mode = 'nearest',
        combine_upsample_fmaps = False,      # combine feature maps from all upsample blocks, used in unet squared successfully
        pixel_shuffle_upsample = True,       # may address checkboard artifacts
    ):
        super().__init__()

        # guide researchers

        assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'

        if dim < 128:
            print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')

        # save locals to take care of some hyperparameters for cascading DDPM
        # NOTE: no new local variable may be introduced above this point, otherwise
        # it would land in the saved config and break `self.__class__(**config)`

        self._locals = locals()
        self._locals.pop('self', None)
        self._locals.pop('__class__', None)

        # determine dimensions

        self.channels = channels
        self.channels_out = default(channels_out, channels)

        # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
        # (2) in self conditioning, one appends the predict x0 (x_start)
        init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
        init_dim = default(init_dim, dim)

        self.self_cond = self_cond

        # optional image conditioning

        self.has_cond_image = cond_images_channels > 0
        self.cond_images_channels = cond_images_channels

        init_channels += cond_images_channels

        # initial convolution

        self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # time conditioning

        cond_dim = default(cond_dim, dim)
        time_cond_dim = dim * 4 * (2 if lowres_cond else 1)

        # embedding time for log(snr) noise from continuous version

        sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
        sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1

        self.to_time_hiddens = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
            nn.SiLU()
        )

        self.to_time_cond = nn.Sequential(
            nn.Linear(time_cond_dim, time_cond_dim)
        )

        # project to time tokens as well as time hiddens

        self.to_time_tokens = nn.Sequential(
            nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
            Rearrange('b (r d) -> b r d', r = num_time_tokens)
        )

        # low res aug noise conditioning

        self.lowres_cond = lowres_cond

        if lowres_cond:
            self.to_lowres_time_hiddens = nn.Sequential(
                LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
                nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
                nn.SiLU()
            )

            self.to_lowres_time_cond = nn.Sequential(
                nn.Linear(time_cond_dim, time_cond_dim)
            )

            self.to_lowres_time_tokens = nn.Sequential(
                nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
                Rearrange('b (r d) -> b r d', r = num_time_tokens)
            )

        # normalizations

        self.norm_cond = nn.LayerNorm(cond_dim)

        # text encoding conditioning (optional)

        self.text_to_cond = None

        if cond_on_text:
            assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
            self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)

        # finer control over whether to condition on text encodings

        self.cond_on_text = cond_on_text

        # attention pooling

        self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents) if attn_pool_text else None

        # for classifier free guidance

        self.max_text_len = max_text_len

        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
        self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))

        # for non-attention based text conditioning at all points in the network where time is also conditioned

        self.to_text_non_attn_cond = None

        if cond_on_text:
            self.to_text_non_attn_cond = nn.Sequential(
                nn.LayerNorm(cond_dim),
                nn.Linear(cond_dim, time_cond_dim),
                nn.SiLU(),
                nn.Linear(time_cond_dim, time_cond_dim)
            )

        # attention related params

        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)

        num_layers = len(in_out)

        # resnet block klass

        num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
        resnet_groups = cast_tuple(resnet_groups, num_layers)

        resnet_klass = partial(ResnetBlock, **attn_kwargs)

        layer_attns = cast_tuple(layer_attns, num_layers)
        layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
        layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)

        use_linear_attn = cast_tuple(use_linear_attn, num_layers)
        use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers)

        assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])

        # downsample klass

        downsample_klass = Downsample

        if cross_embed_downsample:
            downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)

        # initial resnet block (for memory efficient unet)

        self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None

        # scale for resnet skip connections

        self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn]
        reversed_layer_params = list(map(reversed, layer_params))

        # downsampling layers

        skip_connect_dims = [] # keep track of skip connection dimensions

        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)):
            is_last = ind >= (num_resolutions - 1)

            layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None

            if layer_attn:
                transformer_block_klass = TransformerBlock
            elif layer_use_linear_attn:
                transformer_block_klass = LinearAttentionTransformerBlock
            else:
                transformer_block_klass = Identity

            current_dim = dim_in

            # whether to pre-downsample, from memory efficient unet

            pre_downsample = None

            if memory_efficient:
                pre_downsample = downsample_klass(dim_in, dim_out)
                current_dim = dim_out

            skip_connect_dims.append(current_dim)

            # whether to do post-downsample, for non-memory efficient unet

            post_downsample = None
            if not memory_efficient:
                post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(nn.Conv2d(dim_in, dim_out, 3, padding = 1), nn.Conv2d(dim_in, dim_out, 1))

            self.downs.append(nn.ModuleList([
                pre_downsample,
                resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
                nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
                post_downsample
            ]))

        # middle layers

        mid_dim = dims[-1]

        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
        self.mid_attn = TransformerBlock(mid_dim, depth = layer_mid_attns_depth, **attn_kwargs) if attend_at_middle else None
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])

        # upsample klass

        upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample

        # upsampling layers

        upsample_fmap_dims = []

        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
            is_last = ind == (len(in_out) - 1)

            layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None

            if layer_attn:
                transformer_block_klass = TransformerBlock
            elif layer_use_linear_attn:
                transformer_block_klass = LinearAttentionTransformerBlock
            else:
                transformer_block_klass = Identity

            skip_connect_dim = skip_connect_dims.pop()

            upsample_fmap_dims.append(dim_out)

            self.ups.append(nn.ModuleList([
                resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
                upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
            ]))

        # whether to combine feature maps from all upsample blocks before final resnet block out

        self.upsample_combiner = UpsampleCombiner(
            dim = dim,
            enabled = combine_upsample_fmaps,
            dim_ins = upsample_fmap_dims,
            dim_outs = dim
        )

        # whether to do a final residual from initial conv to the final resnet block out

        self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual

        # the residual is the output of self.init_conv, which has init_dim channels
        # (not dim) - using dim here breaks whenever a custom init_dim is given
        final_conv_dim = self.upsample_combiner.dim_out + (init_dim if init_conv_to_final_conv_residual else 0)

        # final optional resnet block and convolution out

        self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None

        final_conv_dim_in = dim if final_resnet_block else final_conv_dim
        final_conv_dim_in += (channels if lowres_cond else 0)

        self.final_conv = nn.Conv2d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2)

        zero_init_(self.final_conv)

        # resize mode

        self.resize_mode = resize_mode

    # if the current settings for the unet are not correct
    # for cascading DDPM, then reinit the unet with the right settings
    def cast_model_parameters(
        self,
        *,
        lowres_cond,
        text_embed_dim,
        channels,
        channels_out,
        cond_on_text
    ):
        """Return ``self`` if the settings already match, else a fresh instance
        built from the saved constructor config with these settings overridden."""
        if lowres_cond == self.lowres_cond and \
            channels == self.channels and \
            cond_on_text == self.cond_on_text and \
            text_embed_dim == self._locals['text_embed_dim'] and \
            channels_out == self.channels_out:
            return self

        updated_kwargs = dict(
            lowres_cond = lowres_cond,
            text_embed_dim = text_embed_dim,
            channels = channels,
            channels_out = channels_out,
            cond_on_text = cond_on_text
        )

        return self.__class__(**{**self._locals, **updated_kwargs})

    # methods for returning the full unet config as well as its parameter state

    def to_config_and_state_dict(self):
        """Return ``(constructor kwargs, state_dict)``."""
        return self._locals, self.state_dict()

    # class method for rehydrating the unet from its config and state dict

    @classmethod
    def from_config_and_state_dict(klass, config, state_dict):
        """Instantiate ``klass(**config)`` and load ``state_dict`` into it."""
        unet = klass(**config)
        unet.load_state_dict(state_dict)
        return unet

    # methods for persisting unet to disk

    def persist_to_file(self, path):
        """Save ``{'config': ..., 'state_dict': ...}`` to ``path`` with ``torch.save``,
        creating parent directories as needed."""
        path = Path(path)
        path.parents[0].mkdir(exist_ok = True, parents = True)

        config, state_dict = self.to_config_and_state_dict()
        pkg = dict(config = config, state_dict = state_dict)
        torch.save(pkg, str(path))

    # class method for rehydrating the unet from file saved with `persist_to_file`

    @classmethod
    def hydrate_from_file(klass, path):
        """Rebuild a unet of class ``klass`` from a file written by ``persist_to_file``.

        NOTE: ``torch.load`` unpickles arbitrary objects - only load trusted files.
        """
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))

        assert 'config' in pkg and 'state_dict' in pkg
        config, state_dict = pkg['config'], pkg['state_dict']

        # use klass (not the hard-coded Unet) so subclasses hydrate as themselves
        return klass.from_config_and_state_dict(config, state_dict)

    # forward with classifier free guidance

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        """Classifier-free guidance: ``null + (cond - null) * cond_scale``.

        Only one forward pass is run when ``cond_scale == 1``.
        """
        logits = self.forward(*args, **kwargs)

        if cond_scale == 1:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def forward(
        self,
        x,
        time,
        *,
        lowres_cond_img = None,
        lowres_noise_times = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        self_cond = None,
        cond_drop_prob = 0.
    ):
        """Predict the network output for noised input ``x`` at noise level ``time``.

        ``cond_drop_prob`` is the probability, per batch element, of replacing the
        text conditioning with the learned null embeddings (classifier free guidance).
        """
        batch_size, device = x.shape[0], x.device

        # condition on self

        if self.self_cond:
            self_cond = default(self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x, self_cond), dim = 1)

        # add low resolution conditioning, if present

        assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
        assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present'

        if exists(lowres_cond_img):
            x = torch.cat((x, lowres_cond_img), dim = 1)

        # condition on input image

        assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa'

        if exists(cond_images):
            assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet'
            cond_images = resize_image_to(cond_images, x.shape[-1], mode = self.resize_mode)
            x = torch.cat((cond_images, x), dim = 1)

        # initial convolution

        x = self.init_conv(x)

        # init conv residual

        if self.init_conv_to_final_conv_residual:
            init_conv_residual = x.clone()

        # time conditioning

        time_hiddens = self.to_time_hiddens(time)

        # derive time tokens

        time_tokens = self.to_time_tokens(time_hiddens)
        t = self.to_time_cond(time_hiddens)

        # add lowres time conditioning to time hiddens
        # and add lowres time tokens along sequence dimension for attention

        if self.lowres_cond:
            lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
            lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
            lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)

            t = t + lowres_t
            time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2)

        # text conditioning

        text_tokens = None

        if exists(text_embeds) and self.cond_on_text:

            # conditional dropout

            text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)

            text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
            text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')

            # calculate text embeds

            text_tokens = self.text_to_cond(text_embeds)

            text_tokens = text_tokens[:, :self.max_text_len]

            if exists(text_mask):
                text_mask = text_mask[:, :self.max_text_len]

            text_tokens_len = text_tokens.shape[1]
            remainder = self.max_text_len - text_tokens_len

            if remainder > 0:
                text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))

            if exists(text_mask):
                if remainder > 0:
                    text_mask = F.pad(text_mask, (0, remainder), value = False)

                text_mask = rearrange(text_mask, 'b n -> b n 1')
                text_keep_mask_embed = text_mask & text_keep_mask_embed

            null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working

            text_tokens = torch.where(
                text_keep_mask_embed,
                text_tokens,
                null_text_embed
            )

            if exists(self.attn_pool):
                text_tokens = self.attn_pool(text_tokens)

            # extra non-attention conditioning by projecting and then summing text embeddings to time
            # termed as text hiddens

            mean_pooled_text_tokens = text_tokens.mean(dim = -2)

            text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)

            null_text_hidden = self.null_text_hidden.to(t.dtype)

            text_hiddens = torch.where(
                text_keep_mask_hidden,
                text_hiddens,
                null_text_hidden
            )

            t = t + text_hiddens

        # main conditioning tokens (c)

        c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)

        # normalize conditioning tokens

        c = self.norm_cond(c)

        # initial resnet block (for memory efficient unet)

        if exists(self.init_resnet_block):
            x = self.init_resnet_block(x, t)

        # go through the layers of the unet, down and up

        hiddens = []

        for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs:
            if exists(pre_downsample):
                x = pre_downsample(x)

            x = init_block(x, t, c)

            for resnet_block in resnet_blocks:
                x = resnet_block(x, t)
                hiddens.append(x)

            x = attn_block(x, c)
            hiddens.append(x)

            if exists(post_downsample):
                x = post_downsample(x)

        x = self.mid_block1(x, t, c)

        if exists(self.mid_attn):
            x = self.mid_attn(x)

        x = self.mid_block2(x, t, c)

        # skip connections are consumed in reverse order of how they were stored
        add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)

        up_hiddens = []

        for init_block, resnet_blocks, attn_block, upsample in self.ups:
            x = add_skip_connection(x)
            x = init_block(x, t, c)

            for resnet_block in resnet_blocks:
                x = add_skip_connection(x)
                x = resnet_block(x, t)

            x = attn_block(x, c)
            up_hiddens.append(x.contiguous())
            x = upsample(x)

        # whether to combine all feature maps from upsample blocks

        x = self.upsample_combiner(x, up_hiddens)

        # final top-most residual if needed

        if self.init_conv_to_final_conv_residual:
            x = torch.cat((x, init_conv_residual), dim = 1)

        if exists(self.final_res_block):
            x = self.final_res_block(x, t)

        if exists(lowres_cond_img):
            x = torch.cat((x, lowres_cond_img), dim = 1)

        return self.final_conv(x)
# null unet
class NullUnet(nn.Module):
    """Placeholder unet that returns its input unchanged.

    It holds one dummy parameter (so it still exposes ``parameters()``), reports
    ``lowres_cond = False``, and ignores every cast request.
    """
    def __init__(self, *unused_args, **unused_kwargs):
        super().__init__()
        self.lowres_cond = False
        self.dummy_parameter = nn.Parameter(torch.zeros(1))

    def cast_model_parameters(self, *unused_args, **unused_kwargs):
        # nothing to reconfigure - the same instance is reused as-is
        return self

    def forward(self, x, *unused_args, **unused_kwargs):
        return x
# predefined unets, with configs lining up with hyperparameters in appendix of paper
class BaseUnet64(Unet):
    """Preset ``Unet`` for the 64x64 base stage; caller kwargs override the presets."""
    def __init__(self, *args, **kwargs):
        config = dict(
            dim = 512,
            dim_mults = (1, 2, 3, 4),
            num_resnet_blocks = 3,
            layer_attns = (False, True, True, True),
            layer_cross_attns = (False, True, True, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = False
        )
        config.update(kwargs)
        super().__init__(*args, **config)
class SRUnet256(Unet):
    """Preset memory-efficient super-resolution ``Unet`` (256 stage); caller kwargs override the presets."""
    def __init__(self, *args, **kwargs):
        config = dict(
            dim = 128,
            dim_mults = (1, 2, 4, 8),
            num_resnet_blocks = (2, 4, 8, 8),
            layer_attns = (False, False, False, True),
            layer_cross_attns = (False, False, False, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = True
        )
        config.update(kwargs)
        super().__init__(*args, **config)
class SRUnet1024(Unet):
    """Preset memory-efficient super-resolution ``Unet`` (1024 stage, no self-attention);
    caller kwargs override the presets."""
    def __init__(self, *args, **kwargs):
        config = dict(
            dim = 128,
            dim_mults = (1, 2, 4, 8),
            num_resnet_blocks = (2, 4, 8, 8),
            layer_attns = False,
            layer_cross_attns = (False, False, False, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = True
        )
        config.update(kwargs)
        super().__init__(*args, **config)
# main imagen ddpm class, which is a cascading DDPM from Ho et al.
class Imagen(nn.Module):
def __init__(
    self,
    unets,
    *,
    image_sizes,                                # for cascading ddpm, image size at each stage
    text_encoder_name = DEFAULT_T5_NAME,
    text_embed_dim = None,
    channels = 3,
    timesteps = 1000,
    cond_drop_prob = 0.1,
    loss_type = 'l2',
    noise_schedules = 'cosine',
    pred_objectives = 'noise',
    random_crop_sizes = None,
    lowres_noise_schedule = 'linear',
    lowres_sample_noise_level = 0.2,            # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
    per_sample_random_aug_noise_level = False,  # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
    condition_on_text = True,
    auto_normalize_img = True,                  # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
    dynamic_thresholding = True,
    dynamic_thresholding_percentile = 0.95,     # unsure what this was based on perusal of paper
    only_train_unet_number = None,
    temporal_downsample_factor = 1,
    resize_cond_video_frames = True,
    resize_mode = 'nearest',
    min_snr_loss_weight = True,                 # https://arxiv.org/abs/2303.09556
    min_snr_gamma = 5
):
    """Build the cascading diffusion wrapper around one or more unets.

    Per-unet settings (timesteps, noise schedules, crop sizes, objectives,
    dynamic thresholding, temporal downsample factors, min-snr settings) may be
    given as a single value or as one value per unet.

    Each unet is passed through ``cast_model_parameters`` so that only the first
    one is not low-res conditioned, and all share the text and channel settings.

    Raises:
        NotImplementedError: if ``loss_type`` is not 'l1', 'l2' or 'huber'.
        AssertionError: on inconsistent per-unet settings (see the asserts below).
    """
    super().__init__()

    # loss

    if loss_type == 'l1':
        loss_fn = F.l1_loss
    elif loss_type == 'l2':
        loss_fn = F.mse_loss
    elif loss_type == 'huber':
        loss_fn = F.smooth_l1_loss
    else:
        raise NotImplementedError()

    self.loss_type = loss_type
    self.loss_fn = loss_fn

    # conditioning hparams

    self.condition_on_text = condition_on_text
    self.unconditional = not condition_on_text

    # channels

    self.channels = channels

    # automatically take care of ensuring that first unet is unconditional
    # while the rest of the unets are conditioned on the low resolution image produced by previous unet

    unets = cast_tuple(unets)
    num_unets = len(unets)

    # determine noise schedules per unet

    timesteps = cast_tuple(timesteps, num_unets)

    # make sure noise schedule defaults to 'cosine', 'cosine', and then 'linear' for rest of super-resolution unets

    noise_schedules = cast_tuple(noise_schedules)
    noise_schedules = pad_tuple_to_length(noise_schedules, 2, 'cosine')
    noise_schedules = pad_tuple_to_length(noise_schedules, num_unets, 'linear')

    # construct noise schedulers
    # (zip stops at the shorter of timesteps / noise_schedules, so one scheduler per unet)

    noise_scheduler_klass = GaussianDiffusionContinuousTimes
    self.noise_schedulers = nn.ModuleList([])

    for timestep, noise_schedule in zip(timesteps, noise_schedules):
        noise_scheduler = noise_scheduler_klass(noise_schedule = noise_schedule, timesteps = timestep)
        self.noise_schedulers.append(noise_scheduler)

    # randomly cropping for upsampler training

    self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
    assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'

    # lowres augmentation noise schedule

    self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule)

    # ddpm objectives - predicting noise by default

    self.pred_objectives = cast_tuple(pred_objectives, num_unets)

    # get text encoder

    self.text_encoder_name = text_encoder_name
    self.text_embed_dim = default(text_embed_dim, lambda: get_encoded_dim(text_encoder_name))

    self.encode_text = partial(t5_encode_text, name = text_encoder_name)

    # construct unets

    self.unets = nn.ModuleList([])

    self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment
    self.only_train_unet_number = only_train_unet_number

    for ind, one_unet in enumerate(unets):
        assert isinstance(one_unet, (Unet, Unet3D, NullUnet))
        is_first = ind == 0

        # may return a re-instantiated unet if its settings differ from what imagen needs
        one_unet = one_unet.cast_model_parameters(
            lowres_cond = not is_first,
            cond_on_text = self.condition_on_text,
            text_embed_dim = self.text_embed_dim if self.condition_on_text else None,
            channels = self.channels,
            channels_out = self.channels
        )

        self.unets.append(one_unet)

    # unet image sizes

    image_sizes = cast_tuple(image_sizes)
    self.image_sizes = image_sizes

    assert num_unets == len(image_sizes), f'you did not supply the correct number of u-nets ({len(unets)}) for resolutions {image_sizes}'

    self.sample_channels = cast_tuple(self.channels, num_unets)

    # determine whether we are training on images or video

    is_video = any([isinstance(unet, Unet3D) for unet in self.unets])
    self.is_video = is_video

    # pads a per-sample (b,) tensor with trailing singleton dims to match image (4D) or video (5D) tensors
    self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1'))

    self.resize_to = resize_video_to if is_video else resize_image_to
    self.resize_to = partial(self.resize_to, mode = resize_mode)

    # temporal interpolation

    temporal_downsample_factor = cast_tuple(temporal_downsample_factor, num_unets)
    self.temporal_downsample_factor = temporal_downsample_factor

    self.resize_cond_video_frames = resize_cond_video_frames
    self.temporal_downsample_divisor = temporal_downsample_factor[0]

    assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1'
    assert tuple(sorted(temporal_downsample_factor, reverse = True)) == temporal_downsample_factor, 'temporal downsample factor must be in order of descending'

    # cascading ddpm related stuff

    lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
    assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'

    self.lowres_sample_noise_level = lowres_sample_noise_level
    self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level

    # classifier free guidance

    self.cond_drop_prob = cond_drop_prob
    self.can_classifier_guidance = cond_drop_prob > 0.

    # normalize and unnormalize image functions

    self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
    self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
    self.input_image_range = (0. if auto_normalize_img else -1., 1.)

    # dynamic thresholding

    self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
    self.dynamic_thresholding_percentile = dynamic_thresholding_percentile

    # min snr loss weight
    # a unet with min_snr_loss_weight disabled gets None instead of its gamma

    min_snr_loss_weight = cast_tuple(min_snr_loss_weight, num_unets)
    min_snr_gamma = cast_tuple(min_snr_gamma, num_unets)

    assert len(min_snr_loss_weight) == len(min_snr_gamma) == num_unets
    self.min_snr_gamma = tuple((gamma if use_min_snr else None) for use_min_snr, gamma in zip(min_snr_loss_weight, min_snr_gamma))

    # one temp parameter for keeping track of device
    # (non-persistent buffer, so it is excluded from the state dict; read by the `device` property)

    self.register_buffer('_temp', torch.tensor([0.]), persistent = False)

    # default to device of unets passed in

    self.to(next(self.unets.parameters()).device)
def force_unconditional_(self):
    """Switch the wrapper and every unet of the cascade to unconditional mode, in place.

    Sets `condition_on_text` to False and `unconditional` to True on this
    object, and clears `cond_on_text` on each unet in `self.unets`.
    """
    self.condition_on_text, self.unconditional = False, True
    for member_unet in self.unets:
        member_unet.cond_on_text = False
@property
def device(self):
    """Device this module currently lives on.

    Read from the non-persistent `_temp` buffer, which is registered purely to
    track the device and therefore follows the module through `.to(...)`.
    """
    tracking_buffer = self._temp
    return tracking_buffer.device
def get_unet(self, unet_number):
    """Return unet `unet_number` (1-indexed), making it the only unet on `self.device`.

    On first use, the `nn.ModuleList` holding the unets is replaced by a plain
    Python list. The unets are then no longer registered submodules of this
    module; `reset_unets_all_one_device` wraps them back into a ModuleList.
    When the requested unet differs from `self.unet_being_trained_index`, it is
    moved to `self.device` and every other unet is moved to CPU.

    Raises:
        AssertionError: if `unet_number` is outside 1..len(self.unets).
    """
    assert 0 < unet_number <= len(self.unets)
    target_index = unet_number - 1

    if isinstance(self.unets, nn.ModuleList):
        # unregister the unets so moving/serialising this module no longer touches them
        plain_unets = list(self.unets)
        delattr(self, 'unets')
        self.unets = plain_unets

    if target_index != self.unet_being_trained_index:
        for position, candidate in enumerate(self.unets):
            destination = self.device if position == target_index else 'cpu'
            candidate.to(destination)

    self.unet_being_trained_index = target_index
    return self.unets[target_index]
def reset_unets_all_one_device(self, device = None):
    """Re-register every unet as a submodule and move all of them to one device.

    Undoes the plain-list swap performed by `get_unet` by wrapping the unets in
    an `nn.ModuleList` again. Also resets `unet_being_trained_index` to -1.

    Args:
        device: target device; defaults to `self.device`.
    """
    target_device = default(device, self.device)
    registered_unets = nn.ModuleList(list(self.unets))
    self.unets = registered_unets
    self.unets.to(target_device)
    self.unet_being_trained_index = -1
@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None):
    """Context manager that keeps only one unet on `self.device` and parks the rest on CPU.

    Exactly one of `unet_number` (1-indexed) or `unet` must be given. On exit,
    every unet is moved back to the device it was on when the context was
    entered. This also happens when the body raises.

    Raises:
        AssertionError: if both or neither of `unet_number` / `unet` are given.
    """
    assert exists(unet_number) ^ exists(unet)

    if exists(unet_number):
        unet = self.unets[unet_number - 1]

    cpu = torch.device('cpu')

    # remember where each unet lived so the original placement can be restored
    original_devices = [module_device(member) for member in self.unets]

    # iterate rather than calling `self.unets.to(...)` so this also works after
    # `get_unet` has replaced the ModuleList with a plain list
    for member in self.unets:
        member.to(cpu)

    unet.to(self.device)

    try:
        yield
    finally:
        for member, original_device in zip(self.unets, original_devices):
            member.to(original_device)
# overriding state dict functions
def state_dict(self, *args, **kwargs):
    """Return the module state dict with all unets registered on a single device.

    `reset_unets_all_one_device` wraps the unets back into an `nn.ModuleList`,
    so their parameters are included even after `get_unet` unregistered them.
    """
    self.reset_unets_all_one_device()
    state = super().state_dict(*args, **kwargs)
    return state
def load_state_dict(self, *args, **kwargs):
    """Load a state dict after re-registering all unets on a single device.

    `reset_unets_all_one_device` wraps the unets back into an `nn.ModuleList`,
    so their parameters are present as submodules when the weights are loaded.
    """
    self.reset_unets_all_one_device()
    load_result = super().load_state_dict(*args, **kwargs)
    return load_result
# gaussian diffusion methods
def p_mean_variance(
    self,
    unet,
    x,
    t,
    *,
    noise_scheduler,
    text_embeds = None,
    text_mask = None,
    cond_images = None,
    cond_video_frames = None,
    post_cond_video_frames = None,
    lowres_cond_img = None,
    self_cond = None,
    lowres_noise_times = None,
    cond_scale = 1.,
    model_output = None,
    t_next = None,
    pred_objective = 'noise',
    dynamic_threshold = True
):
    """Estimate x_0 from `x` at time `t` and return the scheduler's posterior for `t_next`.

    The unet is only called when `model_output` is not supplied. Its output is
    interpreted according to `pred_objective` ('noise', 'x_start' or 'v'), and
    the resulting x_0 estimate is either dynamically thresholded or clamped to
    [-1, 1] before `noise_scheduler.q_posterior` is evaluated.

    Returns:
        `(noise_scheduler.q_posterior(...), x_start)`. `p_sample` unpacks the
        first element as `(mean, _, log_variance)`.

    Raises:
        AssertionError: if `cond_scale != 1` but the model was not trained with
            conditional dropout.
        ValueError: if `pred_objective` is unknown.
    """
    assert not (cond_scale != 1. and not self.can_classifier_guidance), 'imagen was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

    video_kwargs = dict()
    if self.is_video:
        video_kwargs = dict(
            cond_video_frames = cond_video_frames,
            post_cond_video_frames = post_cond_video_frames,
        )

    pred = default(model_output, lambda: unet.forward_with_cond_scale(
        x,
        noise_scheduler.get_condition(t),
        text_embeds = text_embeds,
        text_mask = text_mask,
        cond_images = cond_images,
        cond_scale = cond_scale,
        lowres_cond_img = lowres_cond_img,
        self_cond = self_cond,
        lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_noise_times),
        **video_kwargs
    ))

    if pred_objective == 'noise':
        x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
    elif pred_objective == 'x_start':
        x_start = pred
    elif pred_objective == 'v':
        x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
    else:
        raise ValueError(f'unknown objective {pred_objective}')

    if dynamic_threshold:
        # following pseudocode in appendix
        # s is the dynamic threshold, determined by percentile of absolute values of reconstructed sample per batch element
        s = torch.quantile(
            rearrange(x_start, 'b ... -> b (...)').abs(),
            self.dynamic_thresholding_percentile,
            dim = -1
        )

        # never shrink the valid range below [-1, 1]
        s.clamp_(min = 1.)
        s = right_pad_dims_to(x_start, s)
        x_start = x_start.clamp(-s, s) / s
    else:
        # out-of-place: for the 'x_start' objective `x_start` aliases `pred`,
        # which may be the caller's `model_output` tensor
        x_start = x_start.clamp(-1., 1.)

    mean_and_variance = noise_scheduler.q_posterior(x_start = x_start, x_t = x, t = t, t_next = t_next)
    return mean_and_variance, x_start
@torch.no_grad()
def p_sample(
    self,
    unet,
    x,
    t,
    *,
    noise_scheduler,
    t_next = None,
    text_embeds = None,
    text_mask = None,
    cond_images = None,
    cond_video_frames = None,
    post_cond_video_frames = None,
    cond_scale = 1.,
    self_cond = None,
    lowres_cond_img = None,
    lowres_noise_times = None,
    pred_objective = 'noise',
    dynamic_threshold = True
):
    """Take one reverse-diffusion step from `x` at time `t` towards `t_next`.

    The posterior mean and log-variance come from `p_mean_variance`. Gaussian
    noise scaled by exp(0.5 * log_variance) is added, except for batch elements
    on their final step. For `GaussianDiffusionContinuousTimes` schedulers the
    final step is where `t_next == 0`; for other schedulers it is where `t == 0`.

    Returns:
        `(next_sample, x_start)` where `x_start` is the x_0 estimate produced
        by `p_mean_variance`.
    """
    batch_size = x.shape[0]

    if self.is_video:
        video_kwargs = dict(
            cond_video_frames = cond_video_frames,
            post_cond_video_frames = post_cond_video_frames,
        )
    else:
        video_kwargs = dict()

    posterior, x_start = self.p_mean_variance(
        unet,
        x = x,
        t = t,
        t_next = t_next,
        noise_scheduler = noise_scheduler,
        text_embeds = text_embeds,
        text_mask = text_mask,
        cond_images = cond_images,
        cond_scale = cond_scale,
        lowres_cond_img = lowres_cond_img,
        self_cond = self_cond,
        lowres_noise_times = lowres_noise_times,
        pred_objective = pred_objective,
        dynamic_threshold = dynamic_threshold,
        **video_kwargs
    )
    model_mean, _, model_log_variance = posterior

    noise = torch.randn_like(x)

    # no noise on the final step
    if isinstance(noise_scheduler, GaussianDiffusionContinuousTimes):
        is_last_sampling_timestep = t_next == 0
    else:
        is_last_sampling_timestep = t == 0

    # per-sample mask broadcast over every non-batch dimension
    broadcast_shape = (batch_size, *((1,) * (x.ndim - 1)))
    nonzero_mask = (1 - is_last_sampling_timestep.float()).reshape(broadcast_shape)

    std = (0.5 * model_log_variance).exp()
    next_sample = model_mean + nonzero_mask * std * noise
    return next_sample, x_start
@torch.no_grad()
def p_sample_loop(
    self,
    unet,
    shape,
    *,
    noise_scheduler,
    lowres_cond_img = None,
    lowres_noise_times = None,
    text_embeds = None,
    text_mask = None,
    cond_images = None,
    cond_video_frames = None,
    post_cond_video_frames = None,
    inpaint_images = None,
    inpaint_videos = None,
    inpaint_masks = None,
    inpaint_resample_times = 5,
    init_images = None,
    skip_steps = None,
    cond_scale = 1,
    pred_objective = 'noise',
    dynamic_threshold = True,
    use_tqdm = True
):
    """Run the full reverse-diffusion loop for one unet and return its unnormalized sample.

    Sampling starts from Gaussian noise of `shape`, plus `init_images` when
    given. It then walks the scheduler's sampling timesteps, dropping the first
    `skip_steps`, and calls `p_sample` at each step. When the unet
    self-conditions, the previous x_0 estimate is fed back in.

    With inpainting (both images/videos and masks given), the known region is
    re-imposed from a noised copy of `inpaint_images` before every step. Each
    timestep is also repeated `inpaint_resample_times` times, re-noising from
    `times_next` back to `times` between repeats.

    Returns:
        The sample clamped to [-1, 1] and passed through `self.unnormalize_img`.
    """
    device = self.device
    batch = shape[0]
    current = torch.randn(shape, device = device)

    # a 5-dimensional shape means video; frames sit on the third-from-last axis
    is_video = len(shape) == 5
    frames = shape[-3] if is_video else None
    resize_kwargs = dict(target_frames = frames) if exists(frames) else dict()

    if exists(init_images):
        current += init_images

    # last x_0 estimate, used for self conditioning
    prev_x_start = None

    # inpainting setup
    inpaint_images = default(inpaint_videos, inpaint_images)
    has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
    resample_times = inpaint_resample_times if has_inpainting else 1

    if has_inpainting:
        inpaint_images = self.resize_to(self.normalize_img(inpaint_images), shape[-1], **resize_kwargs)
        masks_with_channel = rearrange(inpaint_masks, 'b ... -> b 1 ...').float()
        inpaint_masks = self.resize_to(masks_with_channel, shape[-1], **resize_kwargs).bool()

    timesteps = noise_scheduler.get_sampling_timesteps(batch, device = device)
    timesteps = timesteps[default(skip_steps, 0):]

    if self.is_video:
        video_kwargs = dict(
            cond_video_frames = cond_video_frames,
            post_cond_video_frames = post_cond_video_frames,
        )
    else:
        video_kwargs = dict()

    progress = tqdm(timesteps, desc = 'sampling loop time step', total = len(timesteps), disable = not use_tqdm)

    for times, times_next in progress:
        is_last_timestep = times_next == 0

        for remaining in reversed(range(resample_times)):
            is_last_resample_step = remaining == 0

            if has_inpainting:
                noised_inpaint_images, *_ = noise_scheduler.q_sample(inpaint_images, t = times)
                current = current * ~inpaint_masks + noised_inpaint_images * inpaint_masks

            current, prev_x_start = self.p_sample(
                unet,
                current,
                times,
                t_next = times_next,
                text_embeds = text_embeds,
                text_mask = text_mask,
                cond_images = cond_images,
                cond_scale = cond_scale,
                self_cond = prev_x_start if unet.self_cond else None,
                lowres_cond_img = lowres_cond_img,
                lowres_noise_times = lowres_noise_times,
                noise_scheduler = noise_scheduler,
                pred_objective = pred_objective,
                dynamic_threshold = dynamic_threshold,
                **video_kwargs
            )

            # re-noise back to `times` before the next resample pass
            if has_inpainting and not (is_last_resample_step or torch.all(is_last_timestep)):
                renoised = noise_scheduler.q_sample_from_to(current, times_next, times)
                keep_as_is = self.right_pad_dims_to_datatype(is_last_timestep)
                current = torch.where(keep_as_is, current, renoised)

    current.clamp_(-1., 1.)

    # final inpainting: paste the known region back in exactly
    if has_inpainting:
        current = current * ~inpaint_masks + inpaint_images * inpaint_masks

    return self.unnormalize_img(current)
@torch.no_grad()
@eval_decorator
@beartype
def sample(
    self,
    texts: List[str] = None,
    text_masks = None,
    text_embeds = None,
    video_frames = None,
    cond_images = None,
    cond_video_frames = None,
    post_cond_video_frames = None,
    inpaint_videos = None,
    inpaint_images = None,
    inpaint_masks = None,
    inpaint_resample_times = 5,
    init_images = None,
    skip_steps = None,
    batch_size = 1,
    cond_scale = 1.,
    lowres_sample_noise_level = None,
    start_at_unet_number = 1,
    start_image_or_video = None,
    stop_at_unet_number = None,
    return_all_unet_outputs = False,
    return_pil_images = False,
    device = None,
    use_tqdm = True,
    use_one_unet_in_gpu = True
):
    """Sample images (or videos) by running the unet cascade stage by stage.

    Each unet samples at its entry of `self.image_sizes`. A unet with
    `lowres_cond` is conditioned on the previous stage's output, resized to the
    current size and noised with `self.lowres_noise_schedule` at
    `lowres_sample_noise_level`.

    Args:
        texts: captions encoded with `self.encode_text`. Ignored when
            `text_embeds` is given or the model is unconditional.
        text_masks, text_embeds: precomputed text conditioning. Without a mask,
            all-zero embedding rows are treated as masked out.
        video_frames: number of frames to sample. Required for video models
            unless it can be read from `inpaint_images` / `inpaint_videos`.
        cond_images, cond_video_frames, post_cond_video_frames: extra
            conditioning forwarded to each unet (video kwargs only for video models).
        inpaint_images / inpaint_videos, inpaint_masks, inpaint_resample_times:
            inpainting inputs. Images and masks must be supplied together.
        init_images, skip_steps: starting images and number of initial timesteps
            to skip, either scalar or one per unet.
        batch_size: number of samples for unconditional models. It is replaced
            by the number of text embeddings when conditioning on text.
        cond_scale: classifier-free guidance scale, either scalar or one per unet.
        lowres_sample_noise_level: defaults to `self.lowres_sample_noise_level`.
        start_at_unet_number, start_image_or_video: skip the earlier unets and
            upscale the supplied image or video instead.
        stop_at_unet_number: stop after this unet has sampled.
        return_all_unet_outputs: return every stage's output, not just the last.
        return_pil_images: convert outputs to lists of PIL images (images only).
        device: defaults to `self.device`.
        use_tqdm: show progress bars.
        use_one_unet_in_gpu: on CUDA, keep only the active unet on the GPU.

    Returns:
        The last stage's output, or a list of all stage outputs. When
        `return_pil_images` is set, these are given as lists of PIL images.
    """
    device = default(device, self.device)
    self.reset_unets_all_one_device(device = device)

    cond_images = maybe(cast_uint8_images_to_float)(cond_images)

    # encode raw captions unless embeddings were supplied or text is not used
    if exists(texts) and not exists(text_embeds) and not self.unconditional:
        assert all([*map(len, texts)]), 'text cannot be empty'

        with autocast(enabled = False):
            text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

        text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))

    if not self.unconditional:
        assert exists(text_embeds), 'text must be passed in if the network was not trained without text `condition_on_text` must be set to `False` when training'

        text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
        batch_size = text_embeds.shape[0]

    # inpainting

    inpaint_images = default(inpaint_videos, inpaint_images)

    if exists(inpaint_images):
        if self.unconditional:
            if batch_size == 1: # assume researcher wants to broadcast along inpainted images
                batch_size = inpaint_images.shape[0]

        assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=<int>)``'
        assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on'

    assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
    assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented'
    assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

    assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting'

    outputs = []

    is_cuda = next(self.parameters()).is_cuda
    device = next(self.parameters()).device

    lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level)

    num_unets = len(self.unets)

    # condition scaling

    cond_scale = cast_tuple(cond_scale, num_unets)

    # add frame dimension for video

    if self.is_video and exists(inpaint_images):
        video_frames = inpaint_images.shape[2]

        if inpaint_masks.ndim == 3:
            inpaint_masks = repeat(inpaint_masks, 'b h w -> b f h w', f = video_frames)

        assert inpaint_masks.shape[1] == video_frames

    assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'

    all_frame_dims = calc_all_frame_dims(self.temporal_downsample_factor, video_frames)

    frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()

    # for initial image and skipping steps

    init_images = cast_tuple(init_images, num_unets)
    init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images]

    skip_steps = cast_tuple(skip_steps, num_unets)

    # handle starting at a unet greater than 1, for training only-upscaler training

    if start_at_unet_number > 1:
        assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets'
        assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number
        assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'

        prev_image_size = self.image_sizes[start_at_unet_number - 2]
        prev_frame_size = all_frame_dims[start_at_unet_number - 2][0] if self.is_video else None
        img = self.resize_to(start_image_or_video, prev_image_size, **frames_to_resize_kwargs(prev_frame_size))

    # go through each unet in cascade

    stages = zip(
        range(1, num_unets + 1),
        self.unets,
        self.sample_channels,
        self.image_sizes,
        all_frame_dims,
        self.noise_schedulers,
        self.pred_objectives,
        self.dynamic_thresholding,
        cond_scale,
        init_images,
        skip_steps
    )

    for unet_number, unet, channel, image_size, frame_dims, noise_scheduler, pred_objective, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps in tqdm(stages, disable = not use_tqdm):

        if unet_number < start_at_unet_number:
            continue

        assert not isinstance(unet, NullUnet), 'one cannot sample from null / placeholder unets'

        context = self.one_unet_in_gpu(unet = unet) if is_cuda and use_one_unet_in_gpu else nullcontext()

        with context:
            # video kwargs

            video_kwargs = dict()
            if self.is_video:
                video_kwargs = dict(
                    cond_video_frames = cond_video_frames,
                    post_cond_video_frames = post_cond_video_frames,
                )

            video_kwargs = compact(video_kwargs)

            if self.is_video and self.resize_cond_video_frames:
                downsample_scale = self.temporal_downsample_factor[unet_number - 1]
                temporal_downsample_fn = partial(scale_video_time, downsample_scale = downsample_scale)

                video_kwargs = maybe_transform_dict_key(video_kwargs, 'cond_video_frames', temporal_downsample_fn)
                video_kwargs = maybe_transform_dict_key(video_kwargs, 'post_cond_video_frames', temporal_downsample_fn)

            # low resolution conditioning

            lowres_cond_img = lowres_noise_times = None

            # shape of stage - the channel count is this unet's entry of `self.sample_channels`
            shape = (batch_size, channel, *frame_dims, image_size, image_size)

            resize_kwargs = dict(target_frames = frame_dims[0]) if self.is_video else dict()

            if unet.lowres_cond:
                lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)

                lowres_cond_img = self.resize_to(img, image_size, **resize_kwargs)

                lowres_cond_img = self.normalize_img(lowres_cond_img)
                lowres_cond_img, *_ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))

            # init images or video

            if exists(unet_init_images):
                unet_init_images = self.resize_to(unet_init_images, image_size, **resize_kwargs)

            img = self.p_sample_loop(
                unet,
                shape,
                text_embeds = text_embeds,
                text_mask = text_masks,
                cond_images = cond_images,
                inpaint_images = inpaint_images,
                inpaint_masks = inpaint_masks,
                inpaint_resample_times = inpaint_resample_times,
                init_images = unet_init_images,
                skip_steps = unet_skip_steps,
                cond_scale = unet_cond_scale,
                lowres_cond_img = lowres_cond_img,
                lowres_noise_times = lowres_noise_times,
                noise_scheduler = noise_scheduler,
                pred_objective = pred_objective,
                dynamic_threshold = dynamic_threshold,
                use_tqdm = use_tqdm,
                **video_kwargs
            )

            outputs.append(img)

        if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
            break

    output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs

    if not return_pil_images:
        return outputs[output_index]

    if not return_all_unet_outputs:
        outputs = outputs[-1:]

    assert not self.is_video, 'converting sampled video tensor to video file is not supported yet'

    pil_images = list(map(lambda img: list(map(T.ToPILImage(), img.unbind(dim = 0))), outputs))

    return pil_images[output_index] # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png)
@beartype
def p_losses(
    self,
    unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel],
    x_start,
    times,
    *,
    noise_scheduler,
    lowres_cond_img = None,
    lowres_aug_times = None,
    text_embeds = None,
    text_mask = None,
    cond_images = None,
    noise = None,
    times_next = None,
    pred_objective = 'noise',
    min_snr_gamma = None,
    random_crop_size = None,
    **kwargs
):
    """Diffusion training loss of `unet` for clean samples `x_start` at `times`.

    The method proceeds in these steps:
    - Normalize `x_start` and the low-resolution conditioning.
    - Optionally random-crop both together with the noise.
    - Noise `x_start` with `noise_scheduler`.
    - Regress the unet output onto the target chosen by `pred_objective`
      ('noise', 'x_start' or 'v').

    When the unet self-conditions, half of the time a no-grad first pass
    provides an x_0 estimate as extra input. Per-sample losses are reweighted
    from the signal-to-noise ratio, clipped at `min_snr_gamma` when given, and
    then averaged.

    Returns:
        Scalar tensor: the mean of the weighted per-sample losses.

    Raises:
        ValueError: if `pred_objective` is unknown.
    """
    is_video = x_start.ndim == 5

    noise = default(noise, lambda: torch.randn_like(x_start))

    # normalize to [-1, 1]

    x_start = self.normalize_img(x_start)
    lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)

    # random cropping during training
    # for upsamplers

    if exists(random_crop_size):
        if is_video:
            frames = x_start.shape[2]
            x_start, lowres_cond_img, noise = map(lambda t: rearrange(t, 'b c f h w -> (b f) c h w'), (x_start, lowres_cond_img, noise))

        aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)

        # make sure low res conditioner and image both get augmented the same way
        # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
        x_start = aug(x_start)
        lowres_cond_img = aug(lowres_cond_img, params = aug._params)
        noise = aug(noise, params = aug._params)

        if is_video:
            x_start, lowres_cond_img, noise = map(lambda t: rearrange(t, '(b f) c h w -> b c f h w', f = frames), (x_start, lowres_cond_img, noise))

    # get x_t

    x_noisy, log_snr, alpha, sigma = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)

    # also noise the lowres conditioning image
    # at sample time, they then fix the noise level of 0.1 - 0.3

    lowres_cond_img_noisy = None
    if exists(lowres_cond_img):
        lowres_aug_times = default(lowres_aug_times, times)
        lowres_cond_img_noisy, *_ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img))

    # time condition

    noise_cond = noise_scheduler.get_condition(times)

    # unet kwargs

    unet_kwargs = dict(
        text_embeds = text_embeds,
        text_mask = text_mask,
        cond_images = cond_images,
        lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
        lowres_cond_img = lowres_cond_img_noisy,
        cond_drop_prob = self.cond_drop_prob,
        **kwargs
    )

    # self condition if needed

    # Because 'unet' can be an instance of DistributedDataParallel coming from the
    # ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to
    # access the member 'module' of the wrapped unet instance.
    self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet.self_cond

    if self_cond and random() < 0.5:
        with torch.no_grad():
            pred = unet.forward(
                x_noisy,
                noise_cond,
                **unet_kwargs
            ).detach()

            # Keep the estimate in its own variable: `x_start` is still the
            # ground-truth target for the 'x_start' and 'v' objectives below.
            # Convert every objective to an x_0 estimate, which is what
            # p_mean_variance feeds back as self_cond during sampling.
            if pred_objective == 'noise':
                self_cond_x_start = noise_scheduler.predict_start_from_noise(x_noisy, t = times, noise = pred)
            elif pred_objective == 'v':
                self_cond_x_start = noise_scheduler.predict_start_from_v(x_noisy, t = times, v = pred)
            else:
                self_cond_x_start = pred

            unet_kwargs = {**unet_kwargs, 'self_cond': self_cond_x_start}

    # get prediction

    pred = unet.forward(
        x_noisy,
        noise_cond,
        **unet_kwargs
    )

    # prediction objective

    if pred_objective == 'noise':
        target = noise
    elif pred_objective == 'x_start':
        target = x_start
    elif pred_objective == 'v':
        # derivation detailed in Appendix D of Progressive Distillation paper
        # https://arxiv.org/abs/2202.00512
        # this makes distillation viable as well as solve an issue with color shifting in upresoluting unets, noted in imagen-video
        target = alpha * noise - sigma * x_start
    else:
        raise ValueError(f'unknown objective {pred_objective}')

    # losses

    losses = self.loss_fn(pred, target, reduction = 'none')
    losses = reduce(losses, 'b ... -> b', 'mean')

    # min snr loss reweighting

    snr = log_snr.exp()
    maybe_clipped_snr = snr.clone()

    if exists(min_snr_gamma):
        maybe_clipped_snr.clamp_(max = min_snr_gamma)

    if pred_objective == 'noise':
        loss_weight = maybe_clipped_snr / snr
    elif pred_objective == 'x_start':
        loss_weight = maybe_clipped_snr
    elif pred_objective == 'v':
        loss_weight = maybe_clipped_snr / (snr + 1)

    losses = losses * loss_weight
    return losses.mean()
@beartype
def forward(
    self,
    images, # rename to images or video
    unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None,
    texts: List[str] = None,
    text_embeds = None,
    text_masks = None,
    unet_number = None,
    cond_images = None,
    **kwargs
):
    """Training step: compute the diffusion loss for one unet of the cascade.

    Images (or videos) are resized to the unet's target size. For unets after
    the first, a low-resolution conditioning image is made by downsampling to
    the previous stage's size and upsampling back. Captions are encoded when
    `text_embeds` is not given. The loss itself comes from `p_losses`.

    Args:
        images: square images, or videos for video models. A 4-d input to a
            video model is treated as single-frame video with `ignore_time=True`.
        unet: unet to train; defaults to `self.get_unet(unet_number)`.
        texts / text_embeds / text_masks: text conditioning.
        unet_number: 1-indexed unet to train. Required with multiple unets.
        cond_images: extra image conditioning.
        **kwargs: forwarded to `p_losses` and from there to the unet.

    Returns:
        Scalar loss tensor.
    """
    if self.is_video and images.ndim == 4:
        images = rearrange(images, 'b c h w -> b c 1 h w')
        kwargs.update(ignore_time = True)

    assert images.shape[-1] == images.shape[-2], f'the images you pass in must be a square, but received dimensions of {images.shape[-2]}, {images.shape[-1]}'
    assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
    unet_number = default(unet_number, 1)
    assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train on unet #{self.only_train_unet_number}'

    images = cast_uint8_images_to_float(images)
    cond_images = maybe(cast_uint8_images_to_float)(cond_images)

    assert images.dtype == torch.float or images.dtype == torch.half, f'images tensor needs to be floats but {images.dtype} dtype found instead'

    unet_index = unet_number - 1

    unet = default(unet, lambda: self.get_unet(unet_number))

    assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'

    noise_scheduler = self.noise_schedulers[unet_index]
    min_snr_gamma = self.min_snr_gamma[unet_index]
    pred_objective = self.pred_objectives[unet_index]
    target_image_size = self.image_sizes[unet_index]
    random_crop_size = self.random_crop_sizes[unet_index]
    prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None

    b, c, *_, h, w, device, is_video = *images.shape, images.device, images.ndim == 5

    assert c == self.channels
    assert h >= target_image_size and w >= target_image_size

    frames = images.shape[2] if is_video else None
    all_frame_dims = tuple(safe_get_tuple_index(el, 0) for el in calc_all_frame_dims(self.temporal_downsample_factor, frames))
    ignore_time = kwargs.get('ignore_time', False)

    target_frame_size = all_frame_dims[unet_index] if is_video and not ignore_time else None
    prev_frame_size = all_frame_dims[unet_index - 1] if is_video and not ignore_time and unet_index > 0 else None

    frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()

    times = noise_scheduler.sample_random_times(b, device = device)

    if exists(texts) and not exists(text_embeds) and not self.unconditional:
        assert all([*map(len, texts)]), 'text cannot be empty'
        assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'

        with autocast(enabled = False):
            text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

        text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))

    if not self.unconditional:
        text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))

    assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified'
    assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented'

    assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

    # handle video frame conditioning

    if self.is_video and self.resize_cond_video_frames:
        downsample_scale = self.temporal_downsample_factor[unet_index]
        temporal_downsample_fn = partial(scale_video_time, downsample_scale = downsample_scale)
        kwargs = maybe_transform_dict_key(kwargs, 'cond_video_frames', temporal_downsample_fn)
        kwargs = maybe_transform_dict_key(kwargs, 'post_cond_video_frames', temporal_downsample_fn)

    # handle low resolution conditioning

    lowres_cond_img = lowres_aug_times = None
    if exists(prev_image_size):
        # downsample to the previous stage's size, then back up to the target size
        lowres_cond_img = self.resize_to(images, prev_image_size, **frames_to_resize_kwargs(prev_frame_size), clamp_range = self.input_image_range)
        lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, **frames_to_resize_kwargs(target_frame_size), clamp_range = self.input_image_range)

        if self.per_sample_random_aug_noise_level:
            lowres_aug_times = self.lowres_noise_schedule.sample_random_times(b, device = device)
        else:
            # one shared augmentation time for the whole batch
            lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
            lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = b)

    images = self.resize_to(images, target_image_size, **frames_to_resize_kwargs(target_frame_size))

    return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, min_snr_gamma = min_snr_gamma, random_crop_size = random_crop_size, **kwargs)
|