H2OTest / llm_studio /src /h2oai_pipeline_template.py
elineve's picture
Upload 301 files
07423df
raw
history blame
1.37 kB
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
# Default prompt template, stored as ``self.prompt`` by
# H2OTextGenerationPipeline and filled via ``str.format(instruction=...)``.
# The double-brace fields are escaped for ``str.format`` (they render as
# single-brace text, not as format fields). NOTE(review): they look like
# placeholders that an export step replaces with the configured prompt-start,
# end-of-sentence and answer-separator text (the file is a "template") —
# confirm before changing their spelling. Any literal brace left in the final
# string other than ``{instruction}`` would be parsed by ``str.format``.
STYLE = "{{text_prompt_start}}{instruction}{{end_of_sentence}}{{text_answer_separator}}"
class H2OTextGenerationPipeline(TextGenerationPipeline):
    """Text-generation pipeline that wraps inputs in a fixed prompt template.

    Each raw instruction is formatted into ``self.prompt`` (initialised from
    the module-level ``STYLE`` template) before tokenization, and each decoded
    output is trimmed down to the answer that follows the answer separator.
    """

    # Marker the formatted prompt ends with; the answer is taken from the
    # text after its first occurrence.
    # NOTE(review): the double-brace values below look like placeholders that
    # an export step substitutes with the real marker text — confirm before
    # editing their spelling.
    _ANSWER_SEPARATOR = "{{text_answer_separator}}"
    # Marker that opens a prompt turn; the answer is cut at its first
    # occurrence so only the first answer turn is kept.
    _PROMPT_START = "{{text_prompt_start}}"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Stored on the instance; preprocess formats every input with it.
        self.prompt = STYLE

    def preprocess(
        self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs
    ):
        """Wrap ``prompt_text`` in ``self.prompt`` and delegate to the base class.

        Args:
            prompt_text: Raw instruction, substituted for ``{instruction}``.
            prefix: Forwarded unchanged to the base ``preprocess``.
            handle_long_generation: Forwarded unchanged to the base
                ``preprocess``.
            **generate_kwargs: Forwarded unchanged to the base ``preprocess``.

        Returns:
            Whatever the base ``preprocess`` returns for the formatted prompt.
        """
        prompt_text = self.prompt.format(instruction=prompt_text)
        return super().preprocess(
            prompt_text,
            prefix=prefix,
            handle_long_generation=handle_long_generation,
            **generate_kwargs,
        )

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.FULL_TEXT,
        clean_up_tokenization_spaces=True,
        **postprocess_kwargs,
    ):
        """Decode via the base class, then reduce each text to the model's answer.

        Args:
            model_outputs: Model outputs, passed straight to the base class.
            return_type: Forwarded to the base ``postprocess``.
            clean_up_tokenization_spaces: Forwarded to the base
                ``postprocess``.
            **postprocess_kwargs: Any further options the installed
                ``transformers`` version routes to ``postprocess``; forwarded
                unchanged so this override does not reject them.

        Returns:
            The base class's list of records, with each ``"generated_text"``
            value replaced by the extracted answer. Records without that key
            are returned untouched.
        """
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **postprocess_kwargs,
        )
        for rec in records:
            # Some return modes may yield records without decoded text;
            # leave those as the base class produced them.
            if "generated_text" in rec:
                rec["generated_text"] = self._extract_answer(rec["generated_text"])
        return records

    @classmethod
    def _extract_answer(cls, text):
        """Return the first answer turn contained in decoded ``text``.

        Takes the segment between the first and second answer separators (as
        ``split(separator)[1]`` does), strips it, and cuts it at the first
        prompt-start marker. If the separator does not occur in ``text`` (for
        example when only newly generated tokens were decoded), ``text``
        itself is treated as the answer instead of raising ``IndexError``.
        Empty markers are skipped because ``str.split`` raises ``ValueError``
        on an empty separator.
        """
        separator = cls._ANSWER_SEPARATOR
        if separator:
            parts = text.split(separator)
            if len(parts) > 1:
                text = parts[1]
        text = text.strip()

        prompt_start = cls._PROMPT_START
        if prompt_start:
            text = text.split(prompt_start)[0]
        return text.strip()