231 lines
8.4 KiB
Python
231 lines
8.4 KiB
Python
#!/opt/homebrew/bin/python3.11
|
|
"""Unit tests for face_tracker.py"""
|
|
|
|
import sys
|
|
import json
|
|
import math
|
|
import unittest
|
|
import numpy as np
|
|
from typing import Dict, List, Set
|
|
|
|
sys.path.insert(0, "scripts/utils")
|
|
from face_tracker import (
|
|
calculate_bbox_iou,
|
|
calculate_bbox_distance,
|
|
calculate_embedding_similarity,
|
|
match_faces,
|
|
track_faces,
|
|
)
|
|
|
|
|
|
class TestBboxIoU(unittest.TestCase):
    """Exercise ``calculate_bbox_iou`` on boxes given as x/y/width/height dicts."""

    @staticmethod
    def _box(x, y, width, height):
        """Return a bounding-box dict with the keys used throughout these tests."""
        return {"x": x, "y": y, "width": width, "height": height}

    def test_identical_bboxes(self):
        """Comparing a box with itself yields an IoU of 1."""
        box = self._box(100, 100, 200, 200)
        self.assertAlmostEqual(calculate_bbox_iou(box, box), 1.0)

    def test_no_overlap(self):
        """Disjoint boxes yield an IoU of exactly 0."""
        first = self._box(0, 0, 100, 100)
        second = self._box(200, 200, 100, 100)
        self.assertEqual(calculate_bbox_iou(first, second), 0.0)

    def test_partial_overlap(self):
        """Two 100x100 boxes offset by (50, 50) share a 50x50 region."""
        first = self._box(0, 0, 100, 100)
        second = self._box(50, 50, 100, 100)
        # intersection = 2500, union = 10000 + 10000 - 2500
        expected = 2500 / (10000 + 10000 - 2500)

        overlap = calculate_bbox_iou(first, second)

        self.assertGreater(overlap, 0.0)
        self.assertLess(overlap, 1.0)
        self.assertAlmostEqual(overlap, expected)

    def test_contained(self):
        """A small box fully inside a larger one scores strictly between 0 and 1."""
        outer = self._box(0, 0, 200, 200)
        inner = self._box(50, 50, 50, 50)

        overlap = calculate_bbox_iou(outer, inner)

        self.assertGreater(overlap, 0.0)
        self.assertLess(overlap, 1.0)

    def test_zero_area(self):
        """A zero-width box yields an IoU of exactly 0."""
        degenerate = self._box(0, 0, 0, 100)
        regular = self._box(0, 0, 100, 100)
        self.assertEqual(calculate_bbox_iou(degenerate, regular), 0.0)
|
|
|
|
|
|
class TestBboxDistance(unittest.TestCase):
    """Exercise ``calculate_bbox_distance`` on equally sized 100x100 boxes.

    NOTE(review): the method names suggest a centre-to-centre distance; the
    expected values below only pin the result for same-size boxes.
    """

    @staticmethod
    def _square(x, y):
        """Return a 100x100 bounding-box dict whose top-left corner is (x, y)."""
        return {"x": x, "y": y, "width": 100, "height": 100}

    def test_same_center(self):
        """A box is at distance 0 from itself."""
        box = self._square(100, 100)
        self.assertAlmostEqual(calculate_bbox_distance(box, box), 0.0)

    def test_horizontal_shift(self):
        """Shifting a box 100 px along x gives a distance of 100."""
        origin = self._square(0, 0)
        shifted = self._square(100, 0)
        self.assertAlmostEqual(calculate_bbox_distance(origin, shifted), 100.0)

    def test_diagonal_shift(self):
        """Shifting 100 px along both axes gives the Euclidean diagonal."""
        origin = self._square(0, 0)
        shifted = self._square(100, 100)
        diagonal = math.sqrt(100**2 + 100**2)
        self.assertAlmostEqual(calculate_bbox_distance(origin, shifted), diagonal)
