Model Card for bay-llm/gemma-9b-SFT-1020-large-16bit
Instruction tuning: this model has been fine-tuned to follow instructions.
Usage
# Notebook shell magic (IPython/Jupyter/Colab): installs the pinned vLLM
# release, reinstalling it even if it is already present. This line is not
# valid in a plain .py file; drop the leading "!" and run it in a shell.
!pip install vllm==0.6.4.post1 --force-reinstall
# NOTE(review): time, torch, transformers, AutoTokenizer and
# AutoModelForCausalLM appear unused in this script; the model and its
# tokenizer are obtained through vLLM instead.
import time
import torch
import transformers
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
)
# An error occurs here unless packaging==24.1 is installed.
# NOTE(review): the pip line above does not pin packaging itself; install
# packaging==24.1 separately if this import fails.
import vllm
# Print the imported vLLM version (useful to confirm the pinned release is active).
print(vllm.__version__)
# NOTE(review): MAX_LENGTH appears unused; the context window is set via
# max_model_len in the vllm.LLM call instead.
MAX_LENGTH = 1000
# Replace with the model you want to submit to the competition.
MODEL_NAME = "bay-llm/gemma-9b-SFT-1020-large-16bit"

# Load the model with vLLM on a single GPU (tensor_parallel_size=1), letting
# vLLM use up to 95% of GPU memory and capping the context at 1024 tokens.
llm = vllm.LLM(
    model=MODEL_NAME,
    trust_remote_code=True,
    tensor_parallel_size=1,
    max_model_len=1024,
    gpu_memory_utilization=0.95,
)

# Reuse the tokenizer that vLLM loaded together with the model.
tokenizer = llm.get_tokenizer()
# Load ELYZA-tasks-100-TV. Upload the task file before running this cell.
# On the omnicampus development environment, drag and drop the task .jsonl
# into the file pane on the left, then run.
import json

DATASET_PATH = "../elyza-tasks-100-TV_0.jsonl"


def load_json_records(path):
    """Return every JSON value stored back-to-back in the file at *path*.

    Accepts standard JSONL (one object per line) as well as objects that are
    pretty-printed across several lines. Each value is decoded with
    ``json.JSONDecoder.raw_decode``. A nested object that happens to close at
    the end of a line therefore does not split a record. A truncated trailing
    record raises instead of being silently dropped.

    Args:
        path: Path to the task file; it is read as UTF-8.

    Returns:
        The decoded records, in file order.

    Raises:
        json.JSONDecodeError: If the file contains malformed or truncated JSON.
    """
    # Read as UTF-8 explicitly. The task file presumably contains Japanese
    # text, and a platform default such as cp932 would mis-decode it.
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()

    decoder = json.JSONDecoder()
    records = []
    pos = 0
    end = len(text)
    while True:
        # raw_decode does not skip leading whitespace, so step over the
        # newlines/spaces separating consecutive records ourselves.
        while pos < end and text[pos].isspace():
            pos += 1
        if pos == end:
            return records
        record, pos = decoder.raw_decode(text, pos)
        records.append(record)


datasets = load_json_records(DATASET_PATH)
if not datasets:
    raise ValueError(f"No task records found in {DATASET_PATH}")
print(datasets[0])
# Wrap each task's input as a single-turn chat conversation.
messages_list = [
    [{"role": "user", "content": record["input"]}] for record in datasets
]
# Keep the raw user prompts for printing and for the submission file.
prompts = [messages[0]["content"] for messages in messages_list]
# Render each conversation with the model's chat template and tokenize it.
# add_generation_prompt=True appends the tokens that open the assistant turn,
# so the model answers instead of continuing the user message.
prompt_token_ids = [
    tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    for messages in messages_list
]
# temperature=0.5 samples stochastically and no seed is set, so answers can
# differ between runs; max_tokens caps the length of each generated answer.
sampling_params = vllm.SamplingParams(
    temperature=0.5,
    max_tokens=512,
)
# Pass the pre-tokenized prompts as TokensPrompt dicts. This replaces the
# legacy `prompt_token_ids=` keyword of LLM.generate. Outputs come back in
# the same order as the inputs.
outputs = llm.generate(
    [{"prompt_token_ids": ids} for ids in prompt_token_ids],
    sampling_params=sampling_params,
)
# Echo each prompt next to the model's whitespace-trimmed answer,
# separating entries with a line of 80 dashes.
separator = "-" * 80
for task_prompt, request_output in zip(prompts, outputs):
    print(f"prompt: {task_prompt}")
    print(f"output: {request_output.outputs[0].text.strip()}")
    print(separator)
import json

# Build the submission records: the task id, the original prompt and the
# model's whitespace-trimmed answer for each task.
data = [
    {
        # Prefer the record's own task_id when it has one; otherwise fall
        # back to its position in the dataset.
        "task_id": record.get("task_id", idx),
        "input": prompt,
        "output": result.outputs[0].text.strip(),
    }
    for idx, (record, prompt, result) in enumerate(zip(datasets, prompts, outputs))
]

# Write JSON Lines: one object per line, keeping non-ASCII (e.g. Japanese)
# characters readable instead of \u-escaping them.
# NOTE: the output file name is kept as-is; rename it if your submission
# workflow expects a different name.
file_path = 'submmit.jsonl'
with open(file_path, 'w', encoding='utf-8') as file:
    for entry in data:
        file.write(json.dumps(entry, ensure_ascii=False) + "\n")
Uploaded model
- Developed by: bay-llm
- License: gemma
- Finetuned from model: unsloth/gemma-2-9b-bnb-4bit
This Gemma 2 model was trained 2x faster with Unsloth and Hugging Face's TRL library.
- Downloads last month: 27
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.
Model tree for bay-llm/gemma-9b-SFT-1020-large-16bit
Base model
google/gemma-2-9b