using UnityEngine;
using Microsoft.ML.Tokenizers;
using Unity.Sentis;
using System.IO;
using System.Linq;
using System.Collections.Generic;
using System.Collections;

/// <summary>
/// Greedy text generation with a Phi-3.5 model exported to Sentis format.
/// The prompt is tokenized with a <see cref="LlamaTokenizer"/>, the whole context is fed to the
/// model at every step, and the next token is picked by an arg-max over the vocabulary that runs
/// on the GPU in a second, tiny Sentis model.
/// </summary>
public class Phi3Claude : MonoBehaviour
{
    // Size of the last axis of the model's "logits" output.
    private const int VocabSize = 32064;

    private Worker worker_model;
    private Worker worker_decoding;
    private LlamaTokenizer tokenizer;

    // Prompt tokens followed by every token generated so far.
    private readonly List<int> tokens = new();

    // Per-step model inputs created by RefreshTensors; owned (and disposed) by this component.
    private Tensor<int> inputTensor, attentionMaskTensor, positionIdsTensor;

    // Outputs obtained through PeekOutput. They are owned by the workers, so they are never disposed here.
    private Tensor<float> outputLogits;
    private Tensor<int> argMaxTensor;

    [SerializeField] private int maxTokens = 100; // Maximum number of tokens to generate

    private HashSet<int> eosTokens; // End of sequence tokens

    // The GenerateSequence coroutine currently running, if any.
    private Coroutine generation;

    private void Start()
    {
        var tokenizerModelPath = Path.Combine(Application.streamingAssetsPath, "Phi35/tokenizer.model");
        var sentisModelPath = Path.Combine(Application.streamingAssetsPath, "Phi35/model_Uint8.sentis");

        var model = ModelLoader.Load(sentisModelPath);

        // Greedy decoding as its own GPU model: logits (1, seq, vocab) -> arg-max over the vocab axis -> (1, seq).
        var graph = new FunctionalGraph();
        FunctionalTensor logits = graph.AddInput<float>(new DynamicTensorShape(1, -1, VocabSize));
        FunctionalTensor argMax = Functional.ArgMax(logits, 2, false);
        Model greedyModel = graph.Compile(argMax);

        worker_model = new Worker(model, BackendType.GPUCompute);
        worker_decoding = new Worker(greedyModel, BackendType.GPUCompute);

        // Manually set from added_tokens.json
        var specialTokens = new Dictionary<string, int>
        {
            { "<|assistant|>", 32001 },
            { "<|endoftext|>", 32000 },
            { "<|end|>", 32007 },
            { "<|placeholder1|>", 32002 },
            { "<|placeholder2|>", 32003 },
            { "<|placeholder3|>", 32004 },
            { "<|placeholder4|>", 32005 },
            { "<|placeholder5|>", 32008 },
            { "<|placeholder6|>", 32009 },
            { "<|system|>", 32006 },
            { "<|user|>", 32010 }
        };

        // NOTE(review): FileStream cannot open StreamingAssets on platforms where they are packed
        // into an archive (e.g. Android/WebGL); those platforms need UnityWebRequest instead.
        using (Stream tokenizerModelStream = new FileStream(tokenizerModelPath, FileMode.Open, FileAccess.Read))
        {
            tokenizer = LlamaTokenizer.Create(
                tokenizerModelStream,
                addBeginOfSentence: true,
                addEndOfSentence: false,
                specialTokens: specialTokens
            );
        }

        // Manually set from generation_config.json
        eosTokens = new() { 32007, 32001, 32000 };

        Generate("What is the capital of France?");
    }

    /// <summary>
    /// Formats <paramref name="userPrompt"/> with the chat template, tokenizes it, and starts the
    /// generation coroutine. If a generation is already running on this component, it is stopped first.
    /// </summary>
    /// <param name="userPrompt">Text placed in the user turn.</param>
    /// <param name="systemPrompt">Text placed in the system turn.</param>
    public void Generate(string userPrompt, string systemPrompt = "You are a helpful assistant.")
    {
        // Phi-3.5 chat template. The explicit \n keeps the prompt independent of the source file's line endings.
        string completePrompt =
            $"<|system|>\n{systemPrompt}<|end|>\n<|user|>\n{userPrompt}<|end|>\n<|assistant|>\n";
        Debug.Log("Complete prompt : " + completePrompt);

        int[] inputIds = tokenizer.EncodeToIds(completePrompt).ToArray();
        Debug.Log($"Tokenized input: [{string.Join(", ", inputIds)}]");
        Debug.Log($"Decoded tokens: [{tokenizer.Decode(inputIds, true)}]");

        // Two coroutines would share `tokens` and the tensor fields, so only one may run at a time.
        if (generation != null)
        {
            StopCoroutine(generation);
            CleanupTensors();
            generation = null;
        }

        tokens.Clear();
        tokens.AddRange(inputIds);
        generation = StartCoroutine(GenerateSequence());
    }

