Files
momentry_core/scripts/asrx_processor.py

321 lines
11 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/opt/homebrew/bin/python3.11
"""
ASRX Processor - Hybrid Pipeline Wrapper
Pipeline:
1. ffprobe → select best audio track → ffmpeg → 16kHz mono WAV
2. SelfASRXFixed.process() (7-step hybrid speaker diarization)
3. Convert to Rust-expected format
"""
import sys
import json
import argparse
import os
import subprocess
import tempfile
from pathlib import Path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(
0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "asrx_self")
)
from redis_publisher import RedisPublisher
def probe_audio_tracks(video_path: str) -> list:
    """List every audio stream of a media file using ffprobe.

    Args:
        video_path: Path of the media file to inspect.

    Returns:
        One dict per audio stream with the keys ``index``, ``codec``,
        ``language`` ("und" when the stream has no language tag),
        ``channels`` and ``sample_rate``. An empty list is returned when
        ffprobe cannot be run or its output cannot be parsed.
    """
    ffprobe_args = [
        "ffprobe",
        "-v", "quiet",
        "-print_format", "json",
        "-show_streams",
        "-select_streams", "a",
        video_path,
    ]
    try:
        completed = subprocess.run(
            ffprobe_args, capture_output=True, text=True, timeout=30
        )
        payload = json.loads(completed.stdout)
        return [
            {
                "index": info.get("index"),
                "codec": info.get("codec_name"),
                "language": info.get("tags", {}).get("language", "und"),
                "channels": info.get("channels", 0),
                "sample_rate": info.get("sample_rate", "0"),
            }
            for info in payload.get("streams", [])
        ]
    except Exception as err:
        # Best effort: any failure is reported and treated as "no tracks".
        print(f"[ASRX] ffprobe failed: {err}")
        return []
def select_best_track(tracks: list) -> int:
    """Choose the preferred audio track and return its position in ``tracks``.

    Preference order: the first track tagged English ("eng" or "en"), then
    the track with the most channels (the earliest one wins ties), and 0
    when ``tracks`` is empty.

    Args:
        tracks: Track dicts carrying at least ``language`` and ``channels``.

    Returns:
        The list position of the chosen track (not the ffprobe stream index).
    """
    if not tracks:
        return 0

    english_pos = next(
        (
            pos
            for pos, track in enumerate(tracks)
            if track["language"] in ("eng", "en")
        ),
        None,
    )
    if english_pos is not None:
        return english_pos

    # No English track: keep the first track holding the highest channel count.
    winner = 0
    for pos, track in enumerate(tracks):
        if track["channels"] > tracks[winner]["channels"]:
            winner = pos
    return winner
def extract_audio_to_wav(video_path: str, track_index: int, output_wav: str) -> bool:
    """Extract one stream of a media file to a 16 kHz mono 16-bit WAV via ffmpeg.

    Args:
        video_path: Input media file.
        track_index: Stream index used in the ffmpeg ``-map 0:<index>`` selector.
        output_wav: Destination WAV path (overwritten if present, ``-y``).

    Returns:
        True when ffmpeg exits successfully, False on any failure
        (non-zero exit, timeout after 300 s, ffmpeg not runnable, ...).
    """
    stream_selector = f"0:{track_index}"
    ffmpeg_args = [
        "ffmpeg", "-y", "-v", "quiet",
        "-i", video_path,
        "-map", stream_selector,
        "-ar", "16000",
        "-ac", "1",
        "-sample_fmt", "s16",
        output_wav,
    ]
    try:
        subprocess.run(ffmpeg_args, check=True, capture_output=True, timeout=300)
    except Exception as err:
        print(f"[ASRX] ffmpeg extraction failed: {err}")
        return False
    return True
def _cleanup(tmp_dir):
    """Delete a temporary directory tree if it exists.

    A falsy ``tmp_dir`` (None, "") or a path that no longer exists is a
    no-op. Errors during removal are ignored.
    """
    import shutil

    if not tmp_dir or not os.path.exists(tmp_dir):
        return
    shutil.rmtree(tmp_dir, ignore_errors=True)
def _atomic_write(path: str, data: dict):
    """Write ``data`` as indented JSON to ``path`` via a temp file and rename.

    The JSON is first written to ``<path>.tmp`` and then moved over ``path``,
    so readers never observe a half-written file at ``path``.

    Args:
        path: Destination file path.
        data: JSON-serializable object to store.

    Raises:
        TypeError / ValueError: ``data`` is not JSON-serializable.
        OSError: The temp file cannot be written or moved into place.
        In every failure case the temp file is removed and an existing
        ``path`` is left untouched.
    """
    tmp = path + ".tmp"
    try:
        with open(tmp, "w") as f:
            json.dump(data, f, indent=2)
        # os.replace overwrites an existing destination on every platform,
        # whereas os.rename fails on Windows when ``path`` already exists.
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a partial "<path>.tmp" behind when the write fails.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
def _shared_audio_setup(video_path):
    """Extract the preferred audio track into a fresh temporary directory.

    Returns:
        A ``(tmp_dir, audio_path)`` tuple. ``audio_path`` is the extracted
        ``audio.wav`` inside ``tmp_dir`` when extraction succeeds; otherwise
        it is ``video_path`` itself. ``tmp_dir`` is created in both cases and
        must be removed by the caller.
    """
    tracks = probe_audio_tracks(video_path)
    if tracks:
        # select_best_track returns a list position; ffmpeg needs the
        # stream index reported by ffprobe.
        stream_index = tracks[select_best_track(tracks)]["index"]
    else:
        stream_index = 0

    tmp_dir = tempfile.mkdtemp(prefix="asrx_")
    wav_path = os.path.join(tmp_dir, "audio.wav")
    if not extract_audio_to_wav(video_path, stream_index, wav_path):
        print("[ASRX] Audio extraction failed, falling back to original file",
              file=sys.stderr)
        return tmp_dir, video_path
    return tmp_dir, wav_path
