{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "d5ac353e", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import argparse\n", "import os\n", "import shutil\n", "import random\n", "from PIL import Image\n", "\n", "import numpy as np\n", "import torch\n", "import torch.backends.cudnn as cudnn\n", "from transformers import StoppingCriteria, StoppingCriteriaList\n", "\n", "import lavis.tasks as tasks\n", "from lavis.common.config import Config\n", "from lavis.common.dist_utils import get_rank, init_distributed_mode\n", "from lavis.common.logger import setup_logger\n", "from lavis.common.optims import (\n", " LinearWarmupCosineLRScheduler,\n", " LinearWarmupStepLRScheduler,\n", ")\n", "from lavis.common.registry import registry\n", "from lavis.common.utils import now\n", "\n", "# imports modules for registration\n", "from lavis.datasets.builders import *\n", "from lavis.models import *\n", "from lavis.processors import *\n", "from lavis.runners import *\n", "from lavis.tasks import *" ] }, { "cell_type": "code", "execution_count": null, "id": "4fdef7a6", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "shutil.copytree('/ibex/project/c2133/vicuna', '/tmp/vicuna')" ] }, { "cell_type": "code", "execution_count": 2, "id": "661f9e80", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "class StoppingCriteriaSub(StoppingCriteria):\n", "\n", " def __init__(self, stops = [], encounters=1):\n", " super().__init__()\n", " self.stops = stops\n", "\n", " def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):\n", " for stop in self.stops:\n", " if torch.all((stop == input_ids[0][-len(stop):])).item():\n", " return True\n", "\n", " return False\n", "\n", "\n", "stop_words_ids = [torch.tensor([835]).to('cuda:0'), \n", " torch.tensor([2277, 29937]).to('cuda:0')] # '###' can be encoded in different ways.\n", "stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])" ] }, { "cell_type": "code", "execution_count": 6, "id": "1822a77a", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "parser = argparse.ArgumentParser(description=\"Training\")\n", "\n", "parser.add_argument(\"--cfg-path\", required=True, help=\"path to configuration file.\")\n", "parser.add_argument(\n", " \"--options\",\n", " nargs=\"+\",\n", " help=\"override some settings in the used config, the key-value pair \"\n", " \"in xxx=yyy format will be merged into config file (deprecate), \"\n", " \"change to --cfg-options instead.\",\n", ")\n", "\n", "args = parser.parse_args([\"--cfg-path\", \"lavis/projects/blip2/train/vicuna_pretrain_stage2_cc.yaml\"])\n", "\n", "cfg = Config(args)\n", "device = 'cuda:0'" ] }, { "cell_type": "code", "execution_count": 4, "id": "57e90f19", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "vis_processor_cfg = cfg.datasets_cfg.cc_combine.vis_processor.train\n", "vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)" ] }, { "cell_type": "code", "execution_count": 7, "id": "4cc521da", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading LLAMA\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "abeac6970d914446adc1fb73f7e5b5f9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/3 [00:00╭─────────────────────────────── Traceback (most 
recent call last) ────────────────────────────────╮\n", " in <module>:2 \n", " \n", " 1 task = tasks.setup_task(cfg) \n", " 2 model = task.build_model(cfg) \n", " 3 \n", " \n", " /home/zhud/project/blip2/lavis/tasks/base_task.py:33 in build_model \n", " \n", " 30 │ │ model_config = cfg.model_cfg \n", " 31 │ │ \n", " 32 │ │ model_cls = registry.get_model_class(model_config.arch) \n", " 33 │ │ return model_cls.from_config(model_config) \n", " 34 │ \n", " 35 │ def build_datasets(self, cfg): \n", " 36 │ │ \"\"\" \n", " \n", " /home/zhud/project/blip2/lavis/models/blip2_models/blip2_llama.py:315 in from_config \n", " \n", " 312 │ │ ckpt_path = cfg.get(\"ckpt\", \"\") \n", " 313 │ │ if ckpt_path: \n", " 314 │ │ │ print(\"Load BLIP2-LLM Checkpoint: {}\".format(ckpt_path)) \n", " 315 │ │ │ ckpt = torch.load(ckpt_path, map_location=\"cpu\") \n", " 316 │ │ │ msg = model.load_state_dict(ckpt['model'], strict=False) \n", " 317 │ │ \n", " 318 │ │ return model \n", " \n", " /home/zhud/anaconda3/envs/eye/lib/python3.9/site-packages/torch/serialization.py:791 in load \n", " \n", " 788 │ if 'encoding' not in pickle_load_args.keys(): \n", " 789 │ │ pickle_load_args['encoding'] = 'utf-8' \n", " 790 │ \n", " 791 with _open_file_like(f, 'rb') as opened_file: \n", " 792 │ │ if _is_zipfile(opened_file): \n", " 793 │ │ │ # The zipfile reader is going to advance the current file position. \n", " 794 │ │ │ # If we want to actually tail call to torch.jit.load, we need to \n", " \n", " /home/zhud/anaconda3/envs/eye/lib/python3.9/site-packages/torch/serialization.py:271 in \n", " _open_file_like \n", " \n", " 268 \n", " 269 def _open_file_like(name_or_buffer, mode): \n", " 270 │ if _is_path(name_or_buffer): \n", " 271 │ │ return _open_file(name_or_buffer, mode) \n", " 272 │ else: \n", " 273 │ │ if 'w' in mode: \n", " 274 │ │ │ return _open_buffer_writer(name_or_buffer) \n", " \n", " /home/zhud/anaconda3/envs/eye/lib/python3.9/site-packages/torch/serialization.py:252 in __init__ \n", " \n", " 249 \n", " 250 class _open_file(_opener): \n", " 251 │ def __init__(self, name, mode): \n", " 252 │ │ super().__init__(open(name, mode)) \n", " 253 │ \n", " 254 │ def __exit__(self, *args): \n", " 255 │ │ self.file_like.close() \n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", "FileNotFoundError: [Errno 2] No such file or directory: \n", "'/home/zhud/project/blip2/lavis/output/BLIP2/Vicuna_pretrain_stage2_cc/20230405233/checkpoint_3.pth'\n", "\n" ], "text/plain": [ "\u001B[31m╭─\u001B[0m\u001B[31m──────────────────────────────\u001B[0m\u001B[31m \u001B[0m\u001B[1;31mTraceback \u001B[0m\u001B[1;2;31m(most recent call last)\u001B[0m\u001B[31m \u001B[0m\u001B[31m───────────────────────────────\u001B[0m\u001B[31m─╮\u001B[0m\n", "\u001B[31m│\u001B[0m in \u001B[92m\u001B[0m:\u001B[94m2\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m1 \u001B[0mtask = tasks.setup_task(cfg) \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m❱ \u001B[0m2 model = task.build_model(cfg) \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m3 \u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2;33m/home/zhud/project/blip2/lavis/tasks/\u001B[0m\u001B[1;33mbase_task.py\u001B[0m:\u001B[94m33\u001B[0m in \u001B[92mbuild_model\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 30 \u001B[0m\u001B[2m│ │ 
\u001B[0mmodel_config = cfg.model_cfg \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 31 \u001B[0m\u001B[2m│ │ \u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 32 \u001B[0m\u001B[2m│ │ \u001B[0mmodel_cls = registry.get_model_class(model_config.arch) \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m❱ \u001B[0m 33 \u001B[2m│ │ \u001B[0m\u001B[94mreturn\u001B[0m model_cls.from_config(model_config) \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 34 \u001B[0m\u001B[2m│ \u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 35 \u001B[0m\u001B[2m│ \u001B[0m\u001B[94mdef\u001B[0m \u001B[92mbuild_datasets\u001B[0m(\u001B[96mself\u001B[0m, cfg): \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 36 \u001B[0m\u001B[2;90m│ │ \u001B[0m\u001B[33m\"\"\"\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2;33m/home/zhud/project/blip2/lavis/models/blip2_models/\u001B[0m\u001B[1;33mblip2_llama.py\u001B[0m:\u001B[94m315\u001B[0m in \u001B[92mfrom_config\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m312 \u001B[0m\u001B[2m│ │ \u001B[0mckpt_path = cfg.get(\u001B[33m\"\u001B[0m\u001B[33mckpt\u001B[0m\u001B[33m\"\u001B[0m, \u001B[33m\"\u001B[0m\u001B[33m\"\u001B[0m) \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m313 \u001B[0m\u001B[2m│ │ \u001B[0m\u001B[94mif\u001B[0m ckpt_path: \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m314 \u001B[0m\u001B[2m│ │ │ \u001B[0m\u001B[96mprint\u001B[0m(\u001B[33m\"\u001B[0m\u001B[33mLoad BLIP2-LLM Checkpoint: \u001B[0m\u001B[33m{}\u001B[0m\u001B[33m\"\u001B[0m.format(ckpt_path)) \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m❱ \u001B[0m315 \u001B[2m│ │ │ \u001B[0mckpt = torch.load(ckpt_path, map_location=\u001B[33m\"\u001B[0m\u001B[33mcpu\u001B[0m\u001B[33m\"\u001B[0m) \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m316 \u001B[0m\u001B[2m│ │ │ \u001B[0mmsg = model.load_state_dict(ckpt[\u001B[33m'\u001B[0m\u001B[33mmodel\u001B[0m\u001B[33m'\u001B[0m], strict=\u001B[94mFalse\u001B[0m) \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m317 \u001B[0m\u001B[2m│ │ \u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m318 \u001B[0m\u001B[2m│ │ \u001B[0m\u001B[94mreturn\u001B[0m model \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2;33m/home/zhud/anaconda3/envs/eye/lib/python3.9/site-packages/torch/\u001B[0m\u001B[1;33mserialization.py\u001B[0m:\u001B[94m791\u001B[0m in \u001B[92mload\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 788 \u001B[0m\u001B[2m│ \u001B[0m\u001B[94mif\u001B[0m \u001B[33m'\u001B[0m\u001B[33mencoding\u001B[0m\u001B[33m'\u001B[0m \u001B[95mnot\u001B[0m \u001B[95min\u001B[0m pickle_load_args.keys(): \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 789 \u001B[0m\u001B[2m│ │ \u001B[0mpickle_load_args[\u001B[33m'\u001B[0m\u001B[33mencoding\u001B[0m\u001B[33m'\u001B[0m] = \u001B[33m'\u001B[0m\u001B[33mutf-8\u001B[0m\u001B[33m'\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 790 \u001B[0m\u001B[2m│ \u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m❱ \u001B[0m 791 \u001B[2m│ \u001B[0m\u001B[94mwith\u001B[0m _open_file_like(f, \u001B[33m'\u001B[0m\u001B[33mrb\u001B[0m\u001B[33m'\u001B[0m) \u001B[94mas\u001B[0m opened_file: \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 
792 \u001B[0m\u001B[2m│ │ \u001B[0m\u001B[94mif\u001B[0m _is_zipfile(opened_file): \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 793 \u001B[0m\u001B[2m│ │ │ \u001B[0m\u001B[2m# The zipfile reader is going to advance the current file position.\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 794 \u001B[0m\u001B[2m│ │ │ \u001B[0m\u001B[2m# If we want to actually tail call to torch.jit.load, we need to\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2;33m/home/zhud/anaconda3/envs/eye/lib/python3.9/site-packages/torch/\u001B[0m\u001B[1;33mserialization.py\u001B[0m:\u001B[94m271\u001B[0m in \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[92m_open_file_like\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 268 \u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 269 \u001B[0m\u001B[94mdef\u001B[0m \u001B[92m_open_file_like\u001B[0m(name_or_buffer, mode): \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 270 \u001B[0m\u001B[2m│ \u001B[0m\u001B[94mif\u001B[0m _is_path(name_or_buffer): \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m❱ \u001B[0m 271 \u001B[2m│ │ \u001B[0m\u001B[94mreturn\u001B[0m _open_file(name_or_buffer, mode) \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 272 \u001B[0m\u001B[2m│ \u001B[0m\u001B[94melse\u001B[0m: \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 273 \u001B[0m\u001B[2m│ │ \u001B[0m\u001B[94mif\u001B[0m \u001B[33m'\u001B[0m\u001B[33mw\u001B[0m\u001B[33m'\u001B[0m \u001B[95min\u001B[0m mode: \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 274 \u001B[0m\u001B[2m│ │ │ \u001B[0m\u001B[94mreturn\u001B[0m _open_buffer_writer(name_or_buffer) \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2;33m/home/zhud/anaconda3/envs/eye/lib/python3.9/site-packages/torch/\u001B[0m\u001B[1;33mserialization.py\u001B[0m:\u001B[94m252\u001B[0m in \u001B[92m__init__\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 249 \u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 250 \u001B[0m\u001B[94mclass\u001B[0m \u001B[4;92m_open_file\u001B[0m(_opener): \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 251 \u001B[0m\u001B[2m│ \u001B[0m\u001B[94mdef\u001B[0m \u001B[92m__init__\u001B[0m(\u001B[96mself\u001B[0m, name, mode): \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[31m❱ \u001B[0m 252 \u001B[2m│ │ \u001B[0m\u001B[96msuper\u001B[0m().\u001B[92m__init__\u001B[0m(\u001B[96mopen\u001B[0m(name, mode)) \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 253 \u001B[0m\u001B[2m│ \u001B[0m \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 254 \u001B[0m\u001B[2m│ \u001B[0m\u001B[94mdef\u001B[0m \u001B[92m__exit__\u001B[0m(\u001B[96mself\u001B[0m, *args): \u001B[31m│\u001B[0m\n", "\u001B[31m│\u001B[0m \u001B[2m 255 \u001B[0m\u001B[2m│ │ \u001B[0m\u001B[96mself\u001B[0m.file_like.close() \u001B[31m│\u001B[0m\n", "\u001B[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001B[0m\n", "\u001B[1;91mFileNotFoundError: \u001B[0m\u001B[1m[\u001B[0mErrno \u001B[1;36m2\u001B[0m\u001B[1m]\u001B[0m No such file or directory: \n", "\u001B[32m'/home/zhud/project/blip2/lavis/output/BLIP2/Vicuna_pretrain_stage2_cc/20230405233/checkpoint_3.pth'\u001B[0m\n" ] }, "metadata": {}, "output_type": 
"display_data" } ], "source": [ "task = tasks.setup_task(cfg)\n", "model = task.build_model(cfg)" ] }, { "cell_type": "code", "execution_count": 9, "id": "ba874036", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "'/ibex/project/c2133/vicuna'" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [] }, { "cell_type": "markdown", "id": "bf1c4e1c", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Load Checkpoint" ] }, { "cell_type": "code", "execution_count": null, "id": "a2a7f2bd", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "ckpt_path = '/ibex/project/c2133/vicuna_ckpt_test/Vicuna_prompt_stage2_laion/20230410145/checkpoint_4.pth'\n", "ckpt = torch.load(ckpt_path, map_location=\"cpu\")\n", "msg = model.load_state_dict(ckpt['model'], strict=False)\n", "model = model.to(device)" ] }, { "cell_type": "markdown", "id": "035a495f", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Example of Tokenizer" ] }, { "cell_type": "code", "execution_count": 35, "id": "3426ae10", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "texts = [\"A chat\", \"The assistant gives helpful\"]\n", "\n", "llama_tokens = model.llama_tokenizer(\n", " texts, \n", " return_tensors=\"pt\", \n", " padding=\"longest\",\n", " truncation=True,\n", " max_length=10).to(device)" ] }, { "cell_type": "code", "execution_count": 13, "id": "376400a4", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "texts = \"The assistant gives helpful\"\n", "\n", "llama_tokens = model.llama_tokenizer(\n", " texts, \n", " return_tensors=\"pt\", \n", " padding=\"longest\",\n", " truncation=True,\n", " max_length=10).to(device)" ] }, { "cell_type": "code", "execution_count": 14, "id": "6988ee66", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 5])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "llama_tokens.attention_mask.shape" ] }, { "cell_type": "code", "execution_count": 9, "id": "dc9e376d", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "targets = llama_tokens.input_ids.masked_fill(\n", " llama_tokens.input_ids == model.llama_tokenizer.pad_token_id, -100\n", " )" ] }, { "cell_type": "code", "execution_count": 10, "id": "e458fa52", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 3])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.ones([targets.shape[0], targets.shape[0]+1]).shape" ] }, { "cell_type": "code", "execution_count": null, "id": "24607f7a", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "text = \\\n", "\"### Human: What's your name?\" \\\n", "\"### Assistant: \"\n", "\n", "\n", "llama_tokens = model.llama_tokenizer(\n", " text, \n", " return_tensors=\"pt\", \n", " ).to(device)" ] }, { "cell_type": "markdown", "id": "5e69d3e1", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Example of Emb Input" ] }, { "cell_type": "code", "execution_count": 188, "id": "205b092f", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "​\n", "\n", "I'm sorry, I am an AI language model and do not have a physical form or a name. 
My purpose is to assist you with any questions or tasks you may have to the best of my ability. Is there anything specific you would like help with?\n", "###\n" ] } ], "source": [ "inputs_embeds = model.llama_model.model.embed_tokens(llama_tokens.input_ids)\n", "outputs = model.llama_model.generate(\n", " inputs_embeds=inputs_embeds,\n", " query_embeds=None,\n", " attention_mask=llama_tokens.attention_mask,\n", " max_new_tokens=500,\n", " stopping_criteria=stopping_criteria,\n", " )\n", "output_text = model.llama_tokenizer.decode(outputs[0])\n", "print(output_text)" ] }, { "cell_type": "code", "execution_count": 189, "id": "561b42f5", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 16, 5120])" ] }, "execution_count": 189, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs_embeds.shape" ] }, { "cell_type": "markdown", "id": "a1694ad6", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Example of ID Input" ] }, { "cell_type": "code", "execution_count": null, "id": "c1dc7841", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "outputs = model.llama_model.generate(\n", " input_ids=llama_tokens.input_ids,\n", " query_embeds=None,\n", " attention_mask=llama_tokens.attention_mask,\n", " max_new_tokens=500,\n", " stopping_criteria=stopping_criteria,\n", " )\n", "output_text = model.llama_tokenizer.decode(outputs[0])\n", "print(output_text)" ] }, { "cell_type": "markdown", "id": "19dd1f9d", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [] }, { "cell_type": "markdown", "id": "468ac97e", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Example of Mixed Input" ] }, { "cell_type": "code", "execution_count": 47, "id": "4af3a9bf", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "ckpt_path = '/home/zhud/project/blip2/lavis/output/BLIP2/Vicuna_pretrain_stage2_cc/20230408015/checkpoint_2.pth'\n", "ckpt = torch.load(ckpt_path, map_location=\"cpu\")\n", "msg = model.load_state_dict(ckpt['model'], strict=False)\n", "model = model.to(device)" ] }, { "cell_type": "code", "execution_count": 48, "id": "c3148611", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Load the image using PIL\n", "image = Image.open('test_img5.jpg').convert('RGB')\n", "image = vis_processor(image).unsqueeze(0).to(device)\n", "inputs_llama, atts_llama = model.encode_img(image)" ] }, { "cell_type": "code", "execution_count": 53, "id": "07b82707", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "text = \\\n", "\"A chat between a curious human and an artificial intelligence assistant. \" \\\n", "\"The assistant gives helpful, detailed, and polite answers to the human's questions. \"\\\n", "\"Human may ask questions related to a given image. \" \\\n", "\"The image will be wrapped as IMAGE_CONTENT \" \\\n", "\"### Human: To_Split \" \\\n", "\"### Assistant: Received the image. \" \\\n", "\"### Human: Describe the image in detail. Say everthing you see. Describe all the things.\" \\\n", "\"### Assistant: \"\n", "\n", "\n", "text = \\\n", "\"A chat between a curious human and an artificial intelligence assistant. \" \\\n", "\"The assistant gives helpful, detailed, and polite answers to the human's questions. \"\\\n", "\"Human may ask questions related to a given image. \" \\\n", "\"The image will be wrapped as IMAGE_CONTENT \" \\\n", "\"### Human: Describe the image in detail. Say everthing you see. 
To_Split \" \\\n", "\"### Assistant: \"\n", "\n", "text = \\\n", "\"### Human: Describe the image in detail. Say everthing you see. To_Split \" \\\n", "\"### Assistant: \"\n", "\n", "\n", "\n", "# text = \\\n", "# \"A chat between a curious human and an artificial intelligence assistant. \" \\\n", "# \"The assistant gives helpful, detailed, and polite answers to the human's questions. \"\\\n", "# \"Human may ask questions related to a given image. \" \\\n", "# \"The image will be wrapped as IMAGE_CONTENT \" \\\n", "# \"### Human: To_Split \" \\\n", "# \"### Assistant: Received the image. \" \\\n", "# \"### Human: This is a draft of a website. Give me the html code to write this website. \" \\\n", "# \"Btw, you need to come up with some jokes in the website to fill the placeholders. \" \\\n", "# \"Also, make the website colorful and vivid. \" \\\n", "# \"### Assistant: \"\n", "\n", "\n", "# text = \\\n", "# \"Return what the human says. \" \\\n", "# \"### Human: There is a big elephant in the sky. \" \\\n", "# \"### Assistant: There is a big elephant in the sky. \" \\\n", "# \"### Human: fdjlks klcznv_l1 \" \\\n", "# \"### Assistant: fdjlks klcznv_l1 \" \\\n", "# \"### Human: To_Split \" \\\n", "# \"### Assistant: \"\n", "\n", "\n", "text_1, text_2 = text.split('To_Split')\n", "\n", "text_1_tokens = model.llama_tokenizer(text_1, return_tensors=\"pt\").to(device)\n", "text_2_tokens = model.llama_tokenizer(text_2, return_tensors=\"pt\", add_special_tokens=False).to(device)\n", "text_1_emb = model.llama_model.model.embed_tokens(text_1_tokens.input_ids)\n", "text_2_emb = model.llama_model.model.embed_tokens(text_2_tokens.input_ids)" ] }, { "cell_type": "code", "execution_count": 54, "id": "136b9e97", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "The image shows a small bird perched on a tree stump, with a camera lens in the background\n", "\n", "The bird is a small bird, with a bright yellow beak and black feathers. It is perched on a tree stump, with its wings spread out and its beak open. The bird is looking to the left, as if it is about to take off.\n", "\n", "The camera lens in the background is a large, black lens with a silver ring around the front. The lens is attached to a camera, which is not visible in the image. The lens is pointed at the bird, with the camera's viewfinder showing the bird in the center of the frame.\n", "\n", "The background of the image is a forest, with trees and foliage visible in the distance. The trees are covered in leaves, and there is a thick layer of mist or fog in the air, which gives the image a dreamy, ethereal quality.\n", "\n", "The lighting in the image is soft and diffused, with the sun shining through the trees and casting a warm, golden light on the bird and the tree stump. The lighting creates deep shadows in the forest, which add to the sense of mystery and wonder in the image.\n", "\n", "The overall effect of the image is one of peacefulness and tranquility, with the bird and the forest creating a sense of calm and serenity. 
The image is beautifully composed, with the bird and the camera lens creating a visual balance that draws the viewer's eye to the center of the frame.\n", "###\n" ] } ], "source": [ "outputs = model.llama_model.generate(\n", " inputs_embeds=torch.concat([text_1_emb, inputs_llama, text_2_emb], dim=1),\n", " query_embeds=None,\n", " attention_mask=torch.concat([text_1_tokens.attention_mask, atts_llama, text_2_tokens.attention_mask], dim=1),\n", " max_new_tokens=600,\n", " stopping_criteria=stopping_criteria,\n", " )\n", "output_text = model.llama_tokenizer.decode(outputs[0])\n", "print(output_text)" ] }, { "cell_type": "code", "execution_count": 83, "id": "54cc3d4a", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "with open('lavis/prompts/image_caption.txt', 'r') as f:\n", " prompts = f.read().splitlines()" ] }, { "cell_type": "code", "execution_count": 92, "id": "f52cd85c", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "prompt_token = model.llama_tokenizer(prompts, return_tensors=\"pt\", padding=\"longest\",)" ] }, { "cell_type": "code", "execution_count": 103, "id": "4b0cf1d0", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[(15, 6), (16, 11), (17, 17), (18, 17), (19, 27), (20, 18), (21, 21), (22, 4), (23, 6), (24, 2)]\n" ] } ], "source": [ "\n", "\n", "my_list = prompt_token.attention_mask.sum(1).numpy()\n", "counts = {}\n", "\n", "for element in my_list:\n", " if element in counts:\n", " counts[element] += 1\n", " else:\n", " counts[element] = 1\n", "\n", "print(sorted(counts.items(), key=lambda item: item[0]))" ] }, { "cell_type": "code", "execution_count": 58, "id": "f7919e93", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1, 2, 1, 2, 1, 2]\n" ] } ], "source": [ "a,b = [1,1,1], [2,2,2]\n", "c = [i for pair in zip(a,b) for i in pair]\n", "print(c)" ] }, { "cell_type": "markdown", "id": "3c64a037", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Example of Image Input" ] }, { "cell_type": "code", "execution_count": 67, "id": "87164578", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a bird eating from a bird feeder\n", "\n", "bird feeder, bird feeder, bird feeder, bird feeder, bird feeder, bird feeder, bird\n", "bird feeder, bird feeder, bird feeder, bird feeder, bird feeder, bird feeder, bird\n", "bird feeder, bird feeder, bird feeder, bird feeder, bird feeder, bird feeder, bird\n", "bird feeder, bird feeder, bird feeder\n" ] } ], "source": [ "inputs_embeds = model.llama_model.model.embed_tokens(llama_tokens.input_ids)\n", "bos_embeds = model.llama_model.model.embed_tokens(torch.tensor(model.llama_tokenizer.bos_token_id, device=device))[None, None]\n", "outputs = model.llama_model.generate(\n", " inputs_embeds=torch.concat([bos_embeds, inputs_llama], dim=1),\n", " query_embeds=None,\n", " attention_mask=torch.concat([atts_llama[:, :1], atts_llama], dim=1),\n", " max_new_tokens=100,\n", " stopping_criteria=stopping_criteria,\n", " )\n", "output_text = model.llama_tokenizer.decode(outputs[0])\n", "print(output_text)" ] } ], "metadata": { "kernelspec": { "display_name": "eye", "language": "python", "name": "eye" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", 
"pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 }