# File size: 600 Bytes (revision 859da2a)
import pandas as pd
def predict(data, task, model, tokenizer, config, **kwargs):
    """Run ``model.generate`` on each prompt in *data* and decode the results.

    Args:
        data: Either a ``pandas.DataFrame`` (only its first column is used as
            the prompts) or an iterable of prompts such as a list of strings.
        task: Not used by this function; kept for interface compatibility.
        model: Object exposing ``generate(**inputs, **gen_kwargs)``
            (presumably a Hugging Face model -- confirm with callers).
        tokenizer: Callable that encodes a single prompt into the keyword
            inputs for ``model.generate`` and exposes ``batch_decode``.
        config: Not used by this function; kept for interface compatibility.
        **kwargs: Optional ``addn_args`` dict of extra keyword arguments
            forwarded to ``model.generate``. It may override the default
            ``max_length=50``.

    Returns:
        A DataFrame with a single ``output`` column when *data* was a
        DataFrame, otherwise a dict ``{"output": [decoded_text, ...]}``.
    """
    # Record the input type up front so the output can mirror it. Previously
    # this flag was only set for DataFrames, so list input hit an unbound name.
    is_df = isinstance(data, pd.DataFrame)
    if is_df:
        data = data[data.columns[0]].tolist()

    # Merge into a fresh dict: caller-supplied values (including max_length)
    # override the default instead of colliding with it as duplicate keyword
    # arguments, and the caller's addn_args dict is never mutated.
    gen_kwargs = {"max_length": 50, **(kwargs.get("addn_args") or {})}

    results = []
    for prompt in data:
        inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=False)
        outputs = model.generate(**inputs, **gen_kwargs)
        # One prompt is encoded per call; keep the first decoded sequence.
        results.append(tokenizer.batch_decode(outputs)[0])

    if is_df:
        return pd.DataFrame(results, columns=["output"])
    return {"output": results}