def _convert_result(result, output_path):
    """Stage 3: convert a SelfASRXFixed result into the Rust-expected format.

    The frame rate defaults to 30.0. When a ``<stem>.probe.json`` file sits
    next to ``output_path`` (``<stem>`` being the output file name up to its
    first dot) and contains an ``fps`` entry, that value is used instead.
    Any problem reading the probe file leaves the frame rate as it was.

    Args:
        result: Dict with optional ``language``, ``segments``, ``n_speakers``,
            ``speaker_stats`` and ``references`` entries. Each segment must
            provide ``start`` and ``end``.
        output_path: Path of the output JSON; used only to locate the
            probe file.

    Returns:
        Dict with ``language``, ``segments``, ``n_speakers``,
        ``speaker_stats`` and, when present in ``result``, ``references``.
    """
    frame_rate = 30.0
    stem = os.path.basename(output_path).split(".")[0]
    sidecar = os.path.join(os.path.dirname(output_path), f"{stem}.probe.json")
    if os.path.exists(sidecar):
        try:
            with open(sidecar) as handle:
                probe_info = json.load(handle)
            if "fps" in probe_info:
                frame_rate = float(probe_info["fps"])
        except Exception:
            pass

    def _to_output_segment(seg):
        # Frame numbers are the timestamps scaled by fps, truncated by int().
        begin, finish = seg["start"], seg["end"]
        return {
            "start_time": begin,
            "end_time": finish,
            "start_frame": int(begin * frame_rate),
            "end_frame": int(finish * frame_rate),
            "text": seg.get("text", ""),
            "speaker_id": seg.get("speaker_id", seg.get("speaker", "")),
            "language": seg.get("language", ""),
            "lang_prob": seg.get("lang_prob", 0.0),
            "quality": seg.get("quality", 0.0),
        }

    converted = {
        "language": result.get("language"),
        "segments": [
            _to_output_segment(seg) for seg in result.get("segments", [])
        ],
        "n_speakers": result.get("n_speakers", 0),
        "speaker_stats": result.get("speaker_stats", {}),
    }
    if "references" in result:
        converted["references"] = result["references"]
    return converted
def _write_failure(output_path: str, publisher, message: str,
                   show_traceback: bool = False) -> dict:
    """Report a failure and persist an empty result; return that result.

    Publishes ``message`` as an error (when a publisher is present),
    optionally prints the traceback of the exception currently being handled,
    writes the empty result to ``output_path`` and finally publishes
    completion with "0 segments".
    """
    if publisher:
        publisher.error("asrx", message)
    if show_traceback:
        import traceback
        traceback.print_exc()
    output_result = {"language": None, "segments": []}
    _atomic_write(output_path, output_result)
    if publisher:
        publisher.complete("asrx", "0 segments")
    return output_result


def _write_success(result: dict, output_path: str, publisher) -> dict:
    """Convert ``result``, persist it to ``output_path`` and announce completion."""
    output_result = _convert_result(result, output_path)
    segment_count = len(output_result["segments"])
    if publisher:
        publisher.info("asrx", f"ASRX_COMPLETE:{segment_count}")
    _atomic_write(output_path, output_result)
    if publisher:
        publisher.complete("asrx", f"{segment_count} segments")
    print(f"[ASRX] Saved {segment_count} segments "
          f"to {output_path}", file=sys.stderr)
    return output_result


def _resume_pipeline(video_path: str, output_path: str,
                     checkpoint_path: str, publisher) -> dict:
    """Phase 2: resume from an existing checkpoint (Steps 4-7 only)."""
    print("[ASRX] Found checkpoint, resuming from Step 4...")
    tmp_dir = None
    try:
        # Audio setup is inside the try so that a failure here also ends in
        # an empty result file, exactly like the full pipeline.
        tmp_dir, audio_input = _shared_audio_setup(video_path)
        from asrx_self.main_fixed import SelfASRXFixed
        asrx = SelfASRXFixed()
        result = asrx.resume_from_checkpoint(
            checkpoint_path, audio_input, output_path=output_path,
        )
        if "error" in result:
            return _write_failure(output_path, publisher, result["error"])
        output_result = _write_success(result, output_path, publisher)
        # The checkpoint is no longer needed once the result is saved.
        try:
            os.remove(checkpoint_path)
            print(f"[ASRX] Removed checkpoint: {checkpoint_path}")
        except OSError as err:
            # Best effort: a leftover checkpoint must not fail the run.
            print(f"[ASRX] Could not remove checkpoint {checkpoint_path}: {err}",
                  file=sys.stderr)
        return output_result
    except Exception as e:
        return _write_failure(output_path, publisher, str(e),
                              show_traceback=True)
    finally:
        # Runs on every exit path, including a failure while writing the
        # fallback result.
        _cleanup(tmp_dir)


