|
using System.Collections.Generic; |
|
using Unity.Sentis; |
|
using UnityEngine; |
|
using UnityEngine.UI; |
|
using UnityEngine.Video; |
|
using Lays = Unity.Sentis.Layers; |
|
using System.IO; |
|
using FF = Unity.Sentis.Functional; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Runs a YOLOv8n object-detection model with Unity Sentis on frames of a looping
/// video from StreamingAssets, and draws a labelled box over <see cref="displayImage"/>
/// for each detection returned by the compiled post-processing graph.
/// </summary>
public class RunYOLO8n : MonoBehaviour
{
    /// <summary>Model asset passed to <c>ModelLoader.Load</c>.</summary>
    public ModelAsset asset;

    // NOTE(review): not referenced in this class; the model is loaded from `asset`.
    const string modelName = "yolov8n.sentis";

    // Video file name, resolved relative to Application.streamingAssetsPath.
    const string videoName = "giraffes.mp4";

    /// <summary>Newline-separated class names; line i is used as the name of class ID i.</summary>
    public TextAsset labelsAsset;

    /// <summary>UI image that shows the current frame; detection boxes are parented to it.</summary>
    public RawImage displayImage;

    /// <summary>Sprite used for each box outline. Created from <see cref="borderTexture"/> when left unset.</summary>
    public Sprite borderSprite;

    /// <summary>Texture used to build <see cref="borderSprite"/> when no sprite is assigned.</summary>
    public Texture2D borderTexture;

    /// <summary>Font for the label text drawn on each box.</summary>
    public Font font;

    const BackendType backend = BackendType.GPUCompute;

    private Transform displayLocation;

    private IWorker engine;

    private string[] labels;

    // Square render target that each video frame is blitted into before tensor conversion.
    private RenderTexture targetRT;

    // Model input resolution; targetRT and the input tensor use this size.
    private const int imageWidth = 640;

    private const int imageHeight = 640;

    // NOTE(review): not referenced in this class; presumably the number of class-score rows in the model output.
    private const int numClasses = 80;

    private VideoPlayer video;

    // Reused box GameObjects: entry n is used for the n-th detection of a frame.
    List<GameObject> boxPool = new();

    // Both thresholds are captured into the graph compiled in LoadModel (called from Start),
    // so changing them in the Inspector afterwards has no effect until the model is recompiled.
    [SerializeField, Range(0, 1)] float iouThreshold = 0.5f;

    [SerializeField, Range(0, 1)] float scoreThreshold = 0.5f;

    // Upper bound on the number of boxes drawn per frame.
    int maxOutputBoxes = 64;

    // 4x4 matrix mapping a row [cx, cy, w, h] to [cx - w/2, cy - h/2, cx + w/2, cy + h/2].
    TensorFloat centersToCorners;

    // True when this component created borderSprite itself and must therefore destroy it.
    private bool ownsBorderSprite;

    /// <summary>A detection in display space, relative to the display image's local origin.</summary>
    public struct BoundingBox
    {
        public float centerX;

        public float centerY;

        public float width;

        public float height;

        public string label;
    }

    void Start()
    {
        Application.targetFrameRate = 60;

        Screen.orientation = ScreenOrientation.LandscapeLeft;

        // Trim each entry so label files saved with CRLF endings do not leave a '\r' on every name.
        labels = labelsAsset.text.Split('\n');
        for (int i = 0; i < labels.Length; i++)
        {
            labels[i] = labels[i].Trim();
        }

        LoadModel();

        targetRT = new RenderTexture(imageWidth, imageHeight, 0);

        displayLocation = displayImage.transform;

        SetupInput();

        if (borderSprite == null && borderTexture != null)
        {
            // Sprite.Create takes a pivot normalized to [0, 1]; (0.5, 0.5) is the centre.
            borderSprite = Sprite.Create(
                borderTexture,
                new Rect(0, 0, borderTexture.width, borderTexture.height),
                new Vector2(0.5f, 0.5f));
            ownsBorderSprite = true;
        }
    }

