YsnHdn commited on
Commit
b313f2c
·
1 Parent(s): c6eb236

Add : Unit Testing

Browse files
Files changed (1) hide show
  1. testapp.py +61 -0
testapp.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from app import app
3
+ from helper_functions import predict_class, transform_list_of_texts, prepare_text, inference
4
+ import torch
5
+ from transformers import DistilBertForSequenceClassification, AutoTokenizer
6
+
7
+ @pytest.fixture
8
+ def client():
9
+ app.config['TESTING'] = True
10
+ with app.test_client() as client:
11
+ yield client
12
+
13
+ # Unit tests
14
+
15
+ def test_predict_class():
16
+ # Mock the model and tokenizer
17
+ model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
18
+ text = ["This is a sample text for testing."]
19
+
20
+ predicted_class, class_probabilities = predict_class(text, model)
21
+
22
+ assert isinstance(predicted_class, tuple)
23
+ assert isinstance(class_probabilities, dict)
24
+ assert len(class_probabilities) == 17 # Assuming 17 classes
25
+
26
+ def test_transform_list_of_texts():
27
+ tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
28
+ texts = ["This is a sample text.", "Another sample text."]
29
+
30
+ result = transform_list_of_texts(texts, tokenizer, 510, 510, 1, 2550)
31
+
32
+ assert isinstance(result, dict)
33
+ assert "input_ids" in result
34
+ assert "attention_mask" in result
35
+
36
+ # Integration tests
37
+
38
+ def test_pdf_upload(client):
39
+ # You'll need to create a sample PDF file for testing
40
+ with open('sample.pdf', 'rb') as pdf_file:
41
+ data = {'file': (pdf_file, 'sample.pdf')}
42
+ response = client.post('/pdf/upload', data=data, content_type='multipart/form-data')
43
+
44
+ assert response.status_code == 200
45
+ assert b'class_probabilities' in response.data
46
+
47
+ def test_sentence_endpoint(client):
48
+ data = {'text': 'This is a sample sentence for testing.'}
49
+ response = client.post('/sentence', data=data)
50
+
51
+ assert response.status_code == 200
52
+ assert b'predicted_class' in response.data
53
+
54
+ def test_voice_endpoint(client):
55
+ # You'll need to create a sample audio file for testing
56
+ with open('sample_audio.wav', 'rb') as audio_file:
57
+ data = {'audio': (audio_file, 'sample_audio.wav')}
58
+ response = client.post('/voice', data=data, content_type='multipart/form-data')
59
+
60
+ assert response.status_code == 200
61
+ assert b'extracted_text' in response.data