feat: Phase 1 handover - schema migration, correction mechanism, API fixes
Schema changes: dev.chunks->dev.chunk, remove old_chunk_id/chunk_index. Correction: asr-1.json format, generate/apply scripts. API: 37/37 endpoints fixed and tested. Docs: HANDOVER_V2.0.md for M4.
This commit is contained in:
204
scripts/split_asr_segments.py
Normal file
204
scripts/split_asr_segments.py
Normal file
@@ -0,0 +1,204 @@
|
||||
#!/opt/homebrew/bin/python3.11
|
||||
"""
|
||||
Split ASR segments at detected speaker change points.
|
||||
Uses ECAPA-TDNN sub-window classification against reference centroids.
|
||||
|
||||
Output: new asrx_fine.json with fine-grained segments + parent_asr_idx reference.
|
||||
"""
|
||||
import json, sys, os, time, argparse, subprocess, tempfile, shutil
|
||||
import numpy as np
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "asrx_self"))
|
||||
from main_fixed import SelfASRXFixed
|
||||
from speaker_encoder import extract_speaker_embedding, normalize_embeddings
|
||||
import torchaudio, psycopg2
|
||||
|
||||
# Sub-window length in seconds; process_batch multiplies it by the sample
# rate to get the window size in samples.
SUB_WIN = 0.5
# Hop between the starts of consecutive sub-windows, in seconds.
SUB_STRIDE = 0.25
# How many following sub-windows must carry the same label before a speaker
# change is accepted (changes too close to the segment end to be checked are
# accepted as-is).
CHANGE_CONFIRM = 2
# Pieces produced by a split that are shorter than this (seconds) are not
# emitted as their own fine segment.
MIN_DUR = 0.7
# Number of ASR segments passed to process_batch per call from main().
BATCH_SIZE = 500
|
||||
|
||||
def load_reference(uuid, db_url):
    """Build reference speaker centroids for one file.

    Speaker names are read from the database (one per sentence chunk, keyed
    by ``chunk_index``) and paired by position with the segments and
    embeddings stored in ``<uuid>.asrx.json``.

    Args:
        uuid: File UUID used both in the SQL filter and in the ASRX file name.
        db_url: Connection string handed to ``psycopg2.connect``.

    Returns:
        A ``(centroids, name_to_speaker)`` tuple. ``centroids`` maps a speaker
        name to the L2-normalised mean of its reference embeddings (names
        without any embedding are omitted). ``name_to_speaker`` maps each name
        to the ``speaker_id`` of the first segment that carries it.
    """
    conn = psycopg2.connect(db_url)
    try:
        cur = conn.cursor()
        # NOTE(review): the commit message mentions a dev.chunks -> dev.chunk
        # rename and removal of chunk_index; confirm this query against the
        # current schema.
        cur.execute("SELECT chunk_index, metadata->>'new_speaker_name' FROM dev.chunks WHERE file_uuid=%s AND chunk_type='sentence' ORDER BY chunk_index", (uuid,))
        name_by_idx = dict(cur.fetchall())
    finally:
        # Close the connection even when the query fails.
        conn.close()

    asrx_path = f"/Users/accusys/momentry/output_dev/{uuid}.asrx.json"
    with open(asrx_path, encoding="utf-8") as fh:
        asrx_full = json.load(fh)

    embeddings = asrx_full.get("embeddings", [])
    # Only these names contribute reference embeddings.
    ref = {"Cary Grant": [], "Audrey Hepburn": [], "Unknown": []}
    name_to_speaker = {}
    for i, seg in enumerate(asrx_full["segments"]):
        # Segments without a DB row fall back to "Unknown".
        # NOTE(review): a row whose new_speaker_name is NULL yields None here
        # rather than "Unknown" -- confirm whether that is intended.
        name = name_by_idx.get(i, "Unknown")
        if name in ref and i < len(embeddings):
            ref[name].append(np.array(embeddings[i]))
        # Keep the speaker_id of the first segment seen for each name.
        name_to_speaker.setdefault(name, seg["speaker_id"])

    centroids = {}
    for name, embs in ref.items():
        if embs:
            mean = np.mean(embs, axis=0)
            # Epsilon guards against division by zero for an all-zero mean.
            centroids[name] = mean / (np.linalg.norm(mean) + 1e-10)
    return centroids, name_to_speaker
