#!/opt/homebrew/bin/python3.11
"""
ASRX Processor - Hybrid Pipeline Wrapper

Pipeline:
  1. ffprobe → select best audio track → ffmpeg → 16kHz mono WAV
  2. SelfASRXFixed.process() (7-step hybrid speaker diarization)
  3. Convert to Rust-expected format
"""
import sys
import json
import argparse
import contextlib
import os
import shutil
import subprocess
import tempfile
import traceback
from fractions import Fraction
from pathlib import Path

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(
    0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "asrx_self")
)

from redis_publisher import RedisPublisher

# Frame rate used for frame-number conversion when no usable probe sidecar exists.
_DEFAULT_FPS = 30.0


def probe_audio_tracks(video_path: str) -> list:
    """List every audio stream of *video_path* using ffprobe.

    Returns a list of dicts with keys ``index`` (absolute stream index in the
    container), ``codec``, ``language`` (``"und"`` when untagged),
    ``channels`` and ``sample_rate``.  Any failure (ffprobe missing, timeout,
    unparsable output) is reported and yields ``[]`` so the caller can fall
    back to a default stream; this is a deliberate best-effort probe.
    """
    cmd = [
        "ffprobe", "-v", "quiet", "-print_format", "json",
        "-show_streams", "-select_streams", "a", video_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
        data = json.loads(result.stdout)
        return [
            {
                "index": stream.get("index"),
                "codec": stream.get("codec_name"),
                "language": stream.get("tags", {}).get("language", "und"),
                "channels": stream.get("channels", 0),
                "sample_rate": stream.get("sample_rate", "0"),
            }
            for stream in data.get("streams", [])
        ]
    except Exception as e:
        print(f"[ASRX] ffprobe failed: {e}")
        return []


def select_best_track(tracks: list) -> int:
    """Return the list position (not the stream index) of the preferred track.

    Preference order: the first English-tagged track ("eng"/"en", any case),
    otherwise the first track with the most channels, otherwise 0 (also
    returned for an empty list).
    """
    if not tracks:
        return 0
    for i, t in enumerate(tracks):
        if str(t["language"]).lower() in ("eng", "en"):
            return i
    # max() returns the first maximal position, matching a strict ">" scan.
    return max(range(len(tracks)), key=lambda i: tracks[i]["channels"] or 0)


def extract_audio_to_wav(video_path: str, track_index: int | str,
                         output_wav: str) -> bool:
    """Extract one audio stream of *video_path* as a 16 kHz mono s16 WAV.

    Args:
        video_path: Input media file.
        track_index: ffmpeg stream specifier for input 0: an absolute stream
            index (e.g. ``1``) or a type-relative one (e.g. ``"a:0"``).
        output_wav: Destination path; overwritten if present (``-y``).

    Returns:
        True on success, False on any failure (the error is printed).
    """
    cmd = [
        "ffmpeg", "-y", "-v", "error", "-i", video_path,
        "-map", f"0:{track_index}",
        "-ar", "16000", "-ac", "1", "-sample_fmt", "s16", output_wav,
    ]
    try:
        subprocess.run(cmd, check=True, capture_output=True, timeout=300)
        return True
    except subprocess.CalledProcessError as e:
        # "-v error" makes ffmpeg explain the failure on stderr; surface it.
        detail = (e.stderr or b"").decode("utf-8", "replace").strip()
        last_line = detail.splitlines()[-1] if detail else ""
        suffix = f" ({last_line})" if last_line else ""
        print(f"[ASRX] ffmpeg extraction failed: {e}{suffix}")
        return False
    except Exception as e:
        print(f"[ASRX] ffmpeg extraction failed: {e}")
        return False


def _cleanup(tmp_dir):
    """Best-effort removal of *tmp_dir*; a no-op for None or a missing path."""
    if tmp_dir and os.path.exists(tmp_dir):
        shutil.rmtree(tmp_dir, ignore_errors=True)


def _atomic_write(path: str, data: dict):
    """Write *data* as indented JSON to *path* via a sibling ``.tmp`` file.

    ``os.replace`` swaps the file into place (overwriting an existing one on
    every platform, unlike ``os.rename`` on Windows).  If writing fails, the
    partial ``.tmp`` file is removed and the exception propagates.
    """
    tmp = path + ".tmp"
    try:
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp, path)
    except BaseException:
        with contextlib.suppress(OSError):
            os.remove(tmp)
        raise


