File size: 2,456 Bytes
d93a43e c573591 d93a43e 0d8eac0 bee36a2 1b279a5 c573591 6e75c39 1b279a5 6e75c39 15611ba 6e75c39 0d8eac0 6e75c39 ef2888d 15611ba 8f83392 ef2888d 0d8eac0 1b763d0 0d8eac0 6b4d3f3 428cf29 dd36091 7ef0d72 428cf29 0029d81 0d8eac0 6ad6514 aab0394 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import functools

import gradio as gr
import selfies as sf
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# NOTE(review): this module-level value appears unused — greet1/greet2 bind
# their own local results. Kept only for backward compatibility.
sf_output = "zju"

# Hugging Face Hub id of the MolGen checkpoint used for generation.
_MODEL_ID = "zjunlp/MolGen"


@functools.lru_cache(maxsize=1)
def _load_molgen():
    """Load the MolGen tokenizer and model once and memoize them.

    ``from_pretrained`` reads (and possibly downloads) the full checkpoint,
    so doing it per request is very slow; this loads it once per process.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
    model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_ID)
    return tokenizer, model


def greet1(name):
    """Generate candidate molecule strings from an input prompt via beam search.

    Args:
        name: Input string fed to the MolGen tokenizer (the demo examples are
            SELFIES strings such as ``"[C]"``).

    Returns:
        A list of 4 decoded sequences (top 4 of 5 beams), with all spaces
        removed so the tokens are concatenated into a single string.
    """
    tokenizer, model = _load_molgen()
    sf_input = tokenizer(name, return_tensors="pt")
    # Beam search: 5 beams, return the best 4 sequences of 5 to 15 tokens.
    molecules = model.generate(
        input_ids=sf_input["input_ids"],
        attention_mask=sf_input["attention_mask"],
        max_length=15,
        min_length=5,
        num_return_sequences=4,
        num_beams=5,
    )
    # decode() joins tokens with spaces; strip them to rebuild the raw string.
    return [
        tokenizer.decode(
            g, skip_special_tokens=True, clean_up_tokenization_spaces=True
        ).replace(" ", "")
        for g in molecules
    ]
def greet2(name):
    """Generate molecules from an input prompt and render them as a grid image.

    Generation is delegated to ``greet1`` (same tokenizer, model and beam
    search settings). Each generated SELFIES string is decoded with
    ``selfies.decoder`` and parsed by RDKit.

    Args:
        name: Input string passed to ``greet1`` (a SELFIES string in the demo).

    Returns:
        The image produced by ``rdkit.Chem.Draw.MolsToGridImage``: 4 molecules
        per row, 200x200 pixels each, with empty legends.
    """
    # Reuse greet1 rather than duplicating model loading and generate() config.
    smis = [sf.decoder(selfies_str) for selfies_str in greet1(name)]
    # NOTE(review): Chem.MolFromSmiles returns None for SMILES it cannot
    # parse; confirm MolsToGridImage renders such entries acceptably.
    mols = [Chem.MolFromSmiles(smi) for smi in smis]
    return Draw.MolsToGridImage(
        mols,
        molsPerRow=4,
        subImgSize=(200, 200),
        legends=["" for _ in mols],
    )
def greet3(name):
    """Identity passthrough: hand back *name* exactly as received."""
    echoed = name
    return echoed
# Sample SELFIES prompts offered under the demo inputs.
examples = [
    ["[C][=C][C][=C][C][=C][Ring1][=Branch1]"],
    ["[C]"],
]

# Two views over the same prompt: raw generated strings, and rendered molecules.
greeter_1 = gr.Interface(greet1, inputs="textbox", outputs="text")
greeter_2 = gr.Interface(greet2, inputs="textbox", outputs="image")

# Run both interfaces side by side on a shared input.
demo = gr.Parallel(
    greeter_1,
    greeter_2,
    title="Molecular Language Model as Multi-task Generator",
    examples=examples,
)

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()