from typing import Dict, List, Any
import os
import torch
from transformers import AutoTokenizer, AutoModel
import pandas as pd
import time
import numpy as np


class EndpointHandler:
    def __init__(self, path="insilicomedicine/precious3-gpt"):
        # Load the Precious3-GPT model and tokenizer, then fetch the reference
        # lists of known compound and gene entities used to filter generations.
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True).to('cuda')
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.config.bos_token_id = self.tokenizer.bos_token_id
        self.model.config.eos_token_id = self.tokenizer.eos_token_id

        unique_entities_p3 = pd.read_csv('https://huggingface.co/insilicomedicine/precious3-gpt/raw/main/all_entities_with_type.csv')
        self.unique_compounds_p3 = [i.strip() for i in unique_entities_p3[unique_entities_p3.type == 'compound'].entity.to_list()]
        self.unique_genes_p3 = [i.strip() for i in unique_entities_p3[unique_entities_p3.type == 'gene'].entity.to_list()]

    def create_prompt(self, prompt_config):
        """Assemble the model prompt from a config dict of instruction(s), gene lists, and metadata fields."""
        prompt = "[BOS]"
        multi_modal_prefix = ''

        for k, v in prompt_config.items():
            if k == 'instruction':
                prompt += f'<{v}>' if isinstance(v, str) else "".join([f'<{v_i}>' for v_i in v])
            elif k == 'up':
                if v:
                    prompt += f'{multi_modal_prefix}<{k}>{v} ' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} '
            elif k == 'down':
                if v:
                    prompt += f'{multi_modal_prefix}<{k}>{v} ' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} '
            elif k == 'age':
                if isinstance(v, int):
                    if prompt_config['species'].strip() == 'human':
                        prompt += f'<{k}_individ>{v} '
                    elif prompt_config['species'].strip() == 'macaque':
                        prompt += f'<{k}_individ>Macaca-{int(v / 20)} '
            else:
                if v:
                    prompt += f'<{k}>{v.strip()} ' if isinstance(v, str) else f'<{k}>{" ".join(v)} '
                else:
                    prompt += f'<{k}>'
        return prompt

    def custom_generate(self, input_ids, device, max_new_tokens, mode, temperature=0.8, top_p=0.2, top_k=3550,
                        n_next_tokens=50, num_return_sequences=1, random_seed=137):
        """Token-by-token sampling loop that also records the top candidate compound tokens."""
        torch.manual_seed(random_seed)

        # Parameters:
        # temperature - higher value for more randomness, lower for more control
        # top_p - probability threshold for nucleus sampling (aka top-p sampling)
        # top_k - ignore logits below the top-k value to reduce randomness (if non-zero)
        # n_next_tokens - number of top next tokens when predicting compounds

        # Generate sequences
        outputs = []
        next_token_compounds = []

        for _ in range(num_return_sequences):
            start_time = time.time()
            generated_sequence = []
            current_token = input_ids.clone()

            for _ in range(max_new_tokens):  # maximum length of the generated sequence
                # Forward pass through the model
                logits = self.model.forward(
                    input_ids=current_token
                )[0]

                # Apply temperature to logits
                if temperature != 1.0:
                    logits = logits / temperature

                # Apply top-p sampling (nucleus sampling)
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p

                if top_k > 0:
                    sorted_indices_to_remove[..., top_k:] = 1

                # Set the logit values of the removed indices to a very small negative value
                inf_tensor = torch.tensor(float("-inf")).type(torch.bfloat16).to(logits.device)
                logits = logits.where(sorted_indices_to_remove, inf_tensor)

                # When the last context token is the <drug> tag, record the top-n candidate compound tokens once
                if current_token[0][-1] == self.tokenizer.encode('<drug>')[0] and len(next_token_compounds) == 0:
                    next_token_compounds.append(torch.topk(torch.softmax(logits, dim=-1)[0][len(current_token[0]) - 1, :].flatten(), n_next_tokens).indices)

                # Sample the next token
                next_token = torch.multinomial(torch.softmax(logits, dim=-1)[0],
                                               num_samples=1)[len(current_token[0]) - 1, :].unsqueeze(0)

                # Append the sampled token to the generated sequence
                generated_sequence.append(next_token.item())

                # Stop generation if an end token is generated
                if next_token == self.tokenizer.eos_token_id:
                    break

                # Prepare input for the next iteration
                current_token = torch.cat((current_token, next_token), dim=-1)

            print(time.time() - start_time)
            outputs.append(generated_sequence)

        # Process generated up/down gene lists
        processed_outputs = {"up": [], "down": []}
        if mode in ['meta2diff', 'meta2diff2compound']:
            for output in outputs:
                # Tokens before '</up>' form the up-regulated list; keep only known gene entities, preserving order
                up_split_index = output.index(self.tokenizer.convert_tokens_to_ids('</up>'))
                generated_up_raw = [i.strip() for i in self.tokenizer.convert_ids_to_tokens(output[:up_split_index])]
                generated_up = sorted(set(generated_up_raw) & set(self.unique_genes_p3), key=generated_up_raw.index)
                processed_outputs['up'].append(generated_up)

                # Tokens between '</up>' and '</down>' form the down-regulated list
                down_split_index = output.index(self.tokenizer.convert_tokens_to_ids('</down>'))
                generated_down_raw = [i.strip() for i in self.tokenizer.convert_ids_to_tokens(output[up_split_index:down_split_index + 1])]
                generated_down = sorted(set(generated_down_raw) & set(self.unique_genes_p3), key=generated_down_raw.index)
                processed_outputs['down'].append(generated_down)
        else:
            processed_outputs = outputs

        # Map the recorded compound token ids back to token strings
        predicted_compounds_ids = [self.tokenizer.convert_ids_to_tokens(j) for j in next_token_compounds]
        predicted_compounds = []
        for j in predicted_compounds_ids:
            predicted_compounds.append([i.strip() for i in j])

        return processed_outputs, predicted_compounds, random_seed

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """
        Args:
            data (:dict:): The payload with the text prompt and generation parameters.
        """
        device = "cuda"
        parameters = data.pop("parameters", None)
        config_data = data.pop("inputs", None)
        mode = data.pop('mode', 'Not specified')

        prompt = self.create_prompt(config_data)

        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)

        max_new_tokens = self.model.config.max_seq_len - len(input_ids[0])
        try:
            generated_sequence, raw_next_token_generation, out_seed = self.custom_generate(
                input_ids=input_ids, max_new_tokens=max_new_tokens, mode=mode, device=device, **parameters)
            # Keep only predicted tokens that are known compound entities, preserving order
            next_token_generation = [sorted(set(i) & set(self.unique_compounds_p3), key=i.index) for i in raw_next_token_generation]

            if mode == "meta2diff":
                outputs = {"up": generated_sequence['up'], "down": generated_sequence['down']}
                out = {"output": outputs, "mode": mode, "message": "Done!", "input": prompt, 'random_seed': out_seed}
            elif mode == "meta2diff2compound":
                outputs = {"up": generated_sequence['up'], "down": generated_sequence['down']}
                out = {"output": outputs, "compounds": next_token_generation, "raw_output": raw_next_token_generation,
                       "mode": mode, "message": "Done!", "input": prompt, 'random_seed': out_seed}
            elif mode == "diff2compound":
                outputs = generated_sequence
                out = {"output": outputs, "compounds": next_token_generation, "raw_output": raw_next_token_generation,
                       "mode": mode, "message": "Done!", "input": prompt, 'random_seed': out_seed}
            else:
                out = {"message": f"Specify one of the following modes: meta2diff, meta2diff2compound, diff2compound. Your mode is: {mode}"}

        except Exception as e:
            print(e)
            outputs, next_token_generation = [None], [None]
            out = {"output": outputs, "mode": mode, 'message': f"{e}", "input": prompt, 'random_seed': 137}

        return out
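

# ---------------------------------------------------------------------------
# Usage sketch (not part of the deployed handler): a minimal local invocation,
# assuming a CUDA device is available and the payload follows the structure
# __call__ expects ("inputs", "parameters", "mode"). The concrete field values
# below (instruction name, tissue, age) are illustrative placeholders, not
# values confirmed by this file.
if __name__ == "__main__":
    handler = EndpointHandler(path="insilicomedicine/precious3-gpt")
    payload = {
        "inputs": {
            "instruction": ["disease2diff2disease"],  # assumed example instruction token
            "tissue": "whole blood",
            "species": "human",
            "age": 50,
            "up": [],
            "down": [],
        },
        "mode": "meta2diff",
        "parameters": {"temperature": 0.8, "top_p": 0.2, "n_next_tokens": 50, "random_seed": 137},
    }
    result = handler(payload)
    print(result["message"])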