from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from typing import Dict, List, Any, Tuple
import pickle
import math
import re
import gc
from utils import split
import torch
from build_vocab import WordVocab
from pretrain_trfm import TrfmSeq2seq
from transformers import T5EncoderModel, T5Tokenizer
import numpy as np
import pydantic
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE: the route decorators were missing from the snippet; "/" and
# "/predict" below are assumed paths.
@app.get("/")
async def read_root():
    return FileResponse("static/index.html")
class PredictData(pydantic.BaseModel):
    sequence: str
    smiles: str
@app.post("/predict")
async def predict(data: PredictData):
    # EndpointHandler reloads the vocab and transformer weights on every
    # request; constructing it once at startup would avoid the repeated cost.
    endpoint_handler = EndpointHandler()
    result = endpoint_handler.predict({
        "inputs": {
            "sequence": data.sequence,
            "smiles": data.smiles,
        }
    })
    return result
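
# A minimal sketch of how the endpoint can be exercised once the app is
# running; the host/port and the /predict path are assumptions of this
# rewrite (see the decorator note above), not part of the original snippet:
#
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"sequence": "MKTAYIAK", "smiles": "CCO"}'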
# torch_dtype applies to the model, not the tokenizer.
tokenizer = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False)
model = T5EncoderModel.from_pretrained(
    "Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16)
class EndpointHandler:
    def __init__(self, path=""):
        self.tokenizer = tokenizer
        self.model = model
        # Paths to the vocab content and the pretrained SMILES transformer.
        vocab_content_path = "vocab_content.txt"
        trfm_path = "trfm_12_23000.pkl"
        # Load the vocab content (one token per line) instead of a pickle file.
        with open(vocab_content_path, "r", encoding="utf-8") as f:
            vocab_content = f.read().strip().split("\n")
        # Build the vocab and the transformer encoder, then load its weights.
        self.vocab = WordVocab(vocab_content)
        self.trfm = TrfmSeq2seq(len(self.vocab), 256, len(self.vocab), 4)
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.trfm.load_state_dict(torch.load(trfm_path, map_location=device))
        self.trfm.eval()
        # Paths to the pickled downstream regressors.
        self.Km_model_path = "Km.pkl"
        self.Kcat_model_path = "Kcat.pkl"
        self.Kcat_over_Km_model_path = "Kcat_over_Km.pkl"
        # Special-token indices used when encoding SMILES.
        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3
    def predict(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Endpoint logic: fuse the SMILES and sequence embeddings, then run the
        three pretrained regressors on the fused vector.

        Args:
            data (Dict[str, Any]): The input data for the endpoint. It
                contains a single key "inputs", whose value is a dictionary
                with the following keys:
                - sequence (str): amino acid sequence.
                - smiles (str): SMILES representation of the molecule.

        Returns:
            Dict[str, Any]: The output data for the endpoint, with keys:
                - Km (float): predicted Km value.
                - Kcat (float): predicted Kcat value.
                - Vmax (float): predicted Vmax value (produced by the
                  Kcat_over_Km model).
        """
        sequence = data["inputs"]["sequence"]
        smiles = data["inputs"]["smiles"]
        seq_vec = self.Seq_to_vec(sequence)
        smiles_vec = self.smiles_to_vec(smiles)
        fused_vector = np.concatenate((smiles_vec, seq_vec), axis=1)
        pred_Km = self.predict_feature_using_model(
            fused_vector, self.Km_model_path)
        pred_Kcat = self.predict_feature_using_model(
            fused_vector, self.Kcat_model_path)
        pred_Vmax = self.predict_feature_using_model(
            fused_vector, self.Kcat_over_Km_model_path)
        result = {
            "Km": pred_Km,
            "Kcat": pred_Kcat,
            "Vmax": pred_Vmax,
        }
        return result
    def predict_feature_using_model(self, X: np.ndarray, model_path: str) -> float:
        """
        Predict a single kinetic parameter with a pickled regressor. The
        regressor is assumed to predict log10 of the parameter, so the result
        is raised back via 10**pred.
        """
        with open(model_path, "rb") as f:
            model = pickle.load(f)
        # model.predict returns an array; take the single prediction as a float.
        pred_feature = float(model.predict(X)[0])
        return math.pow(10, pred_feature)
    def smiles_to_vec(self, Smiles: str) -> np.ndarray:
        """
        Encode a SMILES string into a vector using the pretrained transformer.
        """
        x_split = [split(sm) for sm in [Smiles]]
        xid, _xseg = self.get_array(x_split, self.vocab)
        # trfm.encode expects shape (seq_len, batch), hence the transpose.
        X = self.trfm.encode(torch.t(xid))
        return X
    def get_inputs(self, sm: str, vocab: WordVocab) -> Tuple[List[int], List[int]]:
        """
        Convert a tokenised SMILES string to padded id and segment lists.
        """
        # Pad up to the character length of the raw string; for a single
        # sequence the padding may be empty.
        seq_len = len(sm)
        tokens = sm.split()
        ids = [vocab.stoi.get(token, self.unk_index) for token in tokens]
        ids = [self.sos_index] + ids + [self.eos_index]
        seg = [1] * len(ids)
        padding = [self.pad_index] * (seq_len - len(ids))
        ids.extend(padding)
        seg.extend(padding)
        return ids, seg
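
    # A hedged worked example (the token ids are hypothetical, since they
    # depend on the vocab file): for sm = "C C O", seq_len is 5 characters
    # and the five entries [sos, id(C), id(C), id(O), eos] already fill it,
    # so no padding is appended; a raw string longer than its token count
    # plus two would be padded with pad_index zeros.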
    def get_array(self, smiles: List[str], vocab: WordVocab) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert a batch of tokenised SMILES strings to id and segment tensors.
        """
        x_id, x_seg = [], []
        for sm in smiles:
            a, b = self.get_inputs(sm, vocab)
            x_id.append(a)
            x_seg.append(b)
        return torch.tensor(x_id), torch.tensor(x_seg)
    def Seq_to_vec(self, Sequence: str) -> np.ndarray:
        """
        Encode an amino acid sequence into a vector using the ProtT5 encoder.
        """
        # ProtT5 expects space-separated residues, e.g. "M K T".
        sequences_Example = [' '.join(Sequence)]
        gc.collect()
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(device)
        self.model = self.model.eval()
        features = []
        for seq in sequences_Example:
            # Map rare/ambiguous residues to X, as ProtT5 was trained that way.
            seq = [re.sub(r"[UZOB]", "X", seq)]
            ids = self.tokenizer.batch_encode_plus(
                seq, add_special_tokens=True, padding=True)
            input_ids = torch.tensor(ids['input_ids']).to(device)
            attention_mask = torch.tensor(ids['attention_mask']).to(device)
            with torch.no_grad():
                embedding = self.model(input_ids=input_ids,
                                       attention_mask=attention_mask)
            embedding = embedding.last_hidden_state.cpu().numpy()
            for seq_num in range(len(embedding)):
                seq_len = (attention_mask[seq_num] == 1).sum().item()
                # Drop the trailing special token; keep per-residue embeddings.
                seq_emd = embedding[seq_num][:seq_len - 1]
                features.append(seq_emd)
        # Mean-pool the per-residue embeddings into one vector per sequence.
        features_normalize = np.array(
            [feat.mean(axis=0) for feat in features], dtype=float)
        return features_normalize
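

# Minimal local entry point, a sketch assuming uvicorn is installed and that
# port 7860 (the Hugging Face Spaces default) is wanted; the original snippet
# does not show how the app is launched.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)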