{"cells":[{"cell_type":"markdown","metadata":{},"source":["# | Import Libraries [↑](#top) \n","\n","***\n","\n","Import all the required libraries for this notebook."]},{"cell_type":"code","execution_count":1,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:05:23.520178Z","iopub.status.busy":"2024-09-19T17:05:23.519758Z","iopub.status.idle":"2024-09-19T17:05:27.482622Z","shell.execute_reply":"2024-09-19T17:05:27.481379Z","shell.execute_reply.started":"2024-09-19T17:05:23.520134Z"},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["/home/ea301b/anaconda3/envs/binh_mamba/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n"," from .autonotebook import tqdm as notebook_tqdm\n"]},{"name":"stdout","output_type":"stream","text":["Current device is: cuda\n","mkdir: cannot create directory ‘output’: File exists\n"]}],"source":["import matplotlib.pyplot as plt\n","import pandas as pd\n","\n","import warnings\n","# import wandb\n","\n","\n","from sklearn.metrics import roc_auc_score\n","from sklearn.utils import shuffle\n","from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold\n","import torch\n","import torch.nn as nn\n","from torch.nn import Parameter\n","import torch.nn.functional as F\n","from torch.optim import Adam, SGD, AdamW\n","from torch.optim.lr_scheduler import OneCycleLR\n","from torch.utils.data import DataLoader, Dataset\n","from tqdm.auto import tqdm\n","\n","# ======= OPTIONS =========\n","pd.set_option('display.max_rows', 500)\n","pd.set_option('display.max_columns', 500)\n","pd.set_option('display.width', 1000)\n","device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n","print(f\"Current device is: {device}\")\n","warnings.filterwarnings(\"ignore\")\n","!mkdir output"]},{"cell_type":"code","execution_count":2,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:05:27.484828Z","iopub.status.busy":"2024-09-19T17:05:27.484129Z","iopub.status.idle":"2024-09-19T17:05:43.959428Z","shell.execute_reply":"2024-09-19T17:05:43.958637Z","shell.execute_reply.started":"2024-09-19T17:05:27.484792Z"},"trusted":true},"outputs":[],"source":["import random\n","import torch.nn as nn\n","from torch.nn import BCEWithLogitsLoss\n","from collections import namedtuple\n","from dataclasses import dataclass, field, asdict\n","from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel\n","from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf\n","# from huggingface_hub import HfApi\n","\n","# import evaluate\n","import numpy as np\n","# from datasets import load_dataset\n","# from transformers import Trainer\n","from transformers import DataCollatorWithPadding\n","from transformers import AutoTokenizer, TrainingArguments\n","import re"]},{"cell_type":"markdown","metadata":{"execution":{"iopub.execute_input":"2024-09-18T04:07:22.734063Z","iopub.status.busy":"2024-09-18T04:07:22.733462Z","iopub.status.idle":"2024-09-18T04:07:22.850537Z","shell.execute_reply":"2024-09-18T04:07:22.849644Z","shell.execute_reply.started":"2024-09-18T04:07:22.734029Z"},"trusted":true},"source":["import wandb\n","from huggingface_hub import login\n","\n","login(token=\"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\")\n","notes = \"Train Mamba With 400k row dataset\""]},{"cell_type":"markdown","metadata":{"papermill":{"duration":0.012589,"end_time":"2022-08-31T07:03:04.13341","exception":false,"start_time":"2022-08-31T07:03:04.120821","status":"completed"},"tags":[]},"source":["# | Load Data [↑](#top) \n","\n","***\n","\n","Load data."]},{"cell_type":"code","execution_count":3,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:05:43.962860Z","iopub.status.busy":"2024-09-19T17:05:43.962010Z","iopub.status.idle":"2024-09-19T17:05:53.355460Z","shell.execute_reply":"2024-09-19T17:05:53.354541Z","shell.execute_reply.started":"2024-09-19T17:05:43.962813Z"},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["Processing Text: 100%|██████████| 165767/165767 [00:04<00:00, 37690.26it/s]\n","Processing Text: 100%|██████████| 1679/1679 [00:00<00:00, 40374.47it/s]"]},{"name":"stdout","output_type":"stream","text":["Trainging DF Processing\n","\n","RangeIndex: 165767 entries, 0 to 165766\n","Data columns (total 4 columns):\n"," # Column Non-Null Count Dtype \n","--- ------ -------------- ----- \n"," 0 id 165767 non-null object\n"," 1 prompt_id 165767 non-null int64 \n"," 2 text 165767 non-null object\n"," 3 generated 165767 non-null int64 \n","dtypes: int64(2), object(2)\n","memory usage: 5.1+ MB\n","None\n","Testing DF Processing\n","\n","RangeIndex: 1679 entries, 0 to 1678\n","Data columns (total 4 columns):\n"," # Column Non-Null Count Dtype \n","--- ------ -------------- ----- \n"," 0 id 1679 non-null object\n"," 1 prompt_id 1679 non-null int64 \n"," 2 text 1679 non-null object\n"," 3 generated 1679 non-null int64 \n","dtypes: int64(2), object(2)\n","memory usage: 52.6+ KB\n","None\n"]},{"name":"stderr","output_type":"stream","text":["\n"]}],"source":["import pandas as pd\n","import re\n","import unicodedata\n","from tqdm import tqdm\n","\n","# Load DataFrame\n","train_df = pd.read_parquet('/home/HardDisk/binh230_intern/Mamba-AI-generated-text-detection/data/Mix-AI-Dataset/train_essays.parquet')\n","valid_df = pd.read_parquet('/home/HardDisk/binh230_intern/Mamba-AI-generated-text-detection/data/Mix-AI-Dataset/valid_essays.parquet')\n","\n","# Define characters to remove\n","char_to_remove = ['{', '£', '\\x97', '¹', 'å', '\\\\', '\\x85', '<', '\\x99', \n"," 'é', ']', '+', 'Ö', '\\xa0', '>', '|', '\\x80', '~', '©', \n"," '/', '\\x93', '$', 'Ó', '²', '^', ';', '`', 'á', '*', '(', \n"," '¶', '®', '[', '\\x94', '\\x91', '#', '-', 'ó', ')', '}', '=']\n","\n","# Define preprocessing function\n","def preprocess_text(text, strategy='light'): \n"," if strategy == \"none\":\n"," text = text\n"," elif strategy == \"light\":\n"," text = text.encode(\"ascii\", \"ignore\").decode('ascii') \n"," text = text.strip()\n"," text = text.strip(\"\\\"\")\n"," for c in char_to_remove:\n"," text = text.replace(c, \"\")\n"," if text and text[-1] != \".\":\n"," text = text.split(\".\")\n"," text = \".\".join(text[:-1])\n"," text += \".\"\n"," else:\n"," text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('ascii')\n"," text = text.lower()\n"," text = re.sub(r'[^a-z0-9\\s.,;?!:()\\'\\\"%-]', '', text)\n"," text = re.sub(r'\\s+', ' ', text).strip()\n"," \n"," return text\n","\n","# Apply preprocessing with progress bar\n","tqdm.pandas(desc=\"Processing Text\")\n","train_df['text'] = train_df['text'].progress_apply(lambda x: preprocess_text(x, strategy='light'))\n","valid_df['text'] = valid_df['text'].progress_apply(lambda x: preprocess_text(x, strategy='light'))\n","\n","# Display the first few rows to verify\n","print(\"Trainging DF Processing\")\n","print(train_df.info())\n","print(\"Testing DF Processing\")\n","print(valid_df.info())\n","\n"]},{"cell_type":"markdown","metadata":{"papermill":{"duration":0.008127,"end_time":"2022-08-31T07:03:11.985369","exception":false,"start_time":"2022-08-31T07:03:11.977242","status":"completed"},"tags":[]},"source":["# | Dataset [↑](#top) \n","\n","***\n","\n"," \n","We need to get the `max_len` from our `tokenizer`. We create a `tqdm` iterator and for each text we extract the tokenized length. Then we get the maximum value and we add 3 for the special tokens `CLS`, `SEP`, `SEP`.\n","\n","- [Hugging Face Padding and Truncation](https://huggingface.co/docs/transformers/pad_truncation): check truncation to `max_length` or `True` (batch max length)."]},{"cell_type":"markdown","metadata":{},"source":["One sample from the dataset should look as following:\n","```python\n","{\n","\t'inputs': {\n","\t\t'input_ids': tensor([1, 279, 883, ..., 0, 0]),\n","\t\t'token_type_ids': tensor([0, 0, 0, ..., 0, 0]),\n","\t\t'attention_mask': tensor([1, 1, 1, ..., 0, 0])\n","\t},\n","\t'label': tensor([0.0]),\n","\t'ids': '000e8c3c7ddb'\n","}\n","```\n","You can check it by running the cell below."]},{"cell_type":"markdown","metadata":{},"source":["import wandb\n","# Định nghĩa tên project để log thông tin quá trình huấn luyện trên wandb\n","os.environ[\"WANDB_PROJECT\"] = \"mamba_LLM_detect_binary_classification\"\n","os.environ[\"WANDB_API_KEY \"] = \"e7432690ce6d9bfdee410567f89d7e38844ed584\"\n","\n","\n","wandb.login()\n","# start a new wandb run to track this script\n","wandb.init(\n"," # set the wandb project where this run will be logged\n"," project=\"mamba_LLM_detect_binary_classification\",\n","\n"," # track hyperparameters and run metadata\n"," config={\n"," \"learning_rate\": 6e-5,\n"," \"architecture\": \"Mamba-130m-with-Linear-Head\",\n"," \"dataset\": \"Test\",\n"," \"epochs\": 1,\n"," \"lr_scheduler_type\": \"cosine\"\n"," }\n",")"]},{"cell_type":"markdown","metadata":{"papermill":{"duration":0.008073,"end_time":"2022-08-31T07:03:17.933189","exception":false,"start_time":"2022-08-31T07:03:17.925116","status":"completed"},"tags":[]},"source":["# | Model [↑](#top) \n","\n","***"]},{"cell_type":"code","execution_count":4,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:05:53.357180Z","iopub.status.busy":"2024-09-19T17:05:53.356739Z","iopub.status.idle":"2024-09-19T17:05:53.373217Z","shell.execute_reply":"2024-09-19T17:05:53.372238Z","shell.execute_reply.started":"2024-09-19T17:05:53.357130Z"},"trusted":true},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
idprompt_idtextgenerated
0e_ddxvqx5i0In recent years, there has been a growing move...1
1e_hi0yzrcv0\\nWhy not cars in our life\\n\\nI have ever met ...1
2e_uesv4xha0A car is considered by many a nessecity for ev...1
3e_2tl5ylwy0H\\n\\nello fellow citezens , we are here to inf...0
4e_s6ci4vj00Have you ever known how if feels not being abl...1
\n","
"],"text/plain":[" id prompt_id text generated\n","0 e_ddxvqx5i 0 In recent years, there has been a growing move... 1\n","1 e_hi0yzrcv 0 \\nWhy not cars in our life\\n\\nI have ever met ... 1\n","2 e_uesv4xha 0 A car is considered by many a nessecity for ev... 1\n","3 e_2tl5ylwy 0 H\\n\\nello fellow citezens , we are here to inf... 0\n","4 e_s6ci4vj0 0 Have you ever known how if feels not being abl... 1"]},"execution_count":4,"metadata":{},"output_type":"execute_result"}],"source":["train_df.head()"]},{"cell_type":"code","execution_count":5,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:05:53.374547Z","iopub.status.busy":"2024-09-19T17:05:53.374259Z","iopub.status.idle":"2024-09-19T17:05:56.661369Z","shell.execute_reply":"2024-09-19T17:05:56.660369Z","shell.execute_reply.started":"2024-09-19T17:05:53.374516Z"},"trusted":true},"outputs":[{"data":{"text/plain":["DatasetDict({\n"," train: Dataset({\n"," features: ['id', 'prompt_id', 'text', 'labels'],\n"," num_rows: 165767\n"," })\n"," test: Dataset({\n"," features: ['id', 'prompt_id', 'text', 'labels'],\n"," num_rows: 1679\n"," })\n","})"]},"execution_count":5,"metadata":{},"output_type":"execute_result"}],"source":["import pandas as pd\n","from datasets import Dataset, DatasetDict\n","from sklearn.model_selection import train_test_split\n","\n","# Assuming train_df is your DataFrame with a 'text' column\n","# Convert the 'id' column to a string to avoid ArrowTypeError\n","# df['id'] = df['id'].astype(str)\n","\n","# Rename the 'generated' column to 'labels'\n","train_df.rename(columns={'generated': 'labels'}, inplace=True)\n","valid_df.rename(columns={'generated': 'labels'}, inplace=True)\n","\n","# # Access the train and test datasets\n","# train_dataset, test_dataset = train_test_split(df, test_size=0.05)\n","\n","# Combine the splits into a DatasetDict\n","dataset_dict = DatasetDict({\n"," 'train': Dataset.from_pandas(train_df),\n"," 'test': Dataset.from_pandas(valid_df),\n","})\n","\n","# Display the first example from each dataset\n","dataset_dict"]},{"cell_type":"code","execution_count":6,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:05:56.662755Z","iopub.status.busy":"2024-09-19T17:05:56.662437Z","iopub.status.idle":"2024-09-19T17:08:50.995345Z","shell.execute_reply":"2024-09-19T17:08:50.994344Z","shell.execute_reply.started":"2024-09-19T17:05:56.662708Z"},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n","Map: 100%|██████████| 165767/165767 [02:04<00:00, 1335.75 examples/s]\n","Map: 100%|██████████| 1679/1679 [00:01<00:00, 1349.49 examples/s]\n"]}],"source":["tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')\n","# Add eos tokens\n","# tokenizer.eos_token = \"<|endoftext|>\"\n","tokenizer.pad_token = tokenizer.eos_token\n","def preprocess_function(examples):\n"," # Tokenize the text with truncation\n"," samples = tokenizer(examples['text'], \n"," truncation=True, \n"," padding='max_length', \n"," max_length=1024, \n"," return_tensors=\"pt\")\n"," \n"," return samples\n","\n","# Apply preprocessing to the dataset\n","tokenized_dataset = dataset_dict.map(preprocess_function, batched=True)\n"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[{"data":{"text/plain":["DatasetDict({\n"," train: Dataset({\n"," features: ['id', 'prompt_id', 'text', 'labels', 'input_ids', 'attention_mask'],\n"," num_rows: 165767\n"," })\n"," test: Dataset({\n"," features: ['id', 'prompt_id', 'text', 'labels', 'input_ids', 'attention_mask'],\n"," num_rows: 1679\n"," })\n","})"]},"execution_count":7,"metadata":{},"output_type":"execute_result"}],"source":["tokenized_dataset"]},{"cell_type":"code","execution_count":8,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:08:50.996892Z","iopub.status.busy":"2024-09-19T17:08:50.996548Z","iopub.status.idle":"2024-09-19T17:08:51.011804Z","shell.execute_reply":"2024-09-19T17:08:51.010782Z","shell.execute_reply.started":"2024-09-19T17:08:50.996858Z"},"trusted":true},"outputs":[],"source":["# Set seed cho hàm random\n","random.seed(42)\n","\n","# Tạo tập train và test\n","train_dataset = tokenized_dataset[\"train\"]\n","test_dataset = tokenized_dataset[\"test\"]\n","# Drop the 'prompt_id' feature from both datasets\n","train_dataset = train_dataset.remove_columns([\"text\"]).remove_columns([\"id\"])\n","test_dataset = test_dataset.remove_columns([\"text\"]).remove_columns([\"id\"])\n","\n","# Tạo tập evaluation để đánh giá trong lúc train\n","# Do số lượng tập test lớn nên chỉ lấy mẫu 1% tập dữ liệu test để đánh giá\n","# total_samples = len(test_dataset)\n","# eval_samples = int(0.5 * total_samples)\n","# eval_indices = random.sample(range(total_samples), eval_samples)\n","# eval_dataset = test_dataset.select(eval_indices)"]},{"cell_type":"code","execution_count":9,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:08:51.013411Z","iopub.status.busy":"2024-09-19T17:08:51.013088Z","iopub.status.idle":"2024-09-19T17:10:28.712704Z","shell.execute_reply":"2024-09-19T17:10:28.711832Z","shell.execute_reply.started":"2024-09-19T17:08:51.013380Z"},"trusted":true},"outputs":[],"source":["import torch\n","import numpy as np\n","from transformers import Trainer, TrainingArguments, DataCollatorWithPadding\n","# Load the model\n","import torch\n","from xlstm import (\n"," xLSTMBlockStack,\n"," xLSTMBlockStackConfig,\n"," mLSTMBlockConfig,\n"," mLSTMLayerConfig,\n"," sLSTMBlockConfig,\n"," sLSTMLayerConfig,\n"," FeedForwardConfig,\n",")\n","\n","# Dataset and Tokenizer Setup\n","data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n"]},{"cell_type":"code","execution_count":10,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["{'verbose': True, 'with_cuda': True, 'extra_ldflags': ['-L/home/ea301b/anaconda3/envs/binh_mamba/lib', '-lcublas'], 'extra_cflags': ['-DSLSTM_HIDDEN_SIZE=1024', '-DSLSTM_BATCH_SIZE=8', '-DSLSTM_NUM_HEADS=4', '-DSLSTM_NUM_STATES=4', '-DSLSTM_DTYPE_B=float', '-DSLSTM_DTYPE_R=__nv_bfloat16', '-DSLSTM_DTYPE_W=__nv_bfloat16', '-DSLSTM_DTYPE_G=__nv_bfloat16', '-DSLSTM_DTYPE_S=__nv_bfloat16', '-DSLSTM_DTYPE_A=float', '-DSLSTM_NUM_GATES=4', '-DSLSTM_SIMPLE_AGG=true', '-DSLSTM_GRADIENT_RECURRENT_CLIPVAL_VALID=false', '-DSLSTM_GRADIENT_RECURRENT_CLIPVAL=0.0', '-DSLSTM_FORWARD_CLIPVAL_VALID=false', '-DSLSTM_FORWARD_CLIPVAL=0.0', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_BFLOAT16_OPERATORS__', '-U__CUDA_NO_BFLOAT16_CONVERSIONS__', '-U__CUDA_NO_BFLOAT162_OPERATORS__', '-U__CUDA_NO_BFLOAT162_CONVERSIONS__'], 'extra_cuda_cflags': ['-Xptxas=\"-v\"', '-gencode', 'arch=compute_80,code=compute_80', '-res-usage', '--use_fast_math', '-O3', '-Xptxas -O3', '--extra-device-vectorization', '-DSLSTM_HIDDEN_SIZE=1024', '-DSLSTM_BATCH_SIZE=8', '-DSLSTM_NUM_HEADS=4', '-DSLSTM_NUM_STATES=4', '-DSLSTM_DTYPE_B=float', '-DSLSTM_DTYPE_R=__nv_bfloat16', '-DSLSTM_DTYPE_W=__nv_bfloat16', '-DSLSTM_DTYPE_G=__nv_bfloat16', '-DSLSTM_DTYPE_S=__nv_bfloat16', '-DSLSTM_DTYPE_A=float', '-DSLSTM_NUM_GATES=4', '-DSLSTM_SIMPLE_AGG=true', '-DSLSTM_GRADIENT_RECURRENT_CLIPVAL_VALID=false', '-DSLSTM_GRADIENT_RECURRENT_CLIPVAL=0.0', '-DSLSTM_FORWARD_CLIPVAL_VALID=false', '-DSLSTM_FORWARD_CLIPVAL=0.0', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_BFLOAT16_OPERATORS__', '-U__CUDA_NO_BFLOAT16_CONVERSIONS__', '-U__CUDA_NO_BFLOAT162_OPERATORS__', '-U__CUDA_NO_BFLOAT162_CONVERSIONS__']}\n"]},{"name":"stderr","output_type":"stream","text":["Using /home/ea301b/.cache/torch_extensions/py312_cu124 as PyTorch extensions root...\n","Detected CUDA files, patching ldflags\n","Emitting ninja build file /home/ea301b/.cache/torch_extensions/py312_cu124/slstm_HS1024BS8NH4NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0/build.ninja...\n","Building extension module slstm_HS1024BS8NH4NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0...\n","Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n","Loading extension module slstm_HS1024BS8NH4NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0...\n"]},{"name":"stdout","output_type":"stream","text":["ninja: no work to do.\n"]},{"data":{"text/plain":["True"]},"execution_count":10,"metadata":{},"output_type":"execute_result"}],"source":["from omegaconf import OmegaConf\n","from dacite import from_dict\n","from dacite import Config as DaciteConfig\n","from xlstm import xLSTMLMModel, xLSTMLMModelConfig\n","\n","xlstm_cfg = \"\"\" \n","vocab_size: 50304\n","mlstm_block:\n"," mlstm:\n"," conv1d_kernel_size: 4\n"," qkv_proj_blocksize: 4\n"," num_heads: 4\n","slstm_block:\n"," slstm:\n"," backend: cuda\n"," num_heads: 4\n"," conv1d_kernel_size: 4\n"," bias_init: powerlaw_blockdependent\n"," feedforward:\n"," proj_factor: 1.3\n"," act_fn: gelu\n","context_length: 1024\n","num_blocks: 16\n","embedding_dim: 1024\n","slstm_at: [1]\n","\"\"\"\n","cfg = OmegaConf.create(xlstm_cfg)\n","cfg = from_dict(data_class=xLSTMLMModelConfig, data=OmegaConf.to_container(cfg), config=DaciteConfig(strict=True))\n","xlstm_stack = xLSTMLMModel(cfg)\n","\n","x = torch.randint(0, 50304, size=(4, 256)).to(\"cuda\")\n","xlstm_stack = xlstm_stack.to(\"cuda\")\n","y = xlstm_stack(x)\n","y.shape[1:] == (256, 50304)\n","# model = xlstm_stack.lm_head"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"data":{"text/plain":["xLSTMLMModel(\n"," (xlstm_block_stack): xLSTMBlockStack(\n"," (blocks): ModuleList(\n"," (0): mLSTMBlock(\n"," (xlstm_norm): LayerNorm()\n"," (xlstm): mLSTMLayer(\n"," (proj_up): Linear(in_features=1024, out_features=4096, bias=False)\n"," (q_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n"," (k_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n"," (v_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n"," (conv1d): CausalConv1d(\n"," (conv): Conv1d(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048)\n"," )\n"," (conv_act_fn): SiLU()\n"," (mlstm_cell): mLSTMCell(\n"," (igate): Linear(in_features=6144, out_features=4, bias=True)\n"," (fgate): Linear(in_features=6144, out_features=4, bias=True)\n"," (outnorm): MultiHeadLayerNorm()\n"," )\n"," (ogate_act_fn): SiLU()\n"," (proj_down): Linear(in_features=2048, out_features=1024, bias=False)\n"," (dropout): Dropout(p=0.0, inplace=False)\n"," )\n"," )\n"," (1): sLSTMBlock(\n"," (xlstm_norm): LayerNorm()\n"," (xlstm): sLSTMLayer(\n"," (conv1d): CausalConv1d(\n"," (conv): Conv1d(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024)\n"," )\n"," (conv_act_fn): SiLU()\n"," (fgate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n"," (igate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n"," (zgate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n"," (ogate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n"," (slstm_cell): sLSTMCell_cuda(function=slstm, hidden_size=1024, num_heads=4)\n"," (group_norm): MultiHeadLayerNorm()\n"," (dropout): Dropout(p=0.0, inplace=False)\n"," )\n"," (ffn_norm): LayerNorm()\n"," (ffn): GatedFeedForward(\n"," (proj_up): Linear(in_features=1024, out_features=2688, bias=False)\n"," (proj_down): Linear(in_features=1344, out_features=1024, bias=False)\n"," (dropout): Dropout(p=0.0, inplace=False)\n"," )\n"," )\n"," (2-15): 14 x mLSTMBlock(\n"," (xlstm_norm): LayerNorm()\n"," (xlstm): mLSTMLayer(\n"," (proj_up): Linear(in_features=1024, out_features=4096, bias=False)\n"," (q_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n"," (k_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n"," (v_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n"," (conv1d): CausalConv1d(\n"," (conv): Conv1d(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048)\n"," )\n"," (conv_act_fn): SiLU()\n"," (mlstm_cell): mLSTMCell(\n"," (igate): Linear(in_features=6144, out_features=4, bias=True)\n"," (fgate): Linear(in_features=6144, out_features=4, bias=True)\n"," (outnorm): MultiHeadLayerNorm()\n"," )\n"," (ogate_act_fn): SiLU()\n"," (proj_down): Linear(in_features=2048, out_features=1024, bias=False)\n"," (dropout): Dropout(p=0.0, inplace=False)\n"," )\n"," )\n"," )\n"," (post_blocks_norm): LayerNorm()\n"," )\n"," (token_embedding): Embedding(50304, 1024)\n"," (emb_dropout): Identity()\n"," (lm_head): Linear(in_features=1024, out_features=2, bias=True)\n",")"]},"execution_count":11,"metadata":{},"output_type":"execute_result"}],"source":["xlstm_stack.lm_head = nn.Linear(xlstm_stack.lm_head.in_features, 2)\n","model = xlstm_stack\n","model.cuda()"]},{"cell_type":"code","execution_count":12,"metadata":{},"outputs":[],"source":["from transformers import DataCollatorWithPadding, AutoConfig\n","\n","# Dataset and Tokenizer Setup\n","data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[],"source":["# import torch\n","# import numpy as np\n","# import wandb # Weights & Biases integration\n","# from torch import nn\n","# import torch\n","# from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score\n","# from typing import Dict, Union\n","# import torch\n","# from transformers import (\n","# DataCollatorWithPadding, \n","# AdamW, \n","# Trainer, \n","# TrainingArguments,\n","# get_cosine_schedule_with_warmup,\n","# TrainerCallback\n","# )\n","# from torch.utils.data import DataLoader\n","# from huggingface_hub import login # For pushing to the Hugging Face Hub\n","\n","# # Authenticate Hugging Face API token\n","# # Make sure you've logged in before running the script\n","# login(token=\"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\")\n","# # Initialize wandb run\n","# wandb.init(project=\"Detect AI Generated Text\", \n","# name=\"xLSTM-base\",\n","# config={\n","# \"learning_rate\": 1e-4,\n","# \"label_smoothing\": 0.03,\n","# \"batch_size\": 8,\n","# \"num_epochs\": 1,\n","# \"optimizer\": \"AdamW\",\n","# \"model\": 'xLSTM',\n","# \"model_params\": sum(p.numel() for p in xlstm_stack.parameters() if p.requires_grad)\n","# })\n","\n"," \n","# # Access the configuration\n","# config = wandb.config\n","\n","# # Now you can call config values like this\n","# learning_rate = config.learning_rate\n","# label_smoothing = config.label_smoothing\n","# batch_size = config.batch_size\n","\n","# # Data Collator Setup\n","# data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n","\n","# # Dataloader Setup\n","# train_data_loader = DataLoader(\n","# train_dataset, \n","# batch_size=batch_size, # Increased batch size since it will be split across GPUs\n","# num_workers=4, \n","# shuffle=True, \n","# pin_memory=True, \n","# collate_fn=data_collator\n","# )\n","\n","# test_data_loader = DataLoader(\n","# test_dataset, \n","# batch_size=2, # Increased batch size\n","# num_workers=4, \n","# shuffle=False, \n","# pin_memory=True, \n","# collate_fn=data_collator\n","# )\n","\n","# # Optimizer Setup\n","# optimizer = AdamW(\n","# xlstm_stack.parameters(),\n","# lr=learning_rate, # Define your learning_rate\n","# weight_decay=0.1\n","# )\n","\n","# # Scheduler Setup (Cosine Annealing)\n","# total_train_steps = len(train_data_loader) * 1 # Adjust based on your epochs\n","# lr_scheduler = get_cosine_schedule_with_warmup(\n","# optimizer,\n","# num_warmup_steps=500, # Can adjust based on needs\n","# num_training_steps=total_train_steps\n","# )\n","\n","\n","# def compute_metrics(eval_pred):\n","# \"\"\"\n","# Compute metrics for Hugging Face Trainer, including AUROC.\n","\n","# Args:\n","# eval_pred: tuple of (predictions, labels) where predictions are logits.\n","\n","# Returns:\n","# dictionary containing the computed metrics, including AUROC.\n","# \"\"\"\n","# # Unpack predictions and labels\n","# logits, labels = eval_pred\n","# preds = logits.argmax(-1) # Get the predicted class\n","\n","# # Calculate accuracy\n","# accuracy = accuracy_score(labels, preds)\n","\n","# # Calculate precision, recall, and F1-score\n","# precision = precision_score(labels, preds, average='weighted')\n","# recall = recall_score(labels, preds, average='weighted')\n","# f1 = f1_score(labels, preds, average='weighted')\n","\n","# # Calculate probabilities using softmax on logits (not on preds)\n","# probs = torch.softmax(torch.tensor(logits), dim=-1).numpy()\n","# # For binary classification, take the probability of the positive class (class 1)\n","# auroc = roc_auc_score(labels, probs[:, 1])\n","\n","\n","# return {\n","# 'accuracy': accuracy,\n","# 'precision': precision,\n","# 'recall': recall,\n","# 'f1': f1,\n","# 'auroc': auroc\n","# }\n","\n","\n","\n","# # Training Arguments Setup\n","# training_args = TrainingArguments(\n","# output_dir=\"./results\", # Directory to save model checkpoints\n","# evaluation_strategy=\"steps\", # Evaluate every few steps\n","# # eval_steps=, # Evaluate every 1000 steps\n","# per_device_train_batch_size=batch_size, # Adjust batch size per device (GPU)\n","# per_device_eval_batch_size=4, # Same for evaluation\n","# num_train_epochs=1, # Define total number of epochs\n","# weight_decay=0.1, # L2 regularization\n","# logging_dir=\"./logs\", # Log directory\n","# fp16=False, # Use mixed precision training\n","# save_steps=2000, # Save model every 20000 steps\n","# label_smoothing_factor=0.03,\n","# hub_model_id=\"xLSTM-4-1\", # Set model name for HF Hub\n","# push_to_hub=True, # Push to Hugging Face Hub\n","# save_total_limit=2, # Only keep the last 2 checkpoints,\n","# metric_for_best_model=\"eval_auroc\", # Use AUROC to determine best model\n","# greater_is_better=True, # Higher AUROC is better\n","# max_grad_norm=1,\n","# report_to=\"wandb\", # Report metrics to Weights & Biases\n","# )\n","\n","# # Initialize the Trainer\n","# trainer = Trainer(\n","# model=model,\n","# args=training_args,\n","# train_dataset=train_dataset, # Replace with your actual training dataset\n","# eval_dataset=test_dataset, # Replace with your actual evaluation dataset\n","# tokenizer=tokenizer,\n","# data_collator=data_collator,\n","# optimizers=(optimizer, lr_scheduler), # Pass the optimizer and scheduler\n","# compute_metrics=compute_metrics # Optional custom metric computation\n","# )\n","\n","# # Training and evaluation\n","# trainer.train()\n","# trainer.evaluate()\n","\n","# # Push to Hub\n","# trainer.push_to_hub()\n","\n","# # Finish wandb logging\n","# wandb.finish()\n"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":[]},{"cell_type":"code","execution_count":14,"metadata":{"execution":{"iopub.execute_input":"2024-09-19T17:10:28.716075Z","iopub.status.busy":"2024-09-19T17:10:28.715683Z","iopub.status.idle":"2024-09-19T17:10:28.726787Z","shell.execute_reply":"2024-09-19T17:10:28.726000Z","shell.execute_reply.started":"2024-09-19T17:10:28.716039Z"},"trusted":true},"outputs":[],"source":["import torch\n","import numpy as np\n","from tqdm import tqdm\n","from sklearn.metrics import roc_auc_score, accuracy_score\n","\n","# Accuracy Calculation\n","def compute_accuracy(labels, predictions):\n"," preds = torch.argmax(predictions, dim=1)\n"," correct = torch.sum(preds == labels)\n"," return correct.item() / len(labels)\n","\n","def TestModel(test_data_loader, model, criterion):\n"," test_losses = []\n"," all_predictions = []\n"," all_actual_values = []\n"," \n"," with torch.no_grad():\n"," for batch in tqdm(test_data_loader):\n"," if len(batch.input_ids) == 0:\n"," # Safeguard against empty sequences.\n"," continue\n","\n"," # Have shape (batch size, token count)\n"," token_sequences = batch.input_ids.cuda()\n"," attention_masks = batch.attention_mask.cuda()\n"," # Has shape (batch size)\n"," labels = batch.labels.cuda()\n","\n"," with torch.cuda.amp.autocast():\n"," output = model(token_sequences)\n"," raw_predictions = output[:, -1, :]\n","\n"," loss = criterion(raw_predictions.view(-1, 2), labels.view(-1))\n"," test_losses.append(loss.detach().cpu())\n","\n"," scaled_predictions = raw_predictions.softmax(dim=1)[:, 1]\n"," all_predictions.extend(scaled_predictions.cpu().numpy())\n"," all_actual_values.extend(labels.cpu().numpy())\n","\n"," all_predictions, all_actual_values = np.array(all_predictions), np.array(all_actual_values)\n","\n"," auroc = roc_auc_score(all_actual_values, all_predictions)\n"," \n"," # Binarize predictions and compute accuracy\n"," binary_predictions = (all_predictions > 0.7).astype(int)\n"," accuracy = accuracy_score(all_actual_values, binary_predictions)\n"," \n"," return accuracy, auroc, np.mean(test_losses)\n"]},{"cell_type":"code","execution_count":15,"metadata":{"collapsed":true,"execution":{"iopub.execute_input":"2024-09-19T17:10:28.728787Z","iopub.status.busy":"2024-09-19T17:10:28.728385Z","iopub.status.idle":"2024-09-20T03:37:05.179864Z","shell.execute_reply":"2024-09-20T03:37:05.178362Z","shell.execute_reply.started":"2024-09-19T17:10:28.728728Z"},"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n","\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mtruonggiabjnh2003\u001b[0m (\u001b[33mtruonggiabjnh2003-fpt-university\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"]},{"data":{"text/html":["Tracking run with wandb version 0.18.5"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["Run data is saved locally in /home/HardDisk/binh230_intern/Mamba-AI-generated-text-detection/train/finetune_model/wandb/run-20241026_104913-w6ldiog1"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["Syncing run experiment_3090_1 to Weights & Biases (docs)
"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":[" View project at https://wandb.ai/truonggiabjnh2003-fpt-university/your_project_name"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":[" View run at https://wandb.ai/truonggiabjnh2003-fpt-university/your_project_name/runs/w6ldiog1"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["lr = 0.0003, label_smoothing = 0.03, output_subdir = 3090_1\n"]},{"name":"stderr","output_type":"stream","text":[" 0%| | 0/20721 [00:00 98\u001b[0m \u001b[43mscaler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscale\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 99\u001b[0m scaler\u001b[38;5;241m.\u001b[39munscale_(optimizer)\n\u001b[1;32m 100\u001b[0m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mclip_grad_norm_(model\u001b[38;5;241m.\u001b[39mparameters(), \u001b[38;5;241m1.0\u001b[39m)\n","File \u001b[0;32m~/anaconda3/envs/binh_mamba/lib/python3.12/site-packages/torch/_tensor.py:521\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 511\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 512\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 513\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 514\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 519\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 520\u001b[0m )\n\u001b[0;32m--> 521\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 522\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 523\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n","File \u001b[0;32m~/anaconda3/envs/binh_mamba/lib/python3.12/site-packages/torch/autograd/__init__.py:289\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 284\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 286\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m 287\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 288\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 289\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 290\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 291\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 292\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 293\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 294\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 295\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 296\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 297\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n","File \u001b[0;32m~/anaconda3/envs/binh_mamba/lib/python3.12/site-packages/torch/autograd/graph.py:768\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 766\u001b[0m unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[1;32m 767\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 768\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 769\u001b[0m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 770\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m 771\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 772\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n","\u001b[0;31mKeyboardInterrupt\u001b[0m: "]}],"source":["import wandb\n","import torch\n","import torch.nn as nn\n","from torch.utils.data import DataLoader\n","from tqdm import tqdm\n","from transformers import Adafactor, AdamW\n","\n","\n","\n","# Variables for the experiment\n","label_smoothing = 0.03\n","output_subdir = '3090_1'\n","max_learning_rates = [3e-4]\n","\n","\n","# Initialize Weights & Biases\n","wandb.init(project=\"your_project_name\", name=\"experiment_3090_1\", config={\n"," \"label_smoothing\": label_smoothing,\n"," \"output_subdir\": output_subdir,\n"," \"max_learning_rate\": max_learning_rates[0],\n"," \"batch_size\": 8\n","})\n","\n","\n","\n","# Run experiment\n","for max_learning_rate in max_learning_rates:\n"," print(f'lr = {max_learning_rate}, label_smoothing = {label_smoothing}, output_subdir = {output_subdir}')\n"," \n"," # Dataloader Setup\n"," data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n"," train_data_loader = DataLoader(\n"," train_dataset, \n"," batch_size=8,\n"," num_workers=4, \n"," shuffle=True, \n"," pin_memory=True, \n"," collate_fn=data_collator\n"," )\n"," test_data_loader = DataLoader(\n"," test_dataset, \n"," batch_size=8,\n"," num_workers=4, \n"," shuffle=False, \n"," pin_memory=True, \n"," collate_fn=data_collator\n"," )\n","\n"," # Optimizer, Criterion, and Scaler Setup\n"," optimizer = AdamW(\n"," model.parameters(),\n"," lr=max_learning_rate,\n"," weight_decay=0.1\n"," )\n"," criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)\n"," scaler = torch.cuda.amp.GradScaler(enabled=True)\n","\n"," total_step_count = len(train_data_loader)\n"," lr_schedule = torch.optim.lr_scheduler.OneCycleLR(\n"," optimizer=optimizer,\n"," max_lr=max_learning_rate,\n"," total_steps=total_step_count,\n"," pct_start=0.1,\n"," anneal_strategy='linear',\n"," cycle_momentum=False\n"," )\n","\n"," best_auroc = -99999999\n"," train_losses = []\n"," model.train()\n","\n"," # Tracking the number of rows processed\n"," total_rows_processed = 0\n"," row_threshold = 50000\n","\n"," print_steps = 500\n","\n"," for batch_index, train_batch in enumerate(tqdm(train_data_loader)):\n"," if len(train_batch.input_ids) == 0:\n"," continue\n","\n"," # Send data to GPU(s)\n"," token_sequences = train_batch.input_ids.to(\"cuda\")\n"," attention_masks = train_batch.attention_mask.to(\"cuda\")\n"," labels = train_batch.labels.to(\"cuda\")\n","\n"," optimizer.zero_grad()\n","\n"," with torch.cuda.amp.autocast():\n"," output = model(token_sequences)\n"," raw_predictions = output[:, -1, :]\n","\n"," loss = criterion(raw_predictions.view(-1, 2), labels.view(-1)) \n","\n"," # Training accuracy\n"," accuracy = compute_accuracy(labels.view(-1), raw_predictions.view(-1, 2))\n","\n"," scaler.scale(loss).backward()\n"," scaler.unscale_(optimizer)\n"," torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n"," scaler.step(optimizer)\n"," scaler.update()\n"," lr_schedule.step()\n","\n"," train_losses.append(loss.detach().cpu())\n","\n"," # Log training accuracy and loss every 500 steps\n"," if (batch_index + 1) % print_steps == 0:\n"," avg_train_loss = sum(train_losses) / len(train_losses)\n"," print(f\"Step {batch_index+1}/{total_step_count}: Avg Train Loss = {avg_train_loss:.4f}, Train Accuracy = {accuracy*100:.2f}%\")\n","\n"," # Log to Weights & Biases\n"," wandb.log({\"train_loss\": avg_train_loss, \"train_accuracy\": accuracy * 100})\n","\n"," train_losses = [] # Reset train loss tracking for the next 500 steps\n","\n"," # Increment the number of rows processed\n"," total_rows_processed += len(train_batch.input_ids)\n","\n"," # Evaluate the model every 50,000 rows\n"," if total_rows_processed >= row_threshold:\n"," model.eval()\n"," val_accuracy, auroc, test_loss = TestModel(test_data_loader, model, criterion)\n"," model.train()\n"," \n"," print(f'Validation Loss: {test_loss:.4f}, Validation Accuracy: {val_accuracy*100:.2f}%, AuROC Score: {auroc*100:.2f}%')\n"," \n"," # Log validation metrics to Weights & Biases\n"," wandb.log({\"val_loss\": test_loss, \"val_accuracy\": val_accuracy * 100, \"auroc\": auroc * 100})\n","\n"," total_rows_processed = 0 # Reset after each evaluation\n","\n"," # Save model if improved\n"," if auroc > best_auroc:\n"," best_auroc = auroc\n"," torch.save(model.state_dict(), f'./weight_models/xLSTM-Base-Val_Accuracy-{val_accuracy*100}%-AuROC_Score-{auroc*100}-Loss-{int(test_loss*1000)}.pth')\n","\n"," print(f'Training Finish !!!')\n","\n","# Finish the W&B run\n","# wandb.finish()\n"]},{"cell_type":"code","execution_count":14,"metadata":{"execution":{"iopub.execute_input":"2024-09-20T04:13:25.274416Z","iopub.status.busy":"2024-09-20T04:13:25.274045Z","iopub.status.idle":"2024-09-20T04:13:30.882150Z","shell.execute_reply":"2024-09-20T04:13:30.880422Z","shell.execute_reply.started":"2024-09-20T04:13:25.274380Z"},"trusted":true},"outputs":[],"source":["torch.save(model.state_dict(), f'./Models/Mamba-780m-Step-{batch_index+1}-Loss-{int(test_loss*1000)}.pth')\n"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.status.busy":"2024-09-20T03:37:05.182217Z","iopub.status.idle":"2024-09-20T03:37:05.182548Z","shell.execute_reply":"2024-09-20T03:37:05.182399Z","shell.execute_reply.started":"2024-09-20T03:37:05.182382Z"},"trusted":true},"outputs":[],"source":["auroc_scores_by_dataset, test_loss"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.status.busy":"2024-09-20T03:37:05.183718Z","iopub.status.idle":"2024-09-20T03:37:05.184069Z","shell.execute_reply":"2024-09-20T03:37:05.183920Z","shell.execute_reply.started":"2024-09-20T03:37:05.183903Z"},"trusted":true},"outputs":[],"source":["import torch\n","import numpy as np\n","from tqdm import tqdm\n","from sklearn.metrics import roc_auc_score\n","\n","model.eval()\n","auroc_scores_by_dataset, test_loss = TestModel(test_data_loader, model, criterion)\n","model.train()\n","\n","# average_auroc = np.average(auroc_scores_by_dataset, weights=[1, 1])\n","# if (average_auroc > best_auroc) or (max(auroc_scores_by_dataset) > 0.993):\n","# best_auroc = average_auroc\n","# if output_subdir is not None:\n","# torch.save(model.state_dict(), f'Models/Mamba/{output_subdir}/S{step_number}_CTX1024.pth')\n","\n","# train_losses = []"]},{"cell_type":"markdown","metadata":{},"source":["### Confusion Matrix\n"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.status.busy":"2024-09-20T03:37:05.185888Z","iopub.status.idle":"2024-09-20T03:37:05.186360Z","shell.execute_reply":"2024-09-20T03:37:05.186134Z","shell.execute_reply.started":"2024-09-20T03:37:05.186109Z"},"trusted":true},"outputs":[],"source":["import matplotlib.pyplot as plt\n","import pandas as pd\n","import seaborn as sns\n","\n","from sklearn.metrics import confusion_matrix\n","\n","def binarize(x, threshold):\n"," if x > threshold:\n"," x = 1\n"," else:\n"," x = 0\n"," return x\n","\n","# Assuming df is your pandas DataFrame\n","oof_df[\"binary\"] = oof_df[\"preds\"].apply(lambda x: binarize(x, 0.5))\n","true_labels = oof_df[\"generated\"].values\n","predicted_labels = oof_df[\"binary\"].values\n","\n","# Get the unique classes from both true and predicted labels\n","classes = np.unique(np.concatenate((true_labels, predicted_labels)))\n","\n","# Compute the confusion matrix\n","cm = confusion_matrix(true_labels, predicted_labels, labels=classes)\n","plt.figure(figsize=(8, 6))\n","sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", xticklabels=classes, yticklabels=classes)\n","plt.xlabel(\"Predicted Labels\")\n","plt.ylabel(\"True Labels\")\n","plt.title(\"Confusion Matrix\")"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":[]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":[]}],"metadata":{"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[{"databundleVersionId":7516023,"sourceId":61542,"sourceType":"competition"},{"datasetId":3936750,"sourceId":6847931,"sourceType":"datasetVersion"},{"datasetId":4325258,"sourceId":7432540,"sourceType":"datasetVersion"},{"datasetId":4336615,"sourceId":7452416,"sourceType":"datasetVersion"}],"dockerImageVersionId":30762,"isGpuEnabled":true,"isInternetEnabled":true,"language":"python","sourceType":"notebook"},"kernelspec":{"display_name":"binh_mamba","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.12.5"}},"nbformat":4,"nbformat_minor":4}