#!/opt/homebrew/bin/python3.11
"""
Split ASR segments at detected speaker change points.

Uses ECAPA-TDNN sub-window classification against reference centroids.
Output: new asrx_fine.json with fine-grained segments + parent_asr_idx reference.
"""
import json
import sys
import os
import time
import argparse
import subprocess
import tempfile
import shutil
import numpy as np
from collections import Counter
from pathlib import Path

# Project-local modules live next to this script and in ./asrx_self, so both
# directories must be on sys.path before the imports below.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "asrx_self"))

from main_fixed import SelfASRXFixed
from speaker_encoder import extract_speaker_embedding, normalize_embeddings
import torchaudio
import psycopg2

SUB_WIN = 0.5        # sub-window length in seconds
SUB_STRIDE = 0.25    # hop between consecutive sub-windows in seconds
CHANGE_CONFIRM = 2   # extra sub-windows that must agree before a change is accepted
MIN_DUR = 0.7        # split pieces shorter than this (seconds) are discarded
BATCH_SIZE = 500     # ASR segments handed to process_batch per call

DEFAULT_VIDEO = (
    "/Users/accusys/momentry/var/sftpgo/data/demo/"
    "Charade (1963) Cary Grant & Audrey Hepburn \uff5c "
    "Comedy Mystery Romance Thriller \uff5c Full Movie.mp4"
)


def load_reference(uuid, db_url):
    """Build per-speaker reference centroids for one file.

    Reads the sentence-chunk speaker names from ``dev.chunks`` and the
    per-segment embeddings from ``<uuid>.asrx.json``, then averages and
    L2-normalises the embeddings of each known name.

    Args:
        uuid: File UUID used both in the SQL filter and in the ASRX file name.
        db_url: Connection string passed to ``psycopg2.connect``.

    Returns:
        ``(centroids, name_to_speaker)`` where ``centroids`` maps a speaker
        name to its unit-length mean embedding (names without any embedding
        are omitted) and ``name_to_speaker`` maps each name to the
        ``speaker_id`` of the first segment carrying that name.
    """
    conn = psycopg2.connect(db_url)
    try:
        cur = conn.cursor()
        cur.execute("SELECT chunk_index, metadata->>'new_speaker_name' FROM dev.chunks WHERE file_uuid=%s AND chunk_type='sentence' ORDER BY chunk_index", (uuid,))
        name_by_idx = dict(cur.fetchall())
    finally:
        # Close the connection even when the query fails.
        conn.close()

    asrx_path = f"/Users/accusys/momentry/output_dev/{uuid}.asrx.json"
    with open(asrx_path) as fh:
        asrx_full = json.load(fh)

    embeddings = asrx_full.get("embeddings", [])
    ref = {"Cary Grant": [], "Audrey Hepburn": [], "Unknown": []}
    name_to_speaker = {}
    for i, seg in enumerate(asrx_full["segments"]):
        name = name_by_idx.get(i, "Unknown")
        # Only the three reference names are collected; a segment without a
        # matching embedding row is skipped.
        if name in ref and i < len(embeddings):
            ref[name].append(np.array(embeddings[i]))
        name_to_speaker.setdefault(name, seg["speaker_id"])

    centroids = {}
    for name, vectors in ref.items():
        if vectors:
            mean_vec = np.mean(vectors, axis=0)
            centroids[name] = mean_vec / (np.linalg.norm(mean_vec) + 1e-10)
    return centroids, name_to_speaker


def extract_audio(video_path, sr=16000):
    """Decode the audio track of ``video_path`` to a mono 16-bit WAV and load it.

    Args:
        video_path: Media file handed to ffmpeg via ``-i``.
        sr: Target sample rate requested from ffmpeg.

    Returns:
        ``(wav_data, sr_actual, tmp)``: the tensor and sample rate returned by
        ``torchaudio.load`` and the temporary directory holding the WAV. The
        caller is responsible for removing ``tmp``.

    Raises:
        subprocess.CalledProcessError / subprocess.TimeoutExpired: when ffmpeg
            fails or exceeds 300 seconds. The temporary directory is removed
            before the exception propagates.
    """
    tmp = tempfile.mkdtemp(prefix="asr_split_")
    try:
        wav = os.path.join(tmp, "audio.wav")
        subprocess.run(
            ["ffmpeg", "-y", "-v", "quiet", "-i", video_path, "-ar", str(sr),
             "-ac", "1", "-sample_fmt", "s16", wav],
            check=True, capture_output=True, timeout=300,
        )
        wav_data, sr_actual = torchaudio.load(wav)
    except Exception:
        # Nobody else knows about the directory yet, so clean it up here.
        shutil.rmtree(tmp, ignore_errors=True)
        raise
    if wav_data.shape[0] > 1:
        # Down-mix to a single channel if the loader still returned several.
        wav_data = wav_data.mean(dim=0, keepdim=True)
    return wav_data, sr_actual, tmp


