"""Gradio Space that turns a protein sequence into an ESM-2 embedding file."""

import copy
import io
import os
import random
import tempfile

import gradio as gr
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import requests
import spaces
import torch
import torch.nn.functional as F
from esm import pretrained, FastaBatchedDataset
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, AutoModelForCausalLM

DEFAULT_MODEL_ID = 'facebook/esm2_t36_3B_UR50D'


def _load_esm_model(name):
    """Load an ESM checkpoint; return ``(model moved to CUDA in eval mode, alphabet)``."""
    # load_model_and_alphabet returns a (model, alphabet) tuple, so only the
    # model can be moved to the GPU; run_example unpacks this tuple.
    model, alphabet = pretrained.load_model_and_alphabet(name)
    return model.to("cuda").eval(), alphabet


models = {
    DEFAULT_MODEL_ID: _load_esm_model('esm2_t36_3B_UR50D'),
}

DESCRIPTION = "Esm2 embedding"

colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'red',
            'lime', 'indigo', 'violet', 'aqua', 'magenta', 'coral', 'gold', 'tan', 'skyblue']


@spaces.GPU
def run_example(protein, model_id=DEFAULT_MODEL_ID):
    """Embed ``protein`` with an ESM-2 model and expose the result as a .pt file.

    Args:
        protein: Amino-acid sequence entered by the user. All whitespace is
            removed and letters are upper-cased before tokenization.
        model_id: Key into ``models``.

    Returns:
        A Gradio update pointing the File output at the saved tensor and making
        it visible. The tensor holds the last-layer per-residue representations
        (BOS/EOS removed), zero-padded along the residue axis to
        ``truncation_seq_length`` rows.

    Raises:
        gr.Error: if the sequence is empty after removing whitespace.
        KeyError: if ``model_id`` is not a key of ``models``.
    """
    model_esm, alphabet = models[model_id]
    protein_name = 'protein_name'
    # Pasted sequences often contain line breaks or lowercase letters.
    protein_seq = "".join((protein or "").split()).upper()
    if not protein_seq:
        raise gr.Error("Please enter a protein sequence.")
    truncation_seq_length = 1024
    # Last transformer layer (36 for esm2_t36_3B_UR50D).
    layer = model_esm.num_layers

    print("start")
    # One sequence always forms exactly one batch, so call the batch converter
    # directly instead of building a FastaBatchedDataset + DataLoader.
    batch_converter = alphabet.get_batch_converter(truncation_seq_length)
    _labels, strs, toks = batch_converter([(protein_name, protein_seq)])
    if torch.cuda.is_available():
        toks = toks.to(device="cuda", non_blocking=True)

    with torch.no_grad():
        out = model_esm(toks, repr_layers=[layer], return_contacts=False)

    truncate_len = min(truncation_seq_length, len(strs[0]))
    # Position 0 is the BOS token; keep only the residue positions. clone()
    # so the saved tensor does not reference the whole batch's storage.
    esm_emb = out["representations"][layer][0, 1:truncate_len + 1].to(device="cpu").clone()
    print("esm embedding generated")

    # Zero-pad the residue axis (dim 0) to a fixed number of rows.
    esm_emb = F.pad(esm_emb, (0, 0, 0, truncation_seq_length - esm_emb.size(0)))

    # A fresh directory per request keeps concurrent users from overwriting
    # each other's file while preserving the "example.pt" download name.
    out_path = os.path.join(tempfile.mkdtemp(), "example.pt")
    torch.save(esm_emb, out_path)
    return gr.update(value=out_path, visible=True)


css = """
  #output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
  }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Esm2 embedding generation"):
        with gr.Row():
            with gr.Column():
                input_protein = gr.Textbox(type="text", label="Upload sequence")
                # The default must be one of the choices; it is passed to
                # run_example as the key into `models`.
                model_selector = gr.Dropdown(choices=list(models.keys()), label="Model",
                                             value=DEFAULT_MODEL_ID)
                # NOTE(review): this button has no event handler; only "Export"
                # runs run_example. Wire it up or remove it.
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                button = gr.Button("Export")
                pt = gr.File(interactive=False, visible=False)
    button.click(run_example, [input_protein, model_selector], pt)

demo.launch(debug=True)