feat: ASRX hybrid pipeline, identity history, worker fixes, checkpoint system

This commit is contained in:
Accusys
2026-06-02 07:13:23 +08:00
parent e3066c3f49
commit e1572907ae
198 changed files with 43705 additions and 8910 deletions

View File

@@ -1,308 +1,728 @@
#!/opt/homebrew/bin/python3.11
"""
Self-implemented ASRX - Fixed Version
使用魯棒的聚類算法
SelfASRXFixed - 7 步 Hybrid Speaker Diarization Pipeline
Pipeline:
1. whisper.transcribe(full_audio) → rough segments + text + language
2. VAD scan each rough segment → refined segments
3. whisper per refined segment → {text, language, lang_prob}
4. ECAPA-TDNN per refined segment → 192-dim embeddings
5. AgglomerativeClustering → speaker_labels
6. Store all embeddings in Qdrant (payload: file_uuid, speaker_id, text, ...)
7. High-quality embeddings → gender classify + store reference in Qdrant
"""
import sys
import json
import time
import os
import numpy as np
from pathlib import Path
from urllib.request import Request, urlopen
from urllib.error import URLError
# 導入自定義模組
from vad import load_vad_model, extract_speech_segments
from speaker_encoder import (
load_speaker_encoder,
extract_speaker_embeddings_batch,
normalize_embeddings
)
from speaker_cluster_fixed import robust_speaker_clustering
def _load_audio(path):
"""載入音頻文件,回傳 (wav_numpy, sample_rate)"""
import soundfile as sf
wav, sr = sf.read(path)
if len(wav.shape) > 1:
wav = np.mean(wav, axis=1)
return wav, sr
def _load_whisper_model(size="small"):
from whisper_local import load_model
return load_model(size)
def _load_vad():
from vad import load_vad_model
return load_vad_model()
def _load_speaker_encoder():
from speaker_encoder import load_speaker_encoder
return load_speaker_encoder()
def _load_gender_classifier():
try:
from speechbrain.inference.classifiers import EncoderClassifier
classifier = EncoderClassifier.from_hparams(
source="speechbrain/gender-recognition-ecapa",
run_opts={"device": "cpu"},
)
print("[Gender] Classifier loaded: speechbrain/gender-recognition-ecapa")
return classifier
except Exception as e:
print(f"[Gender] Classifier not available: {e}")
return None
def _ensure_speaker_collection(qdrant_url, api_key, collection):
"""確認 Qdrant speaker collection 存在,不存在則建立 (dim=192, cosine)"""
try:
url = f"{qdrant_url}/collections/{collection}"
req = Request(url, method="GET",
headers={"api-key": api_key} if api_key else {})
try:
urlopen(req)
return True
except URLError as e:
if getattr(e, "code", None) == 404:
body = json.dumps({
"vectors": {
"size": 192,
"distance": "Cosine"
}
}).encode()
req = Request(url, data=body, method="PUT",
headers={"Content-Type": "application/json",
**({"api-key": api_key} if api_key else {})})
urlopen(req)
print(f"[Qdrant] Created collection: {collection} (dim=192)")
return True
raise
except Exception as e:
print(f"[Qdrant] Cannot access Qdrant: {e}")
return False
def _qdrant_upsert(qdrant_url, api_key, collection, points):
"""批量寫入 Qdrant points"""
try:
url = f"{qdrant_url}/collections/{collection}/points?wait=true"
body = json.dumps({"points": points}).encode()
headers = {"Content-Type": "application/json"}
if api_key:
headers["api-key"] = api_key
req = Request(url, data=body, headers=headers, method="PUT")
urlopen(req)
return True
except Exception as e:
print(f"[Qdrant] Upsert failed: {e}")
return False
def _hash_point_id(file_uuid, label):
"""產生一致的 point ID"""
s = f"{file_uuid}_{label}"
return hash(s) & 0x7FFFFFFFFFFFFFFF
def _save_checkpoint(path: str, data: dict):
"""原子寫入 checkpoint先 .tmp 再 rename"""
tmp = path + ".tmp"
Path(tmp).parent.mkdir(parents=True, exist_ok=True)
with open(tmp, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
os.replace(tmp, path)
def compute_embedding_quality(embeddings, labels):
"""每個 embedding 到所屬 cluster centroid 的餘弦相似度"""
from sklearn.metrics.pairwise import cosine_similarity
unique_labels = set(labels)
centroids = {}
for label in unique_labels:
mask = labels == label
centroid = np.mean(embeddings[mask], axis=0)
norm = np.linalg.norm(centroid)
if norm > 0:
centroid = centroid / norm
centroids[label] = centroid
qualities = []
for emb, label in zip(embeddings, labels):
sim = cosine_similarity([emb], [centroids[label]])[0][0]
qualities.append(sim)
return np.array(qualities)
class SelfASRXFixed:
"""自實作說話人分離系統(修復版)"""
"""7 步 Hybrid Speaker Diarization Pipeline"""
def __init__(self):
print("[SelfASRX-Fixed] Initializing models...")
# 載入 VAD 模型
print("[SelfASRX-Fixed] Loading VAD model (Silero)...")
self.vad_model, self.vad_utils = load_vad_model()
# 載入聲紋模型
print("[SelfASRX-Fixed] Loading speaker encoder (ECAPA-TDNN)...")
self.speaker_encoder = load_speaker_encoder()
print("[SelfASRX-Fixed] Models loaded successfully")
def process(self, audio_path, output_path=None,
min_speech_duration_ms=500,
n_speakers=None,
max_speakers=10):
"""處理音頻文件"""
start_time = time.time()
print(f"\n[SelfASRX-Fixed] Processing: {audio_path}")
print("=" * 60)
# 步驟 1: VAD
print("\n[Step 1] Voice Activity Detection...")
step1_start = time.time()
speech_segments, wav, sample_rate = extract_speech_segments(
audio_path, self.vad_model, self.vad_utils,
min_speech_duration_ms=min_speech_duration_ms
)
step1_time = time.time() - step1_start
print(f" Speech segments: {len(speech_segments)}")
print(f" Total duration: {len(wav)/sample_rate:.2f}s")
print(f" VAD time: {step1_time:.2f}s")
if len(speech_segments) == 0:
print("[SelfASRX-Fixed] No speech detected!")
return {"error": "No speech detected", "segments": []}
# 步驟 2: 聲紋特徵提取
print("\n[Step 2] Speaker embedding extraction...")
step2_start = time.time()
# 提取語音片段音頻
audio_segments = []
for start_sec, end_sec in speech_segments:
start_sample = int(start_sec * sample_rate)
end_sample = int(end_sec * sample_rate)
audio_segments.append(wav[start_sample:end_sample])
# 批量提取嵌入
embeddings = extract_speaker_embeddings_batch(
self.speaker_encoder, audio_segments, sample_rate
)
# 正規化
embeddings = normalize_embeddings(embeddings)
step2_time = time.time() - step2_start
print(f" Embedding shape: {embeddings.shape}")
print(f" Embedding time: {step2_time:.2f}s")
# 步驟 3: 魯棒聚類
print("\n[Step 3] Robust speaker clustering...")
step3_start = time.time()
speaker_labels, estimated_n_speakers = robust_speaker_clustering(
embeddings,
n_speakers=n_speakers,
max_speakers=max_speakers
)
step3_time = time.time() - step3_start
print(f" Clustering time: {step3_time:.2f}s")
# 步驟 4: 建立輸出
print("\n[Step 4] Building output...")
result = {
"audio_path": str(audio_path),
"total_duration": len(wav) / sample_rate,
"n_speech_segments": len(speech_segments),
"n_speakers": int(estimated_n_speakers),
"segments": []
}
for i, ((start, end), label) in enumerate(zip(speech_segments, speaker_labels)):
result["segments"].append({
"index": i,
"start": round(start, 3),
"end": round(end, 3),
"duration": round(end - start, 3),
"speaker": f"SPEAKER_{int(label)}"
})
# 統計每個說話人的總時長
speaker_stats = {}
for seg in result["segments"]:
speaker = seg["speaker"]
if speaker not in speaker_stats:
speaker_stats[speaker] = {"count": 0, "duration": 0}
speaker_stats[speaker]["count"] += 1
speaker_stats[speaker]["duration"] += seg["duration"]
result["speaker_stats"] = speaker_stats
total_time = time.time() - start_time
result["processing_time"] = round(total_time, 2)
result["realtime_factor"] = round(result["total_duration"] / total_time, 2)
print("\n[SelfASRX-Fixed] Processing completed!")
print(f" Total time: {total_time:.2f}s")
print(f" Realtime factor: {result['realtime_factor']:.2f}x")
print(f" Detected speakers: {estimated_n_speakers}")
# 保存結果
if output_path:
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(result, f, indent=2, ensure_ascii=False)
print(f" Results saved to: {output_path}")
print("=" * 60)
return result
print("[SelfASRX] Initializing models...")
print("[SelfASRX] Loading whisper model...")
self.whisper = _load_whisper_model("small")
print("[SelfASRX] Loading VAD model (Silero)...")
self.vad_model, self.vad_utils = _load_vad()
print("[SelfASRX] Loading speaker encoder (ECAPA-TDNN)...")
self.speaker_encoder = _load_speaker_encoder()
print("[SelfASRX] Loading gender classifier...")
self.gender_classifier = _load_gender_classifier()
# Qdrant 設定
self.qdrant_url = os.environ.get("QDRANT_URL", "http://localhost:6333")
self.qdrant_api_key = os.environ.get("QDRANT_API_KEY", "")
schema = os.environ.get("DATABASE_SCHEMA", "public")
self.qdrant_collection = os.environ.get(
"QDRANT_SPEAKER_COLLECTION",
f"momentry_{schema}_speaker"
)
self._qdrant_ok = False
print("[SelfASRX] Models loaded successfully")
def process(self, audio_path, output_path=None, file_uuid=None,
max_speakers=10, quality_threshold=0.85,
checkpoint_path=None):
"""7 步 speaker diarization pipeline
def process_with_segments(self, audio_path, asr_segments, output_path=None):
"""
使用 ASR segment 邊界進行 speaker diarization取代 VAD 步驟。
Args:
audio_path: 音頻文件路徑WAV
asr_segments: ASR segment 列表,每個包含 start/end
output_path: 輸出 JSON 路徑(可選)
audio_path: 音頻文件路徑 (WAV 16kHz mono)
output_path: 輸出 JSON 路徑 (可選)
file_uuid: 檔案 UUID (用於 Qdrant 儲存)
max_speakers: 最大說話人數
quality_threshold: 高品質聲紋門檻 (0-1)
checkpoint_path: Step 3 完成後儲存 checkpoint 路徑
Returns:
dict: segments, speaker_stats, n_speakers, total_duration, references
"""
start_time = time.time()
print(f"\n[SelfASRX-Fixed] Processing with {len(asr_segments)} ASR segments: {audio_path}")
print(f"\n[SelfASRX] Processing: {audio_path}")
print("=" * 60)
# 載入完整音頻
import soundfile as sf
wav, sample_rate = sf.read(audio_path)
if len(wav.shape) > 1:
wav = np.mean(wav, axis=1) # 轉 mono
print(f" Audio loaded: {len(wav)/sample_rate:.2f}s, {sample_rate}Hz")
# 載入音頻
wav, sample_rate = _load_audio(audio_path)
total_duration = len(wav) / sample_rate
print(f" Audio: {total_duration:.2f}s, {sample_rate}Hz")
# 使用 ASR segments 取代 VAD (audio处理用time)
speech_segments = [(s["start_time"], s["end_time"]) for s in asr_segments]
print(f" Speech segments from ASR: {len(speech_segments)}")
# ── Step 1: whisper 粗略定位 (faster-whisper) ──
print("\n[Step 1] Initial whisper transcription...")
t1 = time.time()
seg_gen, info = self.whisper.transcribe(audio_path)
rough_segments = []
for seg in seg_gen:
rough_segments.append({"start": seg.start, "end": seg.end, "text": seg.text})
language = info.language if info else None
print(f" Rough segments: {len(rough_segments)}")
print(f" Language: {language}")
print(f" Step 1 time: {time.time() - t1:.2f}s")
if len(speech_segments) == 0:
print("[SelfASRX-Fixed] No ASR segments provided!")
return {"error": "No ASR segments", "segments": []}
if not rough_segments:
print("[SelfASRX] No speech detected by whisper!")
return {"error": "No speech detected", "segments": []}
# 提取語音片段
audio_segments = []
for start_sec, end_sec in speech_segments:
start_sample = int(start_sec * sample_rate)
end_sample = int(end_sec * sample_rate)
if start_sample >= len(wav):
# ── Step 2: VAD scan 每個 rough segment 細切 ──
print("\n[Step 2] VAD scan for refined segmentation...")
t2 = time.time()
refined_segments = []
for seg in rough_segments:
s = seg["start"]
e = seg["end"]
sub = self._vad_scan_segment(wav, sample_rate, s, e)
if sub:
refined_segments.extend(sub)
else:
refined_segments.append((s, e))
print(f" Refined segments: {len(refined_segments)}")
print(f" Step 2 time: {time.time() - t2:.2f}s")
if not refined_segments:
return {"error": "No segments after VAD scan", "segments": []}
# ── Step 3: whisper per refined segment ──
print("\n[Step 3] Per-segment transcription...")
t3 = time.time()
CHECKPOINT_INTERVAL = 50
segment_texts = []
resume_from = 0
# 載入既有 partial checkpoint中斷續接
if checkpoint_path and os.path.exists(checkpoint_path):
try:
with open(checkpoint_path, "r") as f:
cp = json.load(f)
if cp.get("checkpoint_version") == 2 and not cp.get("step3_completed"):
saved = cp.get("segment_texts", [])
if saved:
resume_from = len(saved)
segment_texts = saved
print(f"[Step 3] Resuming from #{resume_from}/{len(refined_segments)}")
except Exception:
pass
for i, (start_sec, end_sec) in enumerate(refined_segments):
if i < resume_from:
continue
audio_segments.append(wav[start_sample:min(end_sample, len(wav))])
seg_text = self._transcribe_segment(wav, sample_rate, start_sec, end_sec)
segment_texts.append(seg_text)
print(f" Audio segments extracted: {len(audio_segments)}")
if checkpoint_path and (i + 1) % CHECKPOINT_INTERVAL == 0:
_save_checkpoint(checkpoint_path, {
"checkpoint_version": 2,
"step3_completed": False,
"step3_progress": i + 1,
"language": language,
"total_duration": total_duration,
"refined_segments": [[s, e] for s, e in refined_segments],
"segment_texts": [{
"text": st["text"],
"language": st["language"],
"lang_prob": st["lang_prob"],
} for st in segment_texts],
"file_uuid": file_uuid,
"max_speakers": max_speakers,
"quality_threshold": quality_threshold,
})
print(f"[Checkpoint] Step 3: {i+1}/{len(refined_segments)}")
# 批量提取聲紋嵌入
print("\n[Step 2] Speaker embedding extraction...")
step2_start = time.time()
print(f" Step 3 time: {time.time() - t3:.2f}s")
# ── Save final checkpoint after Step 3 ──
if checkpoint_path:
_save_checkpoint(checkpoint_path, {
"checkpoint_version": 2,
"step3_completed": True,
"language": language,
"total_duration": total_duration,
"refined_segments": [[s, e] for s, e in refined_segments],
"segment_texts": [{
"text": st["text"],
"language": st["language"],
"lang_prob": st["lang_prob"],
} for st in segment_texts],
"file_uuid": file_uuid,
"max_speakers": max_speakers,
"quality_threshold": quality_threshold,
})
print(f"[Checkpoint] Step 3 complete, saved to {checkpoint_path}")
# ── Step 4: ECAPA-TDNN per refined segment ──
print("\n[Step 4] Speaker embedding extraction...")
t4 = time.time()
audio_segments = []
for start_sec, end_sec in refined_segments:
s = int(start_sec * sample_rate)
e = int(end_sec * sample_rate)
audio_segments.append(wav[s:min(e, len(wav))])
from speaker_encoder import extract_speaker_embeddings_batch, normalize_embeddings
embeddings = extract_speaker_embeddings_batch(
self.speaker_encoder, audio_segments, sample_rate
)
embeddings = normalize_embeddings(embeddings)
step2_time = time.time() - step2_start
print(f" Embedding shape: {embeddings.shape}")
print(f" Embedding time: {step2_time:.2f}s")
print(f" Embeddings: {embeddings.shape}")
print(f" Step 4 time: {time.time() - t4:.2f}s")
# 聚類
print("\n[Step 3] Robust speaker clustering...")
step3_start = time.time()
# ── Step 5: AgglomerativeClustering ──
print("\n[Step 5] Speaker clustering...")
t5 = time.time()
from speaker_cluster_fixed import robust_speaker_clustering
speaker_labels, estimated_n_speakers = robust_speaker_clustering(
embeddings, n_speakers=None, max_speakers=10
embeddings, n_speakers=None, max_speakers=max_speakers
)
step3_time = time.time() - step3_start
print(f" Clustering time: {step3_time:.2f}s")
print(f" Speakers: {estimated_n_speakers}")
print(f" Step 5 time: {time.time() - t5:.2f}s")
# 建立輸出
result = {
"audio_path": str(audio_path),
"total_duration": len(wav) / sample_rate,
"n_speech_segments": len(speech_segments),
"n_speakers": int(estimated_n_speakers),
"segments": []
}
# 品質計算
qualities = compute_embedding_quality(embeddings, speaker_labels)
for i, ((start, end), label) in enumerate(zip(speech_segments, speaker_labels)):
result["segments"].append({
"index": i,
"start": round(start, 3),
"end": round(end, 3),
"duration": round(end - start, 3),
"speaker": f"SPEAKER_{int(label)}"
})
# 加入 embeddings每個 segment 對應的 192-D speaker embedding
result["embeddings"] = []
for emb in embeddings:
result["embeddings"].append(emb.tolist())
# 建立輸出 segments
segments = []
for i, ((start_sec, end_sec), label) in enumerate(
zip(refined_segments, speaker_labels)):
seg = {
"start": round(start_sec, 3),
"end": round(end_sec, 3),
"start_frame": int(start_sec * 30),
"end_frame": int(end_sec * 30),
"text": segment_texts[i]["text"],
"language": segment_texts[i]["language"],
"lang_prob": segment_texts[i]["lang_prob"],
"speaker": f"SPEAKER_{int(label)}",
"speaker_id": f"SPEAKER_{int(label)}",
"quality": float(qualities[i]),
}
segments.append(seg)
# 統計
speaker_stats = {}
for seg in result["segments"]:
speaker = seg["speaker"]
if speaker not in speaker_stats:
speaker_stats[speaker] = {"count": 0, "duration": 0}
speaker_stats[speaker]["count"] += 1
speaker_stats[speaker]["duration"] += seg["duration"]
result["speaker_stats"] = speaker_stats
for seg in segments:
spk = seg["speaker_id"]
dur = seg["end"] - seg["start"]
if spk not in speaker_stats:
speaker_stats[spk] = {"count": 0, "duration": 0}
speaker_stats[spk]["count"] += 1
speaker_stats[spk]["duration"] += dur
result = {
"language": language or "",
"segments": segments,
"n_speakers": int(estimated_n_speakers),
"speaker_stats": speaker_stats,
"total_duration": total_duration,
"n_segments": len(segments),
}
# ── Step 6: Store embeddings in Qdrant ──
if file_uuid:
print("\n[Step 6] Storing embeddings in Qdrant...")
t6 = time.time()
self._store_speaker_embeddings(segments, embeddings, speaker_labels,
file_uuid)
print(f" Step 6 time: {time.time() - t6:.2f}s")
# ── Step 7: High-quality classification ──
if file_uuid:
print("\n[Step 7] Classifying high-quality embeddings...")
t7 = time.time()
references = self._classify_high_quality_speakers(
segments, embeddings, speaker_labels, file_uuid,
wav, sample_rate, quality_threshold
)
if references:
result["references"] = references
print(f" Step 7 time: {time.time() - t7:.2f}s")
total_time = time.time() - start_time
result["processing_time"] = round(total_time, 2)
result["realtime_factor"] = round(result["total_duration"] / total_time, 2)
print("\n[SelfASRX-Fixed] Processing completed!")
print(f" Total time: {total_time:.2f}s")
print(f" Realtime factor: {result['realtime_factor']:.2f}x")
print(f" Detected speakers: {estimated_n_speakers}")
if total_duration > 0:
result["realtime_factor"] = round(total_duration / total_time, 2)
# 保存輸出
if output_path:
import json
with open(output_path, 'w', encoding='utf-8') as f:
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result, f, indent=2, ensure_ascii=False)
print(f" Results saved to: {output_path}")
print(f"\n[SelfASRX] Saved to: {output_path}")
print(f"\n[SelfASRX] Done! {len(segments)} segments, "
f"{estimated_n_speakers} speakers, "
f"{total_time:.2f}s")
print("=" * 60)
return result
def resume_from_checkpoint(self, checkpoint_path, audio_path,
output_path=None):
"""從 checkpoint 載入 Steps 1-3 結果,執行 Steps 4-7"""
print(f"\n[SelfASRX] Resuming from checkpoint: {checkpoint_path}")
print("=" * 60)
with open(checkpoint_path, "r", encoding="utf-8") as f:
cp = json.load(f)
if not cp.get("step3_completed"):
error_msg = f"Checkpoint step3 not completed (progress: {cp.get('step3_progress', '?')})"
print(f"[SelfASRX] {error_msg}")
return {"error": error_msg, "segments": []}
wav, sample_rate = _load_audio(audio_path)
refined_segments = [tuple(s) for s in cp["refined_segments"]]
segment_texts = cp["segment_texts"]
language = cp.get("language", "")
total_duration = cp.get("total_duration", 0)
file_uuid = cp.get("file_uuid")
max_speakers = cp.get("max_speakers", 10)
quality_threshold = cp.get("quality_threshold", 0.85)
print(f" Loaded checkpoint: {len(refined_segments)} segments, "
f"language={language}, duration={total_duration:.2f}s")
start_time = time.time()
# ── Step 4: ECAPA-TDNN per refined segment ──
print("\n[Step 4] Speaker embedding extraction...")
t4 = time.time()
audio_segments = []
for start_sec, end_sec in refined_segments:
s = int(start_sec * sample_rate)
e = int(end_sec * sample_rate)
audio_segments.append(wav[s:min(e, len(wav))])
from speaker_encoder import extract_speaker_embeddings_batch, normalize_embeddings
embeddings = extract_speaker_embeddings_batch(
self.speaker_encoder, audio_segments, sample_rate
)
embeddings = normalize_embeddings(embeddings)
print(f" Embeddings: {embeddings.shape}")
print(f" Step 4 time: {time.time() - t4:.2f}s")
# ── Step 5: AgglomerativeClustering ──
print("\n[Step 5] Speaker clustering...")
t5 = time.time()
from speaker_cluster_fixed import robust_speaker_clustering
speaker_labels, estimated_n_speakers = robust_speaker_clustering(
embeddings, n_speakers=None, max_speakers=max_speakers
)
print(f" Speakers: {estimated_n_speakers}")
print(f" Step 5 time: {time.time() - t5:.2f}s")
# 品質計算
qualities = compute_embedding_quality(embeddings, speaker_labels)
# 建立輸出 segments
segments = []
for i, ((start_sec, end_sec), label) in enumerate(
zip(refined_segments, speaker_labels)):
seg = {
"start": round(start_sec, 3),
"end": round(end_sec, 3),
"start_frame": int(start_sec * 30),
"end_frame": int(end_sec * 30),
"text": segment_texts[i]["text"],
"language": segment_texts[i]["language"],
"lang_prob": segment_texts[i]["lang_prob"],
"speaker": f"SPEAKER_{int(label)}",
"speaker_id": f"SPEAKER_{int(label)}",
"quality": float(qualities[i]),
}
segments.append(seg)
# 統計
speaker_stats = {}
for seg in segments:
spk = seg["speaker_id"]
dur = seg["end"] - seg["start"]
if spk not in speaker_stats:
speaker_stats[spk] = {"count": 0, "duration": 0}
speaker_stats[spk]["count"] += 1
speaker_stats[spk]["duration"] += dur
result = {
"language": language or "",
"segments": segments,
"n_speakers": int(estimated_n_speakers),
"speaker_stats": speaker_stats,
"total_duration": total_duration,
"n_segments": len(segments),
}
# ── Step 6: Store embeddings in Qdrant ──
if file_uuid:
print("\n[Step 6] Storing embeddings in Qdrant...")
t6 = time.time()
self._store_speaker_embeddings(segments, embeddings, speaker_labels,
file_uuid)
print(f" Step 6 time: {time.time() - t6:.2f}s")
# ── Step 7: High-quality classification ──
if file_uuid:
print("\n[Step 7] Classifying high-quality embeddings...")
t7 = time.time()
references = self._classify_high_quality_speakers(
segments, embeddings, speaker_labels, file_uuid,
wav, sample_rate, quality_threshold
)
if references:
result["references"] = references
print(f" Step 7 time: {time.time() - t7:.2f}s")
total_time = time.time() - start_time
result["processing_time"] = round(total_time, 2)
if total_duration > 0:
result["realtime_factor"] = round(total_duration / total_time, 2)
# 保存輸出
if output_path:
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result, f, indent=2, ensure_ascii=False)
print(f"\n[SelfASRX] Saved to: {output_path}")
print(f"\n[SelfASRX] Done! {len(segments)} segments, "
f"{estimated_n_speakers} speakers, "
f"{total_time:.2f}s")
return result
# ── Internal helpers ──
def _vad_scan_segment(self, wav, sample_rate, start_sec, end_sec):
"""VAD 細切單一段落"""
from vad import scan_within_segment
return scan_within_segment(
wav, sample_rate, start_sec, end_sec,
self.vad_model, self.vad_utils
)
def _transcribe_segment(self, wav, sample_rate, start_sec, end_sec):
"""轉錄單一段落"""
from whisper_local import transcribe_segment
return transcribe_segment(wav, sample_rate, start_sec, end_sec, self.whisper)
def _store_speaker_embeddings(self, segments, embeddings, labels, file_uuid):
"""Step 6: 所有 embedding 存入 Qdrant"""
if not self._ensure_qdrant():
return
points = []
for i, (seg, emb, label) in enumerate(
zip(segments, embeddings, labels)):
point_id = _hash_point_id(file_uuid, f"{i}")
points.append({
"id": point_id,
"vector": emb.tolist(),
"payload": {
"type": "speaker_embedding",
"file_uuid": file_uuid,
"speaker_id": seg["speaker_id"],
"text": seg["text"],
"language": seg["language"],
"start_time": seg["start"],
"end_time": seg["end"],
}
})
ok = _qdrant_upsert(self.qdrant_url, self.qdrant_api_key,
self.qdrant_collection, points)
if ok:
print(f" Stored {len(points)} speaker embeddings to Qdrant")
return ok
def _classify_high_quality_speakers(self, segments, embeddings, labels,
file_uuid, wav, sample_rate,
threshold=0.85):
"""Step 7: 高品質聲紋分級 + 性別分類 → Qdrant reference"""
qualities = compute_embedding_quality(embeddings, labels)
high_mask = qualities >= threshold
if not np.any(high_mask):
print(" No high-quality embeddings found")
return []
unique_labels = set(labels)
references = []
for label in unique_labels:
mask = (labels == label) & high_mask
if not np.any(mask):
continue
high_indices = [i for i in range(len(segments)) if mask[i]]
high_segs = [segments[i] for i in high_indices]
# 取品質最高的 segment index
best_idx = high_indices[int(np.argmax(qualities[mask]))]
best_seg = segments[best_idx]
centroid = np.mean(embeddings[mask], axis=0)
norm = np.linalg.norm(centroid)
if norm > 0:
centroid = centroid / norm
avg_quality = float(np.mean(qualities[mask]))
speaker_id = f"SPEAKER_{int(label)}"
text_samples = [s["text"] for s in high_segs[:5] if s["text"]]
total_dur = sum(s["end"] - s["start"] for s in high_segs)
ref_id = _hash_point_id(file_uuid, f"ref_{label}")
ref_payload = {
"type": "speaker_reference",
"file_uuid": file_uuid,
"speaker_id": speaker_id,
"n_segments": int(np.sum(mask)),
"avg_quality": avg_quality,
"total_duration": round(total_dur, 2),
"language": best_seg.get("language", ""),
"text_samples": text_samples,
}
# 性別分類:用最佳 segment 的音頻
if self.gender_classifier is not None:
try:
import torch
s = int(best_seg["start"] * sample_rate)
e = int(best_seg["end"] * sample_rate)
seg_wav = wav[s:min(e, len(wav))]
seg_tensor = torch.from_numpy(seg_wav).float().unsqueeze(0)
# SpeechBrain gender classifier 接受音頻
out = self.gender_classifier.classify_batch(seg_tensor)
probs = torch.softmax(out[0], dim=-1).squeeze().cpu().detach().numpy()
if len(probs) >= 2:
idx = int(np.argmax(probs))
ref_payload["gender"] = "male" if idx == 0 else "female"
ref_payload["gender_conf"] = float(probs[idx])
else:
ref_payload["gender"] = "unknown"
ref_payload["gender_conf"] = 0.0
except Exception as e:
print(f"[Gender] Classify error: {e}")
ref_payload["gender"] = "unknown"
ref_payload["gender_conf"] = 0.0
else:
ref_payload["gender"] = "unknown"
ref_payload["gender_conf"] = 0.0
_qdrant_upsert(self.qdrant_url, self.qdrant_api_key,
self.qdrant_collection, [{
"id": ref_id,
"vector": centroid.tolist(),
"payload": ref_payload,
}])
references.append({
"speaker_id": speaker_id,
"n_segments": int(np.sum(mask)),
"avg_quality": avg_quality,
"gender": ref_payload["gender"],
})
print(f" Ref: {speaker_id}, gender={ref_payload['gender']}"
f" ({ref_payload['gender_conf']:.2f}), q={avg_quality:.3f}")
return references
def _ensure_qdrant(self):
"""確保 Qdrant collection 可用"""
if not self._qdrant_ok:
ok = _ensure_speaker_collection(
self.qdrant_url, self.qdrant_api_key, self.qdrant_collection
)
self._qdrant_ok = ok
return self._qdrant_ok
def main():
import argparse
parser = argparse.ArgumentParser(description="Self-implemented ASRX (Fixed)")
parser.add_argument("audio_path", help="Path to audio file")
parser = argparse.ArgumentParser(description="SelfASRX - Hybrid Speaker Diarization")
parser.add_argument("audio_path", help="Path to audio file (WAV)")
parser.add_argument("-o", "--output", help="Output JSON path")
parser.add_argument("--min-speech-duration", type=int, default=500)
parser.add_argument("--n-speakers", type=int, default=None)
parser.add_argument("--file-uuid", help="File UUID for Qdrant storage")
parser.add_argument("--max-speakers", type=int, default=10)
parser.add_argument("--quality-threshold", type=float, default=0.85)
parser.add_argument("--resume", help="Checkpoint path to resume from")
parser.add_argument("--checkpoint", help="Save checkpoint path after Step 3")
args = parser.parse_args()
if not Path(args.audio_path).exists():
print(f"Error: Audio file not found: {args.audio_path}")
sys.exit(1)
asrx = SelfASRXFixed()
result = asrx.process(
args.audio_path,
args.output,
min_speech_duration_ms=args.min_speech_duration,
n_speakers=args.n_speakers,
max_speakers=args.max_speakers
)
if args.resume:
if not Path(args.resume).exists():
print(f"Error: Checkpoint not found: {args.resume}")
sys.exit(1)
result = asrx.resume_from_checkpoint(
args.resume, args.audio_path,
output_path=args.output,
)
else:
if not Path(args.audio_path).exists():
print(f"Error: Audio file not found: {args.audio_path}")
sys.exit(1)
result = asrx.process(
args.audio_path,
output_path=args.output,
file_uuid=args.file_uuid,
max_speakers=args.max_speakers,
quality_threshold=args.quality_threshold,
checkpoint_path=args.checkpoint,
)
if "error" not in result:
print("\n[Summary]")
print(f" Audio duration: {result['total_duration']:.2f}s")
print(f" Speech segments: {result['n_speech_segments']}")
print(f" Detected speakers: {result['n_speakers']}")
print(f" Processing time: {result['processing_time']:.2f}s")
print(f" Realtime factor: {result['realtime_factor']:.2f}x")
print("\n[Speaker Statistics]")
for speaker, stats in result['speaker_stats'].items():
pct = stats['duration'] / result['total_duration'] * 100
print(f" {speaker}: {stats['count']} segments, " +
f"{stats['duration']:.2f}s ({pct:.1f}%)")
print(f" Duration: {result['total_duration']:.2f}s")
print(f" Segments: {result['n_segments']}")
print(f" Speakers: {result['n_speakers']}")
if "references" in result:
for ref in result["references"]:
print(f" {ref['speaker_id']}: gender={ref['gender']}, "
f"quality={ref['avg_quality']:.3f}")
if __name__ == "__main__":