In [None]:
pip install coremltools==8.0b1 torch==2.3.0 torchvision torchaudio scikit-learn==1.1.2 

In [38]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import coremltools as ct
import coremltools.optimize as cto
from PIL import Image
import numpy as np
import requests
import os


class BasicBlock(nn.Module):
    """Residual block built from two 3x3 convolutions.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The result is added to
    a shortcut of the input and passed through a final ReLU.

    Args:
        in_planes: Number of channels of the input tensor.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution, and of the projection
            shortcut when one is created.
    """

    # Output channels are ``expansion * planes``; ResNet reads this attribute
    # to size the next block and the final linear layer.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # An empty Sequential returns its input unchanged (identity shortcut).
        self.shortcut = nn.Sequential()
        out_planes = self.expansion * planes
        if stride != 1 or in_planes != out_planes:
            # Spatial size or channel count changes, so project the input
            # with a strided 1x1 convolution to make the addition valid.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    Main path: conv1x1 -> BN -> ReLU -> conv3x3 -> BN -> ReLU -> conv1x1 -> BN.
    The last convolution produces ``expansion * planes`` channels. The result
    is added to a shortcut of the input and passed through a final ReLU.

    Args:
        in_planes: Number of channels of the input tensor.
        planes: Number of channels used by the first two convolutions.
        stride: Stride of the 3x3 convolution, and of the projection shortcut
            when one is created.
    """

    # Output channels are ``expansion * planes``; ResNet reads this attribute
    # to size the next block and the final linear layer.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # An empty Sequential returns its input unchanged (identity shortcut).
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != out_planes:
            # Spatial size or channel count changes, so project the input
            # with a strided 1x1 convolution to make the addition valid.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    """ResNet classifier assembled from a configurable residual block type.

    Layout: 7x7 stride-2 convolution -> BN -> ReLU -> 3x3 stride-2 max pool ->
    four stages of residual blocks (64/128/256/512 planes) -> adaptive average
    pool to 1x1 -> flatten -> linear layer.

    Args:
        block: Callable ``block(in_planes, planes, stride)`` returning a
            module, with an ``expansion`` attribute giving the ratio of its
            output channels to ``planes``.
        num_blocks: Sequence of four ints, the number of blocks per stage.
        num_classes: Size of the output of the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes=1000):
        super().__init__()
        # Running channel count; _make_layer updates it after every block.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks as an ``nn.Sequential``.

        Only the first block receives ``stride``; the remaining blocks use
        stride 1. ``self.in_planes`` is advanced to ``planes * block.expansion``
        after each block so the next one is built with matching input channels.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        # Distinct loop name so the ``stride`` parameter is not shadowed.
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the stem, the four stages, pooling, and the linear layer on ``x``."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # Keep the batch dimension and flatten everything after it.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

