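// Web-page residue captured together with this file (not part of the program):
// "Pippe's picture" / "Upload 4 files" / "7368ee6 verified" / "raw" / "history blame" / "5.74 kB"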
using System;
using System.Collections.Generic;
using System.Linq;
using Unity.Sentis;
using UnityEngine;
/// <summary>
/// Unity component that, when it starts, pairs <see cref="text"/> with one hypothesis per entry of
/// <see cref="classes"/> (built from <see cref="hypothesisTemplate"/>), runs the pairs through the
/// assigned model as a single batch and logs one entailment score per class.
/// </summary>
public sealed class DebertaV3 : MonoBehaviour
{
    // Inspector-assigned inputs. They stay public fields so existing scenes/prefabs keep their values.
    public ModelAsset model;
    public TextAsset vocabulary;
    // When true every class is scored on its own; otherwise the scores compete across classes.
    public bool multipleTrueClasses;
    public string text = "Angela Merkel is a politician in Germany and leader of the CDU";
    public string hypothesisTemplate = "This example is about {}";
    public string[] classes = { "politics", "economy", "entertainment", "environment" };

    Ops ops;
    IWorker engine;
    ITensorAllocator allocator;

    // Maps a vocabulary piece to the index of the first vocabulary line that contains it.
    Dictionary<string, int> vocabularyLookup;

    const int padToken = 0;
    const int startToken = 1;
    const int separatorToken = 2;
    // Added to a vocabulary line index to obtain the token id fed to the model.
    const int vocabToTokenOffset = 260;
    // Prefix carried by the piece that begins a whitespace-separated word.
    const string wordStartMarker = "▁";
    // The scoring code reads the model output as two values per example.
    const int logitsPerExample = 2;
    const BackendType backend = BackendType.GPUCompute;

    /// <summary>
    /// Loads the vocabulary and the model, scores <see cref="text"/> against every class and logs the results.
    /// Does nothing (apart from logging) when the assets are unassigned or there are no classes.
    /// </summary>
    void Start()
    {
        // Unity objects overload ==, so this also covers references that were never assigned in the Inspector.
        if (model == null || vocabulary == null)
        {
            Debug.LogError("DebertaV3: both the model and the vocabulary asset must be assigned.");
            return;
        }

        // Without classes there is no batch to build (the batch size would be zero).
        if (classes == null || classes.Length == 0)
        {
            Debug.LogWarning("DebertaV3: no classes to score.");
            return;
        }

        vocabularyLookup = BuildVocabularyLookup(vocabulary.text);
        allocator = new TensorCachingAllocator();
        ops = WorkerFactory.CreateOps(backend, allocator);
        Model loadedModel = ModelLoader.Load(model);
        engine = WorkerFactory.CreateWorker(backend, loadedModel);

        string[] hypotheses = classes.Select(x => hypothesisTemplate.Replace("{}", x)).ToArray();
        Batch batch = GetTokenizedBatch(text ?? string.Empty, hypotheses);
        float[] scores = GetBatchScores(batch);

        for (int i = 0; i < scores.Length; i++)
        {
            Debug.Log($"[{classes[i]}] Entailment Score: {scores[i]}");
        }
    }

    /// <summary>
    /// Builds the piece-to-index lookup from the vocabulary text, one piece per line.
    /// </summary>
    /// <param name="vocabularyText">Raw contents of the vocabulary asset.</param>
    /// <returns>Lookup from piece to its zero-based line index; the first occurrence wins.</returns>
    static Dictionary<string, int> BuildVocabularyLookup(string vocabularyText)
    {
        string[] lines = vocabularyText.Split('\n');
        Dictionary<string, int> lookup = new(lines.Length, StringComparer.Ordinal);

        for (int i = 0; i < lines.Length; i++)
        {
            // Tolerate CRLF line endings; otherwise every piece would carry a trailing '\r' and never match.
            string piece = lines[i].TrimEnd('\r');

            // An empty line (e.g. after the final newline) is not a usable piece.
            if (piece.Length == 0)
            {
                continue;
            }

            // Keep the first index of a duplicated piece, matching a front-to-back search of the lines.
            lookup.TryAdd(piece, i);
        }

        return lookup;
    }

    /// <summary>
    /// Runs the model on a tokenized batch and converts its "logits" output into one score per example.
    /// </summary>
    /// <param name="batch">Flattened token ids and attention masks plus their batch dimensions.</param>
    /// <returns>One score per batch example, in batch order.</returns>
    float[] GetBatchScores(Batch batch)
    {
        // The input tensors must stay alive until the scores have been read back, hence method-scoped usings.
        using var inputIds = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedTokens);
        using var attentionMask = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedMasks);

        Dictionary<string, Tensor> inputs = new()
        {
            {"input_ids", inputIds},
            {"attention_mask", attentionMask}
        };

        engine.Execute(inputs);

        // PeekOutput hands back a tensor that is not disposed here.
        TensorFloat logits = (TensorFloat)engine.PeekOutput("logits");
        return ScoresFromLogits(logits);
    }

    /// <summary>
    /// Tokenizes the prompt and every hypothesis and lays them out as one padded, flattened batch.
    /// </summary>
    /// <param name="prompt">Text that is paired with each hypothesis.</param>
    /// <param name="hypotheses">One hypothesis per batch example.</param>
    /// <returns>The batch; every example has the same length.</returns>
    Batch GetTokenizedBatch(string prompt, string[] hypotheses)
    {
        List<int> promptTokens = Tokenize(prompt);
        promptTokens.Insert(0, startToken);

        List<int>[] tokenizedHypotheses = hypotheses.Select(Tokenize).ToArray();
        int maxTokenLength = tokenizedHypotheses.Select(x => x.Count).DefaultIfEmpty(0).Max();

        // Each example in the batch follows this format:
        // Start Prompt Separator Hypothesis Separator Padding
        int[] batchedTokens = tokenizedHypotheses.SelectMany(hypothesis => promptTokens
                .Append(separatorToken)
                .Concat(hypothesis)
                .Append(separatorToken)
                .Concat(Enumerable.Repeat(padToken, maxTokenLength - hypothesis.Count)))
            .ToArray();

        // The attention masks have the same length as the tokens.
        // Each attention mask contains repeating 1s for each token, except for padding tokens.
        int[] batchedMasks = tokenizedHypotheses.SelectMany(hypothesis => Enumerable.Repeat(1, promptTokens.Count + 1)
                .Concat(Enumerable.Repeat(1, hypothesis.Count + 1))
                .Concat(Enumerable.Repeat(0, maxTokenLength - hypothesis.Count)))
            .ToArray();

        return new Batch
        {
            BatchCount = hypotheses.Length,
            // Computed from the layout above rather than by dividing, so an empty batch cannot divide by zero.
            BatchLength = promptTokens.Count + 1 + maxTokenLength + 1,
            BatchedTokens = batchedTokens,
            BatchedMasks = batchedMasks
        };
    }

    /// <summary>
    /// Turns the model logits into a single score per example.
    /// </summary>
    /// <param name="logits">Model output with two values per example (shape [batch size, 2]).</param>
    /// <returns>The first-column value of the softmaxed logits for every example.</returns>
    float[] ScoresFromLogits(TensorFloat logits)
    {
        // The logits represent the model's predictions for entailment and non-entailment for each example in the batch.
        // With several true classes allowed, or with only one candidate, each example is normalized on its own
        // (last axis). Otherwise the softmax runs across the candidates (axis 0) so that they compete.
        // The single-candidate check must look at the batch dimension: the total element count is never 1 here,
        // and a softmax across a batch of one would always yield a score of 1.
        bool scoreIndependently = multipleTrueClasses || logits.shape[0] == 1;
        int softmaxAxis = scoreIndependently ? -1 : 0;

        TensorFloat tensorScores = ops.Softmax(logits, softmaxAxis);
        tensorScores.MakeReadable();
        float[] tensorArray = tensorScores.ToReadOnlyArray();
        tensorScores.Dispose();

        // Select the first column which is the column where the scores are stored
        float[] scores = new float[tensorArray.Length / logitsPerExample];
        for (int i = 0; i < scores.Length; i++)
        {
            scores[i] = tensorArray[i * logitsPerExample];
        }

        return scores;
    }

    /// <summary>
    /// Splits the input on whitespace and greedily covers each word with the longest vocabulary pieces,
    /// left to right. The piece that begins a word is looked up with the word-start marker prefixed.
    /// </summary>
    /// <param name="input">Text to tokenize; null or empty yields no tokens.</param>
    /// <returns>Token ids (vocabulary line index plus <c>vocabToTokenOffset</c>), without start/separator tokens.</returns>
    List<int> Tokenize(string input)
    {
        List<int> ids = new();
        if (string.IsNullOrEmpty(input))
        {
            return ids;
        }

        // Dropping empty entries keeps runs of whitespace from producing stray word-start tokens.
        string[] words = input.Split((char[])null, StringSplitOptions.RemoveEmptyEntries);

        foreach (string word in words)
        {
            int start = 0;
            bool atWordStart = true;

            while (start < word.Length)
            {
                // Longest match first: shrink the candidate from the end of the word towards 'start'.
                bool matched = false;
                for (int end = word.Length; end > start; end--)
                {
                    string piece = word.Substring(start, end - start);
                    if (atWordStart)
                    {
                        piece = wordStartMarker + piece;
                    }

                    if (vocabularyLookup.TryGetValue(piece, out int index))
                    {
                        ids.Add(index + vocabToTokenOffset);
                        start = end;
                        atWordStart = false;
                        matched = true;
                        break;
                    }
                }

                if (matched)
                {
                    continue;
                }

                // No marked piece begins this word: emit the bare marker once, then retry the same
                // position without the prefix. Clearing the flag is what guarantees progress.
                if (atWordStart && vocabularyLookup.TryGetValue(wordStartMarker, out int markerIndex))
                {
                    ids.Add(markerIndex + vocabToTokenOffset);
                    atWordStart = false;
                    continue;
                }

                // Nothing in the vocabulary covers this character. Skip it (both halves of a surrogate
                // pair together) instead of throwing or looping forever.
                int skipped = char.IsHighSurrogate(word[start])
                              && start + 1 < word.Length
                              && char.IsLowSurrogate(word[start + 1])
                    ? 2
                    : 1;
                Debug.LogWarning($"DebertaV3: no vocabulary entry covers '{word.Substring(start, skipped)}' in '{word}'; skipping it.");
                start += skipped;
            }
        }

        return ids;
    }

    /// <summary>
    /// Releases the inference resources created in <c>Start</c>.
    /// </summary>
    void OnDestroy()
    {
        engine?.Dispose();
        // The ops object was created on top of the allocator, so release it before the allocator.
        ops?.Dispose();
        allocator?.Dispose();
    }

    /// <summary>
    /// A tokenized batch flattened row by row: <c>BatchCount</c> examples of <c>BatchLength</c> entries each.
    /// </summary>
    struct Batch
    {
        public int BatchCount;
        public int BatchLength;
        public int[] BatchedTokens;
        public int[] BatchedMasks;
    }
}