using UnityEngine;
using Microsoft.ML.Tokenizers;
using Unity.Sentis;
using System.IO;
using System.Linq;
using System.Collections.Generic;
using System.Collections;

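// Phi-3.5-mini text generation with Unity Sentis and Microsoft.ML.Tokenizers.
// Greedy (argmax) decoding; the whole context is re-run through the model every
// step (no KV cache), with the argmax itself computed on-GPU by a second model.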
public class Phi3Claude : MonoBehaviour
{
    Worker worker_model;    // the Phi-3.5 transformer
    Worker worker_decoding; // tiny on-GPU argmax (greedy decoding) model
    LlamaTokenizer tokenizer;

    List<int> tokens = new();
    Tensor<int> inputTensor, attentionMaskTensor, positionIdsTensor;
    Tensor<float> outputLogits;
    Tensor<int> argMaxTensor;

    int maxTokens = 100; // Maximum number of tokens to generate
    List<int> eosTokens; // End of sequence tokens

    private void Start()
    {
        var tokenizerModelPath = Path.Combine(Application.streamingAssetsPath, "Phi35/tokenizer.model");
        var sentisModelPath = Path.Combine(Application.streamingAssetsPath, "Phi35/model_Uint8.sentis");
        var configPath = Path.Combine(Application.streamingAssetsPath, "Phi35/generation_config.json"); // not parsed; eos ids below were copied from it by hand
        
        var model = ModelLoader.Load(sentisModelPath);
        var vocab_size = 32064; // Phi-3.5-mini vocabulary size
        // Build a small functional model that does greedy decoding (argmax over the vocab axis)
        FunctionalGraph graph = new FunctionalGraph();
        FunctionalTensor logits = graph.AddInput<float>(new DynamicTensorShape(1, -1, vocab_size));
        FunctionalTensor argMax = Functional.ArgMax(logits, 2, false);
        Model greedyModel = graph.Compile(argMax);
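        // Design note: with the argmax computed on-GPU, only a (1, seq_len) tensor
        // of token ids ever needs to be read back to the CPU, not the full
        // (1, seq_len, vocab_size) logits.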
        
        worker_model = new Worker(model, BackendType.GPUCompute);
        worker_decoding = new Worker(greedyModel, BackendType.GPUCompute);
        // Manually set from added_tokens.json
        Dictionary<string, int> specialTokens = new()
        {
            { "<|assistant|>", 32001 },
            { "<|endoftext|>", 32000 },
            { "<|end|>", 32007 },
            { "<|placeholder1|>", 32002 },
            { "<|placeholder2|>", 32003 },
            { "<|placeholder3|>", 32004 },
            { "<|placeholder4|>", 32005 },
            { "<|placeholder5|>", 32008 },
            { "<|placeholder6|>", 32009 },
            { "<|system|>", 32006 },
            { "<|user|>", 32010 }
        };
        using (Stream tokenizerModelStream = new FileStream(tokenizerModelPath, FileMode.Open, FileAccess.Read))
        {
            tokenizer = LlamaTokenizer.Create(
                tokenizerModelStream,
                addBeginOfSentence: true,
                addEndOfSentence: false,
                specialTokens: specialTokens
            );
        }

        // Manually set from generation_config.json
        eosTokens = new() { 32007, 32001, 32000 }; // <|end|>, <|assistant|>, <|endoftext|>

        Generate("What is the capital of France?");
    }

    public void Generate(string userPrompt, string systemPrompt = "You are a helpful assistant.") 
    {
        string completePrompt = $@"<|system|>
{systemPrompt}<|end|>
<|user|>
{userPrompt}<|end|>
<|assistant|>";
        Debug.Log("Complete prompt : " + completePrompt);

        int[] inputIds = tokenizer.EncodeToIds(completePrompt).ToArray();
        Debug.Log($"Tokenized input: [{string.Join(", ", inputIds)}]");
        Debug.Log($"Decoded prompt: \"{tokenizer.Decode(inputIds, true)}\"");

        tokens.Clear();
        tokens.AddRange(inputIds);

        StartCoroutine(GenerateSequence());
    }
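
    // Hypothetical usage from another script (names are illustrative):
    //   phi3.Generate("Explain quaternions in one sentence.", "You are concise.");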
    
    private IEnumerator GenerateSequence()
    {
        for (int i = 0; i < maxTokens; i++)
        {
            RefreshTensors(tokens.ToArray());

            worker_model.SetInput("input_ids", inputTensor);
            worker_model.SetInput("attention_mask", attentionMaskTensor);
            worker_model.SetInput("position_ids", positionIdsTensor);
            worker_model.Schedule(); // > 15ms (/!\ should be async)
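            // A non-blocking alternative (a sketch, assuming Sentis' Worker.ScheduleIterable):
            //   var schedule = worker_model.ScheduleIterable();
            //   while (schedule.MoveNext())
            //       yield return null; // spreads the layers across frames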

            outputLogits = worker_model.PeekOutput("logits") as Tensor<float>;
            outputLogits.ReadbackRequest(); // async GPU -> CPU readback

            // Yielding the bool itself would only wait a single frame;
            // WaitUntil polls until the readback has actually completed.
            yield return new WaitUntil(() => outputLogits.IsReadbackRequestDone()); // ~236 ms

            int nextToken = ProcessLogits(); // > 200ms (blocking readback inside)
            tokens.Add(nextToken);

            CleanupTensors();

            if (eosTokens.Contains(nextToken))
                break;
        }

        string generatedText = tokenizer.Decode(tokens.ToArray(), true); // 0 ms
        Debug.Log($"Generated sequence: {generatedText}");
    }
    private int ProcessLogits()
    {
        worker_decoding.SetInput(0, outputLogits);
        worker_decoding.Schedule();
        argMaxTensor = worker_decoding.PeekOutput() as Tensor<int>;
        // DownloadToArray() performs a blocking readback, so no explicit
        // ReadbackRequest()/IsReadbackRequestDone() polling is needed here;
        // that blocking GPU -> CPU wait is what makes this step slow.
        var argMaxTensorArray = argMaxTensor.DownloadToArray();
        int nextToken = argMaxTensorArray[outputLogits.shape[1] - 1]; // argmax at the last position

        Debug.Log($"<color=orange>Next token: [ID = {nextToken}, STR = \"{tokenizer.Decode(new[] { nextToken }, true)}\"]</color>");

        return nextToken;
    }
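
    // Sketch (an assumption, not part of the original): temperature sampling
    // instead of greedy decoding would swap the ArgMax graph for something like
    //   FunctionalTensor probs = Functional.Softmax(logits, 2);
    // then sample from the downloaded last-position distribution on the CPU.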

    private void RefreshTensors(int[] ids) 
    {
        // Rebuild the input tensors with the full context; every token is
        // re-encoded each step since this implementation keeps no KV cache.
        inputTensor = new Tensor<int>(new TensorShape(1, ids.Length), ids);
        attentionMaskTensor = new Tensor<int>(new TensorShape(1, ids.Length), Enumerable.Repeat(1, ids.Length).ToArray());
        positionIdsTensor = new Tensor<int>(new TensorShape(1, ids.Length), Enumerable.Range(0, ids.Length).ToArray());
    }

    private void CleanupTensors()
    {
        // Dispose only the tensors this script created; outputLogits and
        // argMaxTensor come from PeekOutput and are owned by their workers.
        inputTensor?.Dispose();
        attentionMaskTensor?.Dispose();
        positionIdsTensor?.Dispose();
    }

    private void OnDestroy()
    {
        CleanupTensors();
        
        worker_model?.Dispose();
        worker_decoding?.Dispose();
    }
}