File size: 2,425 Bytes
a8b3f00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
from collections.abc import Callable
from typing import Literal

import pytest

# import monkeypatch
from _pytest.monkeypatch import MonkeyPatch
from openai.resources.audio.transcriptions import Transcriptions
from openai.resources.chat import Completions as ChatCompletions
from openai.resources.completions import Completions
from openai.resources.embeddings import Embeddings
from openai.resources.models import Models
from openai.resources.moderations import Moderations

from tests.integration_tests.model_runtime.__mock.openai_chat import MockChatClass
from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass
from tests.integration_tests.model_runtime.__mock.openai_embeddings import MockEmbeddingsClass
from tests.integration_tests.model_runtime.__mock.openai_moderation import MockModerationClass
from tests.integration_tests.model_runtime.__mock.openai_remote import MockModelClass
from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass


def mock_openai(
    monkeypatch: MonkeyPatch,
    methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]],
) -> Callable[[], None]:
    """
    Replace selected OpenAI SDK resource methods with their mock implementations.

    :param monkeypatch: pytest monkeypatch fixture used to apply the patches
    :param methods: features to mock; recognised names are "completion", "chat",
        "remote", "moderation", "speech2text" and "text_embedding". Any other
        name is silently ignored, and repeated names are patched only once.
    :return: zero-argument callable that reverts the patches via ``monkeypatch.undo()``
    """
    # feature -> (SDK class, attribute to replace, mock class, mock attribute).
    # Dict insertion order fixes the order in which patches are applied,
    # independent of the order of ``methods``. The mock attribute is looked up
    # only when its feature is actually requested.
    patch_table = {
        "completion": (Completions, "create", MockCompletionsClass, "completion_create"),
        "chat": (ChatCompletions, "create", MockChatClass, "chat_create"),
        "remote": (Models, "list", MockModelClass, "list"),
        "moderation": (Moderations, "create", MockModerationClass, "moderation_create"),
        "speech2text": (Transcriptions, "create", MockSpeech2TextClass, "speech2text_create"),
        "text_embedding": (Embeddings, "create", MockEmbeddingsClass, "create_embeddings"),
    }

    for feature, (target, attribute, mock_cls, mock_attr) in patch_table.items():
        if feature in methods:
            monkeypatch.setattr(target, attribute, getattr(mock_cls, mock_attr))

    # NOTE: undo() reverts every change recorded on this MonkeyPatch instance,
    # not only the ones made above.
    return monkeypatch.undo


# Mocking is enabled only when MOCK_SWITCH is "true" (case-insensitive); unset means off.
MOCK = os.environ.get("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_openai_mock(request, monkeypatch):
    """
    Optionally patch the OpenAI SDK for the duration of a single test.

    Nothing is patched unless the module-level ``MOCK`` flag is true. The list of
    features to mock is taken from ``request.param``, presumably supplied with
    ``@pytest.mark.parametrize("setup_openai_mock", [[...]], indirect=True)``.
    When no param is present, an empty list is used. The patches are reverted
    after the test body has run.
    """
    requested = getattr(request, "param", [])
    restore = mock_openai(monkeypatch, methods=requested) if MOCK else None

    yield

    if MOCK:
        restore()