""" SelfASRXFixed - 7 步 Hybrid Speaker Diarization Pipeline Pipeline: 1. whisper.transcribe(full_audio) → rough segments + text + language 2. VAD scan each rough segment → refined segments 3. whisper per refined segment → {text, language, lang_prob} 4. ECAPA-TDNN per refined segment → 192-dim embeddings 5. AgglomerativeClustering → speaker_labels 6. Store all embeddings in Qdrant (payload: file_uuid, speaker_id, text, ...) 7. High-quality embeddings → gender classify + store reference in Qdrant """ import sys import json import time import os import numpy as np from pathlib import Path from urllib.request import Request, urlopen from urllib.error import URLError def _load_audio(path): """載入音頻文件,回傳 (wav_numpy, sample_rate)""" import soundfile as sf wav, sr = sf.read(path) if len(wav.shape) > 1: wav = np.mean(wav, axis=1) return wav, sr def _load_whisper_model(size="small"): from whisper_local import load_model return load_model(size) def _load_vad(): from vad import load_vad_model return load_vad_model() def _load_speaker_encoder(): from speaker_encoder import load_speaker_encoder return load_speaker_encoder() def _load_gender_classifier(): try: from speechbrain.inference.classifiers import EncoderClassifier classifier = EncoderClassifier.from_hparams( source="speechbrain/gender-recognition-ecapa", run_opts={"device": "cpu"}, ) print("[Gender] Classifier loaded: speechbrain/gender-recognition-ecapa") return classifier except Exception as e: print(f"[Gender] Classifier not available: {e}") return None def _ensure_speaker_collection(qdrant_url, api_key, collection): """確認 Qdrant speaker collection 存在,不存在則建立 (dim=192, cosine)""" try: url = f"{qdrant_url}/collections/{collection}" req = Request(url, method="GET", headers={"api-key": api_key} if api_key else {}) try: urlopen(req) return True except URLError as e: if getattr(e, "code", None) == 404: body = json.dumps({ "vectors": { "size": 192, "distance": "Cosine" } }).encode() req = Request(url, data=body, method="PUT", headers={"Content-Type": "application/json", **({"api-key": api_key} if api_key else {})}) urlopen(req) print(f"[Qdrant] Created collection: {collection} (dim=192)") return True raise except Exception as e: print(f"[Qdrant] Cannot access Qdrant: {e}") return False def _qdrant_upsert(qdrant_url, api_key, collection, points): """批量寫入 Qdrant points""" try: url = f"{qdrant_url}/collections/{collection}/points?wait=true" body = json.dumps({"points": points}).encode() headers = {"Content-Type": "application/json"} if api_key: headers["api-key"] = api_key req = Request(url, data=body, headers=headers, method="PUT") urlopen(req) return True except Exception as e: print(f"[Qdrant] Upsert failed: {e}") return False def _hash_point_id(file_uuid, label): """產生一致的 point ID""" s = f"{file_uuid}_{label}" return hash(s) & 0x7FFFFFFFFFFFFFFF def _save_checkpoint(path: str, data: dict): """原子寫入 checkpoint(先 .tmp 再 rename)""" tmp = path + ".tmp" Path(tmp).parent.mkdir(parents=True, exist_ok=True) with open(tmp, "w", encoding="utf-8") as f: json.dump(data, f, indent=2, ensure_ascii=False) os.replace(tmp, path) def compute_embedding_quality(embeddings, labels): """每個 embedding 到所屬 cluster centroid 的餘弦相似度""" from sklearn.metrics.pairwise import cosine_similarity unique_labels = set(labels) centroids = {} for label in unique_labels: mask = labels == label centroid = np.mean(embeddings[mask], axis=0) norm = np.linalg.norm(centroid) if norm > 0: centroid = centroid / norm centroids[label] = centroid qualities = [] for emb, label in zip(embeddings, labels): sim = cosine_similarity([emb], [centroids[label]])[0][0] qualities.append(sim) return np.array(qualities) class SelfASRXFixed: """7 步 Hybrid Speaker Diarization Pipeline""" def __init__(self): print("[SelfASRX] Initializing models...") print("[SelfASRX] Loading whisper model...") self.whisper = _load_whisper_model("small") print("[SelfASRX] Loading VAD model (Silero)...") self.vad_model, self.vad_utils = _load_vad() print("[SelfASRX] Loading speaker encoder (ECAPA-TDNN)...") self.speaker_encoder = _load_speaker_encoder() print("[SelfASRX] Loading gender classifier...") self.gender_classifier = _load_gender_classifier() # Qdrant 設定 self.qdrant_url = os.environ.get("QDRANT_URL", "http://localhost:6333") self.qdrant_api_key = os.environ.get("QDRANT_API_KEY", "") schema = os.environ.get("DATABASE_SCHEMA", "public") self.qdrant_collection = os.environ.get( "QDRANT_SPEAKER_COLLECTION", f"momentry_{schema}_speaker" ) self._qdrant_ok = False print("[SelfASRX] Models loaded successfully") def process(self, audio_path, output_path=None, file_uuid=None, max_speakers=10, quality_threshold=0.85, checkpoint_path=None): """7 步 speaker diarization pipeline Args: audio_path: 音頻文件路徑 (WAV 16kHz mono) output_path: 輸出 JSON 路徑 (可選) file_uuid: 檔案 UUID (用於 Qdrant 儲存) max_speakers: 最大說話人數 quality_threshold: 高品質聲紋門檻 (0-1) checkpoint_path: Step 3 完成後儲存 checkpoint 路徑 Returns: dict: segments, speaker_stats, n_speakers, total_duration, references """ start_time = time.time() print(f"\n[SelfASRX] Processing: {audio_path}") print("=" * 60) # 載入音頻 wav, sample_rate = _load_audio(audio_path) total_duration = len(wav) / sample_rate print(f" Audio: {total_duration:.2f}s, {sample_rate}Hz") # ── Step 1: whisper 粗略定位 (faster-whisper) ── print("\n[Step 1] Initial whisper transcription...") t1 = time.time() seg_gen, info = self.whisper.transcribe(audio_path) rough_segments = [] for seg in seg_gen: rough_segments.append({"start": seg.start, "end": seg.end, "text": seg.text}) language = info.language if info else None print(f" Rough segments: {len(rough_segments)}") print(f" Language: {language}") print(f" Step 1 time: {time.time() - t1:.2f}s") if not rough_segments: print("[SelfASRX] No speech detected by whisper!") return {"error": "No speech detected", "segments": []} # ── Step 2: VAD scan 每個 rough segment 細切 ── print("\n[Step 2] VAD scan for refined segmentation...") t2 = time.time() refined_segments = [] for seg in rough_segments: s = seg["start"] e = seg["end"] sub = self._vad_scan_segment(wav, sample_rate, s, e) if sub: refined_segments.extend(sub) else: refined_segments.append((s, e)) print(f" Refined segments: {len(refined_segments)}") print(f" Step 2 time: {time.time() - t2:.2f}s") if not refined_segments: return {"error": "No segments after VAD scan", "segments": []} # ── Step 3: whisper per refined segment ── print("\n[Step 3] Per-segment transcription...") t3 = time.time() CHECKPOINT_INTERVAL = 50 segment_texts = [] resume_from = 0 # 載入既有 partial checkpoint(中斷續接) if checkpoint_path and os.path.exists(checkpoint_path): try: with open(checkpoint_path, "r") as f: cp = json.load(f) if cp.get("checkpoint_version") == 2 and not cp.get("step3_completed"): saved = cp.get("segment_texts", []) if saved: resume_from = len(saved) segment_texts = saved print(f"[Step 3] Resuming from #{resume_from}/{len(refined_segments)}") except Exception: pass for i, (start_sec, end_sec) in enumerate(refined_segments): if i < resume_from: continue seg_text = self._transcribe_segment(wav, sample_rate, start_sec, end_sec) segment_texts.append(seg_text) if checkpoint_path and (i + 1) % CHECKPOINT_INTERVAL == 0: _save_checkpoint(checkpoint_path, { "checkpoint_version": 2, "step3_completed": False, "step3_progress": i + 1, "language": language, "total_duration": total_duration, "refined_segments": [[s, e] for s, e in refined_segments], "segment_texts": [{ "text": st["text"], "language": st["language"], "lang_prob": st["lang_prob"], } for st in segment_texts], "file_uuid": file_uuid, "max_speakers": max_speakers, "quality_threshold": quality_threshold, }) print(f"[Checkpoint] Step 3: {i+1}/{len(refined_segments)}") print(f" Step 3 time: {time.time() - t3:.2f}s") # ── Save final checkpoint after Step 3 ── if checkpoint_path: _save_checkpoint(checkpoint_path, { "checkpoint_version": 2, "step3_completed": True, "language": language, "total_duration": total_duration, "refined_segments": [[s, e] for s, e in refined_segments], "segment_texts": [{ "text": st["text"], "language": st["language"], "lang_prob": st["lang_prob"], } for st in segment_texts], "file_uuid": file_uuid, "max_speakers": max_speakers, "quality_threshold": quality_threshold, }) print(f"[Checkpoint] Step 3 complete, saved to {checkpoint_path}") # ── Step 4: ECAPA-TDNN per refined segment ── print("\n[Step 4] Speaker embedding extraction...") t4 = time.time() audio_segments = [] for start_sec, end_sec in refined_segments: s = int(start_sec * sample_rate) e = int(end_sec * sample_rate) audio_segments.append(wav[s:min(e, len(wav))]) from speaker_encoder import extract_speaker_embeddings_batch, normalize_embeddings embeddings = extract_speaker_embeddings_batch( self.speaker_encoder, audio_segments, sample_rate ) embeddings = normalize_embeddings(embeddings) print(f" Embeddings: {embeddings.shape}") print(f" Step 4 time: {time.time() - t4:.2f}s") # ── Step 5: AgglomerativeClustering ── print("\n[Step 5] Speaker clustering...") t5 = time.time() from speaker_cluster_fixed import robust_speaker_clustering speaker_labels, estimated_n_speakers = robust_speaker_clustering( embeddings, n_speakers=None, max_speakers=max_speakers ) print(f" Speakers: {estimated_n_speakers}") print(f" Step 5 time: {time.time() - t5:.2f}s") # 品質計算 qualities = compute_embedding_quality(embeddings, speaker_labels) # 建立輸出 segments segments = [] for i, ((start_sec, end_sec), label) in enumerate( zip(refined_segments, speaker_labels)): seg = { "start": round(start_sec, 3), "end": round(end_sec, 3), "start_frame": int(start_sec * 30), "end_frame": int(end_sec * 30), "text": segment_texts[i]["text"], "language": segment_texts[i]["language"], "lang_prob": segment_texts[i]["lang_prob"], "speaker": f"SPEAKER_{int(label)}", "speaker_id": f"SPEAKER_{int(label)}", "quality": float(qualities[i]), } segments.append(seg) # 統計 speaker_stats = {} for seg in segments: spk = seg["speaker_id"] dur = seg["end"] - seg["start"] if spk not in speaker_stats: speaker_stats[spk] = {"count": 0, "duration": 0} speaker_stats[spk]["count"] += 1 speaker_stats[spk]["duration"] += dur result = { "language": language or "", "segments": segments, "n_speakers": int(estimated_n_speakers), "speaker_stats": speaker_stats, "total_duration": total_duration, "n_segments": len(segments), } # ── Step 6: Store embeddings in Qdrant ── if file_uuid: print("\n[Step 6] Storing embeddings in Qdrant...") t6 = time.time() self._store_speaker_embeddings(segments, embeddings, speaker_labels, file_uuid) print(f" Step 6 time: {time.time() - t6:.2f}s") # ── Step 7: High-quality classification ── if file_uuid: print("\n[Step 7] Classifying high-quality embeddings...") t7 = time.time() references = self._classify_high_quality_speakers( segments, embeddings, speaker_labels, file_uuid, wav, sample_rate, quality_threshold ) if references: result["references"] = references print(f" Step 7 time: {time.time() - t7:.2f}s") total_time = time.time() - start_time result["processing_time"] = round(total_time, 2) if total_duration > 0: result["realtime_factor"] = round(total_duration / total_time, 2) # 保存輸出 if output_path: Path(output_path).parent.mkdir(parents=True, exist_ok=True) with open(output_path, "w", encoding="utf-8") as f: json.dump(result, f, indent=2, ensure_ascii=False) print(f"\n[SelfASRX] Saved to: {output_path}") print(f"\n[SelfASRX] Done! {len(segments)} segments, " f"{estimated_n_speakers} speakers, " f"{total_time:.2f}s") return result def resume_from_checkpoint(self, checkpoint_path, audio_path, output_path=None): """從 checkpoint 載入 Steps 1-3 結果,執行 Steps 4-7""" print(f"\n[SelfASRX] Resuming from checkpoint: {checkpoint_path}") print("=" * 60) with open(checkpoint_path, "r", encoding="utf-8") as f: cp = json.load(f) if not cp.get("step3_completed"): error_msg = f"Checkpoint step3 not completed (progress: {cp.get('step3_progress', '?')})" print(f"[SelfASRX] {error_msg}") return {"error": error_msg, "segments": []} wav, sample_rate = _load_audio(audio_path) refined_segments = [tuple(s) for s in cp["refined_segments"]] segment_texts = cp["segment_texts"] language = cp.get("language", "") total_duration = cp.get("total_duration", 0) file_uuid = cp.get("file_uuid") max_speakers = cp.get("max_speakers", 10) quality_threshold = cp.get("quality_threshold", 0.85) print(f" Loaded checkpoint: {len(refined_segments)} segments, " f"language={language}, duration={total_duration:.2f}s") start_time = time.time() # ── Step 4: ECAPA-TDNN per refined segment ── print("\n[Step 4] Speaker embedding extraction...") t4 = time.time() audio_segments = [] for start_sec, end_sec in refined_segments: s = int(start_sec * sample_rate) e = int(end_sec * sample_rate) audio_segments.append(wav[s:min(e, len(wav))]) from speaker_encoder import extract_speaker_embeddings_batch, normalize_embeddings embeddings = extract_speaker_embeddings_batch( self.speaker_encoder, audio_segments, sample_rate ) embeddings = normalize_embeddings(embeddings) print(f" Embeddings: {embeddings.shape}") print(f" Step 4 time: {time.time() - t4:.2f}s") # ── Step 5: AgglomerativeClustering ── print("\n[Step 5] Speaker clustering...") t5 = time.time() from speaker_cluster_fixed import robust_speaker_clustering speaker_labels, estimated_n_speakers = robust_speaker_clustering( embeddings, n_speakers=None, max_speakers=max_speakers ) print(f" Speakers: {estimated_n_speakers}") print(f" Step 5 time: {time.time() - t5:.2f}s") # 品質計算 qualities = compute_embedding_quality(embeddings, speaker_labels) # 建立輸出 segments segments = [] for i, ((start_sec, end_sec), label) in enumerate( zip(refined_segments, speaker_labels)): seg = { "start": round(start_sec, 3), "end": round(end_sec, 3), "start_frame": int(start_sec * 30), "end_frame": int(end_sec * 30), "text": segment_texts[i]["text"], "language": segment_texts[i]["language"], "lang_prob": segment_texts[i]["lang_prob"], "speaker": f"SPEAKER_{int(label)}", "speaker_id": f"SPEAKER_{int(label)}", "quality": float(qualities[i]), } segments.append(seg) # 統計 speaker_stats = {} for seg in segments: spk = seg["speaker_id"] dur = seg["end"] - seg["start"] if spk not in speaker_stats: speaker_stats[spk] = {"count": 0, "duration": 0} speaker_stats[spk]["count"] += 1 speaker_stats[spk]["duration"] += dur result = { "language": language or "", "segments": segments, "n_speakers": int(estimated_n_speakers), "speaker_stats": speaker_stats, "total_duration": total_duration, "n_segments": len(segments), } # ── Step 6: Store embeddings in Qdrant ── if file_uuid: print("\n[Step 6] Storing embeddings in Qdrant...") t6 = time.time() self._store_speaker_embeddings(segments, embeddings, speaker_labels, file_uuid) print(f" Step 6 time: {time.time() - t6:.2f}s") # ── Step 7: High-quality classification ── if file_uuid: print("\n[Step 7] Classifying high-quality embeddings...") t7 = time.time() references = self._classify_high_quality_speakers( segments, embeddings, speaker_labels, file_uuid, wav, sample_rate, quality_threshold ) if references: result["references"] = references print(f" Step 7 time: {time.time() - t7:.2f}s") total_time = time.time() - start_time result["processing_time"] = round(total_time, 2) if total_duration > 0: result["realtime_factor"] = round(total_duration / total_time, 2) # 保存輸出 if output_path: Path(output_path).parent.mkdir(parents=True, exist_ok=True) with open(output_path, "w", encoding="utf-8") as f: json.dump(result, f, indent=2, ensure_ascii=False) print(f"\n[SelfASRX] Saved to: {output_path}") print(f"\n[SelfASRX] Done! {len(segments)} segments, " f"{estimated_n_speakers} speakers, " f"{total_time:.2f}s") return result # ── Internal helpers ── def _vad_scan_segment(self, wav, sample_rate, start_sec, end_sec): """VAD 細切單一段落""" from vad import scan_within_segment return scan_within_segment( wav, sample_rate, start_sec, end_sec, self.vad_model, self.vad_utils ) def _transcribe_segment(self, wav, sample_rate, start_sec, end_sec): """轉錄單一段落""" from whisper_local import transcribe_segment return transcribe_segment(wav, sample_rate, start_sec, end_sec, self.whisper) def _store_speaker_embeddings(self, segments, embeddings, labels, file_uuid): """Step 6: 所有 embedding 存入 Qdrant""" if not self._ensure_qdrant(): return points = [] for i, (seg, emb, label) in enumerate( zip(segments, embeddings, labels)): point_id = _hash_point_id(file_uuid, f"{i}") points.append({ "id": point_id, "vector": emb.tolist(), "payload": { "type": "speaker_embedding", "file_uuid": file_uuid, "speaker_id": seg["speaker_id"], "text": seg["text"], "language": seg["language"], "start_time": seg["start"], "end_time": seg["end"], } }) ok = _qdrant_upsert(self.qdrant_url, self.qdrant_api_key, self.qdrant_collection, points) if ok: print(f" Stored {len(points)} speaker embeddings to Qdrant") return ok def _classify_high_quality_speakers(self, segments, embeddings, labels, file_uuid, wav, sample_rate, threshold=0.85): """Step 7: 高品質聲紋分級 + 性別分類 → Qdrant reference""" qualities = compute_embedding_quality(embeddings, labels) high_mask = qualities >= threshold if not np.any(high_mask): print(" No high-quality embeddings found") return [] unique_labels = set(labels) references = [] for label in unique_labels: mask = (labels == label) & high_mask if not np.any(mask): continue high_indices = [i for i in range(len(segments)) if mask[i]] high_segs = [segments[i] for i in high_indices] # 取品質最高的 segment index best_idx = high_indices[int(np.argmax(qualities[mask]))] best_seg = segments[best_idx] centroid = np.mean(embeddings[mask], axis=0) norm = np.linalg.norm(centroid) if norm > 0: centroid = centroid / norm avg_quality = float(np.mean(qualities[mask])) speaker_id = f"SPEAKER_{int(label)}" text_samples = [s["text"] for s in high_segs[:5] if s["text"]] total_dur = sum(s["end"] - s["start"] for s in high_segs) ref_id = _hash_point_id(file_uuid, f"ref_{label}") ref_payload = { "type": "speaker_reference", "file_uuid": file_uuid, "speaker_id": speaker_id, "n_segments": int(np.sum(mask)), "avg_quality": avg_quality, "total_duration": round(total_dur, 2), "language": best_seg.get("language", ""), "text_samples": text_samples, } # 性別分類:用最佳 segment 的音頻 if self.gender_classifier is not None: try: import torch s = int(best_seg["start"] * sample_rate) e = int(best_seg["end"] * sample_rate) seg_wav = wav[s:min(e, len(wav))] seg_tensor = torch.from_numpy(seg_wav).float().unsqueeze(0) # SpeechBrain gender classifier 接受音頻 out = self.gender_classifier.classify_batch(seg_tensor) probs = torch.softmax(out[0], dim=-1).squeeze().cpu().detach().numpy() if len(probs) >= 2: idx = int(np.argmax(probs)) ref_payload["gender"] = "male" if idx == 0 else "female" ref_payload["gender_conf"] = float(probs[idx]) else: ref_payload["gender"] = "unknown" ref_payload["gender_conf"] = 0.0 except Exception as e: print(f"[Gender] Classify error: {e}") ref_payload["gender"] = "unknown" ref_payload["gender_conf"] = 0.0 else: ref_payload["gender"] = "unknown" ref_payload["gender_conf"] = 0.0 _qdrant_upsert(self.qdrant_url, self.qdrant_api_key, self.qdrant_collection, [{ "id": ref_id, "vector": centroid.tolist(), "payload": ref_payload, }]) references.append({ "speaker_id": speaker_id, "n_segments": int(np.sum(mask)), "avg_quality": avg_quality, "gender": ref_payload["gender"], }) print(f" Ref: {speaker_id}, gender={ref_payload['gender']}" f" ({ref_payload['gender_conf']:.2f}), q={avg_quality:.3f}") return references def _ensure_qdrant(self): """確保 Qdrant collection 可用""" if not self._qdrant_ok: ok = _ensure_speaker_collection( self.qdrant_url, self.qdrant_api_key, self.qdrant_collection ) self._qdrant_ok = ok return self._qdrant_ok def main(): import argparse parser = argparse.ArgumentParser(description="SelfASRX - Hybrid Speaker Diarization") parser.add_argument("audio_path", help="Path to audio file (WAV)") parser.add_argument("-o", "--output", help="Output JSON path") parser.add_argument("--file-uuid", help="File UUID for Qdrant storage") parser.add_argument("--max-speakers", type=int, default=10) parser.add_argument("--quality-threshold", type=float, default=0.85) parser.add_argument("--resume", help="Checkpoint path to resume from") parser.add_argument("--checkpoint", help="Save checkpoint path after Step 3") args = parser.parse_args() asrx = SelfASRXFixed() if args.resume: if not Path(args.resume).exists(): print(f"Error: Checkpoint not found: {args.resume}") sys.exit(1) result = asrx.resume_from_checkpoint( args.resume, args.audio_path, output_path=args.output, ) else: if not Path(args.audio_path).exists(): print(f"Error: Audio file not found: {args.audio_path}") sys.exit(1) result = asrx.process( args.audio_path, output_path=args.output, file_uuid=args.file_uuid, max_speakers=args.max_speakers, quality_threshold=args.quality_threshold, checkpoint_path=args.checkpoint, ) if "error" not in result: print("\n[Summary]") print(f" Duration: {result['total_duration']:.2f}s") print(f" Segments: {result['n_segments']}") print(f" Speakers: {result['n_speakers']}") if "references" in result: for ref in result["references"]: print(f" {ref['speaker_id']}: gender={ref['gender']}, " f"quality={ref['avg_quality']:.3f}") if __name__ == "__main__": main()