File size: 5,127 Bytes
084fe8e
acb3380
084fe8e
 
 
acb3380
 
 
084fe8e
acb3380
 
084fe8e
 
 
 
 
 
acb3380
 
 
 
 
084fe8e
 
 
acb3380
 
 
 
 
 
 
 
084fe8e
 
 
 
 
 
 
acb3380
084fe8e
acb3380
 
084fe8e
 
 
 
acb3380
084fe8e
 
 
 
acb3380
084fe8e
 
 
 
 
 
 
 
 
 
 
 
acb3380
084fe8e
acb3380
 
084fe8e
acb3380
084fe8e
acb3380
 
 
 
084fe8e
 
acb3380
 
 
084fe8e
 
 
 
 
acb3380
 
084fe8e
acb3380
084fe8e
acb3380
 
 
 
084fe8e
 
acb3380
 
 
084fe8e
 
 
 
 
acb3380
 
084fe8e
acb3380
084fe8e
acb3380
084fe8e
acb3380
 
 
 
084fe8e
 
acb3380
 
 
084fe8e
 
 
 
 
acb3380
 
084fe8e
 
 
 
 
 
 
 
acb3380
 
 
 
 
 
 
 
084fe8e
 
 
 
acb3380
084fe8e
acb3380
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import re
from typing import Any, Callable, ClassVar, Dict, Optional, Tuple, Type

from openai import OpenAI

from ..utils.decorator import score_exponential_backoff


