A small dummy Llama 2-type model (about 770k parameters) for unit and integration tests. It is suitable for CPU-only machines; see H2O LLM Studio for an example integration test.

The model was created as follows:

```python
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

repo_name = "MaxJeblick/llama2-0b-unit-test"
model_name = "h2oai/h2ogpt-4096-llama2-7b-chat"

# Start from the Llama 2 7B chat config and shrink it to a tiny model.
config = AutoConfig.from_pretrained(model_name)
config.hidden_size = 12
config.max_position_embeddings = 1024
config.intermediate_size = 24
config.num_attention_heads = 2
config.num_hidden_layers = 2
config.num_key_value_heads = 2

# Reuse the original tokenizer unchanged.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Weights are randomly initialized from the config; nothing pretrained is loaded.
model = AutoModelForCausalLM.from_config(config)
print(model.num_parameters())  # 770_940

model.push_to_hub(repo_name, private=False)
tokenizer.push_to_hub(repo_name, private=False)
config.push_to_hub(repo_name, private=False)
```
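The printed parameter count can be reproduced by hand from the config values. The sketch below is a worked check, not code from this repository. It assumes Llama 2 defaults that the snippet inherits rather than sets: a 32,000-token vocabulary, untied input and output embeddings, bias-free linear layers, and RMSNorm layers that hold a single weight vector. These values are assumptions about the base config, not read from it.

```python
# Sketch: reproduce model.num_parameters() for the tiny Llama 2 config.
# ASSUMPTIONS (inherited from the base Llama 2 config, not set in the snippet):
# vocab_size=32000, untied embeddings, no linear-layer biases,
# and RMSNorm layers with one weight vector of size hidden_size.


def llama_param_count(
    vocab_size: int,
    hidden_size: int,
    intermediate_size: int,
    num_hidden_layers: int,
    num_attention_heads: int,
    num_key_value_heads: int,
    tie_word_embeddings: bool = False,
) -> int:
    head_dim = hidden_size // num_attention_heads
    kv_dim = num_key_value_heads * head_dim

    # q_proj and o_proj are hidden x hidden; k_proj and v_proj map hidden -> kv_dim.
    attention = 2 * hidden_size * hidden_size + 2 * hidden_size * kv_dim
    # gate_proj and up_proj (hidden -> intermediate), down_proj (intermediate -> hidden).
    mlp = 3 * hidden_size * intermediate_size
    # input_layernorm and post_attention_layernorm.
    norms = 2 * hidden_size
    per_layer = attention + mlp + norms

    embed = vocab_size * hidden_size
    lm_head = 0 if tie_word_embeddings else vocab_size * hidden_size
    final_norm = hidden_size
    return embed + num_hidden_layers * per_layer + final_norm + lm_head


total = llama_param_count(
    vocab_size=32000,
    hidden_size=12,
    intermediate_size=24,
    num_hidden_layers=2,
    num_attention_heads=2,
    num_key_value_heads=2,
)
print(total)  # 770940
```

Under these assumptions, 768,000 of the 770,940 parameters sit in the embedding and output layers: at a hidden size of 12, the vocabulary dominates the model size.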

Below is a small example test that runs in about 1 second on CPU.

```python
import torch
from transformers import AutoModelForCausalLM


def test_manual_greedy_generate():
    max_new_tokens = 10

    # Note: this runs on CPU.
    model = AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval()
    # Small built-in batch of 3 dummy sequences provided by transformers.
    input_ids = model.dummy_inputs["input_ids"]

    # Reference output from the built-in generate().
    y = model.generate(input_ids, max_new_tokens=max_new_tokens)

    assert y.shape == (3, input_ids.shape[1] + max_new_tokens)

    # Manual greedy decoding: append the argmax token one step at a time.
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(input_ids)

        next_token_logits = outputs.logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

    # Token ids are integers, so compare them exactly.
    assert torch.equal(y, input_ids)
```
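The manual loop in the test is plain greedy decoding: at each step, take the highest-scoring token at the last position and append it to the sequence. The framework-free sketch below shows the same loop on a single list of token ids. `toy_next_token_logits` is a hypothetical stand-in for the model, not part of transformers, and the sketch omits early stopping on an end-of-sequence token.

```python
from typing import Callable, List

Logits = List[float]


def greedy_decode(
    next_token_logits: Callable[[List[int]], Logits],
    input_ids: List[int],
    max_new_tokens: int,
) -> List[int]:
    """Append the argmax token max_new_tokens times (no sampling, no early stop)."""
    ids = list(input_ids)  # copy so the caller's prompt is not modified
    for _ in range(max_new_tokens):
        logits = next_token_logits(ids)
        # Argmax over the vocabulary; ties resolve to the first maximal index.
        next_id = max(range(len(logits)), key=logits.__getitem__)
        ids.append(next_id)
    return ids


def toy_next_token_logits(ids: List[int]) -> Logits:
    # Hypothetical toy "model": 5-token vocabulary that always prefers (last + 1) % 5.
    vocab_size = 5
    logits = [0.0] * vocab_size
    logits[(ids[-1] + 1) % vocab_size] = 1.0
    return logits


out = greedy_decode(toy_next_token_logits, [1, 2], max_new_tokens=4)
print(out)  # [1, 2, 3, 4, 0, 1]
```

As in the test, the output length is the prompt length plus `max_new_tokens`; the tensor version simply runs this loop for every row of the batch at once.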

Tip:

Use a session-scoped pytest fixture so the model is loaded only once per test run, which further reduces test runtime. Tests then receive the shared instance by declaring a parameter named `model`.

```python
import pytest
from transformers import AutoModelForCausalLM


@pytest.fixture(scope="session")
def model():
    # Loaded once per session and shared by every test that requests `model`.
    return AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval()
```