Small dummy Llama 2-type model usable for unit and integration tests. It is suitable for CPU-only machines; see H2O LLM Studio for an example integration test.

The model was created as follows:

from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

repo_name = "MaxJeblick/llama2-0b-unit-test"
model_name = "h2oai/h2ogpt-4096-llama2-7b-chat"
config = AutoConfig.from_pretrained(model_name)
config.hidden_size = 12
config.max_position_embeddings = 1024
config.intermediate_size = 24
config.num_attention_heads = 2
config.num_hidden_layers = 2
config.num_key_value_heads = 2

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_config(config)
print(model.num_parameters())  # 770_940

model.push_to_hub(repo_name, private=False)
tokenizer.push_to_hub(repo_name, private=False)
config.push_to_hub(repo_name, private=False)
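The printed count of 770,940 can be reproduced by hand. The sketch below assumes the base config keeps the Llama 2 defaults for the fields not overridden above: a 32,000-token vocabulary, untied input/output embeddings, no bias terms in the linear layers, and one RMSNorm weight vector per norm. The function name `llama_param_count` is hypothetical.

```python
def llama_param_count(vocab_size, hidden_size, intermediate_size,
                      num_hidden_layers, num_attention_heads,
                      num_key_value_heads, tie_word_embeddings=False):
    """Count parameters of a Llama-style decoder (assumes no bias terms)."""
    head_dim = hidden_size // num_attention_heads
    kv_dim = num_key_value_heads * head_dim

    # Attention: q_proj and o_proj are hidden x hidden; k_proj and v_proj map to kv_dim.
    attn = 2 * hidden_size * hidden_size + 2 * hidden_size * kv_dim
    # Gated MLP: gate_proj and up_proj (hidden -> intermediate), down_proj back.
    mlp = 3 * hidden_size * intermediate_size
    # Two RMSNorm weight vectors per layer (input and post-attention).
    norms = 2 * hidden_size
    per_layer = attn + mlp + norms

    embed = vocab_size * hidden_size
    lm_head = 0 if tie_word_embeddings else vocab_size * hidden_size
    final_norm = hidden_size
    return embed + num_hidden_layers * per_layer + final_norm + lm_head


total = llama_param_count(
    vocab_size=32_000,  # assumed: Llama 2 tokenizer vocabulary size
    hidden_size=12,
    intermediate_size=24,
    num_hidden_layers=2,
    num_attention_heads=2,
    num_key_value_heads=2,
)
print(total)  # 770940
```

Almost all of the parameters sit in the two 32,000 x 12 embedding matrices (input embeddings and `lm_head`); the two decoder layers together contribute fewer than 3,000.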

Below is a small example test that runs in about 1 second.

import torch
from transformers import AutoModelForCausalLM


def test_manual_greedy_generate():
    max_new_tokens = 10

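    # Note: this runs on CPU.
    model = AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval()
    input_ids = model.dummy_inputs["input_ids"]  # fixed batch of 3 short token sequences

    # Greedy decoding via generate(); the loop below reproduces it manually.
    y = model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=False)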

    assert y.shape == (3, input_ids.shape[1] + max_new_tokens)

    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(input_ids)

        next_token_logits = outputs.logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

    # Token ids are integers, so compare them exactly (shape and values).
    assert torch.equal(y, input_ids)
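The manual loop in the test is plain greedy decoding: run the model on the current sequence, take the argmax of the logits at the last position, append that token, and repeat. The sketch below shows the same loop without PyTorch, for a single sequence instead of a batch, using a hypothetical `toy_logits` function in place of the model so that it runs anywhere.

```python
from typing import Callable, List


def greedy_decode(logits_fn: Callable[[List[int]], List[float]],
                  input_ids: List[int], max_new_tokens: int) -> List[int]:
    """Append the highest-scoring next token max_new_tokens times."""
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        scores = logits_fn(ids)  # scores for the token that follows ids[-1]
        next_id = max(range(len(scores)), key=scores.__getitem__)  # argmax
        ids.append(next_id)
    return ids


def toy_logits(ids: List[int], vocab_size: int = 5) -> List[float]:
    """Hypothetical stand-in model: favors (last token + 1) mod vocab_size."""
    favorite = (ids[-1] + 1) % vocab_size
    return [1.0 if tok == favorite else 0.0 for tok in range(vocab_size)]


out = greedy_decode(toy_logits, [3], max_new_tokens=4)
print(out)  # [3, 4, 0, 1, 2]
```

Because this sketch never stops early on an end-of-sequence token, its output length is always the prompt length plus `max_new_tokens`.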

Tip:

Use a session-scoped pytest fixture so the model is loaded only once per test session. This reduces test runtime further.

import pytest
from transformers import AutoModelForCausalLM


@pytest.fixture(scope="session")
def model():
    # Loaded once per test session and shared by every test that requests `model`.
    return AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval()
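Outside pytest, for example in a helper module shared by several scripts, a similar load-once effect can be had with `functools.lru_cache`. This is a different mechanism from a session-scoped fixture; the sketch below uses a hypothetical `load_model` with a call counter in place of the real `from_pretrained` call, so the caching behavior can be checked without downloading anything.

```python
import functools

LOAD_CALLS = 0  # counts how often the expensive load body actually runs


@functools.lru_cache(maxsize=None)
def load_model(repo_id: str) -> dict:
    """Hypothetical loader; in real tests this body would call
    AutoModelForCausalLM.from_pretrained(repo_id).eval()."""
    global LOAD_CALLS
    LOAD_CALLS += 1
    return {"repo_id": repo_id}  # placeholder for the loaded model


a = load_model("MaxJeblick/llama2-0b-unit-test")
b = load_model("MaxJeblick/llama2-0b-unit-test")
print(a is b, LOAD_CALLS)  # True 1
```

As with a session-scoped fixture, every caller receives the same object, so tests should not mutate it (for example by switching it back to training mode).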

Model details: 771k parameters, F32 tensors, stored in Safetensors format.