import abc
import ast
import collections
from typing import (
    Any,
    AsyncGenerator,
    Dict,
    Generator,
    List,
    Optional,
    OrderedDict,
    Union,
)

from h2ogpt_client._gradio_client import GradioClientWrapper
from h2ogpt_client._h2ogpt_enums import (
    DocumentSubset,
    LangChainAction,
    LangChainMode,
    PromptType,
)
from h2ogpt_client._models import Model

# Maps h2oGPT server parameter names (keys) to this client's parameter names (values).
_H2OGPT_PARAMETERS_TO_CLIENT = collections.OrderedDict(
    instruction="instruction",
    iinput="input",
    context="system_pre_context",
    stream_output="stream_output",
    prompt_type="prompt_type",
    prompt_dict="prompt_dict",
    temperature="temperature",
    top_p="top_p",
    top_k="top_k",
    penalty_alpha="penalty_alpha",
    num_beams="beams",
    max_new_tokens="max_output_length",
    min_new_tokens="min_output_length",
    early_stopping="early_stopping",
    max_time="max_time",
    repetition_penalty="repetition_penalty",
    num_return_sequences="number_returns",
    do_sample="enable_sampler",
    chat="chat",
    instruction_nochat="instruction_nochat",
    iinput_nochat="input_context_for_instruction",
    langchain_mode="langchain_mode",
    add_chat_history_to_context="add_chat_history_to_context",
    langchain_action="langchain_action",
    langchain_agents="langchain_agents",
    top_k_docs="langchain_top_k_docs",
    chunk="langchain_enable_chunk",
    chunk_size="langchain_chunk_size",
    document_subset="langchain_document_subset",
    document_choice="langchain_document_choice",
    document_source_substrings="langchain_document_source_substrings",
    document_source_substrings_op="langchain_document_source_substrings_op",
    document_content_substrings="langchain_document_content_substrings",
    document_content_substrings_op="langchain_document_content_substrings_op",
    pre_prompt_query="pre_prompt_query",
    prompt_query="prompt_query",
    pre_prompt_summary="pre_prompt_summary",
    prompt_summary="prompt_summary",
    hyde_llm_prompt="hyde_llm_prompt",
    system_prompt="system_prompt",
    image_audio_loaders="image_audio_loaders",
    pdf_loaders="pdf_loaders",
    url_loaders="url_loaders",
    jq_schema="jq_schema",
    visible_models="model",
    h2ogpt_key="h2ogpt_key",
    add_search_to_context="add_search_to_context",
    chat_conversation="chat_conversation",
    text_context_list="text_context_list",
    docs_ordering_type="docs_ordering_type",
    min_max_new_tokens="min_max_new_tokens",
    max_input_tokens="max_input_tokens",
    max_total_input_tokens="max_total_input_tokens",
    docs_token_handling="docs_token_handling",
    docs_joiner="docs_joiner",
    hyde_level="hyde_level",
    hyde_template="hyde_template",
    hyde_show_only_final="hyde_show_only_final",
    doc_json_mode="doc_json_mode",
    chatbot_role="chatbot_role",
    speaker="speaker",
    tts_language="tts_language",
    tts_speed="tts_speed",
)


def _to_h2ogpt_params(client_params: Dict[str, Any]) -> OrderedDict[str, Any]:
    """Convert client parameters to h2oGPT parameter names, in h2oGPT's expected order."""
    h2ogpt_params: OrderedDict[str, Any] = collections.OrderedDict()
    for h2ogpt_param_name, client_param_name in _H2OGPT_PARAMETERS_TO_CLIENT.items():
        if client_param_name in client_params:
            h2ogpt_params[h2ogpt_param_name] = client_params[client_param_name]
    return h2ogpt_params
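
# Illustrative sketch (not executed anywhere in this module): given client-side
# keyword arguments, _to_h2ogpt_params returns them keyed and ordered by the
# corresponding h2oGPT server names, e.g.
#
#     _to_h2ogpt_params({"input": "some context", "max_output_length": 128})
#     # -> OrderedDict([("iinput", "some context"), ("max_new_tokens", 128)])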


_DEFAULT_PARAMETERS: Dict[str, Any] = dict(
    instruction="",
    input="",
    system_pre_context="",
    stream_output=False,
    prompt_type=PromptType.plain.value,
    prompt_dict="",
    temperature=0.1,
    top_p=1.0,
    top_k=40,
    penalty_alpha=0.0,
    beams=1.0,
    max_output_length=1024,
    min_output_length=0,
    early_stopping=False,
    max_time=360,
    repetition_penalty=1.07,
    number_returns=1,
    enable_sampler=False,
    chat=False,
    instruction_nochat="",
    input_context_for_instruction="",
    langchain_mode=LangChainMode.DISABLED.value,
    add_chat_history_to_context=False,
    langchain_action=LangChainAction.QUERY.value,
    langchain_agents=[],
    langchain_top_k_docs=4,
    langchain_enable_chunk=True,
    langchain_chunk_size=512,
    langchain_document_subset=DocumentSubset.Relevant.name,
    langchain_document_choice=[],
    langchain_document_source_substrings=[],
    langchain_document_source_substrings_op="and",
    langchain_document_content_substrings=[],
    langchain_document_content_substrings_op="and",
    pre_prompt_query=[],
    prompt_query="",
    pre_prompt_summary="",
    prompt_summary="",
    hyde_llm_prompt="",
    system_prompt="",
    image_audio_loaders=[],
    pdf_loaders=[],
    url_loaders=[],
    jq_schema=".[]",
    model=None,
    h2ogpt_key=None,
    add_search_to_context=False,
    chat_conversation=None,
    text_context_list=[],
    docs_ordering_type="reverse_ucurve_sort",
    min_max_new_tokens=256,
    max_input_tokens=-1,
    max_total_input_tokens=-1,
    docs_token_handling="split_or_merge",
    docs_joiner="\n\n",
    hyde_level=0,
    hyde_template=None,
    hyde_show_only_final=None,
    doc_json_mode=False,
    chatbot_role="None",
    speaker="None",
    tts_language="autodetect",
    tts_speed=1.0,
)


class _Completion(abc.ABC):
    _API_NAME = "/submit_nochat_api"

    def __init__(self, client: GradioClientWrapper, parameters: OrderedDict[str, Any]):
        self._client = client
        self._parameters = dict(parameters)

    def _get_parameters(self, prompt: str) -> Dict[str, Any]:
        self._parameters["instruction_nochat"] = prompt
        return self._parameters

    @staticmethod
    def _get_reply(response: str) -> str:
        return ast.literal_eval(response)["response"]

    def _predict(self, prompt: str) -> str:
        response = self._client.predict(
            str(self._get_parameters(prompt)), api_name=self._API_NAME
        )
        return self._get_reply(response)

    def _predict_and_stream(self, prompt: str) -> Generator[str, None, None]:
        generator = self._client.predict_and_stream(
            str(self._get_parameters(prompt)), api_name=self._API_NAME
        )
        # Each streamed response contains the full reply accumulated so far;
        # yield only the part that is new since the previous iteration.
        reply_size_so_far = 0
        for response in generator:
            current_reply = self._get_reply(response)
            new_reply_chunk = current_reply[reply_size_so_far:]
            if not new_reply_chunk:
                continue
            reply_size_so_far += len(new_reply_chunk)
            yield new_reply_chunk

    async def _submit(self, prompt: str) -> str:
        response = await self._client.submit(
            str(self._get_parameters(prompt)), api_name=self._API_NAME
        )
        return self._get_reply(response)

    async def _submit_and_stream(self, prompt: str) -> AsyncGenerator[str, None]:
        generator = self._client.submit_and_stream(
            str(self._get_parameters(prompt)), api_name=self._API_NAME
        )
        # Same delta logic as _predict_and_stream, but asynchronous.
        reply_size_so_far = 0
        async for response in generator:
            current_reply = self._get_reply(response)
            new_reply_chunk = current_reply[reply_size_so_far:]
            if not new_reply_chunk:
                continue
            reply_size_so_far += len(new_reply_chunk)
            yield new_reply_chunk
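
# Wire-format sketch, inferred from how _predict/_submit build and parse payloads
# above (the authoritative schema is defined by the h2oGPT server, not here):
#
#     request:  str() of the parameter dict, e.g.
#               "{'instruction': '', ..., 'instruction_nochat': 'Hi', ...}"
#     response: string form of a dict literal containing at least a 'response' key,
#               which _get_reply extracts via ast.literal_eval(response)["response"]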


class TextCompletionCreator:
    """Builder that can create text completions."""

    def __init__(self, client: GradioClientWrapper):
        self._client = client

    def create(
        self,
        model: Union[None, Model, str] = None,
        prompt_type: PromptType = PromptType.plain,
        input_context_for_instruction: str = "",
        enable_sampler=False,
        temperature: float = 0.1,
        top_p: float = 1.0,
        top_k: int = 40,
        penalty_alpha: float = 0.0,
        beams: float = 1.0,
        early_stopping: bool = False,
        min_output_length: int = 0,
        max_output_length: int = 1024,
        max_time: int = 360,
        repetition_penalty: float = 1.07,
        number_returns: int = 1,
        system_pre_context: str = "",
        langchain_mode: LangChainMode = LangChainMode.DISABLED,
        system_prompt: str = "",
        add_search_to_context: bool = False,
        text_context_list: List[str] = [],
        docs_ordering_type: str = "reverse_ucurve_sort",
        min_max_new_tokens: int = 256,
        max_input_tokens: int = -1,
        max_total_input_tokens: int = -1,
        docs_token_handling: str = "split_or_merge",
        docs_joiner: str = "\n\n",
        hyde_level: int = 0,
        hyde_template: Optional[str] = None,
        hyde_show_only_final: bool = False,
        doc_json_mode: bool = False,
        chatbot_role="None",
        speaker="None",
        tts_language="autodetect",
        tts_speed=1.0,
    ) -> "TextCompletion":
        """
        Creates a new text completion.

        :param model: model to be used; `None` means use the default model
        :param prompt_type: type of the prompt
        :param input_context_for_instruction: input context for instruction
        :param enable_sampler: enable or disable the sampler; required for use of
            temperature, top_p, and top_k
        :param temperature: sampling temperature to use, between 0 and 3.
            Lower values make the output more focused and deterministic, but may
            lead to repetition. Higher values make the output more creative, but
            may lead to hallucinations.
        :param top_p: cumulative probability of tokens to sample from
        :param top_k: number of tokens to sample from
        :param penalty_alpha: penalty_alpha > 0 together with top_k > 1 enables
            contrastive search (not supported by all models)
        :param beams: number of beam searches for optimal overall probability.
            Higher values use more GPU memory and compute.
        :param early_stopping: whether to stop early in beam search
        :param min_output_length: minimum output length
        :param max_output_length: maximum output length
        :param max_time: maximum time to search for optimal output
        :param repetition_penalty: penalty for repetition
        :param number_returns: number of sequences to return
        :param system_pre_context: text directly prepended without prompt processing
        :param langchain_mode: LangChain mode
        :param system_prompt: universal system prompt to override prompt_type's
            system prompt.
            If 'None', 'auto', or None is passed, the automatic per-model value is used.
        :param add_search_to_context: whether to add a web search of the query to the context
        :param text_context_list: list of strings to use as context (up to the model's allowed max_seq_len)
        :param docs_ordering_type: by default uses 'reverse_ucurve_sort' for optimal retrieval
        :param min_max_new_tokens: minimum value for max_new_tokens when auto-adjusting for content of prompt, docs, etc.
        :param max_input_tokens: max input tokens to place into the model context for each LLM call.
            -1 means auto: fully fill the context for query, and fill by original document chunk for summarization.
            >=0 limits context filling to that many tokens.
        :param max_total_input_tokens: like max_input_tokens, but applied across all LLM calls
            of a single summarization/extraction action instead of per LLM call
        :param docs_token_handling: 'chunk' means fill the context with top_k_docs chunks
            (limited by max_input_tokens or model_max_len) for query, or with top_k_docs
            original document chunks for summarization.
            None or 'split_or_merge' means the same as 'chunk' for query; for summarization,
            documents are merged to fill up to max_input_tokens or model_max_len tokens.
        :param docs_joiner: string used to join lists of text when doing split_or_merge. None means '\n\n'
        :param hyde_level: HYDE level for the HYDE approach (https://arxiv.org/abs/2212.10496)
            0: no HYDE
            1: use non-document-based LLM response and original query for the embedding query
            2: use document-based LLM response and original query for the embedding query
            3+: continue iterations of embedding the prior answer and getting a new response
        :param hyde_template:
            None, 'None', or 'auto' uses the internal value and enables HYDE
            '{query}' is the minimal template one can pass
        :param hyde_show_only_final: see h2oGPT server docs
        :param doc_json_mode: whether to give JSON to the LLM and get a JSON response back
        :param chatbot_role: see h2oGPT server docs
        :param speaker: see h2oGPT server docs
        :param tts_language: see h2oGPT server docs
        :param tts_speed: see h2oGPT server docs
        """
        args = locals().copy()
        args["prompt_type"] = prompt_type.value
        args["langchain_mode"] = langchain_mode.value
        params = _to_h2ogpt_params({**_DEFAULT_PARAMETERS, **args})
        params["instruction_nochat"] = None
        params["h2ogpt_key"] = self._client.h2ogpt_key
        return TextCompletion(self._client, params)


class TextCompletion(_Completion):
    """Text completion."""

    async def complete(
        self, prompt: str, enable_streaming: bool = False
    ) -> Union[str, AsyncGenerator[str, None]]:
        """
        Complete this text completion.

        :param prompt: text prompt to generate completion for
        :param enable_streaming: whether to stream the response
        :return: response from the model
        """
        if enable_streaming:
            params = self._get_parameters(prompt)
            params["stream_output"] = True
            return self._submit_and_stream(prompt)
        else:
            return await self._submit(prompt)

    def complete_sync(
        self, prompt: str, enable_streaming: bool = False
    ) -> Union[str, Generator[str, None, None]]:
        """
        Complete this text completion synchronously.

        :param prompt: text prompt to generate completion for
        :param enable_streaming: whether to stream the response
        :return: response from the model
        """
        if enable_streaming:
            params = self._get_parameters(prompt)
            params["stream_output"] = True
            return self._predict_and_stream(prompt)
        else:
            return self._predict(prompt)
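
# Minimal usage sketch. It assumes the package-level Client facade exposes a
# TextCompletionCreator as `client.text_completion`; adjust to your actual
# entry point if it differs.
#
#     from h2ogpt_client import Client  # assumed facade, not defined in this module
#
#     client = Client("http://localhost:7860")  # assumed server address
#     completion = client.text_completion.create(enable_sampler=True, temperature=0.2)
#     print(completion.complete_sync("Explain gradient descent in one sentence."))
#
#     # With enable_streaming=True, complete_sync returns a generator of
#     # incremental reply chunks instead of a single string.
#     for chunk in completion.complete_sync("Tell me a story.", enable_streaming=True):
#         print(chunk, end="")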


class ChatCompletionCreator:
    """Builder that can create chat completions."""

    def __init__(self, client: GradioClientWrapper):
        self._client = client

    def create(
        self,
        model: Union[None, Model, str] = None,
        prompt_type: PromptType = PromptType.plain,
        input_context_for_instruction: str = "",
        enable_sampler=False,
        temperature: float = 0.1,
        top_p: float = 1.0,
        top_k: int = 40,
        penalty_alpha: float = 0.0,
        beams: float = 1.0,
        early_stopping: bool = False,
        min_output_length: int = 0,
        max_output_length: int = 1024,
        max_time: int = 360,
        repetition_penalty: float = 1.07,
        number_returns: int = 1,
        system_pre_context: str = "",
        langchain_mode: LangChainMode = LangChainMode.DISABLED,
        system_prompt: str = "",
        add_search_to_context: bool = False,
        text_context_list: List[str] = [],
        docs_ordering_type: str = "reverse_ucurve_sort",
        min_max_new_tokens: int = 256,
        max_input_tokens: int = -1,
        max_total_input_tokens: int = -1,
        docs_token_handling: str = "split_or_merge",
        docs_joiner: str = "\n\n",
        hyde_level: int = 0,
        hyde_template: Optional[str] = None,
        hyde_show_only_final: bool = False,
        doc_json_mode: bool = False,
        chatbot_role="None",
        speaker="None",
        tts_language="autodetect",
        tts_speed=1.0,
    ) -> "ChatCompletion":
        """
        Creates a new chat completion.

        :param model: model to be used; `None` means use the default model
        :param prompt_type: type of the prompt
        :param input_context_for_instruction: input context for instruction
        :param enable_sampler: enable or disable the sampler; required for use of
            temperature, top_p, and top_k
        :param temperature: sampling temperature to use, between 0 and 3.
            Lower values make the output more focused and deterministic, but may
            lead to repetition. Higher values make the output more creative, but
            may lead to hallucinations.
        :param top_p: cumulative probability of tokens to sample from
        :param top_k: number of tokens to sample from
        :param penalty_alpha: penalty_alpha > 0 together with top_k > 1 enables
            contrastive search (not supported by all models)
        :param beams: number of beam searches for optimal overall probability.
            Higher values use more GPU memory and compute.
        :param early_stopping: whether to stop early in beam search
        :param min_output_length: minimum output length
        :param max_output_length: maximum output length
        :param max_time: maximum time to search for optimal output
        :param repetition_penalty: penalty for repetition
        :param number_returns: number of sequences to return
        :param system_pre_context: text directly prepended without prompt processing
        :param langchain_mode: LangChain mode
        :param system_prompt: universal system prompt to override prompt_type's
            system prompt
        :param add_search_to_context: whether to add a web search of the query to the context
        :param text_context_list: list of strings to use as context (up to the model's allowed max_seq_len)
        :param docs_ordering_type: by default uses 'reverse_ucurve_sort' for optimal retrieval
        :param min_max_new_tokens: minimum value for max_new_tokens when auto-adjusting for content of prompt, docs, etc.
        :param max_input_tokens: max input tokens to place into the model context for each LLM call.
            -1 means auto: fully fill the context for query, and fill by original document chunk for summarization.
            >=0 limits context filling to that many tokens.
        :param max_total_input_tokens: like max_input_tokens, but applied across all LLM calls
            of a single summarization/extraction action instead of per LLM call
        :param docs_token_handling: 'chunk' means fill the context with top_k_docs chunks
            (limited by max_input_tokens or model_max_len) for query, or with top_k_docs
            original document chunks for summarization.
            None or 'split_or_merge' means the same as 'chunk' for query; for summarization,
            documents are merged to fill up to max_input_tokens or model_max_len tokens.
        :param docs_joiner: string used to join lists of text when doing split_or_merge. None means '\n\n'
        :param hyde_level: HYDE level for the HYDE approach (https://arxiv.org/abs/2212.10496)
            0: no HYDE
            1: use non-document-based LLM response and original query for the embedding query
            2: use document-based LLM response and original query for the embedding query
            3+: continue iterations of embedding the prior answer and getting a new response
        :param hyde_template:
            None, 'None', or 'auto' uses the internal value and enables HYDE
            '{query}' is the minimal template one can pass
        :param hyde_show_only_final: see h2oGPT server docs
        :param doc_json_mode: whether to give JSON to the LLM and get a JSON response back
        :param chatbot_role: see h2oGPT server docs
        :param speaker: see h2oGPT server docs
        :param tts_language: see h2oGPT server docs
        :param tts_speed: see h2oGPT server docs
        """
        args = locals().copy()
        args["prompt_type"] = prompt_type.value
        args["langchain_mode"] = langchain_mode.value
        params = _to_h2ogpt_params({**_DEFAULT_PARAMETERS, **args})
        params["instruction_nochat"] = None
        params["add_chat_history_to_context"] = True
        params["h2ogpt_key"] = self._client.h2ogpt_key
        params["chat_conversation"] = []
        return ChatCompletion(self._client, params)


class ChatCompletion(_Completion):
    """Chat completion."""

    def _update_history(self, prompt: str, reply: str) -> None:
        self._parameters["chat_conversation"].append((prompt, reply))

    async def chat(self, prompt: str) -> Dict[str, str]:
        """
        Complete this chat completion.

        :param prompt: text prompt to generate completions for
        :return: chat reply
        """
        reply = await self._submit(prompt)
        self._update_history(prompt, reply)
        return {"user": prompt, "gpt": reply}

    def chat_sync(self, prompt: str) -> Dict[str, str]:
        """
        Complete this chat completion synchronously.

        :param prompt: text prompt to generate completions for
        :return: chat reply
        """
        reply = self._predict(prompt)
        self._update_history(prompt, reply)
        return {"user": prompt, "gpt": reply}

    def chat_history(self) -> List[Dict[str, str]]:
        """Returns the full chat history."""
        return [
            {"user": i[0], "gpt": i[1]} for i in self._parameters["chat_conversation"]
        ]
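
# Minimal usage sketch for chat. As above, it assumes the package-level Client
# facade exposes a ChatCompletionCreator as `client.chat_completion`.
#
#     chat_session = client.chat_completion.create(system_prompt="Be concise.")
#     print(chat_session.chat_sync("What is h2oGPT?")["gpt"])
#     print(chat_session.chat_sync("Who maintains it?")["gpt"])
#     print(chat_session.chat_history())  # [{"user": ..., "gpt": ...}, ...]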