    /// <summary>
    /// Loads the model and wraps it in a graph that selects the boxes to draw. The graph
    /// converts centre boxes to corners, runs NMS, and returns the surviving centre-format
    /// boxes plus their arg-max class IDs. It then creates the worker that runs the graph.
    /// </summary>
    void LoadModel()
    {
        var model1 = ModelLoader.Load(asset);

        centersToCorners = new TensorFloat(new TensorShape(4, 4),
            new float[]
            {
                1, 0, 1, 0,
                0, 1, 0, 1,
                -0.5f, 0, 0.5f, 0,
                0, -0.5f, 0, 0.5f
            });

        var model2 = Functional.Compile(
            input =>
            {
                var modelOutput = model1.Forward(input)[0];
                // Rows 0..3 of the first batch item are box coordinates; transposed to one box per row.
                var boxCoords = modelOutput[0, 0..4, ..].Transpose(0, 1);
                // Remaining rows are per-class scores for each candidate box.
                var allScores = modelOutput[0, 4.., ..];
                // Best score minus the threshold; passed to NMS as the scores input.
                var scores = FF.ReduceMax(allScores, 0) - scoreThreshold;
                var classIDs = FF.ArgMax(allScores, 0);
                var boxCorners = FF.MatMul(boxCoords, FunctionalTensor.FromTensor(centersToCorners));
                var indices = FF.NMS(boxCorners, scores, iouThreshold);
                // Repeat each kept index across the 4 coordinate columns so Gather picks whole rows.
                var indices2 = indices.Unsqueeze(-1).BroadcastTo(new int[] { 4 });
                var coords = FF.Gather(boxCoords, 0, indices2);
                var labelIDs = FF.Gather(classIDs, 0, indices);
                return (coords, labelIDs);
            },
            InputDef.FromModel(model1)[0]
        );

        engine = WorkerFactory.CreateWorker(backend, model2);
    }

    /// <summary>Adds a looping, API-only VideoPlayer that streams <c>videoName</c> from StreamingAssets.</summary>
    void SetupInput()
    {
        video = gameObject.AddComponent<VideoPlayer>();

        video.renderMode = VideoRenderMode.APIOnly;

        video.source = VideoSource.Url;

        video.url = Path.Join(Application.streamingAssetsPath, videoName);

        video.isLooping = true;

        video.Play();
    }

    private void Update()
    {
        ExecuteML();

        if (Input.GetKeyDown(KeyCode.Escape))
        {
            Application.Quit();
        }
    }

    /// <summary>
    /// Hides last frame's boxes, runs the model on the current video frame, and draws
    /// up to <c>maxOutputBoxes</c> detections. Does nothing (beyond hiding boxes) until
    /// the video player has a texture.
    /// </summary>
    public void ExecuteML()
    {
        ClearAnnotations();

        if (!video || !video.texture)
        {
            return;
        }

        float aspect = video.width * 1f / video.height;

        Graphics.Blit(video.texture, targetRT, new Vector2(1f / aspect, 1), new Vector2(0, 0));

        displayImage.texture = targetRT;

        using var input = TextureConverter.ToTensor(targetRT, imageWidth, imageHeight, 3);

        engine.Execute(input);

        // The outputs are not disposed here; presumably the worker owns them (see Sentis PeekOutput docs).
        if (engine.PeekOutput("output_0") is not TensorFloat output ||
            engine.PeekOutput("output_1") is not TensorInt labelIDs)
        {
            return;
        }

        output.CompleteOperationsAndDownload();

        labelIDs.CompleteOperationsAndDownload();

        float displayWidth = displayImage.rectTransform.rect.width;

        float displayHeight = displayImage.rectTransform.rect.height;

        // Scale from model-input pixels (imageWidth x imageHeight) to the display rect.
        float scaleX = displayWidth / imageWidth;

        float scaleY = displayHeight / imageHeight;

        int boxesToDraw = Mathf.Min(output.shape[0], maxOutputBoxes);

        for (int n = 0; n < boxesToDraw; n++)
        {
            var box = new BoundingBox
            {
                // Subtracting half the display size re-centres the coordinates; this assumes
                // displayImage's pivot is its centre (NOTE(review): confirm in the scene).
                centerX = output[n, 0] * scaleX - displayWidth / 2,
                centerY = output[n, 1] * scaleY - displayHeight / 2,
                width = output[n, 2] * scaleX,
                height = output[n, 3] * scaleY,
                label = LabelFor(labelIDs[n]),
            };

            DrawBox(box, n, displayHeight * 0.05f);
        }
    }