def _shared_audio_setup(video_path):
    """Extract the preferred audio track into a fresh temp directory.

    Returns:
        ``(tmp_dir, audio_path)``: ``audio_path`` is the extracted WAV on
        success, or *video_path* itself when extraction fails.  The caller
        owns ``tmp_dir`` and is responsible for passing it to ``_cleanup``.
    """
    tracks = probe_audio_tracks(video_path)
    stream_spec = "a:0"
    if tracks:
        index = tracks[select_best_track(tracks)].get("index")
        if index is not None:
            stream_spec = index
    # Without probe data, "a:0" asks ffmpeg for the first *audio* stream;
    # absolute stream 0 is often the video stream in container formats.

    tmp_dir = tempfile.mkdtemp(prefix="asrx_")
    wav_path = os.path.join(tmp_dir, "audio.wav")
    if extract_audio_to_wav(video_path, stream_spec, wav_path):
        return tmp_dir, wav_path
    print("[ASRX] Audio extraction failed, falling back to original file",
          file=sys.stderr)
    return tmp_dir, video_path


def _load_fps(output_path: str, default: float = _DEFAULT_FPS) -> float:
    """Read the frame rate from the ``<uuid>.probe.json`` sidecar of *output_path*.

    The uuid is the output file name up to its first dot; the sidecar lives
    in the same directory.  The ``fps`` value may be numeric or a rational
    string such as ``"30000/1001"``.  Returns *default* when the sidecar is
    missing/unreadable or the value is not a positive number.
    """
    uuid_part = os.path.basename(output_path).split(".")[0]
    probe_path = os.path.join(os.path.dirname(output_path),
                              f"{uuid_part}.probe.json")
    try:
        with open(probe_path, encoding="utf-8") as pf:
            raw = json.load(pf)["fps"]
        fps = float(Fraction(str(raw)))
    except (OSError, ValueError, KeyError, TypeError, ZeroDivisionError):
        return default
    return fps if fps > 0 else default


def _convert_result(result, output_path):
    """Stage 3: convert a SelfASRXFixed result into the Rust-expected format.

    Each segment gets start/end times in seconds plus frame numbers computed
    with the frame rate from ``_load_fps`` (truncated toward zero).
    ``references`` is copied through when present.
    """
    fps = _load_fps(output_path)
    segments = [
        {
            "start_time": seg["start"],
            "end_time": seg["end"],
            "start_frame": int(seg["start"] * fps),
            "end_frame": int(seg["end"] * fps),
            "text": seg.get("text", ""),
            "speaker_id": seg.get("speaker_id", seg.get("speaker", "")),
            "language": seg.get("language", ""),
            "lang_prob": seg.get("lang_prob", 0.0),
            "quality": seg.get("quality", 0.0),
        }
        for seg in result.get("segments", [])
    ]
    output_result = {
        "language": result.get("language"),
        "segments": segments,
        "n_speakers": result.get("n_speakers", 0),
        "speaker_stats": result.get("speaker_stats", {}),
    }
    if "references" in result:
        output_result["references"] = result["references"]
    return output_result


def _write_empty_result(output_path, publisher, message):
    """Publish *message* as an error, write an empty result file, and return it."""
    if publisher:
        publisher.error("asrx", message)
    output_result = {"language": None, "segments": []}
    _atomic_write(output_path, output_result)
    if publisher:
        publisher.complete("asrx", "0 segments")
    return output_result


def _finish_success(result, output_path, publisher):
    """Convert *result*, write it to *output_path*, publish completion, return it."""
    output_result = _convert_result(result, output_path)
    n_segments = len(output_result["segments"])
    if publisher:
        publisher.info("asrx", f"ASRX_COMPLETE:{n_segments}")
    _atomic_write(output_path, output_result)
    if publisher:
        publisher.complete("asrx", f"{n_segments} segments")
    print(f"[ASRX] Saved {n_segments} segments to {output_path}",
          file=sys.stderr)
    return output_result


def _resume_from_checkpoint(video_path, output_path, checkpoint_path, publisher):
    """Phase 2: re-extract audio and run only Steps 4-7 from *checkpoint_path*.

    The checkpoint is deleted after a successful run.  Any failure (including
    audio setup) writes an empty result.  The temp dir is always removed.
    """
    print("[ASRX] Found checkpoint, resuming from Step 4...")
    tmp_dir = None
    try:
        tmp_dir, audio_input = _shared_audio_setup(video_path)
        from asrx_self.main_fixed import SelfASRXFixed
        asrx = SelfASRXFixed()
        result = asrx.resume_from_checkpoint(
            checkpoint_path, audio_input, output_path=output_path,
        )
        if "error" in result:
            return _write_empty_result(output_path, publisher, result["error"])

        output_result = _finish_success(result, output_path, publisher)
        # The checkpoint is only needed until a successful finish.
        with contextlib.suppress(OSError):
            os.remove(checkpoint_path)
            print(f"[ASRX] Removed checkpoint: {checkpoint_path}")
        return output_result
    except Exception as e:
        traceback.print_exc()
        return _write_empty_result(output_path, publisher, str(e))
    finally:
        _cleanup(tmp_dir)


