File size: 600 Bytes
859da2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import pandas as pd

def predict(data, task, model, tokenizer, config, **kwargs):
    """Generate text for each input item using ``model`` and ``tokenizer``.

    Args:
        data: Either a ``pandas.DataFrame`` (only its first column is used)
            or an iterable of inputs accepted by ``tokenizer``.
        task: Unused by this function; kept for interface compatibility.
        model: Object exposing ``generate(**inputs, **generation_kwargs)``.
        tokenizer: Callable that encodes one input (invoked with
            ``return_tensors="pt"`` and ``return_attention_mask=False``)
            and exposes ``batch_decode(outputs)``.
        config: Unused by this function; kept for interface compatibility.
        **kwargs: Only ``addn_args`` is read: an optional mapping of extra
            keyword arguments forwarded to ``model.generate``.

    Returns:
        A ``DataFrame`` with a single ``output`` column when ``data`` was a
        ``DataFrame``; otherwise ``{"output": [...]}`` with one decoded
        string per input item.
    """
    # Bind the flag on every path; previously it was only assigned inside the
    # DataFrame branch, so plain-list input raised UnboundLocalError.
    is_df = isinstance(data, pd.DataFrame)
    if is_df:
        data = data[data.columns[0]].tolist()

    # Treat a missing or explicit-None ``addn_args`` as "no extra arguments".
    addn_args = kwargs.get("addn_args") or {}
    # ``max_length=50`` is a default that callers may override; passing it as
    # a separate keyword alongside ``**addn_args`` raised TypeError whenever
    # the caller supplied their own ``max_length``.
    generate_kwargs = {"max_length": 50, **addn_args}

    results = []
    for item in data:
        inputs = tokenizer(item, return_tensors="pt", return_attention_mask=False)
        outputs = model.generate(**inputs, **generate_kwargs)
        # Only the first decoded sequence is kept for each input item.
        results.append(tokenizer.batch_decode(outputs)[0])

    if is_df:
        return pd.DataFrame(results, columns=["output"])
    return {"output": results}