import gradio as gr import sys import random import os import pandas as pd import torch import itertools from torch.utils.data import DataLoader from transformers import AutoTokenizer sys.path.append("scripts/") from foldseek_util import get_struc_seq from utils import seed_everything from models import PLTNUM_PreTrainedModel from datasets_ import PLTNUMDataset class Config: def __init__(self): self.batch_size = 2 self.use_amp = False self.num_workers = 1 self.max_length = 512 self.used_sequence = "left" self.padding_side = "right" self.task = "classification" self.sequence_col = "sequence" self.seed = 42 def predict_stability_with_pdb(model_choice, organism_choice, pdb_files, cfg=Config()): results = [] for pdb_file in pdb_files: try: pdb_path = pdb_file.name os.system("chmod 777 bin/foldseek") sequences = get_foldseek_seq(pdb_path) if not sequences: results.append({"file_name": pdb_path, "raw prediction value": None, "binary prediction value": None }) continue sequence = sequences[2] if model_choice == "SaProt" else sequences[0] output = predict_stability_core(model_choice, organism_choice, sequence, cfg) results.append({"file_name": pdb_path, "raw prediction value": output["raw prediction values"][0], "binary prediction value": output["binary prediction values"][0] }) except Exception as e: results.append({"file_name": pdb_file.name, "raw prediction value": None, "binary prediction value": None }) df = pd.DataFrame(results) output_csv = "/tmp/predictions.csv" df.to_csv(output_csv, index=False) return output_csv def predict_stability_with_sequence(model_choice, organism_choice, sequence, cfg=Config()): try: if not sequence: return "No valid sequence provided." return predict_stability_core(model_choice, organism_choice, sequence, cfg) except Exception as e: return f"An error occurred: {str(e)}" def predict_stability_core(model_choice, organism_choice, sequence, cfg=Config()): cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3" cfg.model = f"sagawa/PLTNUM-{model_choice}-{cell_line}" cfg.architecture = model_choice cfg.model_path = f"sagawa/PLTNUM-{model_choice}-{cell_line}" output = predict(cfg, sequence) return output def get_foldseek_seq(pdb_path): parsed_seqs = get_struc_seq( "bin/foldseek", pdb_path, ["A"], process_id=random.randint(0, 10000000), )["A"] return parsed_seqs def predict(cfg, sequence): cfg.token_length = 2 if cfg.architecture == "SaProt" else 1 cfg.device = "cuda" if torch.cuda.is_available() else "cpu" if cfg.used_sequence == "both": cfg.max_length += 1 seed_everything(cfg.seed) df = pd.DataFrame({cfg.sequence_col: [sequence]}) tokenizer = AutoTokenizer.from_pretrained( cfg.model_path, padding_side=cfg.padding_side ) cfg.tokenizer = tokenizer dataset = PLTNUMDataset(cfg, df, train=False) dataloader = DataLoader( dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=True, drop_last=False, ) model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg) model.to(cfg.device) model.eval() predictions = [] for inputs, _ in dataloader: inputs = inputs.to(cfg.device) with torch.no_grad(): with torch.amp.autocast(cfg.device, enabled=cfg.use_amp): preds = ( torch.sigmoid(model(inputs)) if cfg.task == "classification" else model(inputs) ) predictions += preds.cpu().tolist() predictions = list(itertools.chain.from_iterable(predictions)) outputs = { "raw prediction values": predictions, "binary prediction values": [1 if x > 0.5 else 0 for x in predictions] } html_output = f"""
Raw prediction value: {outputs['raw prediction values'][0]}
Binary prediction values: {outputs['binary prediction values'][0]}