def _run_full_pipeline(video_path, output_path, file_uuid, checkpoint_path,
                       publisher):
    """Phase 1: audio preprocessing, the full 7-step pipeline, then conversion.

    SelfASRXFixed is given *checkpoint_path* so that a crash after Step 3 can
    later be resumed with ``--resume``.
    """
    tmp_dir = None
    try:
        # Stage 1: audio track preprocessing.
        tmp_dir, audio_input = _shared_audio_setup(video_path)

        # Stage 2: SelfASRXFixed 7-step pipeline.
        from asrx_self.main_fixed import SelfASRXFixed
        if publisher:
            publisher.info("asrx", "ASRX_LOADING_MODEL")
        asrx = SelfASRXFixed()
        if publisher:
            publisher.info("asrx", "ASRX_TRANSCRIBING")
        result = asrx.process(
            audio_input,
            output_path=None,
            file_uuid=file_uuid or None,
            max_speakers=10,
            quality_threshold=0.85,
            checkpoint_path=checkpoint_path,
        )

        if "error" in result:
            output_result = _write_empty_result(output_path, publisher,
                                                result["error"])
        else:
            # Stage 3: convert to the Rust-expected format.
            output_result = _finish_success(result, output_path, publisher)
        _cleanup(tmp_dir)
        return output_result
    except Exception as e:
        traceback.print_exc()
        output_result = _write_empty_result(output_path, publisher, str(e))
        # If a checkpoint exists (crash after Step 3), keep the temp dir.
        # NOTE(review): the resume path calls _shared_audio_setup, which creates a
        # *new* temp dir, so this retained one is not reused by this module.
        # Confirm whether the checkpoint references files in it; if not, it leaks.
        if not os.path.exists(checkpoint_path):
            _cleanup(tmp_dir)
        else:
            print(f"[ASRX] Checkpoint saved, keeping temp dir for resume: {tmp_dir}")
        return output_result


def process_asrx(video_path: str, output_path: str, uuid: str = "",
                 file_uuid: str = "", resume: bool = False):
    """Main entry point: run ASRX on *video_path* and write JSON to *output_path*.

    Args:
        video_path: Input video/audio file.
        output_path: Destination JSON.  The Stage-1 checkpoint lives at
            ``output_path + ".stage1.json"``.
        uuid: When non-empty, progress is published via ``RedisPublisher``.
        file_uuid: Passed to SelfASRXFixed (``None`` when empty).
        resume: Resume Steps 4-7 from the checkpoint if it exists; otherwise
            the full pipeline runs.

    Returns:
        The result dict that was written.  On failure it is
        ``{"language": None, "segments": []}``.
    """
    publisher = RedisPublisher(uuid) if uuid else None
    if publisher:
        publisher.info("asrx", "ASRX_START")

    checkpoint_path = output_path + ".stage1.json"

    if resume and os.path.exists(checkpoint_path):
        return _resume_from_checkpoint(video_path, output_path,
                                       checkpoint_path, publisher)
    return _run_full_pipeline(video_path, output_path, file_uuid,
                              checkpoint_path, publisher)


def _print_summary(result):
    """Print a short human-readable summary of a result dict to stdout."""
    print("\n[Summary]")
    print(f" Total segments: {len(result.get('segments', []))}")
    if "speaker_stats" in result:
        print(f" Detected speakers: {len(result['speaker_stats'])}")
        for speaker, stats in result["speaker_stats"].items():
            # Tolerate stats entries without a "count" key instead of crashing
            # after the output file has already been written.
            count = stats.get("count", 0) if isinstance(stats, dict) else stats
            print(f" {speaker}: {count} segments")


def main():
    """Command-line entry point."""
    parser = argparse.ArgumentParser(description="ASRX Processor (Hybrid Pipeline)")
    parser.add_argument("video_path", help="Path to video/audio file")
    parser.add_argument("output_path", help="Path to output JSON file")
    parser.add_argument("--uuid", help="UUID for Redis publishing", default="")
    parser.add_argument("--file-uuid", help="File UUID for Qdrant storage", default="")
    parser.add_argument("--resume", action="store_true",
                        help="Resume from checkpoint (skip Steps 1-3)")
    args = parser.parse_args()

    if not args.resume and not Path(args.video_path).exists():
        print(f"Error: Video file not found: {args.video_path}")
        sys.exit(1)

    result = process_asrx(args.video_path, args.output_path, args.uuid,
                          args.file_uuid, resume=args.resume)
    _print_summary(result)


if __name__ == "__main__":
    main()