
ESM-2 Sequence Classifier

This is a small sequence classifier, trained on synthetic data generated by GPT-4, that classifies protein sequences into three categories: enzymes (class 0), receptor_proteins (class 1), and structural_proteins (class 2). It was trained using facebook/esm2_t6_8M_UR50D, one of the ESM-2 models.

This model is not well tested and is intended for experimental and educational purposes. Use with caution.

Using the Model

To use the model, try running:

import torch
from transformers import AutoTokenizer, EsmForSequenceClassification

# Load the trained model and tokenizer
model = EsmForSequenceClassification.from_pretrained("AmelieSchreiber/esm2_t6_8M_UR50D_sequence_classifier_v1")
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Suppose these are your new sequences that you want to classify
# Additional Family 0: Enzymes
new_sequences_0 = [
    "ACGYLKTPKLADPPVLRGDSSVTKAICKPDPVLEK",
    "GVALDECKALDYLPGKPLPMDGKVCQCGSKTPLRP",
    "VLPGYTCGELDCKPGKPLPKCGADKTQVATPFLRG",
    "TCGALVQYPSCADPPVLRGSDSSVKACKKLDPQDK",
    "GALCEECKLCPGADYKPMDGDRLPAAATSKTRPVG",
    "PAVDCKKALVYLPKPLPMDGKVCRGSKTPKTRPYG",
    "VLGYTCGALDCKPGKPLPKCGADKTQVATPFLRGA",
    "CGALVQYPSCADPPVLRGSDSSVKACKKLDPQDKT",
    "ALCEECKLCPGADYKPMDGDRLPAAATSKTRPVGK",
    "AVDCKKALVYLPKPLPMDGKVCRGSKTPKTRPYGR",
]

# Additional Family 1: Receptor Proteins
new_sequences_1 = [
    "VGQRFYGGRQKNRHCELSPLPSACRGSVQGALYTD",
    "KDQVLTVPTYACRCCPKMDSKGRVPSTLRVKSARS",
    "PLAGVACGRGLDYRCPRKMVPGDLQVTPATQRPYG",
    "CGVRLGYPGCADVPLRGRSSFAPRACMKKDPRVTR",
    "RKGVAYLYECRKLRCRADYKPRGMDGRRLPKASTT",
    "RPTGAVNCKQAKVYRGLPLPMMGKVPRVCRSRRPY",
    "RLDGGYTCGQALDCKPGRKPPKMGCADLKSTVATP",
    "LGTCRKLVRYPQCADPPVMGRSSFRPKACCRQDPV",
    "RVGYAMCSPKLCSCRADYKPPMGDGDRLPKAATSK",
    "QPKAVNCRKAMVYRPKPLPMDKGVPVCRSKRPRPY",
]

# Additional Family 2: Structural Proteins
new_sequences_2 = [
    "VGKGFRYGSSQKRYLHCQKSALPPSCRRGKGQGSAT",
    "KDPTVMTVGTYSCQCPKQDSRGSVQPTSRVKTSRSK",
    "PLVGKACGRSSDYKCPGQMVSGGSKQTPASQRPSYD",
    "CGKKLVGYPSSKADVPLQGRSSFSPKACKKDPQMTS",
    "RKGVASLYCSSKLSCKAQYSKGMSDGRSPKASSTTS",
    "RPKSAASCEQAKSYRSLSLPSMKGKVPSKCSRSKRP",
    "RSDVSYTSCSQSKDCKPSKPPKMSGSKDSSTVATPS",
    "LSTCSKKVAYPSSKADPPSSGRSSFSMKACKKQDPPV",
    "RVGSASSEPKSSCSVQSYSKPSMSGDSSPKASSTSK",
    "QPSASNCEKMSSYRPSLPSMSKGVPSSRSKSSPPYQ",
]

# Merge all sequences, then tokenize them and convert to tensors
new_sequences = new_sequences_0 + new_sequences_1 + new_sequences_2
inputs = tokenizer(new_sequences, return_tensors="pt", padding=True, truncation=True)

# Use the model to get the logits
with torch.no_grad():
    logits = model(**inputs).logits

# Get the predicted class for each sequence
predicted_class_ids = torch.argmax(logits, dim=-1)

# Print the predicted class for each sequence
for sequence, predicted_class in zip(new_sequences, predicted_class_ids):
    print(f"Sequence: {sequence}, Predicted class: {predicted_class.item()}")
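
The loop above prints bare class indices. As a minimal sketch, the snippet below turns raw logits into the category names from the description, adds a softmax probability, and checks predictions against the family each sequence was listed under. The names ID2LABEL, softmax, label_predictions, and accuracy are hypothetical helpers written for this illustration, and the label mapping is typed out by hand from the description rather than read from the model's config. It works on plain Python lists, so with the real model you would pass logits.tolist().

```python
import math

# Label names copied from the model description (assumption: this mapping is
# written by hand here, not loaded from the model's config).
ID2LABEL = {0: "enzymes", 1: "receptor_proteins", 2: "structural_proteins"}


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def label_predictions(logit_rows):
    """Return (class_id, label, probability) for each row of raw logits."""
    results = []
    for row in logit_rows:
        probs = softmax(row)
        best = max(range(len(probs)), key=probs.__getitem__)
        results.append((best, ID2LABEL[best], probs[best]))
    return results


def accuracy(predicted_ids, expected_ids):
    """Fraction of predictions that match the expected class ids."""
    if len(predicted_ids) != len(expected_ids):
        raise ValueError("predicted and expected lists differ in length")
    if not expected_ids:
        return 0.0
    hits = sum(p == e for p, e in zip(predicted_ids, expected_ids))
    return hits / len(expected_ids)


# Made-up logits for illustration only; these are not real model outputs.
example_logits = [[2.0, 0.1, -1.0], [0.0, 0.0, 3.0], [0.2, 1.5, 0.3]]
example_expected = [0, 2, 1]

predictions = label_predictions(example_logits)
for class_id, label, prob in predictions:
    print(f"class {class_id} ({label}), probability {prob:.3f}")
print("accuracy:", accuracy([p[0] for p in predictions], example_expected))
```

For the thirty sequences in the example above, the expected ids would be [0] * 10 + [1] * 10 + [2] * 10, since the three lists are concatenated in that order.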
Model size: 7.84M params (Safetensors, tensor type F32)