def ResNet50():
    """Build a ResNet made of Bottleneck blocks with a 3-4-6-3 stage layout."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths)

# Build the network and put it in inference mode before tracing.
model = ResNet50()
model.eval()

# Input geometry used both for tracing and for the Core ML input description.
batch_size = 1
image_size = 224  # may be changed to another input size (e.g. 1024)

# Random example input of shape (batch_size, 3, image_size, image_size).
input_tensor = torch.randn(batch_size, 3, image_size, image_size)

# Record the model's forward pass on the example input as a TorchScript trace.
traced_model = torch.jit.trace(model, input_tensor)

# Convert the trace to Core ML for iOS 18 with a float16 tensor named "input".
input_spec = ct.TensorType(name="input", shape=input_tensor.shape, dtype=np.float16)
coreml_model_iOS18 = ct.convert(
    traced_model,
    inputs=[input_spec],
    minimum_deployment_target=ct.target.iOS18,
)
a = f"resnet-from-scratch-b{batch_size}.mlpackage"
coreml_model_iOS18.save(a)

# -------------------- palettization (LUT) only ----------------------------
print("OptimizationConfig LUT")

# 4-bit palettization in "uniform" mode, set as the global op config.
# ``config`` is reused by the LUT(LINEAR) section further down.
palettizer_config = cto.coreml.OpPalettizerConfig(mode="uniform", nbits=4)
config = cto.coreml.OptimizationConfig(global_config=palettizer_config)

compressed_model = cto.coreml.palettize_weights(coreml_model_iOS18, config)
a = f"rnfs-4bit-b{batch_size}.mlpackage"
compressed_model.save(a)


# -------------------- OptimizationConfig LINEAR ----------------------------
# Linear symmetric weight quantization (int4) followed by activation
# quantization calibrated on synthetic inputs.
print("OptimizationConfig LINEAR")

dt = ct.converters.mil.mil.types.int4
print("-------- (W4) -------- ")

weight_config = cto.coreml.OptimizationConfig(
    global_config=cto.coreml.OpLinearQuantizerConfig(
        mode="linear_symmetric", dtype=dt
    )
)

compressed_model2 = cto.coreml.linear_quantize_weights(coreml_model_iOS18, weight_config)
# The weights above are quantized with dt == int4, so report W4 (not W8).
print("-------- W4 selected! ---------- ")

activation_config = cto.coreml.OptimizationConfig(
    global_config=cto.coreml.experimental.OpActivationLinearQuantizerConfig(
        mode="linear_symmetric"
    )
)
print("-------- Activation A8 quant! ---------- ")
# Calibration data: ten random tensors shaped like the traced input, each
# shifted by its index. No real images are used here.
compressed_model_a8 = cto.coreml.experimental.linear_quantize_activations(
    compressed_model2,
    activation_config,
    [{"input": torch.randn_like(input_tensor) + i} for i in range(10)],
)
# File name follows the A<activation bits>W<weight bits> pattern used by the
# A8W8-LUT4 export in this script: 8-bit activations, 4-bit weights.
a = f"rnfs-A8W4-b{batch_size}.mlpackage"
compressed_model_a8.save(a)


# -------------------- OptimizationConfig LUT(LINEAR) ----------------------------
# int8 linear weight quantization combined with the 4-bit palettization
# config, followed by activation quantization on synthetic inputs.
print("OptimizationConfig LUT(LINEAR)")

# int8 for the linear step; the 4-bit part comes from the palettization
# ``config`` defined in the LUT section (nbits=4).
dt = ct.converters.mil.mil.types.int8
print("-------- LUT(W8) -------- ")
weight_config = cto.coreml.OptimizationConfig(
    global_config=cto.coreml.OpLinearQuantizerConfig(
        mode="linear_symmetric", dtype=dt
    )
)

# NOTE(review): the coremltools docs describe the palettize+quantize joint mode
# as palettizing first and then calling
# linear_quantize_weights(..., joint_compression=True). This section applies
# the two passes in the opposite order -- confirm that the palettization step
# actually takes effect on the already-quantized model.
compressed_model1 = cto.coreml.linear_quantize_weights(coreml_model_iOS18, weight_config)
compressed_model2 = cto.coreml.palettize_weights(compressed_model1, config, joint_compression=True)
print("-------- LUT4+W8 selected! ---------- ")

activation_config = cto.coreml.OptimizationConfig(
    global_config=cto.coreml.experimental.OpActivationLinearQuantizerConfig(
        mode="linear_symmetric"
    )
)
print("-------- Activation A8 quant! ---------- ")
# Calibration data: ten random tensors shaped like the traced input, each
# shifted by its index. No real images are used here.
compressed_model_a8 = cto.coreml.experimental.linear_quantize_activations(
    compressed_model2,
    activation_config,
    [{"input": torch.randn_like(input_tensor) + i} for i in range(10)],
)

a = f"rnfs-A8W8-LUT4-b{batch_size}.mlpackage"
# Save the model produced by this section. Saving ``compressed_model`` here
# would write the LUT-only model from the earlier section under this name.
compressed_model_a8.save(a)

print(a)
print("Done!")


Converting PyTorch Frontend ==> MIL Ops: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋| 440/441 [00:00<00:00, 6548.48 ops/s]
Running MIL frontend_pytorch pipeline: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 139.19 passes/s]
Running MIL default pipeline: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:01<00:00, 57.60 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 233.95 passes/s]


OptimizationConfig LUT



Running compression pass palettize_weights: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:00<00:00, 99.79 ops/s]
Running MIL frontend_milinternal pipeline: 0 passes [00:00, ? passes/s]
Running MIL default pipeline: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 77/77 [00:00<00:00, 176.72 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 180.92 passes/s]


OptimizationConfig LINEAR
-------- (W4) -------- 



Running compression pass linear_quantize_weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:00<00:00, 92.51 ops/s]
Running MIL frontend_milinternal pipeline: 0 passes [00:00, ? passes/s]
Running MIL default pipeline: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 77/77 [00:00<00:00, 167.87 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 209.76 passes/s]


-------- W8 selected! ---------- 
-------- Activation A8 quant! ---------- 



Running activation compression pass insert_prefix_quantize_dequantize_pair: 100%|██████████████████████████████████████████████████████████████████████████████████| 522/522 [00:00<00:00, 7993.67 ops/s]
Running compression pass linear_quantize_activations: start calibrating 10 samples
Running compression pass linear_quantize_activations: calibration may take a while ...
Running compression pass linear_quantize_activations: calibrating sample 1/10 succeeds.
Running compression pass linear_quantize_activations: calibrating sample 2/10 succeeds.
Running compression pass linear_quantize_activations: calibrating sample 3/10 succeeds.
Running compression pass linear_quantize_activations: calibrating sample 4/10 succeeds.
Running compression pass linear_quantize_activations: calibrating sample 5/10 succeeds.
Running compression pass linear_quantize_activations: calibrating sample 6/10 succeeds.
Running compression pass linear_quantize_activations: calibrating sample 7/10 succeeds.
Running comp

OptimizationConfig LUT(LINEAR)
-------- LUT(W8) -------- 



Running compression pass linear_quantize_weights: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:00<00:00, 107.97 ops/s]
Running MIL frontend_milinternal pipeline: 0 passes [00:00, ? passes/s]
Running MIL default pipeline: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 77/77 [00:00<00:00, 176.48 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 215.97 passes/s]





Running compression pass palettize_weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 121/121 [00:00<00:00, 116588.74 ops/s]
Running MIL frontend_milinternal pipeline: 0 passes [00:00, ? passes/s]
Running MIL default pipeline: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 77/77 [00:00<00:00, 180.58 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 198.24 passes/s]


-------- LUT4+W8 selected! ---------- 
-------- Activation A8 quant! ---------- 



Running activation compression pass insert_prefix_quantize_dequantize_pair: 100%|██████████████████████████████████████████████████████████████████████████████████| 522/522 [00:00<00:00, 6895.20 ops/s]
Running compression pass linear_quantize_activations: start calibrating 10 samples
Running compression pass linear_quantize_activations: calibration may take a while ...
Running compression pass linear_quantize_activations: calibrating sample 1/10 succeeds.
Running compression pass linear_quantize_activations: calibrating sample 2/10 succeeds.
Running compression pass linear_quantize_activations: calibrating sample 3/10 succeeds.
Running compression pass linear_quantize_activations: calibrating sample 4/10 succeeds.
Running compression pass linear_quantize_activations: calibrating sample 5/10 succeeds.
Running compression pass linear_quantize_activations: calibrating sample 6/10 succeeds.
Running compression pass linear_quantize_activations: calibrating sample 7/10 succeeds.
Running comp

rnfs-A8W8-LUT4-b1.mlpackage
Done!
