File size: 5,571 Bytes
7368ee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c698696
 
 
 
 
 
ca685d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7368ee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca685d7
 
7368ee6
 
 
ca685d7
 
7368ee6
ca685d7
7368ee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca685d7
7368ee6
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
using System;
using System.Collections.Generic;
using System.Linq;
using Unity.Sentis;
using UnityEngine;

/// <summary>
/// Zero-shot text classification with a DeBERTa-v3 NLI model via Unity Sentis.
/// Each candidate class is turned into a hypothesis ("This example is about {class}"),
/// paired with the input text, and scored for entailment in a single batched inference.
/// </summary>
public sealed class DebertaV3 : MonoBehaviour
{
    // Inspector-assigned Sentis model asset and its vocabulary (one token per line).
    public ModelAsset model;
    public TextAsset vocabulary;
    // When true (or when there is only one class) each hypothesis is scored
    // independently (entailment vs. contradiction); otherwise the entailment
    // scores are normalized across all candidate classes.
    public bool multipleTrueClasses;
    public string text = "Angela Merkel is a politician in Germany and leader of the CDU";
    public string hypothesisTemplate = "This example is about {}";
    public string[] classes = { "politics", "economy", "entertainment", "environment" };

    IWorker engine;
    string[] vocabularyTokens;
    // Token -> first vocabulary index. Built once in Start; avoids an O(V)
    // Array.IndexOf scan for every sub-word lookup during tokenization.
    Dictionary<string, int> vocabularyIndex;

    // Special token ids used by the model.
    const int padToken = 0;
    const int startToken = 1;
    const int separatorToken = 2;
    // Offset between a token's line index in the vocabulary file and its id in
    // the model's token space.
    const int vocabToTokenOffset = 260;

    void Start()
    {
        if (classes == null || classes.Length == 0)
        {
            Debug.LogError("There need to be more than 0 classes");
            return;
        }

        // Normalize line endings so the vocabulary splits cleanly on either CRLF or LF files.
        vocabularyTokens = vocabulary.text.Replace("\r", "").Split("\n");

        vocabularyIndex = new Dictionary<string, int>(vocabularyTokens.Length);
        for (int i = 0; i < vocabularyTokens.Length; i++)
        {
            // Keep the first occurrence only, matching Array.IndexOf semantics
            // should the file contain duplicate lines.
            if (!vocabularyIndex.ContainsKey(vocabularyTokens[i]))
                vocabularyIndex.Add(vocabularyTokens[i], i);
        }

        Model baseModel = ModelLoader.Load(model);
        Model modelWithScoring = Functional.Compile(
            input =>
            {
                // The logits represent the model's predictions for entailment and non-entailment for each example in the batch.
                // They are of shape [batch size, 2] i.e. with two values per example.
                // To obtain a single score per example, a softmax function is applied
                FunctionalTensor logits = baseModel.Forward(input)[0];

                if (multipleTrueClasses || classes.Length == 1)
                {
                    // Softmax over the entailment vs. contradiction dimension for each label independently
                    logits = Functional.Softmax(logits);
                }
                else
                {
                    // Softmax over all candidate labels
                    logits = Functional.Softmax(logits, 0);
                }

                // The scores are stored along the first column
                return new []{logits[.., 0]};
            },
            InputDef.FromModel(baseModel)
        );

        engine = WorkerFactory.CreateWorker(BackendType.GPUCompute, modelWithScoring);

        string[] hypotheses = classes.Select(x => hypothesisTemplate.Replace("{}", x)).ToArray();
        Batch batch = GetTokenizedBatch(text, hypotheses);
        float[] scores = GetBatchScores(batch);

        for (int i = 0; i < scores.Length; i++)
        {
            Debug.Log($"[{classes[i]}] Entailment Score: {scores[i]}");
        }
    }

