---
base_model: tokyotech-llm/Swallow-7b-hf
library_name: peft
---

# Model Info

This model applies LLM2Vec to Swallow. Only the PEFT adapter is distributed.

LLM2Vec training consists of two stages: MNTP (masked next token prediction) and SimCSE. This repository contains the adapter obtained by applying SimCSE after MNTP. For the MNTP adapter, please refer to [this link](https://huggingface.co/uzabase/LLM2Vec-Llama-2-7b-hf-wikipedia-jp-mntp).

## Model Details

### Model Description

- **Model type:** PEFT
- **Language(s) (NLP):** Japanese
- **License:** Apache 2.0
- **Finetuned from model:** [Swallow-7b-hf](https://huggingface.co/tokyotech-llm/Swallow-7b-hf)

### Model Sources

- **Repository:** https://github.com/McGill-NLP/llm2vec
- **Paper:** https://arxiv.org/abs/2404.05961

# Usage

- Please see the [Usage section of the original LLM2Vec model card](https://huggingface.co/McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp-unsup-simcse#usage).

# Benchmark

- The following tables are summaries. Full details are available [here](https://tech.uzabase.com/entry/2024/09/30/114245).

## MTEB (Japanese)

| | Classification | Clustering | PairClassification | Reranking | BitextMining | Retrieval | STS | Average |
| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| Llama2-Llm2vec-eng | 0.527 | 0.258 | 0.501 | 0.217 | 0.275 | 0.296 | 0.765 | 0.408 |
| Llama2-Llm2vec-jpn | 0.570 | 0.365 | 0.510 | 0.349 | 0.470 | 0.417 | 0.795 | 0.498 |
| **Swallow-Llm2vec-jpn (This repo)** | 0.621 | 0.391 | 0.510 | 0.475 | 0.475 | 0.491 | 0.832 | 0.523 |

## MTEB (English)

| | Classification | Clustering | PairClassification | Reranking | Retrieval | STS | Average |
| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| Llama2-Llm2vec-eng | 0.709 | 0.386 | 0.780 | 0.588 | 0.329 | 0.723 | 0.586 |
| Llama2-Llm2vec-jpn | 0.722 | 0.428 | 0.785 | 0.594 | 0.371 | 0.717 | 0.603 |
| **Swallow-Llm2vec-jpn (This repo)** | 0.695 | 0.385 | 0.751 | 0.576 | 0.318 | 0.710 | 0.572 |

# Training Details

## Training Data

- The SimCSE training corpus was built from
  [Wikipedia](https://huggingface.co/datasets/wikimedia/wikipedia) (the `20231101.ja` Japanese dump).
- Script for building the SimCSE corpus:

```python
import argparse
import random
import re
from pathlib import Path

from datasets import load_dataset
from tqdm import tqdm


def main(args):
    random.seed(args.seed)
    # Japanese Wikipedia dump (2023-11-01) from the Hugging Face Hub
    wiki_ds = load_dataset("wikimedia/wikipedia", "20231101.ja")
    # Sample N distinct articles; indexing with a list returns a dict of columns
    sampled_index = random.sample(range(len(wiki_ds["train"])), args.N)
    sample_wiki = wiki_ds["train"][sampled_index]

    output_texts = []
    for title, text in tqdm(zip(sample_wiki["title"], sample_wiki["text"]), total=args.N):
        # Each article title is kept as its own line
        output_texts.append(title)
        # Split the body on newlines and on the Japanese full stop "。"
        sentences = re.split("[\n。]", text)
        for sentence in sentences:
            # Keep sentences longer than min_sentence_len characters and
            # restore the full stop removed by the split
            if len(sentence) > args.min_sentence_len:
                output_texts.append(sentence.strip() + "。")

    # One training example per line; explicit UTF-8 for Japanese text
    with args.output_path.open(mode="w", encoding="utf-8") as f:
        for line in output_texts:
            f.write(line)
            f.write("\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--N", default=200000, type=int)  # number of articles to sample
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("-o", "--output_path", type=Path, required=True)
    parser.add_argument("--min_sentence_len", default=50, type=int)
    args = parser.parse_args()
    main(args)
```

## Training Hyperparameters

- simcse_dropout: 0.3
- bidirectional: true
- pooling_mode: "mean"
- remove_unused_columns: false
- learning_rate: 3e-5
- loss_scale: 20
- batch_size: 256
- gradient_accumulation_steps: 1
- max_seq_length: 128
- lora_r: 16
- torch_dtype: "bfloat16"
- attn_implementation: "flash_attention_2"
- seed: 42
- bf16: true
- gradient_checkpointing: true

## Accelerator Settings

- deepspeed_config:
  - gradient_accumulation_steps: 1
  - gradient_clipping: 1.0
  - offload_optimizer_device: nvme
  - offload_optimizer_nvme_path: /nvme
  - zero3_save_16bit_model: true
  - zero_stage: 2
- distributed_type: DEEPSPEED
- downcast_bf16: 'no'
- dynamo_config:
  - dynamo_backend: INDUCTOR
  - dynamo_mode: default
  - dynamo_use_dynamic: true
  - dynamo_use_fullgraph: true
- enable_cpu_affinity: false
- machine_rank: 0
- main_training_function: main
- mixed_precision: bf16
- num_machines: 1
- num_processes: 2
- rdzv_backend: static
- same_network: true
- use_cpu: false

## Framework versions

- Python: 3.12.3
- PEFT: 0.11.1
- Sentence Transformers: 3.0.1
- Transformers: 4.41.0
- PyTorch: 2.3.0
- Accelerate: 0.30.1
- Datasets: 2.20.0
- Tokenizers: 0.19.1
- MTEB: 1.13.0
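To check whether a local environment matches the library versions listed above, a minimal sketch using only the standard library's `importlib.metadata` might look like the following. The helper names `installed_version` and `find_mismatches` are hypothetical, and the mapping from the display names above to PyPI distribution names (for example, "PyTorch" to `torch`) is an assumption. The Python interpreter version itself is not checked here.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, Optional

# Versions from "Framework versions"; keys are assumed PyPI distribution names.
EXPECTED_VERSIONS: Dict[str, str] = {
    "peft": "0.11.1",
    "sentence-transformers": "3.0.1",
    "transformers": "4.41.0",
    "torch": "2.3.0",
    "accelerate": "0.30.1",
    "datasets": "2.20.0",
    "tokenizers": "0.19.1",
    "mteb": "1.13.0",
}


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def find_mismatches(
    expected: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Optional[str]]:
    """Map each distribution whose installed version differs from `expected`
    to the version actually found (None when it is not installed)."""
    mismatches: Dict[str, Optional[str]] = {}
    for name, wanted in expected.items():
        found = lookup(name)
        # PyTorch wheels often carry a local suffix such as "2.3.0+cu121";
        # compare only the public part before "+".
        if found is None or found.split("+")[0] != wanted:
            mismatches[name] = found
    return mismatches


if __name__ == "__main__":
    for name, found in find_mismatches(EXPECTED_VERSIONS).items():
        print(f"{name}: expected {EXPECTED_VERSIONS[name]}, found {found or 'not installed'}")
```

The `lookup` parameter is injectable so the comparison logic can be exercised without the listed packages installed. An empty result means every listed library matches.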