{ "cells": [ { "cell_type": "markdown", "id": "1092f43b", "metadata": {}, "source": [ "# Convert CLIP models to CoreML" ] }, { "cell_type": "code", "execution_count": null, "id": "e5f63e7a", "metadata": {}, "outputs": [], "source": [ "!pip install torch transformers coremltools" ] }, { "cell_type": "code", "execution_count": null, "id": "a7f0ab67", "metadata": {}, "outputs": [], "source": [ "from transformers import CLIPProcessor, CLIPModel\n", "\n", "model_version = \"laion/CLIP-ViT-H-14-laion2B-s32B-b79K\"\n", "\n", "processor = CLIPProcessor.from_pretrained(model_version)" ] }, { "cell_type": "markdown", "id": "4bd0aa05", "metadata": {}, "source": [ "# Text model" ] }, { "cell_type": "code", "execution_count": null, "id": "19851197", "metadata": {}, "outputs": [], "source": [ "# Wrap CLIPModel so that forward() returns get_text_features()\n", "class WrappedCLIPModel_Text(CLIPModel):\n", "    def forward(self, *args, **kwargs):\n", "        return self.get_text_features(*args, **kwargs)\n", "\n", "model_pt_text = WrappedCLIPModel_Text.from_pretrained(model_version)\n", "model_pt_text.eval()" ] }, { "cell_type": "code", "execution_count": null, "id": "c8b3a1ca", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "with torch.no_grad():\n", "    text = \"the \" + \" \".join([\"example text\"]*37)  # tokenizes to 77 tokens, CLIP's fixed context length\n", "    processed_text = processor(text=text, images=None, return_tensors=\"pt\", padding=True)\n", "    print(len(processed_text.input_ids[0]), processed_text.input_ids)\n", "    model_traced = torch.jit.trace(model_pt_text, processed_text.input_ids, strict=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "5066eb03", "metadata": { "scrolled": true }, "outputs": [], "source": [ "import coremltools as ct\n", "import numpy as np\n", "\n", "# Convert traced model to CoreML\n", "text_input_shape = ct.Shape(shape=(1, 77))\n", "\n", "model_coreml = ct.convert(\n", "    model_traced,\n", "    inputs=[ct.TensorType(name=\"input_text_token_ids\", shape=text_input_shape, dtype=np.float32)],\n", "    outputs=[ct.TensorType(name=\"output_embedding\", dtype=np.float16)],\n", "    minimum_deployment_target=ct.target.macOS13,\n", "    convert_to='mlprogram'\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "a323b1b8", "metadata": {}, "outputs": [], "source": [ "model_coreml.get_spec().description" ] }, { "cell_type": "code", "execution_count": null, "id": "04773702", "metadata": {}, "outputs": [], "source": [ "model_coreml.save(\"CLIP-ViT-H-14-laion2B-s32B-b79K.text-encoder.mlpackage\")" ] }, { "cell_type": "markdown", "id": "346ade90", "metadata": {}, "source": [ "## Check correctness\n", "You should see a mean difference on the order of 1e-5." ] }, { "cell_type": "code", "execution_count": null, "id": "9fcaef03", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "with torch.no_grad():\n", "    processed_text = processor(text=\"hello there\", images=None, return_tensors=\"pt\", padding=True)\n", "    input_ids = processed_text.input_ids\n", "    # pad to 77 tokens with the end-of-text token id (49407) to match the fixed CoreML input shape\n", "    input_ids = torch.cat([input_ids, torch.tensor([[49407] * (77-input_ids.shape[1])])], dim=1)\n", "    print(\"input shape:\", input_ids.shape)\n", "\n", "    res_pt = model_pt_text(**processed_text)\n", "    print(f\"original output: shape {res_pt.shape}, {res_pt}\")\n", "\n", "    coreml_out = model_coreml.predict({'input_text_token_ids': input_ids.float()})\n", "    res_coreml = torch.tensor(coreml_out['output_embedding'])\n", "    print(f\"coreml output: shape {res_coreml.shape}, {res_coreml}, type {type(res_coreml)}\")\n",
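"\n", "    # (Added, optional) in addition to the elementwise difference below, a direction-based\n", "    # check can be useful: the cosine similarity between the PyTorch and CoreML embeddings\n", "    # should be very close to 1.0 (the image-encoder check further down does the same).\n", "    print(f\"cosine: {torch.nn.functional.cosine_similarity(res_pt, res_coreml)}\")\n",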
" \n", " difference = res_pt - res_coreml\n", " print(f\"mean difference: {torch.sum(difference)/difference.shape[1]}, max: {torch.max(difference)}\")\n", "\n" ] }, { "cell_type": "markdown", "id": "ec415cc5", "metadata": {}, "source": [ "# Image encoder" ] }, { "cell_type": "code", "execution_count": null, "id": "9228b9dc", "metadata": {}, "outputs": [], "source": [ "# wrap CLIPModel so that forward() function returns get_image_features()\n", "class WrappedCLIPModel_Image(CLIPModel): \n", " def forward(self, *args, **kwargs):\n", " return self.get_image_features(*args, **kwargs)\n", "\n", "model_pt_image = WrappedCLIPModel_Image.from_pretrained(model_version)\n", "model_pt_image.eval()" ] }, { "cell_type": "code", "execution_count": null, "id": "e9560396", "metadata": {}, "outputs": [], "source": [ "from PIL import Image\n", "import torch\n", "\n", "with torch.no_grad():\n", " image = Image.open(\"example.jpg\") \n", " processed_image = processor(text=None, images=image, return_tensors=\"pt\", padding=True)\n", " trace_input = torch.rand_like(processed_image.pixel_values)\n", " model_traced = torch.jit.trace(model_pt_image, trace_input, strict=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "37adb85f", "metadata": { "scrolled": true }, "outputs": [], "source": [ "import coremltools as ct\n", "import numpy as np\n", "\n", "# Convert traced model to CoreML\n", "image_input_shape = ct.Shape(shape=trace_input.shape)\n", "\n", "model_coreml = ct.convert(\n", " model_traced,\n", " inputs=[ct.TensorType(name=\"input_image_preproessed\", shape=image_input_shape, dtype=np.float16)],\n", " outputs=[ct.TensorType(name=\"output_embedding\", dtype=np.float16)],\n", " minimum_deployment_target=ct.target.macOS13,\n", " convert_to='mlprogram'\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "9cb1b830", "metadata": {}, "outputs": [], "source": [ "model_coreml.get_spec().description" ] }, { "cell_type": "code", "execution_count": null, "id": "281451f8", "metadata": {}, "outputs": [], "source": [ "model_coreml.save(\"CLIP-ViT-H-14-laion2B-s32B-b79K.image-encoder.mlpackage\")" ] }, { "cell_type": "markdown", "id": "9f2e43c3", "metadata": {}, "source": [ "## Check correctness\n", "Should see a mean difference on the order of 1e-5 " ] }, { "cell_type": "code", "execution_count": null, "id": "7cfe24af", "metadata": {}, "outputs": [], "source": [ "\n", "with torch.no_grad():\n", " image = Image.open(\"example.jpg\")\n", "\n", " processed_image = processor(text=None, images=image, return_tensors=\"pt\", padding=True)\n", " print(\"input shape:\", processed_image.pixel_values.shape)\n", "\n", " res_pt = model_pt_image.get_image_features(**processed_image)\n", " print(f\"original output: shape {res_pt.shape}, {res_pt}\")\n", "\n", " coreml_out = model_coreml.predict({'input_image_preproessed': processed_image.pixel_values})\n", " res_coreml = torch.tensor(coreml_out['output_embedding'])\n", " print(f\"coreml output: shape {res_coreml.shape}, {res_coreml}, type {type(res_coreml)}\")\n", "\n", " difference = res_pt - res_coreml\n", " print(f\"mean difference: {torch.sum(difference)/difference.shape[1]}, cosine: {torch.nn.functional.cosine_similarity(res_pt, res_coreml)}, max: {torch.max(difference)}\")" ] }, { "cell_type": "markdown", "id": "154fffa4", "metadata": {}, "source": [ "# Check performance" ] }, { "cell_type": "code", "execution_count": null, "id": "55260e23", "metadata": {}, "outputs": [], "source": [ "import time\n", "from tqdm.auto import tqdm\n", "\n", 
"model_pt_image = model_pt_image.to('mps', dtype=torch.float16)\n", "\n", "start = time.perf_counter()\n", "for i in tqdm(range(100)):\n", " model_pt_image(pixel_values = torch.rand_like(processed_image.pixel_values, device=model_pt_image.device, dtype=torch.float16))\n", "end = time.perf_counter()\n", "print(\"original (GPU): \", (end-start)/100)\n", "\n", "start = time.perf_counter()\n", "for i in tqdm(range(100)):\n", " model_coreml.predict({'input_image_preproessed': torch.rand_like(processed_image.pixel_values)})\n", "end = time.perf_counter()\n", "print(\"coreml: \", (end-start)/100)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "41449a3a", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" } }, "nbformat": 4, "nbformat_minor": 5 }