    /// <summary>
    /// Runs the batched inference and returns one entailment score per hypothesis.
    /// </summary>
    float[] GetBatchScores(Batch batch)
    {
        using var inputIds = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedTokens);
        using var attentionMask = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedMasks);

        Dictionary<string, Tensor> inputs = new()
        {
            {"input_0", inputIds},
            {"input_1", attentionMask}
        };

        engine.Execute(inputs);
        // PeekOutput returns a tensor owned by the worker — do not dispose it here.
        TensorFloat scores = (TensorFloat)engine.PeekOutput("output_0");
        scores.CompleteOperationsAndDownload();

        return scores.ToReadOnlyArray();
    }

    /// <summary>
    /// Tokenizes the prompt once and pairs it with each tokenized hypothesis,
    /// padding every example to a common length.
    /// </summary>
    Batch GetTokenizedBatch(string prompt, string[] hypotheses)
    {
        Batch batch = new Batch();

        List<int> promptTokens = Tokenize(prompt);
        promptTokens.Insert(0, startToken);

        List<int>[] tokenizedHypotheses = hypotheses.Select(Tokenize).ToArray();
        int maxTokenLength = tokenizedHypotheses.Max(x => x.Count);

        // Each example in the batch follows this format:
        // Start Prompt Separator Hypothesis Separator Padding

        int[] batchedTokens = tokenizedHypotheses.SelectMany(hypothesis => promptTokens
                .Append(separatorToken)
                .Concat(hypothesis)
                .Append(separatorToken)
                .Concat(Enumerable.Repeat(padToken, maxTokenLength - hypothesis.Count)))
            .ToArray();

        // The attention masks have the same length as the tokens.
        // Each attention mask contains repeating 1s for each token, except for padding tokens.

        int[] batchedMasks = tokenizedHypotheses.SelectMany(hypothesis => Enumerable.Repeat(1, promptTokens.Count + 1)
                .Concat(Enumerable.Repeat(1, hypothesis.Count + 1))
                .Concat(Enumerable.Repeat(0, maxTokenLength - hypothesis.Count)))
            .ToArray();

        batch.BatchCount = hypotheses.Length;
        batch.BatchLength = batchedTokens.Length / hypotheses.Length;
        batch.BatchedTokens = batchedTokens;
        batch.BatchedMasks = batchedMasks;

        return batch;
    }

    /// <summary>
    /// Greedy longest-match sub-word tokenization against the loaded vocabulary.
    /// Word starts are marked with the SentencePiece "▁" prefix.
    /// </summary>
    List<int> Tokenize(string input)
    {
        // Split on any whitespace. RemoveEmptyEntries prevents consecutive
        // whitespace from producing empty "words" (which would emit spurious
        // "▁" tokens).
        string[] words = input.Split((char[])null, StringSplitOptions.RemoveEmptyEntries);

        List<int> ids = new();

        foreach (string word in words)
        {
            int start = 0;
            while (start < word.Length)
            {
                bool matched = false;

                // Try the longest remaining slice first, shrinking until a
                // vocabulary entry is found.
                for (int length = word.Length - start; length > 0; length--)
                {
                    string subWord = start == 0
                        ? "▁" + word.Substring(0, length)
                        : word.Substring(start, length);

                    if (vocabularyIndex.TryGetValue(subWord, out int index))
                    {
                        ids.Add(index + vocabToTokenOffset);
                        start += length;
                        matched = true;
                        break;
                    }
                }

                if (!matched)
                {
                    // No sub-word (not even a single character) is in the
                    // vocabulary: skip one character and continue. The previous
                    // scan could loop forever or throw on out-of-vocabulary
                    // characters; always advancing guarantees termination.
                    start++;
                }
            }
        }

        return ids;
    }

    void OnDestroy() => engine?.Dispose();

    // Flattened token/mask buffers for one inference batch.
    struct Batch
    {
        public int BatchCount;    // number of examples (hypotheses)
        public int BatchLength;   // tokens per example after padding
        public int[] BatchedTokens;
        public int[] BatchedMasks;
    }
}