Small dummy Llama 2-type model usable for unit and integration tests. It is suitable for CPU-only machines; see H2O LLM Studio for an example integration test.
The model was created as follows:
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
repo_name = "MaxJeblick/llama2-0b-unit-test"
model_name = "h2oai/h2ogpt-4096-llama2-7b-chat"
config = AutoConfig.from_pretrained(model_name)
config.hidden_size = 12
config.max_position_embeddings = 1024
config.intermediate_size = 24
config.num_attention_heads = 2
config.num_hidden_layers = 2
config.num_key_value_heads = 2
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_config(config)
print(model.num_parameters()) # 770_940
model.push_to_hub(repo_name, private=False)
tokenizer.push_to_hub(repo_name, private=False)
config.push_to_hub(repo_name, private=False)
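The parameter count printed above can be reproduced by hand from the config values. The following is a minimal sketch, not part of the original creation script. It assumes a vocabulary size of 32,000 (inherited from the base config and not stated above), untied input and output embeddings, and bias-free linear layers. The helper name `count_llama_parameters` is hypothetical.

```python
# Assumption: vocab_size comes from the base model's config and is not
# listed in the creation script; 32_000 is the usual Llama 2 value.
VOCAB_SIZE = 32_000
HIDDEN_SIZE = 12
INTERMEDIATE_SIZE = 24
NUM_HIDDEN_LAYERS = 2


def count_llama_parameters(vocab_size, hidden_size, intermediate_size, num_hidden_layers):
    """Count trainable parameters of a Llama-style decoder (hypothetical helper).

    Assumes untied embeddings, no biases, and as many key/value heads as
    attention heads, so k and v projections are hidden_size x hidden_size.
    """
    embeddings = vocab_size * hidden_size        # input token embeddings
    lm_head = vocab_size * hidden_size           # output projection (assumed untied)
    attention = 4 * hidden_size * hidden_size    # q, k, v, o projections
    mlp = 3 * hidden_size * intermediate_size    # gate, up, down projections
    norms = 2 * hidden_size                      # two RMSNorm weights per layer
    per_layer = attention + mlp + norms
    final_norm = hidden_size                     # RMSNorm before the output head
    return embeddings + lm_head + num_hidden_layers * per_layer + final_norm


total = count_llama_parameters(
    VOCAB_SIZE, HIDDEN_SIZE, INTERMEDIATE_SIZE, NUM_HIDDEN_LAYERS
)
print(total)  # 770940
```

Under these assumptions almost all of the parameters (768,000) sit in the two embedding matrices; the two transformer layers contribute only 2,928.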
Below is a small example that runs in about 1 second.
import torch
from transformers import AutoModelForCausalLM
def test_manual_greedy_generate():
    max_new_tokens = 10
    # note this is on CPU!
    model = AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval()
    input_ids = model.dummy_inputs["input_ids"]
    y = model.generate(input_ids, max_new_tokens=max_new_tokens)
    assert y.shape == (3, input_ids.shape[1] + max_new_tokens)
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(input_ids)
        next_token_logits = outputs.logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
    assert torch.allclose(y, input_ids)
Tip:
Use fixtures with session scope to load the model only once. This will decrease test runtime further.
import pytest
from transformers import AutoModelForCausalLM
@pytest.fixture(scope="session")
def model():
    return AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval()