|
using UnityEngine; |
|
using Microsoft.ML.Tokenizers; |
|
using Unity.Sentis; |
|
using System.IO; |
|
using System.Linq; |
|
using System.Collections.Generic; |
|
using System.Collections; |
|
|
|
/// <summary>
/// Greedy (arg-max) text generation driven by two Sentis workers: one runs the
/// language model loaded from StreamingAssets ("Phi35/model_Uint8.sentis") and
/// one runs a small arg-max graph over the model's "logits" output.
/// </summary>
/// <remarks>
/// The whole token sequence is re-submitted to the model on every step (the
/// code feeds "input_ids", "attention_mask" and "position_ids" only), so the
/// cost of each step grows with the sequence length.
/// NOTE(review): the asset paths suggest a Phi-3.5 export - confirm the vocab
/// size and special-token ids below against the shipped tokenizer/config.
/// </remarks>
public class Phi3Claude : MonoBehaviour
{
    // Size of the last logits axis declared for the arg-max graph input.
    const int VocabSize = 32064;

    Worker worker_model;
    Worker worker_decoding;
    LlamaTokenizer tokenizer;

    // Prompt ids followed by every generated id of the current run.
    readonly List<int> tokens = new();

    // Input tensors are created by this class and must be disposed by it.
    Tensor<int> inputTensor, attentionMaskTensor, positionIdsTensor;

    // Output tensors come from Worker.PeekOutput and stay owned by the worker;
    // they are only referenced here and never disposed by this class.
    Tensor<float> outputLogits;
    Tensor<int> argMaxTensor;

    int maxTokens = 100;
    List<int> eosTokens;

    // Handle of the running generation coroutine, or null when idle.
    Coroutine generationRoutine;

    /// <summary>
    /// Loads the model and tokenizer, builds the arg-max graph, creates both
    /// workers and starts a sample generation.
    /// </summary>
    private void Start()
    {
        var tokenizerModelPath = Path.Combine(Application.streamingAssetsPath, "Phi35/tokenizer.model");
        var sentisModelPath = Path.Combine(Application.streamingAssetsPath, "Phi35/model_Uint8.sentis");

        var model = ModelLoader.Load(sentisModelPath);

        // Tiny graph that reduces logits (1, sequence, vocab) to the arg-max
        // token id per position, so only integers have to leave the GPU.
        var graph = new FunctionalGraph();
        FunctionalTensor logits = graph.AddInput<float>(new DynamicTensorShape(1, -1, VocabSize));
        FunctionalTensor argMax = Functional.ArgMax(logits, 2, false);
        Model greedyModel = graph.Compile(argMax);

        worker_model = new Worker(model, BackendType.GPUCompute);
        worker_decoding = new Worker(greedyModel, BackendType.GPUCompute);

        Dictionary<string, int> specialTokens = new()
        {
            { "<|assistant|>", 32001 },
            { "<|endoftext|>", 32000 },
            { "<|end|>", 32007 },
            { "<|placeholder1|>", 32002 },
            { "<|placeholder2|>", 32003 },
            { "<|placeholder3|>", 32004 },
            { "<|placeholder4|>", 32005 },
            { "<|placeholder5|>", 32008 },
            { "<|placeholder6|>", 32009 },
            { "<|system|>", 32006 },
            { "<|user|>", 32010 }
        };

        using (Stream tokenizerModelStream = File.OpenRead(tokenizerModelPath))
        {
            tokenizer = LlamaTokenizer.Create(
                tokenizerModelStream,
                addBeginOfSentence: true,
                addEndOfSentence: false,
                specialTokens: specialTokens
            );
        }

        // Generation stops on <|end|>, <|assistant|> or <|endoftext|>.
        eosTokens = new() { 32007, 32001, 32000 };

        Generate("What is the capital of France?");
    }

    /// <summary>
    /// Builds the chat prompt, tokenizes it and starts asynchronous greedy
    /// generation. Any generation that is still running is cancelled first.
    /// </summary>
    /// <param name="userPrompt">Text placed in the user turn of the prompt.</param>
    /// <param name="systemPrompt">Text placed in the system turn of the prompt.</param>
    public void Generate(string userPrompt, string systemPrompt = "You are a helpful assistant.")
    {
        if (tokenizer == null)
        {
            Debug.LogError("Generate was called before the tokenizer was initialised.");
            return;
        }

        // A previous run shares `tokens` and the tensor fields with the new
        // one, so it has to be stopped and its tensors released first.
        if (generationRoutine != null)
        {
            StopCoroutine(generationRoutine);
            generationRoutine = null;
            CleanupTensors();
        }

        // Explicit "\n" keeps the prompt independent of the source file's
        // line endings (a verbatim multi-line literal would embed CRLF on a
        // Windows checkout and change the tokenization).
        string completePrompt =
            "<|system|>\n" +
            $"{systemPrompt}<|end|>\n" +
            "<|user|>\n" +
            $"{userPrompt}<|end|>\n" +
            "<|assistant|>";
        Debug.Log("Complete prompt : " + completePrompt);

        int[] inputIds = tokenizer.EncodeToIds(completePrompt).ToArray();
        Debug.Log($"Tokenized input: [{string.Join(", ", inputIds)}]");
        Debug.Log($"Decoded tokens: [{string.Join(", ", tokenizer.Decode(inputIds, true))}]");

        tokens.Clear();
        tokens.AddRange(inputIds);

        generationRoutine = StartCoroutine(GenerateSequence());
    }

    /// <summary>
    /// Coroutine that appends up to <c>maxTokens</c> greedily chosen tokens to
    /// <c>tokens</c>, stopping early on an end-of-sequence id, and finally
    /// logs the decoded sequence (prompt included).
    /// </summary>
    private IEnumerator GenerateSequence()
    {
        for (int step = 0; step < maxTokens; step++)
        {
            RefreshTensors(tokens.ToArray());

            worker_model.SetInput("input_ids", inputTensor);
            worker_model.SetInput("attention_mask", attentionMaskTensor);
            worker_model.SetInput("position_ids", positionIdsTensor);
            worker_model.Schedule();

            outputLogits = worker_model.PeekOutput("logits") as Tensor<float>;
            if (outputLogits == null)
            {
                Debug.LogError("Model output \"logits\" is missing or is not a float tensor.");
                CleanupTensors();
                break;
            }

            // The logits are consumed directly by the arg-max worker, so they
            // never need to be downloaded; only the arg-max ids are read back.
            worker_decoding.SetInput(0, outputLogits);
            worker_decoding.Schedule();

            argMaxTensor = worker_decoding.PeekOutput() as Tensor<int>;
            if (argMaxTensor == null)
            {
                Debug.LogError("Arg-max output is missing or is not an int tensor.");
                CleanupTensors();
                break;
            }

            argMaxTensor.ReadbackRequest();

            // Yield at least one frame per token, then keep yielding until the
            // readback has really completed. Yielding the bool returned by
            // IsReadbackRequestDone() would only ever wait a single frame.
            do
            {
                yield return null;
            }
            while (!argMaxTensor.IsReadbackRequestDone());

            int nextToken = ProcessLogits();
            tokens.Add(nextToken);

            CleanupTensors();

            if (eosTokens.Contains(nextToken))
            {
                break;
            }
        }

        string generatedText = tokenizer.Decode(tokens.ToArray(), true);
        Debug.Log($"Generated sequence: {generatedText}");

        generationRoutine = null;
    }

    /// <summary>
    /// Reads the arg-max ids whose readback has already completed and returns
    /// the id predicted for the last sequence position.
    /// </summary>
    /// <returns>The id of the next token.</returns>
    private int ProcessLogits()
    {
        int[] argMaxIds = argMaxTensor.DownloadToArray();

        // The arg-max graph drops the vocab axis, leaving one id per input
        // position; the final entry is the prediction for the next token.
        int nextToken = argMaxIds[argMaxIds.Length - 1];

        Debug.Log($"<color=orange>Next token: [ID = {nextToken}, STR = \"{tokenizer.Decode(new[] { nextToken }, true)}\"]</color>");

        return nextToken;
    }

    /// <summary>
    /// Creates the three model input tensors, each of shape (1, ids.Length):
    /// the ids themselves, an all-ones attention mask and positions 0..n-1.
    /// </summary>
    /// <param name="ids">Token ids of the full sequence so far.</param>
    private void RefreshTensors(int[] ids)
    {
        var shape = new TensorShape(1, ids.Length);

        inputTensor = new Tensor<int>(shape, ids);
        attentionMaskTensor = new Tensor<int>(shape, Enumerable.Repeat(1, ids.Length).ToArray());
        positionIdsTensor = new Tensor<int>(shape, Enumerable.Range(0, ids.Length).ToArray());
    }

    /// <summary>
    /// Disposes the input tensors created by <c>RefreshTensors</c> and clears
    /// every tensor reference so a later call cannot dispose them twice.
    /// </summary>
    private void CleanupTensors()
    {
        inputTensor?.Dispose();
        inputTensor = null;

        attentionMaskTensor?.Dispose();
        attentionMaskTensor = null;

        positionIdsTensor?.Dispose();
        positionIdsTensor = null;

        // Peeked outputs belong to their workers and are released when the
        // workers are disposed; only the references are dropped here.
        outputLogits = null;
        argMaxTensor = null;
    }

    /// <summary>
    /// Releases any live tensors and both workers when the component is destroyed.
    /// </summary>
    private void OnDestroy()
    {
        CleanupTensors();

        worker_model?.Dispose();
        worker_model = null;

        worker_decoding?.Dispose();
        worker_decoding = null;
    }
}