File size: 1,845 Bytes
a8b3f00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
from collections.abc import Callable
from typing import Literal

import httpx
import pytest
from _pytest.monkeypatch import MonkeyPatch


def mock_get(*args, **kwargs):
    """
    Stand-in for ``httpx.get`` that serves a fixed model listing.

    Only the ``headers`` keyword is inspected; positional arguments (e.g. the
    URL) and every other keyword are accepted and ignored.

    :return: a 200 ``httpx.Response`` whose JSON body lists two models
    :raises httpx.HTTPStatusError: when the ``Authorization`` header is not
        exactly ``"Bearer test"``
    """
    # httpx.Headers accepts None, mappings and (key, value) sequences and does
    # case-insensitive lookups, matching how real HTTP header names behave.
    headers = httpx.Headers(kwargs.get("headers"))
    request = httpx.Request("GET", "")

    if headers.get("Authorization") != "Bearer test":
        raise httpx.HTTPStatusError(
            "Invalid API key",
            request=request,
            # Attach the request so callers inspecting ``exc.response.request``
            # don't hit httpx's "request instance has not been set" error.
            response=httpx.Response(401, request=request),
        )

    return httpx.Response(
        200,
        json={
            "items": [
                {"title": "Model 1", "_id": "model1"},
                {"title": "Model 2", "_id": "model2"},
            ]
        },
        request=request,
    )


def mock_stream(*args, **kwargs):
    """
    Stand-in for ``httpx.stream``.

    All arguments are accepted and ignored. The returned object works as a
    context manager (entering yields the object itself) and exposes
    ``status_code`` (always 200) and ``iter_bytes()``, which yields a single
    fixed chunk of bytes.
    """

    class _FakeStreamResponse:
        def __init__(self):
            self.status_code = 200

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # A falsy return lets any exception raised in the ``with`` body propagate.
            return None

        def iter_bytes(self):
            yield from (b"Mocked audio data",)

    fake = _FakeStreamResponse()
    return fake


def mock_fishaudio(
    monkeypatch: MonkeyPatch,
    methods: list[Literal["list-models", "tts"]],
) -> Callable[[], None]:
    """
    Patch the httpx entry points used to talk to fishaudio.

    :param monkeypatch: pytest monkeypatch fixture used to apply the patches
    :param methods: endpoints to fake; ``"list-models"`` replaces
        ``httpx.get`` with ``mock_get`` and ``"tts"`` replaces
        ``httpx.stream`` with ``mock_stream``. Other values are ignored.
    :return: function that reverts every patch recorded on ``monkeypatch``
        (not only the ones applied here)
    """
    # Ordered mapping: method name -> (httpx attribute, replacement callable).
    patch_table = {
        "list-models": ("get", mock_get),
        "tts": ("stream", mock_stream),
    }

    for method_name, (attr_name, replacement) in patch_table.items():
        if method_name in methods:
            monkeypatch.setattr(httpx, attr_name, replacement)

    return monkeypatch.undo


# Opt-in switch: MOCK_SWITCH must equal "true" (any letter case, no surrounding
# whitespace) to enable the httpx mocks; any other value, or unset, disables them.
MOCK = os.environ.get("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_fishaudio_mock(request, monkeypatch):
    """
    Optionally install the fishaudio httpx mocks for the duration of a test.

    When ``MOCK`` is false the fixture does nothing. Otherwise it patches the
    endpoints named by the indirect parametrization value (``request.param``),
    defaulting to none, and reverts the patches after the test finishes.
    """
    if not MOCK:
        yield
        return

    selected = getattr(request, "param", [])
    restore = mock_fishaudio(monkeypatch, methods=selected)

    yield

    restore()