Files
momentry_core/scripts/asrx_processor_custom.py
Warren e75c4d6f07 cleanup: remove dead code and duplicate docs
- Remove session-ses_2f27.md (161KB raw session log)
- Remove 49 ROOT_* duplicate files across REFERENCE/
- Remove 14 duplicate files between REFERENCE/ root and history/
- Remove asr_legacy.rs (dead code, replaced by asr.rs)
- Remove src/core/worker/ (duplicate JobWorker)
- Remove src/core/layers/ (empty directory)
- Remove 4 .bak files in src/
- Remove 7 dead private methods in worker/processor.rs
- Remove backup directory from git tracking
2026-05-04 01:31:21 +08:00

329 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/opt/homebrew/bin/python3.11
"""
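ASRX Processor - Custom Implementation Wrapper
Uses SpeechBrain ECAPA-TDNN (no HuggingFace token required)
Pipeline:
1. Preprocess: ffprobe audio tracks → select best track → extract 16 kHz mono WAV
2. Segment: reuse ASR segment boundaries from the matching .asr.json when
   available; otherwise fall back to VAD (Silero)
3. Process: Speaker embedding (ECAPA-TDNN) → Spectral clustering
4. Output: segments with speaker_id and start/end frames (fps read from
   {uuid}.probe.json, default 30)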
"""
import sys
import json
import argparse
import os
import subprocess
import tempfile
from pathlib import Path
# Put this script's directory and its asrx_self/ subdirectory at the front of
# sys.path so local modules (e.g. redis_publisher imported below) resolve no
# matter what the caller's working directory is. Resulting front of sys.path:
# <script dir>/asrx_self, then <script dir>.
sys.path[:0] = [
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "asrx_self"),
    os.path.dirname(os.path.abspath(__file__)),
]
from redis_publisher import RedisPublisher
def probe_audio_tracks(video_path: str, timeout: float = 30) -> list:
    """List the audio streams of ``video_path`` using ffprobe.

    Probing is best-effort: any failure (ffprobe missing, timeout, bad
    output) is logged to stderr and yields an empty list.

    Args:
        video_path: Media file to inspect.
        timeout: Seconds to wait for ffprobe before giving up.

    Returns:
        One dict per audio stream with keys ``index`` (ffprobe stream index,
        later used as ``-map 0:<index>``), ``codec``, ``language`` (``"und"``
        when untagged), ``channels`` (0 when absent) and ``sample_rate`` (as
        reported by ffprobe; the string ``"0"`` when absent).
    """
    cmd = [
        "ffprobe", "-v", "quiet", "-print_format", "json",
        "-show_streams", "-select_streams", "a", video_path,
    ]
    try:
        proc = subprocess.run(
            cmd,
            capture_output=True,
            encoding="utf-8",
            errors="replace",  # odd metadata tags must not abort probing
            timeout=timeout,
        )
        # ffprobe may print nothing on failure; treat that as "no streams".
        data = json.loads(proc.stdout or "{}")
    except (OSError, subprocess.SubprocessError, ValueError) as e:
        print(f"[ASRX] ffprobe failed: {e}", file=sys.stderr)
        return []
    if not isinstance(data, dict):
        print("[ASRX] ffprobe failed: unexpected JSON output", file=sys.stderr)
        return []
    if proc.returncode != 0 and not data.get("streams"):
        print(f"[ASRX] ffprobe failed: exit code {proc.returncode}", file=sys.stderr)

    tracks = []
    for stream in data.get("streams") or []:
        tags = stream.get("tags") or {}
        tracks.append(
            {
                "index": stream.get("index"),
                "codec": stream.get("codec_name"),
                "language": tags.get("language", "und"),
                "channels": stream.get("channels", 0),
                "sample_rate": stream.get("sample_rate", "0"),
            }
        )
    return tracks
def select_best_track(tracks: list) -> int:
    """Pick which audio track to diarize.

    Priority:
    1. The first track tagged English (``eng`` or ``en``, case-insensitive).
    2. Otherwise the track with the most channels (the earliest wins ties).
    3. ``0`` when ``tracks`` is empty.

    Args:
        tracks: Track dicts as produced by ``probe_audio_tracks`` (reads the
            ``index``, ``language`` and ``channels`` keys).

    Returns:
        The position of the chosen track within ``tracks`` (not the ffprobe
        stream index).
    """
    if not tracks:
        return 0

    # Priority 1: English track. Tags are free-form, so normalise case.
    for pos, track in enumerate(tracks):
        if str(track.get("language") or "").lower() in ("eng", "en"):
            print(f"[ASRX] Selected English track (index {track.get('index')})", file=sys.stderr)
            return pos

    # Priority 2: most channels; max() returns the first of equal maxima.
    best = max(range(len(tracks)), key=lambda i: tracks[i].get("channels") or 0)
    chosen = tracks[best]
    print(
        f"[ASRX] Selected track {best} (stream index {chosen.get('index')}, "
        f"lang={chosen.get('language')}, ch={chosen.get('channels')})",
        file=sys.stderr,
    )
    return best
def extract_audio_to_wav(
    video_path: str, track_index: int, output_wav: str, timeout: float = 300
) -> bool:
    """Extract one audio stream to a 16 kHz, mono, 16-bit PCM WAV using ffmpeg.

    Args:
        video_path: Source media file.
        track_index: ffprobe stream index to extract (passed as ``-map 0:<n>``).
        output_wav: Destination WAV path; overwritten if it exists (``-y``).
        timeout: Seconds to allow ffmpeg before aborting.

    Returns:
        True when ffmpeg exits successfully. Otherwise False (missing binary,
        non-zero exit, timeout, invalid arguments), with the reason logged to
        stderr.
    """
    cmd = [
        "ffmpeg", "-y", "-v", "error",  # "error" keeps failure details in stderr
        "-i", video_path,
        "-map", f"0:{track_index}",
        "-ar", "16000",
        "-ac", "1",
        "-sample_fmt", "s16",
        output_wav,
    ]
    try:
        subprocess.run(cmd, check=True, capture_output=True, timeout=timeout)
    except subprocess.CalledProcessError as e:
        detail = (e.stderr or b"").decode("utf-8", "replace").strip()
        suffix = f" ({detail[-500:]})" if detail else ""
        print(f"[ASRX] ffmpeg extraction failed: {e}{suffix}", file=sys.stderr)
        return False
    except (OSError, subprocess.SubprocessError, ValueError) as e:
        # OSError: ffmpeg not installed; SubprocessError: timeout;
        # ValueError: e.g. an embedded NUL byte in a path.
        print(f"[ASRX] ffmpeg extraction failed: {e}", file=sys.stderr)
        return False
    return True
def _cleanup(tmp_dir):
"""Clean up temporary directory."""
if tmp_dir and os.path.exists(tmp_dir):
import shutil
shutil.rmtree(tmp_dir, ignore_errors=True)
def _asr_json_candidates(video_path: str, output_path: str) -> list:
    """Return the ordered, de-duplicated paths where this video's ASR JSON may live.

    1. ``output_path`` with its ``.asrx.json`` suffix swapped for ``.asr.json``.
       This applies only when the path really ends with ``.asrx.json``; a blind
       ``str.replace`` would otherwise return the output path itself and
       re-read a stale result.
    2. ``<output dir>/<video file name without extension>.asr.json``.
    """
    suffix = ".asrx.json"
    candidates = []
    if output_path and output_path.endswith(suffix):
        candidates.append(output_path[: -len(suffix)] + ".asr.json")
    output_dir = os.path.dirname(output_path) if output_path else "."
    video_stem = os.path.basename(video_path).rsplit(".", 1)[0]
    candidates.append(os.path.join(output_dir, video_stem + ".asr.json"))
    return list(dict.fromkeys(candidates))


def _load_asr_segments(asr_candidates: list) -> tuple:
    """Load ASR segments from the first existing candidate path (best-effort).

    Returns ``(segments, asr_path, reason)``. ``segments`` is ``[]`` when nothing
    usable was loaded, and ``asr_path`` is ``""`` when no candidate exists.
    ``reason`` is a short tag recorded in the output's ``_debug`` section.
    """
    asr_path = next((p for p in asr_candidates if p and os.path.exists(p)), "")
    if not asr_path:
        first = asr_candidates[0] if asr_candidates else ""
        print(
            f"[ASRX] ASR output not found, tried {len(asr_candidates)} paths. "
            f"First candidate: {first}",
            file=sys.stderr,
        )
        return [], "", f"asr_json_not_found_tried_{len(asr_candidates)}_paths"
    try:
        with open(asr_path, encoding="utf-8") as f:
            asr_data = json.load(f)
        segments = asr_data.get("segments", [])
        if not isinstance(segments, list):
            raise TypeError(f"'segments' is {type(segments).__name__}, expected list")
    except Exception as e:  # deliberate: any problem falls back to VAD diarization
        print(f"[ASRX] Failed to load ASR segments: {e}", file=sys.stderr)
        return [], asr_path, f"load_error_{e}"
    print(f"[ASRX] Loaded {len(segments)} ASR segments from {asr_path}", file=sys.stderr)
    return segments, asr_path, f"loaded_{len(segments)}_segments"


def _read_probe_fps(output_path: str, default: float = 30.0) -> float:
    """Read the frame rate from ``{uuid}.probe.json`` next to ``output_path``.

    The uuid is the first dot-separated component of the output file name
    (``{uuid}.{type}.json``). Returns ``default`` when the probe file is
    missing or unreadable, lacks ``fps``, or holds a non-finite or
    non-positive value. Such a value would crash ``int()`` (NaN) or make
    every frame number meaningless.
    """
    import math

    output_dir = os.path.dirname(output_path) if output_path else "."
    uuid_part = os.path.basename(output_path or "").split(".")[0]
    probe_path = os.path.join(output_dir, f"{uuid_part}.probe.json")
    if not os.path.exists(probe_path):
        return default
    try:
        with open(probe_path, encoding="utf-8") as pf:
            fps = float(json.load(pf)["fps"])
    except (OSError, ValueError, TypeError, KeyError) as e:
        print(f"[ASRX] Could not read fps from {probe_path}: {e}", file=sys.stderr)
        return default
    if not (math.isfinite(fps) and fps > 0):
        print(f"[ASRX] Ignoring invalid fps {fps} from {probe_path}", file=sys.stderr)
        return default
    print(f"[ASRX] FPS from probe: {fps}", file=sys.stderr)
    return fps


def _to_rust_segments(segments, fps: float) -> list:
    """Convert diarizer segments (``start``/``end`` seconds, ``speaker``) to the
    Rust-expected schema with truncated frame numbers and empty text."""
    return [
        {
            "start_time": seg["start"],
            "end_time": seg["end"],
            "start_frame": int(seg["start"] * fps),
            "end_frame": int(seg["end"] * fps),
            "text": "",
            "speaker_id": seg["speaker"],
        }
        for seg in segments
    ]


def _write_result(path: str, payload: dict) -> None:
    """Write ``payload`` as indented JSON to ``path``."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def process_asrx_custom(video_path: str, output_path: str, uuid: str = ""):
    """Run speaker diarization on ``video_path`` and write the result to ``output_path``.

    Stages:
    1. Probe the audio tracks, pick one, and extract it to a temporary
       16 kHz mono WAV. If extraction fails, the original file is used.
    2. Load matching ASR segments, if any, so their boundaries can be reused.
    3. Run ``SelfASRXFixed``: the ASR-segment mode when segments exist,
       otherwise the VAD-based mode.
    4. Convert the segments to the frame-based schema and save them.

    Args:
        video_path: Input video/audio file. A relative path is resolved against
            the caller's working directory.
        output_path: Destination JSON file, resolved the same way.
        uuid: When non-empty, progress is published via ``RedisPublisher``.

    Returns:
        The dict written to ``output_path``. On any failure this is
        ``{"language": None, "segments": []}``.

    Side effects:
        Changes the process working directory to this script's directory
        (needed for model loading). Removes its temporary directory on exit.
    """
    publisher = RedisPublisher(uuid) if uuid else None
    if publisher:
        publisher.info("asrx", "ASRX_START")
    tmp_dir = None
    try:
        # Resolve caller-relative paths BEFORE os.chdir(); afterwards they would
        # silently be reinterpreted relative to the scripts directory.
        if os.path.exists(video_path):
            video_path = os.path.abspath(video_path)
        if output_path:
            output_path = os.path.abspath(output_path)

        # Ensure working directory is the scripts dir for model loading
        script_dir = os.path.dirname(os.path.abspath(__file__))
        os.chdir(script_dir)

        # Debug: check ffmpeg availability
        import shutil

        print(f"[ASRX] ffmpeg: {shutil.which('ffmpeg')}", file=sys.stderr)
        print(f"[ASRX] CWD: {os.getcwd()}", file=sys.stderr)

        # ---- Stage 1: Audio Track Preprocessing ----
        print("\n[ASRX] ===== Stage 1: Audio Track Analysis =====", file=sys.stderr)
        print(f"[ASRX] Input: {video_path}", file=sys.stderr)
        tracks = probe_audio_tracks(video_path)
        if tracks:
            print(f"[ASRX] Found {len(tracks)} audio track(s):", file=sys.stderr)
            for t in tracks:
                print(f" Track {t['index']}: {t['codec']} {t['channels']}ch {t['sample_rate']}Hz lang={t['language']}", file=sys.stderr)
        else:
            print("[ASRX] No audio tracks found via ffprobe, using raw file", file=sys.stderr)

        # select_best_track returns a list position; ffmpeg needs the stream index.
        track_idx = select_best_track(tracks) if tracks else 0
        actual_track_index = tracks[track_idx]["index"] if tracks else track_idx

        tmp_dir = tempfile.mkdtemp(prefix="asrx_")
        wav_path = os.path.join(tmp_dir, "audio.wav")
        if extract_audio_to_wav(video_path, actual_track_index, wav_path):
            wav_size = os.path.getsize(wav_path)
            print(f"[ASRX] Audio extracted: {wav_path} ({wav_size / 1024 / 1024:.1f}MB)", file=sys.stderr)
            audio_input = wav_path
        else:
            print("[ASRX] Audio extraction failed, falling back to original file", file=sys.stderr)
            audio_input = video_path

        # ---- Stage 2: Load ASR segments for time alignment ----
        asr_candidates = _asr_json_candidates(video_path, output_path)
        asr_segments, asr_path, asr_fallback_reason = _load_asr_segments(asr_candidates)

        # ---- Stage 3: ASRX Processing ----
        from asrx_self.main_fixed import SelfASRXFixed

        if publisher:
            publisher.info("asrx", "ASRX_LOADING_MODEL")
        asrx = SelfASRXFixed()
        if publisher:
            publisher.info("asrx", "ASRX_TRANSCRIBING")

        if asr_segments:
            # Use ASR segment boundaries for speaker embedding extraction
            print(f"[ASRX] Using {len(asr_segments)} ASR segments for diarization", file=sys.stderr)
            result = asrx.process_with_segments(audio_input, asr_segments, output_path=None)
        else:
            # Fallback: VAD-based diarization
            result = asrx.process(
                audio_input,
                output_path=None,
                min_speech_duration_ms=500,
                max_speakers=10,
            )

        if "error" in result:
            if publisher:
                publisher.error("asrx", result["error"])
            output_result = {"language": None, "segments": []}
            _write_result(output_path, output_result)
            if publisher:
                publisher.complete("asrx", "0 segments")
            return output_result

        # ---- Stage 4: Convert to Rust-expected format and save ----
        fps = _read_probe_fps(output_path)
        output_result = {
            "language": None,
            "segments": _to_rust_segments(result["segments"], fps),
        }
        # Optional metadata passed through unchanged from the diarizer.
        if "speaker_stats" in result:
            output_result["speaker_stats"] = result["speaker_stats"]
        # Per-segment speaker embeddings (presumably 192-D ECAPA-TDNN vectors).
        if "embeddings" in result:
            output_result["embeddings"] = result["embeddings"]

        segment_count = len(output_result["segments"])
        if publisher:
            publisher.info("asrx", f"ASRX_COMPLETE:{segment_count}")
        output_result["_debug"] = {"asr_fallback": asr_fallback_reason, "asr_path": asr_path}
        _write_result(output_path, output_result)
        if publisher:
            publisher.complete("asrx", f"{segment_count} segments")
        print(f"[ASRX-Custom] Saved {segment_count} segments to {output_path}", file=sys.stderr)
        return output_result
    except Exception as e:
        if publisher:
            publisher.error("asrx", str(e))
        import traceback

        traceback.print_exc()
        # Return (and persist) an empty result on error
        output_result = {"language": None, "segments": []}
        _write_result(output_path, output_result)
        if publisher:
            publisher.complete("asrx", "0 segments")
        return output_result
    finally:
        # Runs on every exit path, including a failed output write.
        _cleanup(tmp_dir)
if __name__ == "__main__":
    # CLI entry point: validate the input path, run the pipeline, print a summary.
    parser = argparse.ArgumentParser(
        description="ASRX Processor (Custom Implementation)"
    )
    parser.add_argument("video_path", help="Path to video/audio file")
    parser.add_argument("output_path", help="Path to output JSON file")
    parser.add_argument("--uuid", help="UUID for Redis publishing", default="")
    args = parser.parse_args()
    if not Path(args.video_path).exists():
        # Diagnostics belong on stderr, like the rest of the [ASRX] logging.
        print(f"Error: Video file not found: {args.video_path}", file=sys.stderr)
        sys.exit(1)
    result = process_asrx_custom(args.video_path, args.output_path, args.uuid)
    print("\n[Summary]")
    print(f" Total segments: {len(result['segments'])}")
    if "speaker_stats" in result:
        print(f" Detected speakers: {len(result['speaker_stats'])}")
        # Assumes speaker_stats maps speaker id -> dict with a "count" key
        # (produced by SelfASRXFixed, not visible here) — TODO confirm.
        for speaker, stats in result["speaker_stats"].items():
            print(f" {speaker}: {stats['count']} segments")