from typing import Any, Optional

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.prompt.entities.advanced_prompt_entities import MemoryConfig


class PromptTransform:
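    """Helpers for assembling prompt messages that include chat history,
    sized to fit the model's remaining context window.
    """
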
    def _append_chat_histories(
        self,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        prompt_messages: list[PromptMessage],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> list[PromptMessage]:
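        """Append as much chat history from memory as the remaining token budget allows."""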
        rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
        histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
        prompt_messages.extend(histories)

        return prompt_messages

    def _calculate_rest_token(
        self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
    ) -> int:
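        """Estimate how many tokens are left for chat history.

        Falls back to a 2000-token budget when the model schema does not declare a context size.
        """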
        rest_tokens = 2000

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_instance = ModelInstance(
                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
            )

            # Tokens already consumed by the prompt messages assembled so far.
            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

            # Reserve room for the completion: read max_tokens (or its template alias)
            # from the model's parameter rules, defaulting to 0 when not configured.
            max_tokens = 0
            for parameter_rule in model_config.model_schema.parameter_rules:
                if parameter_rule.name == "max_tokens" or (
                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
                ):
                    max_tokens = (
                        model_config.parameters.get(parameter_rule.name)
                        or model_config.parameters.get(parameter_rule.use_template)
                    ) or 0

            # Budget left for history = context window - reserved completion tokens
            # - tokens already used by the current prompt, floored at zero.
            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
            rest_tokens = max(rest_tokens, 0)

        return rest_tokens

    def _get_history_messages_from_memory(
        self,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        max_token_limit: int,
        human_prefix: Optional[str] = None,
        ai_prefix: Optional[str] = None,
    ) -> str:
        """Get memory messages."""
        kwargs = {"max_token_limit": max_token_limit}

        if human_prefix:
            kwargs["human_prefix"] = human_prefix

        if ai_prefix:
            kwargs["ai_prefix"] = ai_prefix

        if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
            kwargs["message_limit"] = memory_config.window.size

        return memory.get_history_prompt_text(**kwargs)

    def _get_history_messages_list_from_memory(
        self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
    ) -> list[PromptMessage]:
        """Get memory messages."""
        message_limit = None
        if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
            message_limit = memory_config.window.size

        return memory.get_history_prompt_messages(
            max_token_limit=max_token_limit,
            message_limit=message_limit,
        )