    /// <summary>
    /// Appends up to <c>maxTokens</c> greedily chosen tokens to <c>tokens</c>, one per iteration.
    /// Stops early once an end-of-sequence token has been appended, then logs the decoded sequence
    /// (prompt included).
    /// </summary>
    private IEnumerator GenerateSequence()
    {
        for (int i = 0; i < maxTokens; i++)
        {
            // The full context is re-run every step (no past key/values are passed), so the cost grows with length.
            RefreshTensors(tokens.ToArray());
            worker_model.SetInput("input_ids", inputTensor);
            worker_model.SetInput("attention_mask", attentionMaskTensor);
            worker_model.SetInput("position_ids", positionIdsTensor);
            worker_model.Schedule();

            // Chain the logits straight into the arg-max model on the GPU. Only the small
            // (1, seq) arg-max result is read back, never the vocab-sized logits.
            outputLogits = worker_model.PeekOutput("logits") as Tensor<float>;
            worker_decoding.SetInput(0, outputLogits);
            worker_decoding.Schedule();
            argMaxTensor = worker_decoding.PeekOutput() as Tensor<int>;

            // Request an async readback and resume only once it has completed.
            // (`yield return someBool` would just wait a single frame.)
            argMaxTensor.ReadbackRequest();
            yield return new WaitUntil(() => argMaxTensor.IsReadbackRequestDone());

            int nextToken = ProcessLogits();
            tokens.Add(nextToken);
            CleanupTensors();

            if (eosTokens.Contains(nextToken))
            {
                break;
            }
        }

        string generatedText = tokenizer.Decode(tokens.ToArray(), true);
        Debug.Log($"Generated sequence: {generatedText}");
        generation = null;
    }

    /// <summary>
    /// Returns the token predicted after the last position, taken from <c>argMaxTensor</c>.
    /// Must be called only after the readback requested in GenerateSequence has completed.
    /// </summary>
    private int ProcessLogits()
    {
        int[] argMaxValues = argMaxTensor.DownloadToArray();
        // argMaxTensor has shape (1, seq); its last entry is the prediction for the next token.
        int nextToken = argMaxValues[argMaxValues.Length - 1];
        Debug.Log($"Next token: [ID = {nextToken}, STR = \"{tokenizer.Decode(new[] { nextToken }, true)}\"]");
        return nextToken;
    }

    /// <summary>
    /// Builds the three (1, ids.Length) model inputs for the current context:
    /// the token ids, an all-ones attention mask, and positions 0..n-1.
    /// </summary>
    private void RefreshTensors(int[] ids)
    {
        var shape = new TensorShape(1, ids.Length);
        inputTensor = new Tensor<int>(shape, ids);
        attentionMaskTensor = new Tensor<int>(shape, Enumerable.Repeat(1, ids.Length).ToArray());
        positionIdsTensor = new Tensor<int>(shape, Enumerable.Range(0, ids.Length).ToArray());
    }

    /// <summary>
    /// Disposes the per-step input tensors and clears every tensor field. Outputs obtained with
    /// PeekOutput belong to the workers, so they are only dereferenced here, not disposed.
    /// </summary>
    private void CleanupTensors()
    {
        inputTensor?.Dispose();
        attentionMaskTensor?.Dispose();
        positionIdsTensor?.Dispose();
        inputTensor = null;
        attentionMaskTensor = null;
        positionIdsTensor = null;

        outputLogits = null;
        argMaxTensor = null;
    }

    private void OnDestroy()
    {
        CleanupTensors();
        worker_model?.Dispose();
        worker_decoding?.Dispose();
    }
}