|
|
|
|
|
|
class TestEmbeddingSimilarity(unittest.TestCase):
    """Exercise ``calculate_embedding_similarity`` on small hand-built vectors.

    The expected values (1 for identical, -1 for opposite, 0 for orthogonal)
    are those of cosine similarity.
    """

    def test_identical(self):
        """A vector compared with itself scores 1."""
        vector = [1.0, 0.0, 0.0]
        score = calculate_embedding_similarity(vector, vector)
        self.assertAlmostEqual(score, 1.0)

    def test_opposite(self):
        """Vectors pointing in opposite directions score -1."""
        score = calculate_embedding_similarity([1.0, 0.0], [-1.0, 0.0])
        self.assertAlmostEqual(score, -1.0)

    def test_orthogonal(self):
        """Perpendicular vectors score 0."""
        score = calculate_embedding_similarity([1.0, 0.0], [0.0, 1.0])
        self.assertAlmostEqual(score, 0.0)

    def test_partial_match(self):
        """Vectors 45 degrees apart score strictly between 0.5 and 1."""
        score = calculate_embedding_similarity([1.0, 0.0, 0.0], [0.5, 0.5, 0.0])
        self.assertGreater(score, 0.5)
        self.assertLess(score, 1.0)

    def test_none_embedding(self):
        """A missing embedding on either side scores exactly 0."""
        left_missing = calculate_embedding_similarity(None, [1.0, 0.0])
        self.assertEqual(left_missing, 0.0)
        right_missing = calculate_embedding_similarity([1.0, 0.0], None)
        self.assertEqual(right_missing, 0.0)

    def test_zero_norm(self):
        """An all-zero vector scores exactly 0."""
        score = calculate_embedding_similarity([0.0, 0.0], [1.0, 0.0])
        self.assertEqual(score, 0.0)
|
|
|
|
|
|
class TestMatchFaces(unittest.TestCase):
    """Tests for ``match_faces``.

    Every assertion compares the returned dict against a mapping keyed by the
    index of a face in the current list; the value is the index of the matched
    face in the previous list, or ``-1`` when the face is left unmatched.

    Each test method is defined exactly once. The class previously carried a
    truncated ``test_cut_boundary_split`` stub (no call, no assertion) and a
    verbatim copy of ``test_edge_exit_reject``; both were shadowed by the later
    definitions of the same name and therefore never ran.
    """

    def make_face(self, x, y, w, h, embedding=None, c=1.0):
        """Build a face dict for ``match_faces``.

        Args:
            x, y: top-left corner of the bounding box.
            w, h: bounding-box width and height.
            embedding: optional embedding vector; the ``"embedding"`` key is
                only added when this is truthy.
            c: value stored under ``"confidence"``.

        Returns:
            Dict with ``x``, ``y``, ``width``, ``height``, ``confidence`` and,
            when provided, ``embedding``.
        """
        f = {"x": x, "y": y, "width": w, "height": h, "confidence": c}
        if embedding:
            f["embedding"] = embedding
        return f

    def test_no_previous(self):
        """With no previous faces, the current face is unmatched."""
        curr = [self.make_face(0, 0, 100, 100)]
        result = match_faces(curr, [])
        self.assertEqual(result, {0: -1})

    def test_same_position_match(self):
        """Identical box and identical embedding match."""
        curr = [self.make_face(0, 0, 100, 100, embedding=[1.0, 0.0])]
        prev = [self.make_face(0, 0, 100, 100, embedding=[1.0, 0.0])]
        result = match_faces(curr, prev)
        self.assertEqual(result, {0: 0})

    def test_reject_low_sim_low_iou(self):
        """Disjoint boxes with orthogonal embeddings do not match."""
        curr = [self.make_face(300, 300, 100, 100, embedding=[0.0, 1.0])]
        prev = [self.make_face(0, 0, 100, 100, embedding=[1.0, 0.0])]
        result = match_faces(curr, prev)
        self.assertEqual(result, {0: -1})

    def test_match_by_position_only(self):
        """Identical boxes match even when the embeddings are orthogonal."""
        curr = [self.make_face(0, 0, 100, 100, embedding=[0.0, 1.0])]
        prev = [self.make_face(0, 0, 100, 100, embedding=[1.0, 0.0])]
        result = match_faces(curr, prev)
        self.assertEqual(result, {0: 0})

    def test_reject_size_change(self):
        """A 10x10 box is not matched to a 100x100 box at the same corner."""
        curr = [self.make_face(0, 0, 10, 10, embedding=[0.3, 0.7])]
        prev = [self.make_face(0, 0, 100, 100, embedding=[0.7, 0.3])]
        result = match_faces(curr, prev)
        self.assertEqual(result, {0: -1})

    def test_cut_boundary_split(self):
        """Identical boxes do not match across a cut boundary.

        The cut at frame 5 lies between ``prev_frame=4`` and ``curr_frame=6``.
        """
        curr = [self.make_face(0, 0, 100, 100)]
        prev = [self.make_face(0, 0, 100, 100)]
        result = match_faces(
            curr, prev, cut_boundaries={5}, prev_frame=4, curr_frame=6
        )
        self.assertEqual(result, {0: -1})

    def test_edge_exit_reject(self):
        """Disjoint boxes with differing embeddings do not match."""
        curr = [self.make_face(200, 200, 100, 100, embedding=[0.3, 0.7])]
        prev = [self.make_face(0, 0, 100, 100, embedding=[0.7, 0.3])]
        result = match_faces(curr, prev)
        self.assertEqual(result, {0: -1})

    def test_multiple_faces_no_conflict(self):
        """Two well-separated faces each match their own predecessor."""
        curr = [
            self.make_face(0, 0, 100, 100, embedding=[1.0, 0.0, 0.0]),
            self.make_face(200, 200, 100, 100, embedding=[0.0, 1.0, 0.0]),
        ]
        prev = [
            self.make_face(0, 0, 100, 100, embedding=[1.0, 0.0, 0.0]),
            self.make_face(200, 200, 100, 100, embedding=[0.0, 1.0, 0.0]),
        ]
        result = match_faces(curr, prev)
        self.assertEqual(result, {0: 0, 1: 1})
