import torch
from llm_studio.python_configs.text_causal_language_modeling_config import (
    ConfigProblemBase,
)
from llm_studio.src.models.text_causal_language_modeling_model import Model
from llm_studio.src.utils.modeling_utils import TokenStoppingCriteria, activate_neftune


def test_token_stopping_criteria():
    token_stopping_criteria = TokenStoppingCriteria(
        stop_word_ids=torch.tensor([0, 1, 2, 8]), prompt_input_ids_len=4
    )
    input_ids = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            [2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
            [3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
            [4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
            [5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
        ]
    ).long()
    # prompt input len is 4, so the generated ids of the last sample in the batch are
    # [9, 10, 11, 12, 13, 14], which do not trigger the stopping criteria
    assert not token_stopping_criteria(input_ids=input_ids, scores=None)

    token_stopping_criteria = TokenStoppingCriteria(
        stop_word_ids=torch.tensor([6]), prompt_input_ids_len=0
    )
    # first sample reads [0, 1, 2, 3, 4, 5], so the stopping criteria is not triggered
    assert not token_stopping_criteria(input_ids=input_ids[:, :6], scores=None)
    assert token_stopping_criteria(input_ids=input_ids[:, :7], scores=None)

    # Test stopping criteria with compound tokens
    token_stopping_criteria = TokenStoppingCriteria(
        stop_word_ids=torch.tensor([[6, 7]]), prompt_input_ids_len=0
    )
    assert not token_stopping_criteria(input_ids=input_ids[:, :6], scores=None)
    assert not token_stopping_criteria(input_ids=input_ids[:, :7], scores=None)
    assert token_stopping_criteria(input_ids=input_ids[:, :8], scores=None)

    # Test stopping criteria with stop word ids being longer than generated text
    token_stopping_criteria = TokenStoppingCriteria(
        stop_word_ids=torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]]),
        prompt_input_ids_len=0,
    )
    assert not token_stopping_criteria(input_ids=input_ids, scores=None)


def test_neftune_is_disabled_in_inference():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cfg = ConfigProblemBase(llm_backbone="h2oai/llama2-0b-unit-test")
    cfg.architecture.backbone_dtype = "float32"

    model = Model(cfg).eval().to(device)

    input_batch = {
        "input_ids": torch.randint(
            0,
            1000,
            (1, 10),
        ).to(device),
        "attention_mask": torch.ones(1, 10).to(device),
    }

    with torch.no_grad():
        outputs = model.backbone(**input_batch)

    activate_neftune(model, neftune_noise_alpha=10)
    assert model.backbone.get_input_embeddings().neftune_noise_alpha == 10
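    # NEFTune noise is only meant to be added during training; with the model in
    # eval mode, the forward pass below should be numerically identical to the one
    # taken before activate_neftune was called.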

    with torch.no_grad():
        outputs_after_neftune = model.backbone(**input_batch)

    assert torch.allclose(outputs["logits"], outputs_after_neftune["logits"])

    # state dict does not contain neftune noise
    assert [key for key in model.state_dict() if "neftune" in key] == []