Last commit not found
from typing import Any, Dict, List, Optional | |
from .operator import StreamInstanceOperator | |
class Tasker: | |
pass | |
class FormTask(Tasker, StreamInstanceOperator): | |
inputs: List[str] | |
outputs: List[str] | |
metrics: List[str] | |
augmentable_inputs: List[str] = [] | |
def verify(self): | |
for augmentable_input in self.augmentable_inputs: | |
assert ( | |
augmentable_input in self.inputs | |
), f"augmentable_input f{augmentable_input} is not part of {self.inputs}" | |
def process( | |
self, instance: Dict[str, Any], stream_name: Optional[str] = None | |
) -> Dict[str, Any]: | |
try: | |
inputs = {key: instance[key] for key in self.inputs} | |
except KeyError as e: | |
raise KeyError( | |
f"Unexpected FormTask input column names ({[key for key in self.inputs if key not in instance]})." | |
f"The available input names: {list(instance.keys())}" | |
) from e | |
try: | |
outputs = {key: instance[key] for key in self.outputs} | |
except KeyError as e: | |
raise KeyError( | |
f"Unexpected FormTask output column names: {[key for key in self.outputs if key not in instance]}" | |
f" \n available names:{list(instance.keys())}\n given output names:{self.outputs}" | |
) from e | |
return { | |
"inputs": inputs, | |
"outputs": outputs, | |
"metrics": self.metrics, | |
} | |
class MultipleChoiceTask(FormTask): | |
choices_field: str = "choices" | |
choices_separator: str = "\n" | |
enumeration_suffix: str = ". " | |
use_text_in_target: bool = False | |
alphabet: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" | |
def process_single_choice( | |
self, choice: str, index: int, use_text: bool = True | |
) -> str: | |
try: | |
processed_choice = f"{self.alphabet[index]}" | |
except IndexError as e: | |
raise ValueError( | |
f"Too many choices, the length of alphabet '{self.alphabet}': {len(self.alphabet)} is the limit" | |
) from e | |
if use_text: | |
processed_choice += f"{self.enumeration_suffix}{choice}" | |
return processed_choice | |
def process_choices(self, choices: List[str]) -> str: | |
processed_choices = [] | |
for index, choice in enumerate(choices): | |
processed_choices.append(self.process_single_choice(choice, index)) | |
return self.choices_separator.join(processed_choices) | |
def process_target(self, choices, target_index): | |
return self.process_single_choice( | |
choices[target_index], target_index, use_text=self.use_text_in_target | |
) | |
def process( | |
self, instance: Dict[str, Any], stream_name: Optional[str] = None | |
) -> Dict[str, Any]: | |
result = super().process(instance, stream_name) | |
target_key, target_value = next(iter(result["outputs"].items())) | |
choices = result["inputs"][self.choices_field] | |
target_index_in_choices = choices.index(target_value) | |
processed_choices = self.process_choices(choices) | |
processed_target = self.process_target(choices, target_index_in_choices) | |
result["inputs"][self.choices_field] = processed_choices | |
result["outputs"][target_key] = processed_target | |
return result | |