from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType

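# Prompt template applied to every request. The "{{...}}" markers appear to be
# export-time placeholders that are expected to be substituted with the model's
# actual special tokens before this pipeline is used; "{instruction}" is filled
# in per request via str.format() in preprocess().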
STYLE = "{{text_prompt_start}}{instruction}{{end_of_sentence}}{{text_answer_separator}}" |
|
|
|
|
|
class H2OTextGenerationPipeline(TextGenerationPipeline): |
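    """TextGenerationPipeline that wraps requests in the H2O prompt template.

    preprocess() formats the user instruction into STYLE before tokenization;
    postprocess() keeps only the text between the answer separator and the next
    prompt-start marker, so callers receive just the model's answer.
    """
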
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prompt = STYLE

    def preprocess(
        self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs
    ):
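        # Wrap the raw instruction in the prompt template before the parent
        # class tokenizes it.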
        prompt_text = self.prompt.format(instruction=prompt_text)
        return super().preprocess(
            prompt_text,
            prefix=prefix,
            handle_long_generation=handle_long_generation,
            **generate_kwargs,
        )

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.FULL_TEXT,
        clean_up_tokenization_spaces=True,
    ):
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
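        # With ReturnType.FULL_TEXT the decoded output still contains the prompt,
        # so keep only the text after the answer separator and drop anything that
        # follows a new prompt-start marker. Note that split(...)[1] raises
        # IndexError if the separator token never appears in the output.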
        for rec in records:
            rec["generated_text"] = (
                rec["generated_text"]
                .split("{{text_answer_separator}}")[1]
                .strip()
                .split("{{text_prompt_start}}")[0]
                .strip()
            )
        return records
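

# Usage sketch (an illustration, not part of the original file; the model name is a
# placeholder and this assumes the "{{...}}" markers above have already been
# rendered to the model's real special tokens):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("<exported-model-name>")
#   model = AutoModelForCausalLM.from_pretrained("<exported-model-name>")
#   pipe = H2OTextGenerationPipeline(model=model, tokenizer=tokenizer)
#   print(pipe("Why is drinking water so healthy?", max_new_tokens=256)[0]["generated_text"])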