def classify(emb, centroids):
    """Return the centroid name whose dot product with ``emb`` is largest.

    Raises:
        ValueError: if ``centroids`` is empty.
    """
    if not centroids:
        raise ValueError("no reference centroids available for classification")
    return max(centroids, key=lambda n: float(np.dot(emb, centroids[n])))


def _majority(labels):
    """Return the most common label (ties resolved by Counter ordering)."""
    return Counter(labels).most_common(1)[0][0]


def _classify_span(wav, first, last, sr, centroids, encoder):
    """Embed samples ``first:last`` as one span, normalise, and classify."""
    # NOTE(review): assumes extract_speaker_embedding returns a 1-D numeric
    # vector for a (channels, samples) numpy array -- confirm against
    # speaker_encoder.
    emb = np.asarray(extract_speaker_embedding(encoder, wav[:, first:last].numpy(), sr))
    emb = emb / (np.linalg.norm(emb) + 1e-10)
    return classify(emb, centroids)


def _smooth(names):
    """Replace each interior label with the majority of its 3-label window."""
    smoothed = list(names)
    for i in range(1, len(names) - 1):
        smoothed[i] = _majority(names[i - 1:i + 2])
    return smoothed


def _find_splits(smoothed, sub_t):
    """Return the window start times at which the smoothed label changes.

    A change at window ``i`` is accepted when the next ``CHANGE_CONFIRM``
    windows carry the same label, or unconditionally when fewer than that
    many windows remain.
    """
    splits = []
    prev = smoothed[0]
    for i in range(1, len(smoothed)):
        if smoothed[i] == prev:
            continue
        near_end = i + CHANGE_CONFIRM >= len(smoothed)
        confirmed = near_end or all(
            smoothed[j] == smoothed[i] for j in range(i, i + CHANGE_CONFIRM + 1)
        )
        if confirmed:
            splits.append(sub_t[i])
            prev = smoothed[i]
    return splits


def process_batch(asr_segs, wav, sr, centroids, encoder, offset_start=0, index_offset=0):
    """Split each ASR segment at detected speaker changes.

    Segments shorter than one second, or yielding fewer than three
    sub-windows, are classified as a whole. Longer segments are classified
    per sub-window (``SUB_WIN`` long, ``SUB_STRIDE`` apart), smoothed, and cut
    where the label changes.

    Args:
        asr_segs: Sequence of dicts with ``"start"`` and ``"end"`` in seconds.
        wav: Audio tensor indexed as ``wav[:, samples]``.
        sr: Sample rate of ``wav``.
        centroids: Mapping of speaker name to reference embedding.
        encoder: Object passed through to ``extract_speaker_embedding``.
        offset_start: Time (seconds) of ``wav``'s first sample on the ASR
            timeline; subtracted when slicing, added back to split times.
        index_offset: Global index of ``asr_segs[0]``. Pass the batch start
            when ``asr_segs`` is a slice so the returned parent index refers
            to the full segment list rather than to the slice.

    Returns:
        List of ``(start, end, speaker_name, parent_idx)`` tuples.
    """
    win = int(SUB_WIN * sr)
    stride = int(SUB_STRIDE * sr)
    results = []
    for local_idx, seg in enumerate(asr_segs):
        parent_idx = index_offset + local_idx
        st = seg["start"] - offset_start
        et = seg["end"] - offset_start
        first = int(st * sr)
        last = int(et * sr)

        if et - st < 1.0:
            label = _classify_span(wav, first, last, sr, centroids, encoder)
            results.append((seg["start"], seg["end"], label, parent_idx))
            continue

        sub_e, sub_t = [], []
        for wpos in range(first, last - win + 1, stride):
            chunk = wav[:, wpos:wpos + win]
            sub_e.append(extract_speaker_embedding(encoder, chunk.numpy(), sr))
            sub_t.append(wpos / sr + offset_start)

        if len(sub_e) < 3:
            # Too few windows to smooth or confirm a change.
            label = _classify_span(wav, first, last, sr, centroids, encoder)
            results.append((seg["start"], seg["end"], label, parent_idx))
            continue

        sub_e = normalize_embeddings(np.array(sub_e))
        names = [classify(emb, centroids) for emb in sub_e]
        splits = _find_splits(_smooth(names), sub_t)

        if not splits:
            results.append((seg["start"], seg["end"], _majority(names), parent_idx))
            continue

        boundaries = [seg["start"], *splits, seg["end"]]
        pieces = []
        for ps, pe in zip(boundaries, boundaries[1:]):
            if pe - ps < MIN_DUR:
                continue
            # Label each piece from the raw (unsmoothed) window labels inside it.
            members = [names[i] for i, t in enumerate(sub_t) if ps <= t < pe]
            label = _majority(members) if members else _majority(names)
            pieces.append((round(ps, 2), round(pe, 2), label, parent_idx))
        if not pieces:
            # Every piece was below MIN_DUR; keep the parent segment intact
            # rather than dropping it from the output altogether.
            pieces.append((seg["start"], seg["end"], _majority(names), parent_idx))
        results.extend(pieces)
    return results


