File size: 3,752 Bytes
82a161e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import uuid
import g4f
from g4f import ChatCompletion

TEST_PROMPT = "Generate a sentence with 'ocean'"
EXPECTED_RESPONSE_CONTAINS = "ocean"


class Provider:
    def __init__(self, name, models):
        """
        Initialize the provider with its name and models.
        """
        self.name = name
        self.models = models if isinstance(models, list) else [models]

    def __str__(self):
        return self.name


class ModelProviderManager:
    def __init__(self):
        """
        Initialize the manager that manages the working (active) providers for each model.
        """
        self._working_model_providers = {}

    def add_provider(self, model, provider_name):
        """
        Add a provider to the working provider list of the specified model.
        """
        if model not in self._working_model_providers:
            self._working_model_providers[model] = []
        self._working_model_providers[model].append(provider_name)

    def get_working_providers(self):
        """
        Return the currently active providers for each model.
        """
        return self._working_model_providers


def _fetch_providers_having_models():
    """
    Get providers that have models from g4f.Providers.
    """
    model_providers = []

    for provider_name in dir(g4f.Provider):
        provider = getattr(g4f.Provider, provider_name)

        if _is_provider_applicable(provider):
            model_providers.append(Provider(provider_name, provider.model))

    return model_providers


def _is_provider_applicable(provider):
    """
    Check if the provider has a model and doesn't require authentication.
    """
    return (hasattr(provider, 'model') and
            hasattr(provider, '_create_completion') and
            hasattr(provider, 'needs_auth') and
            not provider.needs_auth)


def _generate_test_messages():
    """
    Generate messages for testing.
    """
    return [{"role": "system", "content": "You are a trained AI assistant."},
            {"role": "user", "content": TEST_PROMPT}]


def _manage_chat_completion(manager, model_providers, test_messages):
    """
    Generate chat completion for each provider's models and handle positive and negative results.
    """
    for provider in model_providers:
        for model in provider.models:
            try:
                response = _generate_chat_response(
                    provider.name, model, test_messages)
                if EXPECTED_RESPONSE_CONTAINS in response.lower():
                    _print_success_response(provider, model)
                    manager.add_provider(model, provider.name)
                else:
                    raise Exception(f"Unexpected response: {response}")
            except Exception as error:
                _print_error_response(provider, model, error)


def _generate_chat_response(provider_name, model, test_messages):
    """
    Generate a chat response given a provider name, a model, and test messages.
    """
    return ChatCompletion.create(
        model=model,
        messages=test_messages,
        chatId=str(uuid.uuid4()),
        provider=getattr(g4f.Provider, provider_name)
    )


def _print_success_response(provider, model):
    print(f"\u2705 [{provider}] - [{model}]: Success")


def _print_error_response(provider, model, error):
    print(f"\u26D4 [{provider}] - [{model}]: Error - {str(error)}")


def get_active_model_providers():
    """
    Get providers that are currently working (active).
    """
    model_providers = _fetch_providers_having_models()
    test_messages = _generate_test_messages()
    manager = ModelProviderManager()

    _manage_chat_completion(manager, model_providers, test_messages)

    return manager.get_working_providers()