File size: 1,804 Bytes
bfa1717
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import logging

import torch
import torch.nn as nn
import torchtext
from label_studio_ml.model import LabelStudioMLBase

from sentiment_cnn import SentimentCNN
 
class SentimentModel(LabelStudioMLBase):
    """Label Studio ML backend that pre-annotates text tasks with a sentiment choice.

    Wraps a ``SentimentCNN`` and converts its per-text predictions into
    Label Studio "choices" pre-annotations.
    """

    def __init__(self, **kwargs):
        """Initialise the backend and load the sentiment model.

        All keyword arguments are forwarded unchanged to ``LabelStudioMLBase``.
        """
        super().__init__(**kwargs)

        # NOTE(review): these paths are relative, so they presumably resolve
        # against the process working directory -- confirm the backend is
        # launched from the directory that contains ``data/``.
        self.sentiment_model = SentimentCNN(
            state_dict='data/cnn.pt',
            vocab='data/vocab_obj.pt')

        # Maps the model's integer class index to the choice value emitted
        # in the pre-annotation.
        self.label_map = {
            1: "Positive",
            0: "Negative"}

    def predict(self, tasks, **kwargs):
        """Return one "choices" pre-annotation per task.

        Args:
            tasks: Iterable of Label Studio task dicts. Each task's ``data``
                dict must contain the field named by the first tag's first
                input (``schema['inputs'][0]['value']``).
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A list with one dict per task, in task order, shaped like
            ``{'score': float, 'result': [{'from_name': ..., 'to_name': ...,
            'type': 'choices', 'value': {'choices': [label]}}]}``.

        Raises:
            IndexError: If the parsed labeling config is empty, or the first
                tag has an empty ``to_name``/``inputs`` list.
            KeyError: If the config/task lacks an expected key, or the model
                returns a class index not present in ``label_map``.
        """
        logger = logging.getLogger(__name__)
        predictions = []

        # Only the FIRST tag in the parsed labeling config is used; its
        # from_name/to_name pair and input field drive every prediction.
        from_name, schema = list(self.parsed_label_config.items())[0]
        to_name = schema['to_name'][0]
        data_name = schema['inputs'][0]['value']

        for task in tasks:
            text = task['data'][data_name]
            predicted_class, predicted_prob = self.sentiment_model.predict_sentiment(text)
            # Lazy %-args: the message is only formatted when DEBUG is enabled,
            # instead of unconditionally printing every task to stdout.
            logger.debug("%s\nprediction: %s probability: %s",
                         text, predicted_class, predicted_prob)

            # int() lets a 0-d tensor / numpy integer (if that is what
            # predict_sentiment returns -- unverified) hit the int keys of
            # label_map; a plain int or bool passes through unchanged.
            label = self.label_map[int(predicted_class)]

            # NOTE(review): confirm whether predicted_prob is the confidence of
            # the predicted class or P(Positive); Label Studio treats 'score'
            # as the prediction's confidence.
            predictions.append({
                'score': float(predicted_prob),
                'result': [{
                    'from_name': from_name,
                    'to_name': to_name,
                    'type': 'choices',
                    'value': {
                        'choices': [label],
                    },
                }],
            })
        return predictions