def _full_pipeline(video_path: str, output_path: str, checkpoint_path: str,
                   file_uuid: str, publisher) -> dict:
    """Phase 1: run the full 7-step pipeline."""
    tmp_dir = None
    try:
        # Stage 1: Audio Track Preprocessing
        tmp_dir, audio_input = _shared_audio_setup(video_path)
        # Stage 2: SelfASRXFixed 7-step pipeline
        from asrx_self.main_fixed import SelfASRXFixed
        if publisher:
            publisher.info("asrx", "ASRX_LOADING_MODEL")
        asrx = SelfASRXFixed()
        if publisher:
            publisher.info("asrx", "ASRX_TRANSCRIBING")
        result = asrx.process(
            audio_input,
            output_path=None,
            file_uuid=file_uuid or None,
            max_speakers=10,
            quality_threshold=0.85,
            checkpoint_path=checkpoint_path,
        )
        if "error" in result:
            output_result = _write_failure(output_path, publisher,
                                           result["error"])
            _cleanup(tmp_dir)
            return output_result
        # Stage 3: Convert to Rust-expected format
        output_result = _write_success(result, output_path, publisher)
        _cleanup(tmp_dir)
        return output_result
    except Exception as e:
        try:
            return _write_failure(output_path, publisher, str(e),
                                  show_traceback=True)
        finally:
            # If a checkpoint already exists (crash after Step 3 completed),
            # keep the temp dir holding the WAV; otherwise remove it.
            # NOTE(review): the resume path in this file re-extracts audio
            # into a new temp dir and never reads this one, so the kept
            # directory appears to be leaked - confirm whether
            # resume_from_checkpoint relies on it.
            if not os.path.exists(checkpoint_path):
                _cleanup(tmp_dir)
            else:
                print(f"[ASRX] Checkpoint saved, keeping temp dir for resume: {tmp_dir}")


def process_asrx(video_path: str, output_path: str, uuid: str = "",
                 file_uuid: str = "", resume: bool = False):
    """Main entry point: run (or resume) the ASRX pipeline for one file.

    Args:
        video_path: Path to the input video/audio file.
        output_path: Path of the JSON result file. The checkpoint path is
            derived from it as ``output_path + ".stage1.json"``.
        uuid: When non-empty, progress is published through
            ``RedisPublisher(uuid)``; otherwise nothing is published.
        file_uuid: Forwarded to ``SelfASRXFixed.process`` (``None`` when empty).
        resume: When True and the checkpoint file exists, only the resume
            path runs; otherwise the full pipeline runs.

    Returns:
        The result dict written to ``output_path``. On any failure this is
        ``{"language": None, "segments": []}``.
    """
    publisher = RedisPublisher(uuid) if uuid else None
    if publisher:
        publisher.info("asrx", "ASRX_START")
    checkpoint_path = output_path + ".stage1.json"
    if resume and os.path.exists(checkpoint_path):
        return _resume_pipeline(video_path, output_path, checkpoint_path,
                                publisher)
    return _full_pipeline(video_path, output_path, checkpoint_path,
                          file_uuid, publisher)
if __name__ == "__main__":
    # Command-line entry point: parse arguments, run the pipeline and print
    # a short summary of the result.
    cli = argparse.ArgumentParser(description="ASRX Processor (Hybrid Pipeline)")
    cli.add_argument("video_path", help="Path to video/audio file")
    cli.add_argument("output_path", help="Path to output JSON file")
    cli.add_argument("--uuid", default="", help="UUID for Redis publishing")
    cli.add_argument("--file-uuid", default="",
                     help="File UUID for Qdrant storage")
    cli.add_argument("--resume", action="store_true",
                     help="Resume from checkpoint (skip Steps 1-3)")
    opts = cli.parse_args()

    # The input file is only required to exist when not resuming.
    if not (opts.resume or Path(opts.video_path).exists()):
        print(f"Error: Video file not found: {opts.video_path}")
        sys.exit(1)

    outcome = process_asrx(
        opts.video_path,
        opts.output_path,
        opts.uuid,
        opts.file_uuid,
        resume=opts.resume,
    )

    print("\n[Summary]")
    print(f" Total segments: {len(outcome.get('segments', []))}")
    if "speaker_stats" in outcome:
        per_speaker = outcome["speaker_stats"]
        print(f" Detected speakers: {len(per_speaker)}")
        for speaker, stats in per_speaker.items():
            print(f" {speaker}: {stats['count']} segments")