|
|
|
|
|
|
class TestTrackFaces(unittest.TestCase):
    """Exercise ``track_faces`` on small synthetic per-frame detections."""

    @staticmethod
    def _detection(x, y, confidence=1.0):
        """Return a 100x100 face dict whose top-left corner is (x, y)."""
        return {"x": x, "y": y, "width": 100, "height": 100, "confidence": confidence}

    def make_face_data(self, frames_data: Dict[int, List[Dict]]) -> Dict:
        """Wrap per-frame face lists in the input structure given to ``track_faces``.

        Frame numbers become string keys, each face list is stored under
        ``"faces"``, and the metadata carries a fixed fps of 25.0.
        """
        by_frame = {
            str(frame_no): {"faces": detections}
            for frame_no, detections in frames_data.items()
        }
        return {"frames": by_frame, "metadata": {"fps": 25.0}}

    def test_single_frame(self):
        """One face in one frame yields a single trace "0" spanning frame 0."""
        payload = self.make_face_data({0: [self._detection(0, 0)]})

        traces = track_faces(payload).get("traces", {})

        self.assertEqual(len(traces), 1)
        trace = traces.get("0", {})
        self.assertEqual(trace["start_frame"], 0)
        self.assertEqual(trace["end_frame"], 0)

    def test_face_appears_disappears(self):
        """A face, an empty frame, then a face elsewhere yields two traces."""
        payload = self.make_face_data({
            0: [self._detection(0, 0)],
            1: [],
            2: [self._detection(100, 100)],
        })

        traces = track_faces(payload).get("traces", {})

        self.assertEqual(len(traces), 2)

    def test_same_face_stable(self):
        """The same box in frames 0-2 yields one trace covering all three."""
        template = self._detection(50, 50)
        # Each frame gets its own copy so frames never share a dict object.
        payload = self.make_face_data(
            {frame_no: [dict(template)] for frame_no in range(3)}
        )

        traces = track_faces(payload).get("traces", {})

        self.assertEqual(len(traces), 1)
        only_trace = next(iter(traces.values()))
        self.assertEqual(only_trace["start_frame"], 0)
        self.assertEqual(only_trace["end_frame"], 2)

    def test_cut_splits_trace(self):
        """A cut boundary at frame 1 splits an otherwise stable face in two."""
        payload = self.make_face_data({
            0: [self._detection(50, 50)],
            1: [self._detection(50, 50)],
        })

        tracked = track_faces(payload, cut_boundaries={1})

        self.assertEqual(len(tracked["traces"]), 2)

    def test_trace_stats_present(self):
        """The result metadata exposes the three trace-count statistics."""
        payload = self.make_face_data({0: [self._detection(0, 0, confidence=0.9)]})

        stats = track_faces(payload)["metadata"]["trace_stats"]

        for key in ("total_traces", "active_traces", "long_traces"):
            self.assertIn(key, stats)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: let unittest discover and run the tests above.
    unittest.main()
|