feat: ASRX hybrid pipeline, identity history, worker fixes, checkpoint system
This commit is contained in:
@@ -1,308 +1,728 @@
|
||||
#!/opt/homebrew/bin/python3.11
|
||||
"""
|
||||
Self-implemented ASRX - Fixed Version
|
||||
使用魯棒的聚類算法
|
||||
SelfASRXFixed - 7 步 Hybrid Speaker Diarization Pipeline
|
||||
|
||||
Pipeline:
|
||||
1. whisper.transcribe(full_audio) → rough segments + text + language
|
||||
2. VAD scan each rough segment → refined segments
|
||||
3. whisper per refined segment → {text, language, lang_prob}
|
||||
4. ECAPA-TDNN per refined segment → 192-dim embeddings
|
||||
5. AgglomerativeClustering → speaker_labels
|
||||
6. Store all embeddings in Qdrant (payload: file_uuid, speaker_id, text, ...)
|
||||
7. High-quality embeddings → gender classify + store reference in Qdrant
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from urllib.request import Request, urlopen
|
||||
from urllib.error import URLError
|
||||
|
||||
# 導入自定義模組
|
||||
from vad import load_vad_model, extract_speech_segments
|
||||
from speaker_encoder import (
|
||||
load_speaker_encoder,
|
||||
extract_speaker_embeddings_batch,
|
||||
normalize_embeddings
|
||||
)
|
||||
from speaker_cluster_fixed import robust_speaker_clustering
|
||||
|
||||
def _load_audio(path):
    """Read an audio file and return ``(samples, sample_rate)``.

    Audio with more than one channel is down-mixed to mono by averaging
    across the channel axis (axis 1).
    """
    import soundfile as sf

    samples, rate = sf.read(path)
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    return samples, rate
|
||||
|
||||
|
||||
def _load_whisper_model(size="small"):
    """Load the whisper model of the given size through ``whisper_local``."""
    import whisper_local

    return whisper_local.load_model(size)
|
||||
|
||||
|
||||
def _load_vad():
    """Load the VAD model through ``vad.load_vad_model`` and return its result."""
    import vad

    return vad.load_vad_model()
|
||||
|
||||
|
||||
def _load_speaker_encoder():
    """Load the speaker encoder through ``speaker_encoder.load_speaker_encoder``."""
    import speaker_encoder

    return speaker_encoder.load_speaker_encoder()
|
||||
|
||||
|
||||
def _load_gender_classifier():
    """Try to load the SpeechBrain gender classifier on CPU.

    Returns:
        The loaded classifier, or ``None`` when importing or loading it
        fails for any reason (the failure is printed, not raised).
    """
    try:
        from speechbrain.inference.classifiers import EncoderClassifier

        model = EncoderClassifier.from_hparams(
            source="speechbrain/gender-recognition-ecapa",
            run_opts={"device": "cpu"},
        )
        print("[Gender] Classifier loaded: speechbrain/gender-recognition-ecapa")
    except Exception as exc:
        print(f"[Gender] Classifier not available: {exc}")
        model = None
    return model
|
||||
|
||||
|
||||
def _ensure_speaker_collection(qdrant_url, api_key, collection):
    """Make sure the Qdrant speaker collection exists, creating it if needed.

    A GET on the collection URL is tried first. Only an HTTP 404 triggers
    creation (192-dimensional vectors, cosine distance); any other failure
    is reported and treated as "Qdrant unavailable".

    Args:
        qdrant_url: Base URL of the Qdrant server.
        api_key: API key sent as the ``api-key`` header; falsy means no header.
        collection: Collection name.

    Returns:
        True when the collection exists or was created, False otherwise.
    """
    auth_headers = {"api-key": api_key} if api_key else {}
    try:
        url = f"{qdrant_url}/collections/{collection}"
        try:
            # Close the response explicitly so the connection is not leaked.
            with urlopen(Request(url, method="GET", headers=auth_headers)):
                return True
        except URLError as err:
            # HTTPError is a URLError subclass that carries ``code``.
            if getattr(err, "code", None) != 404:
                raise
            # An HTTPError wraps an open response body; release it.
            close = getattr(err, "close", None)
            if callable(close):
                close()

        body = json.dumps({
            "vectors": {
                "size": 192,
                "distance": "Cosine"
            }
        }).encode()
        create_req = Request(
            url, data=body, method="PUT",
            headers={"Content-Type": "application/json", **auth_headers},
        )
        with urlopen(create_req):
            pass
        print(f"[Qdrant] Created collection: {collection} (dim=192)")
        return True
    except Exception as e:
        # Best effort: the caller skips Qdrant storage when this returns False.
        print(f"[Qdrant] Cannot access Qdrant: {e}")
        return False
|
||||
|
||||
|
||||
def _qdrant_upsert(qdrant_url, api_key, collection, points):
    """Write a batch of points to a Qdrant collection.

    Sends one PUT to ``/collections/<collection>/points?wait=true``.

    Args:
        qdrant_url: Base URL of the Qdrant server.
        api_key: API key sent as the ``api-key`` header; falsy means no header.
        collection: Target collection name.
        points: JSON-serialisable list of point dicts.

    Returns:
        True when the request succeeded, False on any error (which is
        printed, not raised).
    """
    try:
        url = f"{qdrant_url}/collections/{collection}/points?wait=true"
        body = json.dumps({"points": points}).encode()
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["api-key"] = api_key
        req = Request(url, data=body, headers=headers, method="PUT")
        # Close the response explicitly so the connection is not leaked.
        with urlopen(req):
            pass
        return True
    except Exception as e:
        print(f"[Qdrant] Upsert failed: {e}")
        return False
|
||||
|
||||
|
||||
def _hash_point_id(file_uuid, label):
    """Return a stable, non-negative 63-bit point ID for ``(file_uuid, label)``.

    The built-in ``hash()`` is salted per interpreter process for strings,
    so it yields a different ID on every run. Re-processing a file would
    then add duplicate points instead of overwriting the earlier ones. A
    SHA-256 digest is used so the same inputs always map to the same ID.
    """
    import hashlib

    key = f"{file_uuid}_{label}".encode("utf-8")
    digest = hashlib.sha256(key).digest()
    # Keep the original contract: an integer masked to 63 bits.
    return int.from_bytes(digest[:8], "big") & 0x7FFFFFFFFFFFFFFF
|
||||
|
||||
|
||||
def _save_checkpoint(path: str, data: dict):
    """Write a checkpoint atomically: dump to ``<path>.tmp``, then rename.

    The parent directory is created when missing. If serialisation or the
    rename fails, the partial ``.tmp`` file is removed and the exception is
    re-raised, so an existing checkpoint at ``path`` is left untouched.

    Args:
        path: Destination file (``str`` or ``os.PathLike``).
        data: JSON-serialisable checkpoint content.
    """
    path = os.fspath(path)
    tmp = path + ".tmp"
    Path(tmp).parent.mkdir(parents=True, exist_ok=True)
    try:
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a truncated temp file behind.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
|
||||
|
||||
|
||||
def compute_embedding_quality(embeddings, labels):
    """Cosine similarity of every embedding to the centroid of its cluster.

    The centroid of a cluster is the mean of its embeddings. A zero-norm
    embedding or centroid yields a similarity of 0.0.

    Args:
        embeddings: Array-like of shape ``(n, dim)``.
        labels: Array-like of ``n`` cluster labels; lists are accepted.

    Returns:
        ``numpy.ndarray`` of ``n`` float similarities, in input order.

    Raises:
        ValueError: If ``embeddings`` and ``labels`` differ in length.
    """
    embeddings = np.asarray(embeddings, dtype=float)
    # A plain list would make ``labels == label`` a scalar False.
    labels = np.asarray(labels)
    if len(embeddings) == 0:
        return np.array([])
    if len(embeddings) != len(labels):
        raise ValueError(
            f"embeddings and labels differ in length: "
            f"{len(embeddings)} != {len(labels)}"
        )

    qualities = np.zeros(len(embeddings), dtype=float)
    emb_norms = np.linalg.norm(embeddings, axis=1)
    for label in np.unique(labels):
        mask = labels == label
        centroid = embeddings[mask].mean(axis=0)
        dots = embeddings[mask] @ centroid
        denom = emb_norms[mask] * np.linalg.norm(centroid)
        # Where the denominator is zero the similarity stays 0.0.
        qualities[mask] = np.divide(
            dots, denom, out=np.zeros_like(dots), where=denom > 0
        )
    return qualities
|
||||
|
||||
|
||||
class SelfASRXFixed:
    """Seven-step hybrid speaker diarization pipeline.

    Steps, as run by :meth:`process`:

    1. whisper transcription of the whole file -> rough segments + language
    2. VAD scan inside every rough segment -> refined segments
    3. whisper transcription of every refined segment
    4. speaker embedding for every refined segment
    5. clustering of the embeddings -> speaker labels
    6. storage of all embeddings in Qdrant (only when ``file_uuid`` is given)
    7. one reference per speaker built from high-quality embeddings, with a
       gender guess, stored in Qdrant (only when ``file_uuid`` is given)

    :meth:`resume_from_checkpoint` reloads the result of steps 1-3 from a
    checkpoint file and runs steps 4-7 only.
    """

    # Number of segments transcribed between two partial Step 3 checkpoints.
    CHECKPOINT_INTERVAL = 50
    # Multiplier turning seconds into start_frame/end_frame.
    # NOTE(review): presumably the video frame rate -- confirm with consumers.
    FRAMES_PER_SECOND = 30

    def __init__(self):
        """Load all models and read the Qdrant settings from the environment."""
        print("[SelfASRX] Initializing models...")

        print("[SelfASRX] Loading whisper model...")
        self.whisper = _load_whisper_model("small")

        print("[SelfASRX] Loading VAD model (Silero)...")
        self.vad_model, self.vad_utils = _load_vad()

        print("[SelfASRX] Loading speaker encoder (ECAPA-TDNN)...")
        self.speaker_encoder = _load_speaker_encoder()

        print("[SelfASRX] Loading gender classifier...")
        self.gender_classifier = _load_gender_classifier()

        # Qdrant settings
        self.qdrant_url = os.environ.get("QDRANT_URL", "http://localhost:6333")
        self.qdrant_api_key = os.environ.get("QDRANT_API_KEY", "")
        schema = os.environ.get("DATABASE_SCHEMA", "public")
        self.qdrant_collection = os.environ.get(
            "QDRANT_SPEAKER_COLLECTION",
            f"momentry_{schema}_speaker"
        )
        # Set to True once the collection has been confirmed or created.
        self._qdrant_ok = False

        print("[SelfASRX] Models loaded successfully")

    def process(self, audio_path, output_path=None, file_uuid=None,
                max_speakers=10, quality_threshold=0.85,
                checkpoint_path=None):
        """Run the full seven-step diarization pipeline.

        Args:
            audio_path: Path of the audio file (WAV).
            output_path: Optional path of the JSON file to write the result to.
            file_uuid: File UUID; steps 6 and 7 (Qdrant) run only when given.
            max_speakers: Upper bound on the number of speakers for clustering.
            quality_threshold: Minimum quality (0-1) for an embedding to count
                as high quality in step 7.
            checkpoint_path: Optional checkpoint file. It is written during
                and after Step 3; an existing partial checkpoint is used to
                continue Step 3 where it stopped.

        Returns:
            dict with ``language``, ``segments``, ``n_speakers``,
            ``speaker_stats``, ``total_duration``, ``n_segments``,
            ``processing_time`` and, when available, ``realtime_factor`` and
            ``references``. On failure: ``{"error": ..., "segments": []}``.
        """
        start_time = time.time()
        print(f"\n[SelfASRX] Processing: {audio_path}")
        print("=" * 60)

        wav, sample_rate = _load_audio(audio_path)
        total_duration = len(wav) / sample_rate
        print(f" Audio: {total_duration:.2f}s, {sample_rate}Hz")

        # ── Step 1: rough localisation with whisper ──
        print("\n[Step 1] Initial whisper transcription...")
        t1 = time.time()
        seg_gen, info = self.whisper.transcribe(audio_path)
        rough_segments = [
            {"start": seg.start, "end": seg.end, "text": seg.text}
            for seg in seg_gen
        ]
        language = info.language if info else None
        print(f" Rough segments: {len(rough_segments)}")
        print(f" Language: {language}")
        print(f" Step 1 time: {time.time() - t1:.2f}s")

        if not rough_segments:
            print("[SelfASRX] No speech detected by whisper!")
            return {"error": "No speech detected", "segments": []}

        # ── Step 2: VAD scan inside every rough segment ──
        print("\n[Step 2] VAD scan for refined segmentation...")
        t2 = time.time()
        refined_segments = []
        for seg in rough_segments:
            sub = self._vad_scan_segment(wav, sample_rate,
                                         seg["start"], seg["end"])
            if sub:
                refined_segments.extend(sub)
            else:
                # VAD found nothing: keep the rough segment as it is.
                refined_segments.append((seg["start"], seg["end"]))
        print(f" Refined segments: {len(refined_segments)}")
        print(f" Step 2 time: {time.time() - t2:.2f}s")

        if not refined_segments:
            return {"error": "No segments after VAD scan", "segments": []}

        # ── Step 3: whisper per refined segment ──
        print("\n[Step 3] Per-segment transcription...")
        t3 = time.time()
        n_refined = len(refined_segments)

        # Continue an interrupted run when a usable partial checkpoint exists.
        segment_texts = self._load_partial_checkpoint(checkpoint_path, n_refined)
        resume_from = len(segment_texts)
        if resume_from:
            print(f"[Step 3] Resuming from #{resume_from}/{n_refined}")

        for i, (start_sec, end_sec) in enumerate(
                refined_segments[resume_from:], start=resume_from):
            segment_texts.append(
                self._transcribe_segment(wav, sample_rate, start_sec, end_sec)
            )
            done = i + 1
            if checkpoint_path and done % self.CHECKPOINT_INTERVAL == 0:
                _save_checkpoint(checkpoint_path, self._checkpoint_payload(
                    False, language, total_duration, refined_segments,
                    segment_texts, file_uuid, max_speakers, quality_threshold,
                    progress=done,
                ))
                print(f"[Checkpoint] Step 3: {done}/{n_refined}")

        print(f" Step 3 time: {time.time() - t3:.2f}s")

        # ── Final checkpoint after Step 3 ──
        if checkpoint_path:
            _save_checkpoint(checkpoint_path, self._checkpoint_payload(
                True, language, total_duration, refined_segments,
                segment_texts, file_uuid, max_speakers, quality_threshold,
            ))
            print(f"[Checkpoint] Step 3 complete, saved to {checkpoint_path}")

        result = self._run_steps_4_to_7(
            wav, sample_rate, refined_segments, segment_texts, language,
            total_duration, file_uuid, max_speakers, quality_threshold,
            output_path, start_time,
        )
        print("=" * 60)
        return result

    def resume_from_checkpoint(self, checkpoint_path, audio_path,
                               output_path=None):
        """Load the result of steps 1-3 from a checkpoint and run steps 4-7.

        Args:
            checkpoint_path: Checkpoint written by :meth:`process` after
                Step 3 completed.
            audio_path: Path of the audio file the checkpoint belongs to.
            output_path: Optional path of the JSON file to write the result to.

        Returns:
            The same dict as :meth:`process`. An error dict is returned when
            the checkpoint is not complete or its texts and segments do not
            line up.
        """
        print(f"\n[SelfASRX] Resuming from checkpoint: {checkpoint_path}")
        print("=" * 60)

        with open(checkpoint_path, "r", encoding="utf-8") as f:
            cp = json.load(f)

        if not cp.get("step3_completed"):
            error_msg = f"Checkpoint step3 not completed (progress: {cp.get('step3_progress', '?')})"
            print(f"[SelfASRX] {error_msg}")
            return {"error": error_msg, "segments": []}

        refined_segments = [tuple(s) for s in cp["refined_segments"]]
        segment_texts = cp["segment_texts"]
        if len(segment_texts) != len(refined_segments):
            # Every segment needs its text; a mismatch would otherwise fail
            # later with an IndexError or silently drop segments.
            error_msg = (f"Checkpoint is inconsistent: {len(segment_texts)} texts "
                         f"for {len(refined_segments)} segments")
            print(f"[SelfASRX] {error_msg}")
            return {"error": error_msg, "segments": []}

        wav, sample_rate = _load_audio(audio_path)
        language = cp.get("language", "")
        total_duration = cp.get("total_duration", 0)
        file_uuid = cp.get("file_uuid")
        max_speakers = cp.get("max_speakers", 10)
        quality_threshold = cp.get("quality_threshold", 0.85)

        print(f" Loaded checkpoint: {len(refined_segments)} segments, "
              f"language={language}, duration={total_duration:.2f}s")

        start_time = time.time()
        return self._run_steps_4_to_7(
            wav, sample_rate, refined_segments, segment_texts, language,
            total_duration, file_uuid, max_speakers, quality_threshold,
            output_path, start_time,
        )

    # ── Internal helpers ──

    @staticmethod
    def _checkpoint_payload(completed, language, total_duration,
                            refined_segments, segment_texts, file_uuid,
                            max_speakers, quality_threshold, progress=None):
        """Build the version-2 checkpoint dict written during/after Step 3."""
        payload = {
            "checkpoint_version": 2,
            "step3_completed": completed,
        }
        if progress is not None:
            payload["step3_progress"] = progress
        payload.update({
            "language": language,
            "total_duration": total_duration,
            # float() keeps NumPy scalars from breaking json.dump.
            "refined_segments": [[float(s), float(e)] for s, e in refined_segments],
            "segment_texts": [{
                "text": st["text"],
                "language": st["language"],
                "lang_prob": st["lang_prob"],
            } for st in segment_texts],
            "file_uuid": file_uuid,
            "max_speakers": max_speakers,
            "quality_threshold": quality_threshold,
        })
        return payload

    @staticmethod
    def _load_partial_checkpoint(checkpoint_path, n_segments):
        """Return the segment texts saved by an interrupted Step 3.

        An empty list is returned when there is no checkpoint, when it cannot
        be read, when it is not an unfinished version-2 checkpoint, or when it
        holds more texts than there are segments now.
        """
        if not checkpoint_path or not os.path.exists(checkpoint_path):
            return []
        try:
            # Checkpoints are written as UTF-8; read them the same way.
            with open(checkpoint_path, "r", encoding="utf-8") as f:
                cp = json.load(f)
        except (OSError, ValueError) as exc:
            print(f"[Checkpoint] Ignoring unreadable checkpoint {checkpoint_path}: {exc}")
            return []
        if not isinstance(cp, dict):
            return []
        if cp.get("checkpoint_version") != 2 or cp.get("step3_completed"):
            return []
        saved = cp.get("segment_texts") or []
        if not isinstance(saved, list) or len(saved) > n_segments:
            print(f"[Checkpoint] Ignoring checkpoint that does not match "
                  f"the current {n_segments} segments")
            return []
        return list(saved)

    @classmethod
    def _build_segments(cls, refined_segments, segment_texts, speaker_labels,
                        qualities):
        """Combine times, texts, speaker labels and qualities into segment dicts."""
        fps = cls.FRAMES_PER_SECOND
        segments = []
        for (start_sec, end_sec), text, label, quality in zip(
                refined_segments, segment_texts, speaker_labels, qualities):
            speaker = f"SPEAKER_{int(label)}"
            segments.append({
                "start": round(start_sec, 3),
                "end": round(end_sec, 3),
                "start_frame": int(start_sec * fps),
                "end_frame": int(end_sec * fps),
                "text": text["text"],
                "language": text["language"],
                "lang_prob": text["lang_prob"],
                "speaker": speaker,
                "speaker_id": speaker,
                "quality": float(quality),
            })
        return segments

    @staticmethod
    def _speaker_stats(segments):
        """Count segments and sum durations per speaker."""
        stats = {}
        for seg in segments:
            entry = stats.setdefault(seg["speaker_id"],
                                     {"count": 0, "duration": 0})
            entry["count"] += 1
            entry["duration"] += seg["end"] - seg["start"]
        return stats

    def _run_steps_4_to_7(self, wav, sample_rate, refined_segments,
                          segment_texts, language, total_duration, file_uuid,
                          max_speakers, quality_threshold, output_path,
                          start_time):
        """Run steps 4-7, add the timing fields and optionally save the result."""
        # ── Step 4: speaker embedding per refined segment ──
        print("\n[Step 4] Speaker embedding extraction...")
        t4 = time.time()
        audio_segments = []
        for start_sec, end_sec in refined_segments:
            first = int(start_sec * sample_rate)
            last = int(end_sec * sample_rate)
            audio_segments.append(wav[first:min(last, len(wav))])

        from speaker_encoder import extract_speaker_embeddings_batch, normalize_embeddings
        embeddings = extract_speaker_embeddings_batch(
            self.speaker_encoder, audio_segments, sample_rate
        )
        embeddings = normalize_embeddings(embeddings)
        print(f" Embeddings: {embeddings.shape}")
        print(f" Step 4 time: {time.time() - t4:.2f}s")

        # ── Step 5: clustering ──
        print("\n[Step 5] Speaker clustering...")
        t5 = time.time()
        from speaker_cluster_fixed import robust_speaker_clustering
        speaker_labels, estimated_n_speakers = robust_speaker_clustering(
            embeddings, n_speakers=None, max_speakers=max_speakers
        )
        # The element-wise label comparisons below need an array.
        speaker_labels = np.asarray(speaker_labels)
        print(f" Speakers: {estimated_n_speakers}")
        print(f" Step 5 time: {time.time() - t5:.2f}s")

        qualities = compute_embedding_quality(embeddings, speaker_labels)
        segments = self._build_segments(
            refined_segments, segment_texts, speaker_labels, qualities
        )

        result = {
            "language": language or "",
            "segments": segments,
            "n_speakers": int(estimated_n_speakers),
            "speaker_stats": self._speaker_stats(segments),
            "total_duration": total_duration,
            "n_segments": len(segments),
        }

        if file_uuid:
            # ── Step 6: store all embeddings in Qdrant ──
            print("\n[Step 6] Storing embeddings in Qdrant...")
            t6 = time.time()
            self._store_speaker_embeddings(segments, embeddings,
                                           speaker_labels, file_uuid)
            print(f" Step 6 time: {time.time() - t6:.2f}s")

            # ── Step 7: references from high-quality embeddings ──
            print("\n[Step 7] Classifying high-quality embeddings...")
            t7 = time.time()
            references = self._classify_high_quality_speakers(
                segments, embeddings, speaker_labels, file_uuid,
                wav, sample_rate, quality_threshold
            )
            if references:
                result["references"] = references
            print(f" Step 7 time: {time.time() - t7:.2f}s")

        total_time = time.time() - start_time
        result["processing_time"] = round(total_time, 2)
        # Guard both operands: avoids a meaningless 0 and a division by zero.
        if total_duration > 0 and total_time > 0:
            result["realtime_factor"] = round(total_duration / total_time, 2)

        if output_path:
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, "w", encoding="utf-8") as f:
                json.dump(result, f, indent=2, ensure_ascii=False)
            print(f"\n[SelfASRX] Saved to: {output_path}")

        print(f"\n[SelfASRX] Done! {len(segments)} segments, "
              f"{estimated_n_speakers} speakers, "
              f"{total_time:.2f}s")
        return result

    def _vad_scan_segment(self, wav, sample_rate, start_sec, end_sec):
        """Run the VAD scan inside one segment (delegates to ``vad``)."""
        from vad import scan_within_segment
        return scan_within_segment(
            wav, sample_rate, start_sec, end_sec,
            self.vad_model, self.vad_utils
        )

    def _transcribe_segment(self, wav, sample_rate, start_sec, end_sec):
        """Transcribe one segment (delegates to ``whisper_local``)."""
        from whisper_local import transcribe_segment
        return transcribe_segment(wav, sample_rate, start_sec, end_sec, self.whisper)

    def _store_speaker_embeddings(self, segments, embeddings, labels, file_uuid):
        """Step 6: store every segment embedding in Qdrant.

        Returns:
            True when the points were written, False otherwise.
        """
        if not self._ensure_qdrant():
            return False

        points = []
        for i, (seg, emb) in enumerate(zip(segments, embeddings)):
            points.append({
                "id": _hash_point_id(file_uuid, f"{i}"),
                "vector": emb.tolist(),
                "payload": {
                    "type": "speaker_embedding",
                    "file_uuid": file_uuid,
                    "speaker_id": seg["speaker_id"],
                    "text": seg["text"],
                    "language": seg["language"],
                    "start_time": seg["start"],
                    "end_time": seg["end"],
                }
            })

        ok = _qdrant_upsert(self.qdrant_url, self.qdrant_api_key,
                            self.qdrant_collection, points)
        if ok:
            print(f" Stored {len(points)} speaker embeddings to Qdrant")
        return ok

    def _classify_gender(self, wav, sample_rate, seg):
        """Guess the gender for one segment; returns ``(gender, confidence)``.

        ``("unknown", 0.0)`` is returned when no classifier is loaded, when
        it fails, or when it yields fewer than two class scores.
        """
        if self.gender_classifier is None:
            return "unknown", 0.0
        try:
            import torch
            first = int(seg["start"] * sample_rate)
            last = int(seg["end"] * sample_rate)
            seg_wav = wav[first:min(last, len(wav))]
            seg_tensor = torch.from_numpy(seg_wav).float().unsqueeze(0)
            out = self.gender_classifier.classify_batch(seg_tensor)
            probs = torch.softmax(out[0], dim=-1).squeeze().cpu().detach().numpy()
            if len(probs) >= 2:
                idx = int(np.argmax(probs))
                # NOTE(review): assumes class index 0 is "male" -- confirm
                # against the classifier's label encoder.
                return ("male" if idx == 0 else "female"), float(probs[idx])
        except Exception as exc:
            print(f"[Gender] Classify error: {exc}")
        return "unknown", 0.0

    def _classify_high_quality_speakers(self, segments, embeddings, labels,
                                        file_uuid, wav, sample_rate,
                                        threshold=0.85):
        """Step 7: build one reference per speaker from high-quality embeddings.

        For every speaker that has at least one embedding with quality
        ``>= threshold``, the normalised mean of those embeddings is stored
        in Qdrant as a ``speaker_reference`` point. The gender guess uses the
        audio of the speaker's best-quality segment.

        Returns:
            List of dicts with ``speaker_id``, ``n_segments``,
            ``avg_quality`` and ``gender``; empty when nothing qualifies.
        """
        labels = np.asarray(labels)
        embeddings = np.asarray(embeddings)
        qualities = np.asarray(compute_embedding_quality(embeddings, labels))
        high_mask = qualities >= threshold

        if not np.any(high_mask):
            print(" No high-quality embeddings found")
            return []

        # Check once instead of sending one failing request per speaker.
        qdrant_ready = self._ensure_qdrant()

        references = []
        for label in np.unique(labels):
            mask = (labels == label) & high_mask
            if not np.any(mask):
                continue
            high_indices = np.flatnonzero(mask)
            high_segs = [segments[i] for i in high_indices]

            # Segment with the highest quality for this speaker.
            best_idx = int(high_indices[int(np.argmax(qualities[mask]))])
            best_seg = segments[best_idx]

            centroid = np.mean(embeddings[mask], axis=0)
            norm = np.linalg.norm(centroid)
            if norm > 0:
                centroid = centroid / norm

            n_high = int(np.sum(mask))
            avg_quality = float(np.mean(qualities[mask]))
            speaker_id = f"SPEAKER_{int(label)}"
            text_samples = [s["text"] for s in high_segs[:5] if s["text"]]
            total_dur = sum(s["end"] - s["start"] for s in high_segs)
            gender, gender_conf = self._classify_gender(wav, sample_rate, best_seg)

            ref_payload = {
                "type": "speaker_reference",
                "file_uuid": file_uuid,
                "speaker_id": speaker_id,
                "n_segments": n_high,
                "avg_quality": avg_quality,
                "total_duration": round(total_dur, 2),
                "language": best_seg.get("language", ""),
                "text_samples": text_samples,
                "gender": gender,
                "gender_conf": gender_conf,
            }

            if qdrant_ready:
                _qdrant_upsert(self.qdrant_url, self.qdrant_api_key,
                               self.qdrant_collection, [{
                                   "id": _hash_point_id(file_uuid, f"ref_{int(label)}"),
                                   "vector": centroid.tolist(),
                                   "payload": ref_payload,
                               }])

            references.append({
                "speaker_id": speaker_id,
                "n_segments": n_high,
                "avg_quality": avg_quality,
                "gender": gender,
            })

            print(f" Ref: {speaker_id}, gender={gender}"
                  f" ({gender_conf:.2f}), q={avg_quality:.3f}")

        return references

    def _ensure_qdrant(self):
        """Return True when the Qdrant collection is usable.

        A successful check is remembered; a failed one is retried on the
        next call.
        """
        if not self._qdrant_ok:
            self._qdrant_ok = _ensure_speaker_collection(
                self.qdrant_url, self.qdrant_api_key, self.qdrant_collection
            )
        return self._qdrant_ok
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: run the pipeline or resume from a checkpoint.

    The audio file (and the checkpoint, with ``--resume``) is checked before
    any model is loaded, so a wrong path fails immediately with exit code 1.
    Both modes read the audio file, so it is checked in both.
    """
    import argparse

    parser = argparse.ArgumentParser(description="SelfASRX - Hybrid Speaker Diarization")
    parser.add_argument("audio_path", help="Path to audio file (WAV)")
    parser.add_argument("-o", "--output", help="Output JSON path")
    parser.add_argument("--file-uuid", help="File UUID for Qdrant storage")
    parser.add_argument("--max-speakers", type=int, default=10)
    parser.add_argument("--quality-threshold", type=float, default=0.85)
    parser.add_argument("--resume", help="Checkpoint path to resume from")
    parser.add_argument("--checkpoint", help="Save checkpoint path after Step 3")
    args = parser.parse_args()

    # Validate inputs first: loading the models is expensive.
    if args.resume and not Path(args.resume).exists():
        print(f"Error: Checkpoint not found: {args.resume}")
        sys.exit(1)
    if not Path(args.audio_path).exists():
        print(f"Error: Audio file not found: {args.audio_path}")
        sys.exit(1)

    asrx = SelfASRXFixed()

    if args.resume:
        result = asrx.resume_from_checkpoint(
            args.resume, args.audio_path,
            output_path=args.output,
        )
    else:
        result = asrx.process(
            args.audio_path,
            output_path=args.output,
            file_uuid=args.file_uuid,
            max_speakers=args.max_speakers,
            quality_threshold=args.quality_threshold,
            checkpoint_path=args.checkpoint,
        )

    if "error" not in result:
        print("\n[Summary]")
        print(f" Duration: {result['total_duration']:.2f}s")
        print(f" Segments: {result['n_segments']}")
        print(f" Speakers: {result['n_speakers']}")
        for ref in result.get("references", []):
            print(f" {ref['speaker_id']}: gender={ref['gender']}, "
                  f"quality={ref['avg_quality']:.3f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user