File size: 4,303 Bytes
58b9de9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import unittest
from unittest.mock import patch

import pandas as pd

import src.backend.evaluate_model as evaluate_model
import src.envs as envs


class TestEvaluator(unittest.TestCase):
    """Unit tests for ``evaluate_model.Evaluator`` with its collaborators patched.

    Each test replaces ``SummaryGenerator`` and ``EvaluationModel`` (and, where
    needed, ``pd.read_csv`` and ``util.format_results``) with mocks via
    ``unittest.mock.patch``.

    Note on argument order: stacked ``@patch`` decorators are applied
    bottom-up, so the mock created by the decorator closest to the ``def`` is
    passed as the first mock argument of the test method.
    """

    def setUp(self):
        """Define the placeholder constructor arguments shared by every test."""
        self.model_name = 'test_model'
        self.revision = 'test_revision'
        self.precision = 'test_precision'
        self.batch_size = 10
        self.device = 'test_device'
        self.no_cache = False
        self.limit = 10

    def _make_evaluator(self):
        """Build an ``Evaluator`` from the placeholder arguments set in ``setUp``.

        Call this from inside a patched test method so that the constructor
        sees the mocked collaborators rather than the real ones.
        """
        return evaluate_model.Evaluator(self.model_name, self.revision,
                                        self.precision, self.batch_size,
                                        self.device, self.no_cache, self.limit)

    @patch('src.backend.evaluate_model.SummaryGenerator')
    @patch('src.backend.evaluate_model.EvaluationModel')
    def test_evaluator_initialization(self, mock_eval_model, mock_summary_generator):
        """Construction instantiates both collaborators and stores the model name."""
        evaluator = self._make_evaluator()

        mock_summary_generator.assert_called_once_with(self.model_name, self.revision)
        mock_eval_model.assert_called_once_with(envs.HEM_PATH)
        self.assertEqual(evaluator.model, self.model_name)

    @patch('src.backend.evaluate_model.EvaluationModel')
    @patch('src.backend.evaluate_model.SummaryGenerator')
    def test_evaluator_initialization_error(self, mock_summary_generator, mock_eval_model):
        """A failure while creating ``EvaluationModel`` surfaces from the constructor."""
        mock_eval_model.side_effect = Exception('test_exception')

        with self.assertRaises(Exception):
            self._make_evaluator()

    @patch('src.backend.evaluate_model.SummaryGenerator')
    @patch('src.backend.evaluate_model.EvaluationModel')
    @patch('src.backend.evaluate_model.pd.read_csv')
    @patch('src.backend.util.format_results')
    def test_evaluate_method(self, mock_format_results, mock_read_csv, mock_eval_model,
                             mock_summary_generator):
        """``evaluate`` reads the source CSV and forwards the metrics to ``format_results``."""
        evaluator = self._make_evaluator()

        # Mock setup. ``return_value`` of a patched class is the instance the
        # Evaluator received at construction, so configuring it here still
        # affects the already-built evaluator.
        generator = mock_summary_generator.return_value
        eval_model = mock_eval_model.return_value

        mock_format_results.return_value = {'test': 'result'}
        mock_read_csv.return_value = pd.DataFrame({'column1': ['data1', 'data2']})
        generator.generate_summaries.return_value = pd.DataFrame(
            {'column1': ['summary1', 'summary2']})
        generator.avg_length = 100
        generator.answer_rate = 1.0
        generator.error_rate = 0.0
        eval_model.compute_accuracy.return_value = 1.0
        eval_model.hallucination_rate = 0.0
        eval_model.evaluate_hallucination.return_value = [0.5]

        # Method call and assertions.
        # NOTE(review): the return value of evaluate() is not checked here;
        # presumably it is the format_results() output - confirm and assert it.
        evaluator.evaluate()

        mock_format_results.assert_called_once_with(model_name=self.model_name,
                                                    revision=self.revision,
                                                    precision=self.precision,
                                                    accuracy=1.0, hallucination_rate=0.0,
                                                    answer_rate=1.0, avg_summary_len=100,
                                                    error_rate=0.0)
        mock_read_csv.assert_called_once_with(envs.SOURCE_PATH)

    @patch('src.backend.evaluate_model.SummaryGenerator')
    @patch('src.backend.evaluate_model.EvaluationModel')
    @patch('src.backend.evaluate_model.pd.read_csv')
    def test_evaluate_with_file_not_found(self, mock_read_csv, mock_eval_model,
                                          mock_summary_generator):
        """A ``FileNotFoundError`` from ``pd.read_csv`` propagates out of ``evaluate``."""
        mock_read_csv.side_effect = FileNotFoundError('test_exception')
        evaluator = self._make_evaluator()

        with self.assertRaises(FileNotFoundError):
            evaluator.evaluate()


# Run this module's tests through unittest's runner when executed as a script.
if __name__ == '__main__':
    unittest.main()