|
||||
|
||||
def extract_audio(video_path, sr=16000):
    """Decode the audio track of ``video_path`` to a mono WAV and load it.

    ffmpeg writes ``audio.wav`` into a fresh temporary directory, which is
    then loaded with ``torchaudio.load``.

    Args:
        video_path: Path of the media file passed to ffmpeg.
        sr: Sample rate requested from ffmpeg (``-ar``).

    Returns:
        ``(wav_data, sr_actual, tmp)``: the loaded audio with a single channel
        row, the sample rate reported by torchaudio, and the temporary
        directory path. The caller is responsible for deleting ``tmp``.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited with a non-zero status.
        subprocess.TimeoutExpired: ffmpeg ran longer than 300 seconds.
    """
    tmp = tempfile.mkdtemp(prefix="asr_split_")
    try:
        wav = os.path.join(tmp, "audio.wav")
        subprocess.run(["ffmpeg", "-y", "-v", "quiet", "-i", video_path,
                        "-ar", str(sr), "-ac", "1", "-sample_fmt", "s16", wav],
                       check=True, capture_output=True, timeout=300)
        wav_data, sr_actual = torchaudio.load(wav)
    except BaseException:
        # On failure the caller never learns the directory path, so remove it
        # here before propagating the error.
        shutil.rmtree(tmp, ignore_errors=True)
        raise
    # Average the channels if the loaded file still has more than one.
    if wav_data.shape[0] > 1:
        wav_data = wav_data.mean(dim=0, keepdim=True)
    return wav_data, sr_actual, tmp
|
||||
|
||||
def classify(emb, centroids):
    """Return the name in ``centroids`` whose vector has the largest dot product with ``emb``.

    Ties go to the name that comes first in ``centroids``; an empty mapping
    raises ``ValueError``.
    """
    scores = {label: float(np.dot(emb, center)) for label, center in centroids.items()}
    return max(scores, key=scores.get)
|
||||
|
||||
def _embed_and_classify(wav, sr, lo, hi, centroids, encoder):
    """Embed samples ``lo:hi`` of ``wav`` as one clip and return the best centroid name."""
    clip = wav[:, lo:hi]
    emb = extract_speaker_embedding(encoder, clip.numpy(), sr)
    emb /= np.linalg.norm(emb) + 1e-10
    return classify(emb, centroids)


def _majority(labels):
    """Return the most frequent label; ties go to the label encountered first."""
    return Counter(labels).most_common(1)[0][0]


def _smooth_labels(labels):
    """Replace every interior label by the majority of itself and its two neighbours."""
    smoothed = list(labels)
    for i in range(1, len(labels) - 1):
        smoothed[i] = _majority(labels[i - 1:i + 2])
    return smoothed


def _find_split_times(smoothed, times):
    """Return the sub-window times at which an accepted label change starts."""
    splits = []
    current = smoothed[0]
    total = len(smoothed)
    for i in range(1, total):
        label = smoothed[i]
        if label == current:
            continue
        # Too close to the end to look CHANGE_CONFIRM windows ahead: accept.
        near_end = i + CHANGE_CONFIRM >= total
        if near_end or all(nxt == label for nxt in smoothed[i + 1:i + CHANGE_CONFIRM + 1]):
            splits.append(times[i])
            current = label
    return splits


def process_batch(asr_segs, wav, sr, centroids, encoder, offset_start=0):
    """Split ASR segments at detected speaker changes.

    Segments shorter than one second (or yielding fewer than three
    sub-windows) are classified as a whole. Longer segments are cut into
    overlapping sub-windows, each sub-window is classified against
    ``centroids``, the label sequence is smoothed, and confirmed label
    changes become split points.

    Args:
        asr_segs: Sequence of dicts with ``"start"`` and ``"end"`` times in
            seconds.
        wav: Audio indexed as ``wav[:, a:b]`` whose slices expose
            ``.numpy()``.
        sr: Sample rate of ``wav``.
        centroids: Mapping of speaker name to reference vector, as consumed
            by ``classify``.
        encoder: Speaker encoder passed through to
            ``extract_speaker_embedding``.
        offset_start: Time in seconds subtracted from segment times before
            indexing into ``wav`` and added back to sub-window times.

    Returns:
        List of ``(start, end, speaker_name, idx)`` tuples, where ``idx`` is
        the position of the parent segment *within* ``asr_segs``. Pieces
        produced by a split have their times rounded to two decimals.
    """
    win = int(SUB_WIN * sr)
    hop = int(SUB_STRIDE * sr)
    results = []
    for si, seg in enumerate(asr_segs):
        seg_start, seg_end = seg["start"], seg["end"]
        rel_start = seg_start - offset_start
        rel_end = seg_end - offset_start
        lo, hi = int(rel_start * sr), int(rel_end * sr)

        # Short segments are not worth sub-dividing.
        if rel_end - rel_start < 1.0:
            label = _embed_and_classify(wav, sr, lo, hi, centroids, encoder)
            results.append((seg_start, seg_end, label, si))
            continue

        sub_embs, sub_times = [], []
        for pos in range(lo, hi - win + 1, hop):
            window = wav[:, pos:pos + win]
            sub_embs.append(extract_speaker_embedding(encoder, window.numpy(), sr))
            sub_times.append(pos / sr + offset_start)

        # Too few windows to smooth and confirm a change.
        if len(sub_embs) < 3:
            label = _embed_and_classify(wav, sr, lo, hi, centroids, encoder)
            results.append((seg_start, seg_end, label, si))
            continue

        sub_embs = normalize_embeddings(np.array(sub_embs))
        names = [classify(emb, centroids) for emb in sub_embs]
        overall = _majority(names)

        splits = _find_split_times(_smooth_labels(names), sub_times)
        if not splits:
            results.append((seg_start, seg_end, overall, si))
            continue

        bounds = [seg_start] + splits + [seg_end]
        emitted = 0
        for piece_start, piece_end in zip(bounds, bounds[1:]):
            # Pieces shorter than MIN_DUR are skipped.
            if piece_end - piece_start < MIN_DUR:
                continue
            piece_names = [name for name, t in zip(names, sub_times)
                           if piece_start <= t < piece_end]
            label = _majority(piece_names) if piece_names else overall
            results.append((round(piece_start, 2), round(piece_end, 2), label, si))
            emitted += 1

        # If every piece was below MIN_DUR the segment would disappear from
        # the output; keep it unsplit with the overall majority label.
        if not emitted:
            results.append((seg_start, seg_end, overall, si))

    return results
