import sys
import json

import torch
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer

from .handler import DataHandler

assert torch.cuda.is_available(), "No cuda device detected"


class Inferer:
    """
    A basic inference class for accessing medAlpaca models programmatically.

    This class provides methods for loading supported medAlpaca models, tokenizing inputs,
    and generating outputs based on the specified model and configurations.

    Attributes:
        available_models (dict): A dictionary containing the supported models and their configurations.

    Args:
        model_name (str): The name of the medAlpaca model to use for inference.
        prompt_template (str): The path to the JSON file containing the prompt template.
        base_model (str, optional): If LoRA is used, this should point to the base model weights.
        model_max_length (int, optional): Maximum number of input tokens to the model. Default is 512.
        load_in_8bit (bool, optional): Whether to load the model with 8-bit quantization. Default is False.
        torch_dtype (torch.dtype, optional): The torch dtype in which to load the base model. Default is torch.float16.
        peft (bool, optional): If the model was trained in 8-bit or with LoRA, the PEFT library should be used
            to load the model. Default is False.

    Example:

        medalpaca = Inferer("medalpaca/medalpaca-7b", "prompts/alpaca.json")
        response = medalpaca(input="What is Amoxicillin?")
    """

    def __init__(
        self,
        model_name: str,
        prompt_template: str,
        base_model: str = None,
        model_max_length: int = 512,
        load_in_8bit: bool = False,
        torch_dtype: torch.dtype = torch.float16,
        peft: bool = False,
    ) -> None:

        if base_model and not peft:
            raise ValueError(
                "You have specified a base model, but `peft` is False. "
                "This would load the base model only."
            )

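        # Load the model weights (and, if requested, the LoRA adapter via PEFT) onto the first GPU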
        self.model = self._load_model(
            model_name=model_name,
            base_model=base_model or model_name,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch_dtype,
            peft=peft,
        )

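        # The tokenizer is loaded from the base model checkpoint so that its vocabulary matches the weights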
        tokenizer = self._load_tokenizer(base_model or model_name)

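        # The DataHandler builds prompts from the template and wraps the tokenizer for inference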
        self.data_handler = DataHandler(
            tokenizer,
            prompt_template=prompt_template,
            model_max_length=model_max_length,
            train_on_inputs=False,
        )

    def _load_model(
        self,
        model_name: str,
        base_model: str,
        load_in_8bit: bool,
        torch_dtype: torch.dtype,
        peft: bool,
    ) -> torch.nn.Module:

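        # LLaMA checkpoints need the dedicated Llama classes; all other models go through the Auto classes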
        if "llama" in base_model.lower():
            load_model = LlamaForCausalLM
        else:
            load_model = AutoModelForCausalLM

        model = load_model.from_pretrained(
            base_model,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch_dtype,
            device_map={"": 0},
        )

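        # Wrap the base model with the LoRA adapter weights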
        if peft:
            model = PeftModel.from_pretrained(
                model,
                model_id=model_name,
                torch_dtype=torch_dtype,
                device_map={"": 0},
            )

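        # Without 8-bit quantization, cast the weights to half precision to save GPU memory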
        if not load_in_8bit:
            model.half()

        model.eval()

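        # torch.compile (PyTorch >= 2.0) can speed up inference; it is not supported on Windows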
        if torch.__version__ >= "2" and sys.platform != "win32":
            model = torch.compile(model)

        return model

    def _load_tokenizer(self, model_name: str):
        if "llama" in model_name.lower():
            tokenizer = LlamaTokenizer.from_pretrained(model_name)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
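        # Pad with token id 0 and pad on the left, as expected for decoder-only generation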
        tokenizer.pad_token_id = 0
        tokenizer.padding_side = "left"
        return tokenizer

    def __call__(
        self,
        input: str,
        instruction: str = None,
        output: str = None,
        max_new_tokens: int = 128,
        verbose: bool = False,
        **generation_kwargs,
    ) -> str:
        """
        Generate a response from the medAlpaca model using the given input and instruction.

        Args:
            input (str):
                The input text to provide to the model.
            instruction (str, optional):
                An optional instruction to guide the model's response.
            output (str, optional):
                Prepended to the model's output, e.g. for 1-shot prompting.
            max_new_tokens (int, optional):
                The maximum number of new tokens the model may generate.
            verbose (bool, optional):
                If True, print the prompt before generating a response.
            **generation_kwargs:
                Keyword arguments passed to the `GenerationConfig`.
                See here for possible arguments: https://huggingface.co/docs/transformers/v4.20.1/en/main_classes/text_generation

        Returns:
            str: The generated response from the medAlpaca model.
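
        Example:
            The generation settings below are illustrative only; any `GenerationConfig`
            field (e.g. `do_sample`, `temperature`, `num_beams`) can be passed as a
            keyword argument::

                response = medalpaca(
                    input="What is Amoxicillin?",
                    max_new_tokens=256,
                    do_sample=True,
                    temperature=0.7,
                )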
        """

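        # Assemble the prompt from the template using instruction, input, and (optional) output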
        prompt = self.data_handler.generate_prompt(instruction=instruction, input=input, output=output)
        if verbose:
            print(prompt)

        input_tokens = self.data_handler.tokenizer(prompt, return_tensors="pt")
        input_token_ids = input_tokens["input_ids"].to("cuda")

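        # All remaining keyword arguments (e.g. temperature, num_beams) are forwarded to the GenerationConfig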
        generation_config = GenerationConfig(**generation_kwargs)

        with torch.no_grad():
            generation_output = self.model.generate(
                input_ids=input_token_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
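        # Decode the full sequence and keep only the text after the prompt's output marker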
        generation_output_decoded = self.data_handler.tokenizer.decode(generation_output.sequences[0])
        split = f'{self.data_handler.prompt_template["output"]}{output or ""}'
        response = generation_output_decoded.split(split)[-1].strip()
        return response