DCWIR-Demo / textattack /models /wrappers /demo_model_wrapper.py
PFEemp2024's picture
add necessary file
63775f2
raw
history blame contribute delete
No virus
730 Bytes
from textattack.models.wrappers import HuggingFaceModelWrapper
class TADModelWrapper(HuggingFaceModelWrapper):
    """Expose a model that offers an ``infer`` method through the wrapper interface.

    Each input text is passed to the wrapped model's ``infer`` method and the
    ``"probs"`` entry of every result is collected, giving one entry per input
    (presumably per-class probabilities), e.g.
        [[0.218262017, 0.7817379832267761]]
    """

    def __init__(self, model):
        # Only the model is stored; the parent constructor is not invoked here.
        # NOTE(review): presumably deliberate because no tokenizer is used -- confirm.
        self.model = model

    def __call__(self, text_inputs, **kwargs):
        """Return the ``"probs"`` entry of ``self.model.infer`` for each text.

        Args:
            text_inputs: Iterable of input texts; each one is passed to
                ``infer`` individually.
            **kwargs: Extra keyword arguments forwarded unchanged to ``infer``.

        Returns:
            A list holding one ``"probs"`` value per input, in input order.
        """
        # print_result=False is always passed so that inference stays quiet.
        return [
            self.model.infer(sample, print_result=False, **kwargs)["probs"]
            for sample in text_inputs
        ]