    // Returns the class name for a model class ID, or the ID itself when the labels file has no entry for it.
    private string LabelFor(int classId)
    {
        return classId >= 0 && classId < labels.Length ? labels[classId] : classId.ToString();
    }

    /// <summary>
    /// Shows <paramref name="box"/> using pooled box number <paramref name="id"/>,
    /// creating a new pooled box when the pool is not yet large enough.
    /// </summary>
    /// <param name="box">Box geometry and label in display space.</param>
    /// <param name="id">Pool slot to use.</param>
    /// <param name="fontSize">Label font size; truncated to an integer.</param>
    public void DrawBox(BoundingBox box, int id, float fontSize)
    {
        GameObject panel;

        if (id < boxPool.Count)
        {
            panel = boxPool[id];
            panel.SetActive(true);
        }
        else
        {
            panel = CreateNewBox(Color.yellow);
        }

        // Y is negated: box coordinates grow downward while UI local Y grows upward.
        panel.transform.localPosition = new Vector3(box.centerX, -box.centerY);

        RectTransform rt = panel.GetComponent<RectTransform>();

        rt.sizeDelta = new Vector2(box.width, box.height);

        var label = panel.GetComponentInChildren<Text>();

        label.text = box.label;

        label.fontSize = (int)fontSize;
    }

    /// <summary>
    /// Creates a sliced-image box with a child text label under <c>displayLocation</c>,
    /// adds it to the pool, and returns it.
    /// </summary>
    /// <param name="color">Tint for both the outline image and the label text.</param>
    public GameObject CreateNewBox(Color color)
    {
        var panel = new GameObject("ObjectBox");

        panel.AddComponent<CanvasRenderer>();

        Image img = panel.AddComponent<Image>();

        img.color = color;

        img.sprite = borderSprite;

        img.type = Image.Type.Sliced;

        panel.transform.SetParent(displayLocation, false);

        var text = new GameObject("ObjectLabel");

        text.AddComponent<CanvasRenderer>();

        text.transform.SetParent(panel.transform, false);

        Text txt = text.AddComponent<Text>();

        txt.font = font;

        txt.color = color;

        txt.fontSize = 40;

        txt.horizontalOverflow = HorizontalWrapMode.Overflow;

        RectTransform rt2 = text.GetComponent<RectTransform>();

        // Inset the label 20 units from the left and raise its top edge by 30 units.
        // Offsets are set before the anchors, matching the original assignment order.
        rt2.offsetMin = new Vector2(20, 0);

        rt2.offsetMax = new Vector2(0, 30);

        rt2.anchorMin = new Vector2(0, 0);

        rt2.anchorMax = new Vector2(1, 1);

        boxPool.Add(panel);

        return panel;
    }

    /// <summary>Deactivates every pooled box; the boxes are kept for reuse.</summary>
    public void ClearAnnotations()
    {
        foreach (var box in boxPool)
        {
            box.SetActive(false);
        }
    }

    private void OnDestroy()
    {
        centersToCorners?.Dispose();

        engine?.Dispose();

        // targetRT was created with `new RenderTexture`, so it must be released and destroyed explicitly.
        if (targetRT != null)
        {
            targetRT.Release();
            Destroy(targetRT);
        }

        // Only destroy the sprite if Start created it; an Inspector-assigned sprite is an asset we don't own.
        if (ownsBorderSprite && borderSprite != null)
        {
            Destroy(borderSprite);
        }
    }
}