|
from transformers import TextGenerationPipeline |
|
from transformers.pipelines.text_generation import ReturnType |
|
|
|
# Prompt template for MambaGPTTextGenerationPipeline. The pipeline fills the
# `{instruction}` placeholder with the caller's text via str.format, and
# treats whatever the model emits after the `<|answer|>` tag as its reply.
# NOTE(review): "</s>" is presumably the tokenizer's end-of-sequence string;
# confirm against the model's tokenizer.
STYLE: str = "<|prompt|>{instruction}</s><|answer|>"
|
|
|
|
|
class MambaGPTTextGenerationPipeline(TextGenerationPipeline):
    """Text-generation pipeline that applies the ``STYLE`` prompt template.

    ``preprocess`` wraps each raw instruction in the template before handing it
    to ``TextGenerationPipeline``; ``postprocess`` reduces every decoded
    ``generated_text`` to just the model's answer.
    """

    # Tags used by the prompt template. The answer is the text following
    # _ANSWER_TAG, cut off at any _PROMPT_TAG the model emits afterwards
    # (i.e. a new turn it started on its own).
    _ANSWER_TAG = "<|answer|>"
    _PROMPT_TAG = "<|prompt|>"

    def __init__(self, *args, **kwargs):
        """Build the base pipeline and attach the prompt template."""
        super().__init__(*args, **kwargs)
        # Template with an ``{instruction}`` placeholder, filled in preprocess().
        self.prompt = STYLE

    def preprocess(
        self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs
    ):
        """Format *prompt_text* into ``self.prompt``, then defer to the base class.

        Only the template is parsed by ``str.format``, so braces inside
        *prompt_text* are inserted verbatim. All other arguments are forwarded
        unchanged.
        """
        prompt_text = self.prompt.format(instruction=prompt_text)
        return super().preprocess(
            prompt_text,
            prefix=prefix,
            handle_long_generation=handle_long_generation,
            **generate_kwargs,
        )

    @classmethod
    def _extract_answer(cls, text):
        """Return the answer portion of one decoded generation *text*.

        Keeps the segment between the first and second ``<|answer|>`` tags.
        When the tag is absent (e.g. ``return_full_text=False``, where the
        prompt was already removed) the whole text is used instead of raising.
        The result is truncated at the first ``<|prompt|>`` and stripped of
        surrounding whitespace.
        """
        segments = text.split(cls._ANSWER_TAG)
        answer = segments[1] if len(segments) > 1 else segments[0]
        return answer.split(cls._PROMPT_TAG)[0].strip()

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.FULL_TEXT,
        clean_up_tokenization_spaces=True,
        **postprocess_kwargs,
    ):
        """Decode via the base class, then reduce each text to the answer.

        Args:
            model_outputs: Forward-pass outputs, passed unchanged to the base class.
            return_type: ``ReturnType`` forwarded to the base class.
            clean_up_tokenization_spaces: Forwarded to the base class.
            **postprocess_kwargs: Any further options accepted by the installed
                ``TextGenerationPipeline.postprocess``; forwarded as-is.

        Returns:
            The base class's list of records. Each record whose
            ``generated_text`` is a string has it replaced by the extracted
            answer; other records are returned untouched.
        """
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **postprocess_kwargs,
        )
        for rec in records:
            text = rec.get("generated_text")
            # Token-id outputs (ReturnType.TENSORS) carry no "generated_text";
            # leave any non-string output as produced instead of raising.
            if isinstance(text, str):
                rec["generated_text"] = self._extract_answer(text)
        return records
|
|