# tests/src/models/test_text_causal_language_modeling_model.py
import torch

from llm_studio.python_configs.text_causal_language_modeling_config import (
ConfigProblemBase,
)
from llm_studio.src.models.text_causal_language_modeling_model import Model
from llm_studio.src.utils.modeling_utils import TokenStoppingCriteria, activate_neftune


def test_token_stopping_criteria():
token_stopping_criteria = TokenStoppingCriteria(
stop_word_ids=torch.tensor([0, 1, 2, 8]), prompt_input_ids_len=4
)
input_ids = torch.tensor(
[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
[4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
[5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
]
).long()
    # prompt_input_ids_len is 4, so the generated ids of the last sample in the
    # batch are [9, 10, 11, 12, 13, 14]; they contain no stop word, so the
    # stopping criteria must not trigger
    assert not token_stopping_criteria(input_ids=input_ids, scores=None)

    token_stopping_criteria = TokenStoppingCriteria(
stop_word_ids=torch.tensor([6]), prompt_input_ids_len=0
)
    # the first sample reads [0, 1, 2, 3, 4, 5] and lacks token 6, so the
    # stopping criteria must not trigger
assert not token_stopping_criteria(input_ids=input_ids[:, :6], scores=None)
assert token_stopping_criteria(input_ids=input_ids[:, :7], scores=None)

    # Test stopping criteria with a compound (multi-token) stop sequence
token_stopping_criteria = TokenStoppingCriteria(
stop_word_ids=torch.tensor([[6, 7]]), prompt_input_ids_len=0
)
assert not token_stopping_criteria(input_ids=input_ids[:, :6], scores=None)
assert not token_stopping_criteria(input_ids=input_ids[:, :7], scores=None)
assert token_stopping_criteria(input_ids=input_ids[:, :8], scores=None)

    # Test stopping criteria with a stop sequence longer than the generated text
token_stopping_criteria = TokenStoppingCriteria(
stop_word_ids=torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]]),
prompt_input_ids_len=0,
)
assert not token_stopping_criteria(input_ids=input_ids, scores=None)
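

# The asserts above encode the semantics "stop only once *every* sample in the
# batch has produced a stop word among its generated tokens". A minimal sketch
# of that logic, assuming TokenStoppingCriteria behaves this way (the real
# implementation lives in llm_studio.src.utils.modeling_utils and may differ in
# detail; _should_stop is a hypothetical helper, not part of the library):
def _should_stop(input_ids, stop_word_ids, prompt_len):
    generated = input_ids[:, prompt_len:]
    # a 1D tensor is treated as a list of single-token stop words,
    # a 2D tensor as a list of multi-token stop sequences
    if stop_word_ids.dim() == 1:
        stop_word_ids = stop_word_ids.unsqueeze(-1)
    stopped = torch.zeros(generated.shape[0], dtype=torch.bool)
    for stop_seq in stop_word_ids:
        n = stop_seq.shape[0]
        if generated.shape[1] < n:
            continue  # a stop sequence longer than the generated text cannot match
        # slide a window of length n over each sample and look for a full match
        windows = generated.unfold(1, n, 1)  # (batch, num_windows, n)
        stopped |= (windows == stop_seq).all(dim=-1).any(dim=-1)
    return bool(stopped.all())
# At generation time the criteria object is presumably handed to model.generate
# via transformers.StoppingCriteriaList, which is why it is called with
# (input_ids, scores) in the asserts above.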


def test_neftune_is_disabled_in_inference():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg = ConfigProblemBase(llm_backbone="h2oai/llama2-0b-unit-test")
cfg.architecture.backbone_dtype = "float32"
model = Model(cfg).eval().to(device)
input_batch = {
"input_ids": torch.randint(
0,
1000,
(1, 10),
).to(device),
"attention_mask": torch.ones(1, 10).to(device),
}
with torch.no_grad():
outputs = model.backbone(**input_batch)

    activate_neftune(model, neftune_noise_alpha=10)
assert model.backbone.get_input_embeddings().neftune_noise_alpha == 10
with torch.no_grad():
outputs_after_neftune = model.backbone(**input_batch)
assert torch.allclose(outputs["logits"], outputs_after_neftune["logits"])
    # the state dict must not contain any neftune-related keys
    assert [key for key in model.state_dict() if "neftune" in key] == []
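

# Why the logits match: NEFTune (arXiv:2310.05914) adds uniform noise to the
# input embeddings, scaled by alpha / sqrt(seq_len * hidden_dim), but only while
# the embedding module is in training mode; an eval() model therefore produces
# identical outputs before and after activation. A minimal sketch of that
# gating, assuming activate_neftune registers a forward hook on the embedding
# layer in the spirit of the trl/transformers implementation (neftune_hook here
# is a hypothetical illustration, not the library's actual helper):
def neftune_hook(module, inputs, output):
    if module.training:
        # noise magnitude shrinks with sequence length and hidden size
        mag = module.neftune_noise_alpha / (output.shape[1] * output.shape[2]) ** 0.5
        output = output + torch.zeros_like(output).uniform_(-mag, mag)
    return output
# Because neftune_noise_alpha is set as a plain attribute (not an nn.Parameter),
# nothing neftune-related lands in the state dict, which the final assert checks.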