mawairon commited on
Commit
04805af
1 Parent(s): 40aaf6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -5
app.py CHANGED
@@ -5,7 +5,11 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  # Load the Hugging Face model and tokenizer
6
  model_name = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t' # Replace with the actual model name
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
 
 
 
9
 
10
  # Define a function to process the DNA sequence
11
  def analyze_dna(sequence):
@@ -13,12 +17,21 @@ def analyze_dna(sequence):
13
  inputs = tokenizer(sequence, return_tensors='pt')
14
  # Get model predictions
15
  outputs = model(**inputs)
16
-
17
- predictions = outputs.logits.argmax(dim=-1).item()
18
- return f"Prediction: {predictions}"
 
 
 
 
 
 
 
 
 
19
 
20
  # Create a Gradio interface
21
- demo = gr.Interface(fn=analyze_dna, inputs="text", outputs="text")
22
 
23
  # Launch the interface
24
  demo.launch()
 
5
  # Load the Hugging Face model and tokenizer
6
  model_name = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t' # Replace with the actual model name
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels = 38)
9
+
10
+ # Ensure the model has the correct number of classes
11
+ num_classes = model.config.num_labels
12
+ assert num_classes == 38, f"The model has {num_classes} classes, but 38 were expected."
13
 
14
  # Define a function to process the DNA sequence
15
  def analyze_dna(sequence):
 
17
  inputs = tokenizer(sequence, return_tensors='pt')
18
  # Get model predictions
19
  outputs = model(**inputs)
20
+
21
+ # Convert logits to probabilities
22
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze().tolist()
23
+
24
+ # Get the top 5 most likely classes
25
+ top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
26
+ top_5_probs = [probabilities[i] for i in top_5_indices]
27
+
28
+ # Prepare the output as a list of tuples (class_index, probability)
29
+ result = [(index, prob) for index, prob in zip(top_5_indices, top_5_probs)]
30
+
31
+ return result
32
 
33
  # Create a Gradio interface
34
+ demo = gr.Interface(fn=analyze_dna, inputs="text", outputs="json")
35
 
36
  # Launch the interface
37
  demo.launch()