metadata
tags:
- text-generation
license: other
Usage
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
import torch
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): `config` is loaded but not referenced anywhere below — the base
# model id is hard-coded instead. Presumably config.base_model_name_or_path names
# the same checkpoint; confirm before relying on either.
config = PeftConfig.from_pretrained("Ashishkr/llama2-qrecc-context-resolution")

# Tokenizer of the base checkpoint the adapter is applied to.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Load the base Llama-2 weights, wrap them with the PEFT adapter, and move the
# combined model to the selected device.
model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf"),
    "Ashishkr/llama2-qrecc-context-resolution",
).to(device)
def response_generate(
model: AutoModelForCausalLM,
tokenizer: AutoTokenizer,
prompt: str,
max_new_tokens: int = 128,
temperature: float = 0.7,
):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inputs = tokenizer(
[prompt],
return_tensors="pt",
return_token_type_ids=False,
).to(
device
)
with torch.autocast("cuda", dtype=torch.bfloat16):
response = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
return_dict_in_generate=True,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
)
decoded_output = tokenizer.decode(
response["sequences"][0],
skip_special_tokens=True,
)
return decoded_output
# Example query rewrite. The prompt places prior turns after >>CONTEXT<<, the new
# user query after >>USER<<, and ends with >>REWRITE<< so the model's
# continuation is (presumably) the self-contained rewritten query.
# NOTE(review): "repsonse" in the prompt is a typo; it is kept verbatim in case
# the adapter was trained on this exact wording — confirm before correcting.
prompt = """ Strictly use the context provided, to generate the repsonse. No additional information to be added. Re-write the user query using the context .
>>CONTEXT<<Where did jessica go to school? Where did she work at?>>USER<<What did she do next for work?>>REWRITE<<"""

response = response_generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_new_tokens=20,
    temperature=0.1,
)
print(response)