import time
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer

from precious3_gpt_multi_modal import Precious3MPTForCausalLM


class EndpointHandler:

    def __init__(self, path: str = ""):
        """
        Initialize the EndpointHandler: load the model, tokenizer, entity
        vocabularies, and gene embeddings.

        Args:
            path (str): Path to the pretrained model directory.
        """
        # Prefer GPU, but fall back to CPU so the handler can still start on
        # machines without CUDA.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.path = path

        self.model = self._load_model(path)
        print('Model loaded')

        self.tokenizer = AutoTokenizer.from_pretrained(
            "insilicomedicine/precious3-gpt-multi-modal", trust_remote_code=True
        )
        print('Tokenizer loaded')

        self._set_model_token_ids()

        self.unique_compounds_p3, self.unique_genes_p3 = self._load_unique_entities()
        self.emb_gpt_genes, self.emb_hgt_genes = self._load_embeddings()
        print('Embeddings loaded')
    def _load_model(self, path: str) -> Precious3MPTForCausalLM:
        """ Load the pretrained model in bfloat16 and move it to the target device. """
        return Precious3MPTForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to(self.device)

    def _set_model_token_ids(self):
        """ Copy the tokenizer's special token IDs into the model config. """
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.config.bos_token_id = self.tokenizer.bos_token_id
        self.model.config.eos_token_id = self.tokenizer.eos_token_id

    def _load_unique_entities(self) -> Tuple[List[str], List[str]]:
        """ Load unique entities from the online CSV and return lists of compounds and genes. """
        unique_entities_p3 = pd.read_csv('https://huggingface.co/insilicomedicine/precious3-gpt/raw/main/all_entities_with_type.csv')
        unique_compounds = [i.strip() for i in unique_entities_p3[unique_entities_p3.type == 'compound'].entity.to_list()]
        unique_genes = [i.strip() for i in unique_entities_p3[unique_entities_p3.type == 'gene'].entity.to_list()]
        return unique_compounds, unique_genes

    def _load_embeddings(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """ Load gene embeddings (GPT and HGT modalities) and return them as gene-symbol -> vector dicts. """
        emb_gpt_genes = pd.read_pickle('https://huggingface.co/insilicomedicine/precious3-gpt-multi-modal/resolve/main/multi-modal-data/emb_gpt_genes.pickle')
        emb_hgt_genes = pd.read_pickle('https://huggingface.co/insilicomedicine/precious3-gpt-multi-modal/resolve/main/multi-modal-data/emb_hgt_genes.pickle')
        return (dict(zip(emb_gpt_genes.gene_symbol.tolist(), emb_gpt_genes.embs.tolist())),
                dict(zip(emb_hgt_genes.gene_symbol.tolist(), emb_hgt_genes.embs.tolist())))
    def create_prompt(self, prompt_config: Dict[str, Any]) -> str:
        """
        Create a prompt string based on the provided configuration.

        Args:
            prompt_config (Dict[str, Any]): Configuration dict containing prompt variables.

        Returns:
            str: The formatted prompt string.
        """
        prompt = "[BOS]"
        # Placeholder tokens that the model substitutes with the four modality embeddings.
        multi_modal_prefix = '<modality0><modality1><modality2><modality3>' * 3

        for k, v in prompt_config.items():
            if k == 'instruction':
                prompt += f'<{v}>' if isinstance(v, str) else "".join([f'<{v_i}>' for v_i in v])
            elif k in ('up', 'down'):
                # Gene lists are prefixed with the modality placeholders.
                if v:
                    prompt += f'{multi_modal_prefix}<{k}>{v} </{k}>' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} </{k}>'
            elif k == 'age':
                if isinstance(v, int):
                    # Non-human ages are mapped to 'Macaca-<age/20>' buckets.
                    if prompt_config['species'].strip() == 'human':
                        prompt += f'<{k}_individ>{v} </{k}_individ>'
                    else:
                        prompt += f'<{k}_individ>Macaca-{int(v / 20)} </{k}_individ>'
            else:
                if v:
                    prompt += f'<{k}>{v.strip()} </{k}>' if isinstance(v, str) else f'<{k}>{" ".join(v)} </{k}>'
                else:
                    prompt += f'<{k}></{k}>'

        print('Generated prompt:', prompt)
        return prompt
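
    # Worked example of create_prompt (hypothetical config values; the
    # instruction and meta-token names below are illustrative, not a
    # definitive list of what the model vocabulary supports):
    #
    #   handler.create_prompt({
    #       'instruction': ['disease2diff'],
    #       'tissue': 'lung',
    #       'species': 'human',
    #       'age': 65,
    #       'up': [],
    #       'down': [],
    #   })
    #   -> "[BOS]<disease2diff><tissue>lung </tissue><species>human </species>"
    #      "<age_individ>65 </age_individ>"
    #
    # Empty 'up'/'down' lists contribute nothing to the prompt.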
    def custom_generate(self,
                        input_ids: torch.Tensor,
                        acc_embs_up_kg_mean: Optional[np.ndarray],
                        acc_embs_down_kg_mean: Optional[np.ndarray],
                        acc_embs_up_txt_mean: Optional[np.ndarray],
                        acc_embs_down_txt_mean: Optional[np.ndarray],
                        device: str,
                        max_new_tokens: int,
                        mode: str,
                        temperature: float = 0.8,
                        top_p: float = 0.2,
                        top_k: int = 3550,
                        n_next_tokens: int = 50,
                        num_return_sequences: int = 1,
                        random_seed: int = 137) -> Tuple[Dict[str, List], List[List], int]:
        """
        Generate sequences based on input IDs and accumulated embeddings.

        Args:
            input_ids (torch.Tensor): Input token IDs for generation.
            acc_embs_up_kg_mean (Optional[np.ndarray]): Accumulated embeddings for UP genes (KG mean).
            acc_embs_down_kg_mean (Optional[np.ndarray]): Accumulated embeddings for DOWN genes (KG mean).
            acc_embs_up_txt_mean (Optional[np.ndarray]): Accumulated embeddings for UP genes (text mean).
            acc_embs_down_txt_mean (Optional[np.ndarray]): Accumulated embeddings for DOWN genes (text mean).
            device (str): The device to perform computation on.
            max_new_tokens (int): Maximum number of new tokens to generate.
            mode (str): Generation mode; determines how outputs are post-processed.
            temperature (float): Temperature for randomness in sampling.
            top_p (float): Top-p (nucleus) sampling threshold.
            top_k (int): Top-k sampling threshold.
            n_next_tokens (int): Number of candidate tokens to keep when predicting compounds.
            num_return_sequences (int): Number of sequences to return.
            random_seed (int): Random seed for reproducibility.

        Returns:
            Tuple[Dict[str, List], List[List], int]: Processed outputs, predicted compounds, and the random seed.
        """
        torch.manual_seed(random_seed)

        # Wrap each available modality embedding in a batch dimension; absent
        # modalities stay None and are ignored by the model.
        modality0_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_kg_mean), 0).to(device) if isinstance(acc_embs_up_kg_mean, np.ndarray) else None
        modality1_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_kg_mean), 0).to(device) if isinstance(acc_embs_down_kg_mean, np.ndarray) else None
        modality2_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_txt_mean), 0).to(device) if isinstance(acc_embs_up_txt_mean, np.ndarray) else None
        modality3_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_txt_mean), 0).to(device) if isinstance(acc_embs_down_txt_mean, np.ndarray) else None

        # Special-token IDs used inside the decoding loop; encode them once
        # instead of on every iteration.
        modality_token_ids = [self.tokenizer.encode(f'<modality{i}>')[0] for i in range(4)]
        drug_token_id = self.tokenizer.encode('<drug>')[0]
        up_token_id = self.tokenizer.encode('<up>')[0]
        down_token_id = self.tokenizer.encode('<down>')[0]
        up_close_token_id = self.tokenizer.encode('</up>')[0]
        down_close_token_id = self.tokenizer.encode('</down>')[0]

        outputs = []
        next_token_compounds = []
        next_token_up_genes = []
        next_token_down_genes = []

        for _ in range(num_return_sequences):
            start_time = time.time()
            generated_sequence = []
            current_token = input_ids.clone()
            next_token = current_token[0][-1]
            generated_tokens_counter = 0

            while generated_tokens_counter < max_new_tokens - 1:
                if next_token == self.tokenizer.eos_token_id:
                    generated_sequence.append(current_token)
                    break

                # Full forward pass over the current sequence; the model swaps
                # the modality embeddings in at the placeholder token positions.
                logits = self.model.forward(
                    input_ids=current_token,
                    modality0_emb=modality0_emb,
                    modality0_token_id=modality_token_ids[0],
                    modality1_emb=modality1_emb,
                    modality1_token_id=modality_token_ids[1],
                    modality2_emb=modality2_emb,
                    modality2_token_id=modality_token_ids[2],
                    modality3_emb=modality3_emb,
                    modality3_token_id=modality_token_ids[3],
                )[0]

                if temperature != 1.0:
                    logits = logits / temperature

                # Top-p (nucleus) plus top-k filtering. The removal mask is
                # computed in sorted order, shifted right so the single most
                # likely token always survives, then scattered back to
                # vocabulary order before masking.
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = False

                if top_k > 0:
                    sorted_indices_to_remove[..., top_k:] = True

                indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
                logits = logits.masked_fill(indices_to_remove, float("-inf"))

                # When the model is about to name a drug, capture its top
                # candidate compound tokens once.
                if current_token[0][-1] == drug_token_id and len(next_token_compounds) == 0:
                    next_token_compounds.append(torch.topk(torch.softmax(logits, dim=-1)[0][-1, :].flatten(), n_next_tokens).indices)

                # After an <up> tag, take the top candidate gene tokens in one
                # shot instead of sampling them one by one, then close the tag.
                if current_token[0][-1] == up_token_id and len(next_token_up_genes) == 0:
                    n_next_tokens_4_genes = 250
                    top_k_up_genes = torch.topk(torch.softmax(logits, dim=-1)[0][-1, :].flatten(), n_next_tokens_4_genes).indices
                    next_token_up_genes.append(top_k_up_genes)
                    generated_tokens_counter += len(top_k_up_genes)
                    current_token = torch.cat((current_token, top_k_up_genes.unsqueeze(0),
                                               torch.tensor([up_close_token_id]).unsqueeze(0).to(device)), dim=-1)
                    continue

                # Same one-shot expansion for the <down> gene list.
                if current_token[0][-1] == down_token_id and len(next_token_down_genes) == 0:
                    n_next_tokens_4_genes = 250
                    top_k_down_genes = torch.topk(torch.softmax(logits, dim=-1)[0][-1, :].flatten(), n_next_tokens_4_genes).indices
                    next_token_down_genes.append(top_k_down_genes)
                    generated_tokens_counter += len(top_k_down_genes)
                    current_token = torch.cat((current_token, top_k_down_genes.unsqueeze(0),
                                               torch.tensor([down_close_token_id]).unsqueeze(0).to(device)), dim=-1)
                    continue

                next_token = torch.multinomial(torch.softmax(logits, dim=-1)[0], num_samples=1)[-1, :].unsqueeze(0)
                current_token = torch.cat((current_token, next_token), dim=-1)
                generated_tokens_counter += 1

            # Record the sequence even when the loop exhausts the token budget
            # without emitting EOS.
            if not generated_sequence:
                generated_sequence.append(current_token)

            print("Generation time:", time.time() - start_time)
            outputs.append(generated_sequence)

        processed_outputs = self.process_generated_outputs(next_token_up_genes, next_token_down_genes, mode)

        predicted_compounds_ids = [self.tokenizer.convert_ids_to_tokens(j) for j in next_token_compounds]
        predicted_compounds = [[i.strip() for i in j] for j in predicted_compounds_ids]

        return processed_outputs, predicted_compounds, random_seed
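
    # Sampling illustration for custom_generate (hypothetical numbers): with a
    # 4-token vocabulary whose softmax probabilities are [0.5, 0.3, 0.15, 0.05]
    # and top_p=0.2, the cumulative probability already exceeds 0.2 at the
    # first token, so only that token survives nucleus filtering; top_p=0.85
    # would keep the first three, and top_k would then cap the survivors.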

    def process_generated_outputs(self, next_token_up_genes: List[List], next_token_down_genes: List[List], mode: str) -> Dict[str, List]:
        """
        Process generated outputs for UP and DOWN genes based on the mode.

        Args:
            next_token_up_genes (List[List]): List of tokens generated for UP genes.
            next_token_down_genes (List[List]): List of tokens generated for DOWN genes.
            mode (str): Generation mode.

        Returns:
            Dict[str, List]: Processed outputs based on the model mode.
        """
        processed_outputs = {"up": [], "down": []}
        if mode in ['meta2diff', 'meta2diff2compound']:
            processed_outputs['up'] = self._get_unique_genes(next_token_up_genes)
            processed_outputs['down'] = self._get_unique_genes(next_token_down_genes)
        else:
            processed_outputs = {"generated_sequences": []}

        return processed_outputs
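
    # Shape of the processed output in 'meta2diff' / 'meta2diff2compound'
    # mode (hypothetical gene symbols), one inner list per returned sequence:
    #   {"up": [["TP53", "EGFR", ...]], "down": [["BRCA1", ...]]}
    # Any other mode falls through to {"generated_sequences": []}.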

    def _get_unique_genes(self, tokens: List[List]) -> List[List[str]]:
        """
        Get unique gene symbols from generated tokens.

        Args:
            tokens (List[List]): List of token IDs.

        Returns:
            List[List[str]]: List of unique gene symbols for each token sequence.
        """
        predicted_genes = []
        predicted_genes_tokens = [self.tokenizer.convert_ids_to_tokens(j) for j in tokens]
        for j in predicted_genes_tokens:
            generated_sample = [i.strip() for i in j]
            # Keep only known gene symbols, de-duplicated and ordered by first
            # appearance in the generated sample.
            predicted_genes.append(sorted(set(generated_sample) & set(self.unique_genes_p3), key=generated_sample.index))
        return predicted_genes

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle an incoming request to the endpoint and generate a response.

        Args:
            data (Dict[str, Any]): The payload with the text prompt and generation parameters.

        Returns:
            Dict[str, Any]: The resulting output dictionary for the request.
        """
        data = data.copy()
        # Default to {} so **parameters below never receives None.
        parameters = data.pop("parameters", None) or {}
        config_data = data.pop("inputs", None)
        mode = data.pop('mode', 'Not specified')

        config_data_copy = config_data.copy()

        prompt = self.create_prompt(config_data_copy)
        if mode != "diff2compound":
            prompt += "<up>"

        inputs = self.tokenizer(prompt, return_tensors="pt")

        # Token ID 3 flags unknown tokens in this vocabulary; their presence
        # means part of the prompt was not recognized.
        if 3 in inputs['input_ids'][0]:
            decoded_tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
            print(f"\n>>> Warning! There are unknown tokens in prompt: {''.join(decoded_tokens)} \n")

        input_ids = inputs["input_ids"].to(self.device)

        max_new_tokens = self.model.config.max_seq_len - len(input_ids[0])

        acc_embs_up1_mean, acc_embs_up2_mean, acc_embs_down1_mean, acc_embs_down2_mean = self._get_accumulated_embeddings(config_data)

        generated_sequence, raw_next_token_generation, out_seed = self.custom_generate(
            input_ids=input_ids,
            acc_embs_up_kg_mean=acc_embs_up1_mean,
            acc_embs_down_kg_mean=acc_embs_down1_mean,
            acc_embs_up_txt_mean=acc_embs_up2_mean,
            acc_embs_down_txt_mean=acc_embs_down2_mean,
            max_new_tokens=max_new_tokens, mode=mode,
            device=self.device, **parameters
        )

        # Keep only known compounds, de-duplicated and ordered by first appearance.
        next_token_generation = [sorted(set(i) & set(self.unique_compounds_p3), key=i.index) for i in raw_next_token_generation]

        out = self._prepare_output(generated_sequence, next_token_generation, mode, prompt, out_seed)

        return out

    def _get_accumulated_embeddings(self, config_data: Dict[str, List[str]]) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Retrieve accumulated embeddings for UP and DOWN genes.

        Args:
            config_data (Dict[str, List[str]]): Configuration dictionary with gene information.

        Returns:
            Tuple[Optional[np.ndarray], ...]: Mean accumulated embeddings for UP and DOWN genes.
        """
        acc_embs_up1 = []
        acc_embs_up2 = []
        if 'up' in config_data:
            for gs in config_data['up']:
                # Only use genes present in both embedding tables so the two
                # lists stay aligned; unknown genes are skipped silently.
                if gs in self.emb_hgt_genes and gs in self.emb_gpt_genes:
                    acc_embs_up1.append(self.emb_hgt_genes[gs])
                    acc_embs_up2.append(self.emb_gpt_genes[gs])

        acc_embs_up1_mean = np.array(acc_embs_up1).mean(0) if acc_embs_up1 else None
        acc_embs_up2_mean = np.array(acc_embs_up2).mean(0) if acc_embs_up2 else None

        acc_embs_down1 = []
        acc_embs_down2 = []
        if 'down' in config_data:
            for gs in config_data['down']:
                if gs in self.emb_hgt_genes and gs in self.emb_gpt_genes:
                    acc_embs_down1.append(self.emb_hgt_genes[gs])
                    acc_embs_down2.append(self.emb_gpt_genes[gs])

        acc_embs_down1_mean = np.array(acc_embs_down1).mean(0) if acc_embs_down1 else None
        acc_embs_down2_mean = np.array(acc_embs_down2).mean(0) if acc_embs_down2 else None

        return acc_embs_up1_mean, acc_embs_up2_mean, acc_embs_down1_mean, acc_embs_down2_mean
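
    # Worked illustration of the pooling above (hypothetical 3-d embeddings):
    # for config_data['up'] = ['TP53', 'EGFR'] with
    #   emb_hgt_genes['TP53'] = [1.0, 0.0, 2.0]
    #   emb_hgt_genes['EGFR'] = [3.0, 2.0, 0.0]
    # the KG mean is np.array([2.0, 1.0, 1.0]); genes missing from either
    # embedding table are skipped entirely.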

    def _prepare_output(self, generated_sequence: Any, next_token_generation: List[List], mode: str, prompt: str, out_seed: int) -> Dict[str, Any]:
        """
        Prepare the output dictionary based on the mode of operation.

        Args:
            generated_sequence (Any): The generated sequences from the model.
            next_token_generation (List[List]): The next tokens generated.
            mode (str): Mode of operation.
            prompt (str): The input prompt that was used.
            out_seed (int): Random seed used in generation.

        Returns:
            Dict[str, Any]: Output dictionary with structured results.
        """
        try:
            if mode == "meta2diff":
                outputs = {"up": generated_sequence['up'], "down": generated_sequence['down']}
                out = {"output": outputs, "mode": mode, "message": "Done!", "input": prompt, 'random_seed': out_seed}
            elif mode == "meta2diff2compound":
                outputs = {"up": generated_sequence['up'], "down": generated_sequence['down']}
                out = {
                    "output": outputs, "compounds": next_token_generation, "mode": mode,
                    "message": "Done!", "input": prompt, 'random_seed': out_seed}
            elif mode == "diff2compound":
                outputs = generated_sequence
                out = {
                    "output": outputs, "compounds": next_token_generation, "mode": mode,
                    "message": "Done!", "input": prompt, 'random_seed': out_seed}
            else:
                out = {"message": f"Specify one of the following modes: meta2diff, meta2diff2compound, diff2compound. Your mode is: {mode}"}

        except Exception as e:
            print(e)
            outputs, next_token_generation = [None], [None]
            # Fall back to the default seed when output preparation fails.
            out = {"output": outputs, "mode": mode, 'message': f"{e}", "input": prompt, 'random_seed': 137}

        return out
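

# Minimal local usage sketch. This block is an assumption-laden illustration,
# not part of the deployed handler: the model path, the instruction name, and
# the payload values are hypothetical, and a machine with the weights
# downloaded locally is assumed.
if __name__ == "__main__":
    handler = EndpointHandler(path="./precious3-gpt-multi-modal")
    response = handler({
        "inputs": {
            "instruction": ["disease2diff"],  # hypothetical instruction token
            "tissue": "lung",
            "species": "human",
            "age": 65,
            "up": [],
            "down": [],
        },
        "mode": "meta2diff",
        # Keys mirror custom_generate's keyword arguments.
        "parameters": {"temperature": 0.8, "top_p": 0.2, "top_k": 3550,
                       "n_next_tokens": 50, "random_seed": 137},
    })
    print(response["mode"], response["message"])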