class BaseProcessor(object):
    """Base class and name-based factory for processors.

    Concrete processors register themselves under a string name with
    :meth:`register_processor`. Calling ``BaseProcessor(name, ...)`` then
    returns an instance of the subclass registered under ``name``.

    A processor answers a query through :meth:`ask_info` (implemented by
    subclasses) and rates the answer with an OpenAI chat model along three
    axes -- relevance, confidence and surprise -- each scaled to [0, 1].
    """

    # Shared by every subclass: processor name -> processor class.
    _processor_registry: ClassVar[Dict[str, Type["BaseProcessor"]]] = {}

    # Chat model and token budget used for every scoring prompt. Subclasses
    # may override these class attributes.
    scorer_model: str = "gpt-4-0125-preview"
    scorer_max_tokens: int = 50

    # Leading (optionally signed) number of a model reply, e.g. "4", "4.",
    # "3.5/5", "2 - somewhat related". Leading whitespace is skipped.
    _SCORE_PATTERN = re.compile(r"\s*([-+]?(?:\d+(?:\.\d*)?|\.\d+))")

    @classmethod
    def register_processor(
        cls, processor_name: str
    ) -> Callable[[Type["BaseProcessor"]], Type["BaseProcessor"]]:
        """Return a class decorator that registers a subclass under a name.

        Usage::

            @BaseProcessor.register_processor("my_processor")
            class MyProcessor(BaseProcessor):
                ...

        Registering another class under an existing name replaces the
        previous entry.

        Args:
            processor_name: Key later passed to ``BaseProcessor(...)``.

        Returns:
            A decorator that records the class and returns it unchanged.
        """

        def decorator(
            subclass: Type["BaseProcessor"],
        ) -> Type["BaseProcessor"]:
            cls._processor_registry[processor_name] = subclass
            return subclass

        return decorator

    def __new__(
        cls, processor_name: str, *args: Any, **kwargs: Any
    ) -> "BaseProcessor":
        """Allocate an instance of the class registered as ``processor_name``.

        Raises:
            ValueError: If no class is registered under ``processor_name``.
        """
        try:
            processor_cls = cls._processor_registry[processor_name]
        except KeyError:
            raise ValueError(
                f"No processor registered with name '{processor_name}'"
            ) from None
        # object.__new__ must not receive the constructor arguments. Python
        # passes them to __init__ itself, but only when the returned object is
        # an instance of ``cls``.
        return super().__new__(processor_cls)

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Run the initialisation hooks; constructor arguments are ignored here.

        The hooks run in this fixed order: scorer, executor, messenger,
        task info.
        """
        self.init_scorer()
        self.init_executor()
        self.init_messenger()
        self.init_task_info()

    def init_executor(self) -> None:
        """Hook for subclasses: set up whatever produces the processor's output."""
        raise NotImplementedError(
            "The 'init_executor' method must be implemented in derived classes."
        )

    def init_messenger(self) -> None:
        """Hook for subclasses: set up the processor's messaging component."""
        raise NotImplementedError(
            "The 'init_messenger' method must be implemented in derived classes."
        )

    def init_task_info(self) -> None:
        """Hook for subclasses: set up task-specific information."""
        raise NotImplementedError(
            "The 'init_task_info' method must be implemented in derived classes."
        )

    def init_scorer(self) -> None:
        """Create the OpenAI client (default settings) used by the ask_* scorers."""
        self.scorer = OpenAI()

    def ask(
        self, query: str, text: str, image: str, audio: str, video_frames: str
    ) -> Tuple[str, float]:
        """Answer ``query`` from the given inputs and score the answer.

        Returns:
            ``(gist, score)``: the answer from :meth:`ask_info`, and its
            combined score from :meth:`ask_score`. The component scores are
            printed.
        """
        gist = self.ask_info(
            query=query,
            text=text,
            image=image,
            audio=audio,
            video_frames=video_frames,
        )
        score = self.ask_score(query, gist, verbose=True)
        return gist, score

    @classmethod
    def _parse_score(cls, content: Optional[str]) -> float:
        """Convert a model reply on the 0-5 scale into a score in [0, 1].

        An empty or missing reply scores 0.0. Only the leading number is read,
        so replies such as "4.", "4/5" or "3 - related" are accepted. The
        result is clamped to [0, 1].

        Raises:
            ValueError: If the reply does not start with a number.
        """
        if not content:
            return 0.0
        match = cls._SCORE_PATTERN.match(content)
        if match is None:
            raise ValueError(
                f"Could not parse a score from model reply: {content!r}"
            )
        # The prompts ask for a value from 0 to 5.
        return min(max(float(match.group(1)) / 5, 0.0), 1.0)

    def _ask_scalar(self, prompt: str) -> float:
        """Send ``prompt`` to the scoring model and return its rating in [0, 1]."""
        response = self.scorer.chat.completions.create(
            model=self.scorer_model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=self.scorer_max_tokens,
        )
        return self._parse_score(response.choices[0].message.content)

    # NOTE(review): score_exponential_backoff presumably retries the call on
    # failure (including the ValueError from an unparseable reply) -- see
    # ..utils.decorator to confirm.
    @score_exponential_backoff(retries=5, base_wait_time=1)
    def ask_relevance(self, query: str, gist: str) -> float:
        """Rate how related ``gist`` is to ``query``, in [0, 1]."""
        return self._ask_scalar(
            f"How related is the information ({gist}) with the query ({query})? Answer with a number from 0 to 5 and do not add any other thing."
        )

    @score_exponential_backoff(retries=5, base_wait_time=1)
    def ask_confidence(self, query: str, gist: str) -> float:
        """Rate how confidently ``gist`` is a must-know, in [0, 1].

        ``query`` is accepted for interface symmetry but is not sent.
        """
        return self._ask_scalar(
            f"How confident do you think the information ({gist}) is a must-know? Answer with a number from 0 to 5 and do not add any other thing."
        )

    @score_exponential_backoff(retries=5, base_wait_time=1)
    def ask_surprise(
        self, query: str, gist: str, history_gists: Optional[str] = None
    ) -> float:
        """Rate how surprising ``gist`` is as a processor output, in [0, 1].

        Args:
            query: Accepted for interface symmetry; not sent to the model.
            gist: The processor output being rated.
            history_gists: Optional earlier outputs. When non-empty, they are
                included in the prompt as context for the rating.
        """
        if history_gists:
            prompt = f"How surprising do you think the information ({gist}) is as an output of the processor, given the previously collected information ({history_gists})? Answer with a number from 0 to 5 and do not add any other thing."
        else:
            prompt = f"How surprising do you think the information ({gist}) is as an output of the processor? Answer with a number from 0 to 5 and do not add any other thing."
        return self._ask_scalar(prompt)

    def ask_score(
        self,
        query: str,
        gist: str,
        verbose: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> float:
        """Combine relevance, confidence and surprise into one score.

        Args:
            query: The query ``gist`` answers.
            gist: The processor output being scored.
            verbose: If true, print the three component scores.
            *args, **kwargs: Forwarded to :meth:`ask_surprise` only (e.g.
                ``history_gists``). The relevance and confidence scorers take
                no extra arguments.

        Returns:
            ``relevance * confidence * surprise``.
        """
        relevance = self.ask_relevance(query, gist)
        confidence = self.ask_confidence(query, gist)
        surprise = self.ask_surprise(query, gist, *args, **kwargs)
        if verbose:
            print(
                f"Relevance: {relevance}, Confidence: {confidence}, Surprise: {surprise}"
            )

        return relevance * confidence * surprise

    def ask_info(self, *args: Any, **kwargs: Any) -> str:
        """Hook for subclasses: produce the gist answering a query.

        :meth:`ask` calls it with keyword arguments ``query``, ``text``,
        ``image``, ``audio`` and ``video_frames``.
        """
        raise NotImplementedError(
            "The 'ask_info' method must be implemented in derived classes."
        )