def main():
    """Command-line entry point: split one file's ASR segments and save JSON."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--uuid", default="aeed71342a899fe4b4c57b7d41bcb692")
    parser.add_argument("--output", help="Output path for fine ASRX JSON")
    parser.add_argument("--video", default=DEFAULT_VIDEO, help="Source video to extract audio from")
    args = parser.parse_args()

    UUID = args.uuid
    BASE = "/Users/accusys/momentry/output_dev"
    DB_URL = "postgresql://accusys@localhost:5432/momentry?host=/tmp"
    VIDEO = args.video

    print(f"Processing {UUID}")
    centroids, _name_to_speaker = load_reference(UUID, DB_URL)
    print(f"Centroids: {list(centroids.keys())}")

    with open(f"{BASE}/{UUID}.asr.json") as fh:
        asr = json.load(fh)
    asr_segs = asr["segments"]
    print(f"ASR segments: {len(asr_segs)}")

    print("Extracting audio...")
    wav, sr, tmp_dir = extract_audio(VIDEO)
    try:
        print(f"Audio: {wav.shape[1]/sr:.0f}s")

        inst = SelfASRXFixed()
        encoder = inst.speaker_encoder

        all_results = []
        t0 = time.time()
        for batch_start in range(0, len(asr_segs), BATCH_SIZE):
            batch = asr_segs[batch_start:batch_start + BATCH_SIZE]
            # index_offset keeps parent_asr_idx global across batches.
            segs = process_batch(batch, wav, sr, centroids, encoder, index_offset=batch_start)
            all_results.extend(segs)
            pct = (batch_start + len(batch)) * 100 // len(asr_segs)
            print(f"  {batch_start+len(batch)}/{len(asr_segs)} ({pct}%) -> {len(all_results)} segments [{time.time()-t0:.0f}s]")
    finally:
        # Remove the extracted WAV even when processing fails part-way.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    # Build output
    spk_stats = {}
    out_segs = []
    # Assign sequential SPEAKER_X IDs based on name order
    name_order = {name: i for i, name in enumerate(sorted({s[2] for s in all_results}))}
    for start, end, name, asr_idx in all_results:
        sid = f"SPEAKER_{name_order[name]}"
        stats = spk_stats.setdefault(sid, {"count": 0, "duration": 0})
        stats["count"] += 1
        stats["duration"] += end - start
        out_segs.append({
            "start_time": start,
            "end_time": end,
            "speaker_id": sid,
            "speaker_name": name,
            "parent_asr_idx": asr_idx,
        })

    output = {
        "uuid": UUID,
        "language": "en",
        "segments": out_segs,
        "speaker_stats": spk_stats,
        "total_asr_segments": len(asr_segs),
        "total_fine_segments": len(out_segs),
    }

    output_path = args.output or f"{BASE}/{UUID}.asrx_fine.json"
    with open(output_path, "w") as fh:
        json.dump(output, fh, indent=2)
    print(f"\nSaved: {output_path}")
    print(f"Segments: {len(out_segs)} (was {len(asr_segs)}, +{len(out_segs)-len(asr_segs)})")
    print(f"Speakers: {len(spk_stats)}")
    for sid, st in sorted(spk_stats.items()):
        print(f"  {sid}: {st['count']} segs, {st['duration']:.0f}s")


if __name__ == "__main__":
    main()