|
||||
|
||||
def main():
    """Command-line entry point: split one file's ASR segments by speaker.

    Loads the reference centroids and the ``<uuid>.asr.json`` segments,
    extracts the audio of the video, runs ``process_batch`` over the segments
    in slices of ``BATCH_SIZE`` and writes the fine-grained segments as JSON
    (default ``<uuid>.asrx_fine.json`` in the output directory).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--uuid", default="aeed71342a899fe4b4c57b7d41bcb692")
    parser.add_argument("--output", help="Output path for fine ASRX JSON")
    parser.add_argument(
        "--video",
        default="/Users/accusys/momentry/var/sftpgo/data/demo/Charade (1963) Cary Grant & Audrey Hepburn \uff5c Comedy Mystery Romance Thriller \uff5c Full Movie.mp4",
        help="Path of the video whose audio is analysed",
    )
    parser.add_argument(
        "--db-url",
        default="postgresql://accusys@localhost:5432/momentry?host=/tmp",
        help="PostgreSQL connection string used to read speaker names",
    )
    args = parser.parse_args()

    uuid = args.uuid
    base = "/Users/accusys/momentry/output_dev"

    print(f"Processing {uuid}")

    # The name -> speaker_id mapping is returned but not needed here.
    centroids, _name_to_speaker = load_reference(uuid, args.db_url)
    print(f"Centroids: {list(centroids.keys())}")

    with open(f"{base}/{uuid}.asr.json", encoding="utf-8") as fh:
        asr = json.load(fh)
    asr_segs = asr["segments"]
    print(f"ASR segments: {len(asr_segs)}")

    print("Extracting audio...")
    wav, sr, tmp_dir = extract_audio(args.video)
    try:
        print(f"Audio: {wav.shape[1]/sr:.0f}s")

        inst = SelfASRXFixed()
        encoder = inst.speaker_encoder

        all_results = []
        t0 = time.time()
        for batch_start in range(0, len(asr_segs), BATCH_SIZE):
            batch = asr_segs[batch_start:batch_start + BATCH_SIZE]
            segs = process_batch(batch, wav, sr, centroids, encoder)
            # process_batch numbers segments from 0 within each batch; shift
            # to the index in the full ASR list so parent_asr_idx is global.
            all_results.extend(
                (start, end, name, batch_start + local_idx)
                for start, end, name, local_idx in segs
            )
            done = batch_start + len(batch)
            pct = done * 100 // len(asr_segs)
            print(f" {done}/{len(asr_segs)} ({pct}%) -> {len(all_results)} segments [{time.time()-t0:.0f}s]")
    finally:
        # Remove the extracted audio even if processing failed.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    # Build output
    spk_stats = {}
    out_segs = []
    # Assign sequential SPEAKER_X IDs based on name order
    name_order = {name: i for i, name in enumerate(sorted({r[2] for r in all_results}))}

    for start, end, name, asr_idx in all_results:
        sid = f"SPEAKER_{name_order[name]}"
        stats = spk_stats.setdefault(sid, {"count": 0, "duration": 0})
        stats["count"] += 1
        stats["duration"] += end - start
        out_segs.append({
            "start_time": start,
            "end_time": end,
            "speaker_id": sid,
            "speaker_name": name,
            "parent_asr_idx": asr_idx,
        })

    output = {
        "uuid": uuid,
        "language": "en",
        "segments": out_segs,
        "speaker_stats": spk_stats,
        "total_asr_segments": len(asr_segs),
        "total_fine_segments": len(out_segs),
    }

    output_path = args.output or f"{base}/{uuid}.asrx_fine.json"
    with open(output_path, "w", encoding="utf-8") as fh:
        json.dump(output, fh, indent=2)
    print(f"\nSaved: {output_path}")
    print(f"Segments: {len(out_segs)} (was {len(asr_segs)}, +{len(out_segs)-len(asr_segs)})")
    print(f"Speakers: {len(spk_stats)}")
    for sid, stats in sorted(spk_stats.items()):
        print(f" {sid}: {stats['count']} segs, {stats['duration']:.0f}s")
|
||||
|
||||
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user