feat: ASRX hybrid pipeline, identity history, worker fixes, checkpoint system
This commit is contained in:
@@ -1,124 +1,320 @@
|
||||
#!/opt/homebrew/bin/python3.11
|
||||
"""
|
||||
ASRX Processor - Speaker Diarization
|
||||
Uses whisperx for speaker diarization (local model)
|
||||
ASRX Processor - Hybrid Pipeline Wrapper
|
||||
|
||||
Pipeline:
|
||||
1. ffprobe → select best audio track → ffmpeg → 16kHz mono WAV
|
||||
2. SelfASRXFixed.process() (7-step hybrid speaker diarization)
|
||||
3. Convert to Rust-expected format
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(
|
||||
0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "asrx_self")
|
||||
)
|
||||
|
||||
from redis_publisher import RedisPublisher
|
||||
|
||||
|
||||
def process_asrx(video_path: str, output_path: str, uuid: str = ""):
|
||||
"""Process video for speaker diarization using whisperx"""
|
||||
def probe_audio_tracks(video_path: str) -> list:
|
||||
"""ffprobe 列出所有音軌"""
|
||||
cmd = [
|
||||
"ffprobe", "-v", "quiet", "-print_format", "json",
|
||||
"-show_streams", "-select_streams", "a", video_path,
|
||||
]
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
|
||||
data = json.loads(result.stdout)
|
||||
tracks = []
|
||||
for stream in data.get("streams", []):
|
||||
tracks.append({
|
||||
"index": stream.get("index"),
|
||||
"codec": stream.get("codec_name"),
|
||||
"language": stream.get("tags", {}).get("language", "und"),
|
||||
"channels": stream.get("channels", 0),
|
||||
"sample_rate": stream.get("sample_rate", "0"),
|
||||
})
|
||||
return tracks
|
||||
except Exception as e:
|
||||
print(f"[ASRX] ffprobe failed: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def select_best_track(tracks: list) -> int:
|
||||
"""選最佳音軌: English > 最多channels > 0"""
|
||||
if not tracks:
|
||||
return 0
|
||||
for i, t in enumerate(tracks):
|
||||
if t["language"] in ("eng", "en"):
|
||||
return i
|
||||
best = 0
|
||||
for i, t in enumerate(tracks):
|
||||
if t["channels"] > tracks[best]["channels"]:
|
||||
best = i
|
||||
return best
|
||||
|
||||
|
||||
def extract_audio_to_wav(video_path: str, track_index: int, output_wav: str) -> bool:
|
||||
"""ffmpeg 提取音軌為 16kHz mono WAV"""
|
||||
cmd = [
|
||||
"ffmpeg", "-y", "-v", "quiet",
|
||||
"-i", video_path,
|
||||
"-map", f"0:{track_index}",
|
||||
"-ar", "16000",
|
||||
"-ac", "1",
|
||||
"-sample_fmt", "s16",
|
||||
output_wav,
|
||||
]
|
||||
try:
|
||||
subprocess.run(cmd, check=True, capture_output=True, timeout=300)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"[ASRX] ffmpeg extraction failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _cleanup(tmp_dir):
|
||||
if tmp_dir and os.path.exists(tmp_dir):
|
||||
import shutil
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def _atomic_write(path: str, data: dict):
|
||||
tmp = path + ".tmp"
|
||||
with open(tmp, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
os.rename(tmp, path)
|
||||
|
||||
|
||||
def _shared_audio_setup(video_path):
|
||||
"""提取音頻,回傳 (tmp_dir, wav_path)"""
|
||||
tracks = probe_audio_tracks(video_path)
|
||||
track_idx = select_best_track(tracks) if tracks else 0
|
||||
actual_track_index = tracks[track_idx]["index"] if tracks else track_idx
|
||||
|
||||
tmp_dir = tempfile.mkdtemp(prefix="asrx_")
|
||||
wav_path = os.path.join(tmp_dir, "audio.wav")
|
||||
|
||||
if extract_audio_to_wav(video_path, actual_track_index, wav_path):
|
||||
return tmp_dir, wav_path
|
||||
print("[ASRX] Audio extraction failed, falling back to original file",
|
||||
file=sys.stderr)
|
||||
return tmp_dir, video_path
|
||||
|
||||
|
||||
def _convert_result(result, output_path):
|
||||
"""Stage 3: 將 SelfASRXFixed result 轉為 Rust-expected format"""
|
||||
fps = 30.0
|
||||
base_name = os.path.basename(output_path)
|
||||
uuid_part = base_name.split(".")[0]
|
||||
probe_path = os.path.join(os.path.dirname(output_path),
|
||||
f"{uuid_part}.probe.json")
|
||||
if os.path.exists(probe_path):
|
||||
try:
|
||||
with open(probe_path) as pf:
|
||||
probe_data = json.load(pf)
|
||||
if "fps" in probe_data:
|
||||
fps = float(probe_data["fps"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
output_result = {
|
||||
"language": result.get("language"),
|
||||
"segments": [],
|
||||
"n_speakers": result.get("n_speakers", 0),
|
||||
"speaker_stats": result.get("speaker_stats", {}),
|
||||
}
|
||||
|
||||
for seg in result.get("segments", []):
|
||||
start_sec = seg["start"]
|
||||
end_sec = seg["end"]
|
||||
output_result["segments"].append({
|
||||
"start_time": start_sec,
|
||||
"end_time": end_sec,
|
||||
"start_frame": int(start_sec * fps),
|
||||
"end_frame": int(end_sec * fps),
|
||||
"text": seg.get("text", ""),
|
||||
"speaker_id": seg.get("speaker_id", seg.get("speaker", "")),
|
||||
"language": seg.get("language", ""),
|
||||
"lang_prob": seg.get("lang_prob", 0.0),
|
||||
"quality": seg.get("quality", 0.0),
|
||||
})
|
||||
|
||||
if "references" in result:
|
||||
output_result["references"] = result["references"]
|
||||
|
||||
return output_result
|
||||
|
||||
|
||||
def process_asrx(video_path: str, output_path: str, uuid: str = "",
|
||||
file_uuid: str = "", resume: bool = False):
|
||||
"""主處理函數"""
|
||||
publisher = RedisPublisher(uuid) if uuid else None
|
||||
if publisher:
|
||||
publisher.info("asrx", "ASRX_START")
|
||||
|
||||
try:
|
||||
import whisperx
|
||||
import torch
|
||||
except ImportError:
|
||||
if publisher:
|
||||
publisher.error("asrx", "whisperx not installed")
|
||||
result = {"language": None, "segments": []}
|
||||
if publisher:
|
||||
publisher.complete("asrx", "0 segments")
|
||||
with open(output_path, "w") as f:
|
||||
json.dump(result, f, indent=2)
|
||||
return result
|
||||
checkpoint_path = output_path + ".stage1.json"
|
||||
|
||||
if publisher:
|
||||
publisher.info("asrx", "ASRX_LOADING_MODEL")
|
||||
# ── Phase 2: Resume from checkpoint (Steps 4-7 only) ──
|
||||
if resume and os.path.exists(checkpoint_path):
|
||||
print(f"[ASRX] Found checkpoint, resuming from Step 4...")
|
||||
tmp_dir, audio_input = _shared_audio_setup(video_path)
|
||||
try:
|
||||
from asrx_self.main_fixed import SelfASRXFixed
|
||||
asrx = SelfASRXFixed()
|
||||
|
||||
result = asrx.resume_from_checkpoint(
|
||||
checkpoint_path, audio_input, output_path=output_path,
|
||||
)
|
||||
|
||||
if "error" in result:
|
||||
if publisher:
|
||||
publisher.error("asrx", result["error"])
|
||||
output_result = {"language": None, "segments": []}
|
||||
_atomic_write(output_path, output_result)
|
||||
if publisher:
|
||||
publisher.complete("asrx", "0 segments")
|
||||
_cleanup(tmp_dir)
|
||||
return output_result
|
||||
|
||||
output_result = _convert_result(result, output_path)
|
||||
|
||||
if publisher:
|
||||
publisher.info("asrx",
|
||||
f"ASRX_COMPLETE:{len(output_result['segments'])}")
|
||||
|
||||
_atomic_write(output_path, output_result)
|
||||
|
||||
if publisher:
|
||||
publisher.complete(
|
||||
"asrx", f"{len(output_result['segments'])} segments")
|
||||
|
||||
print(f"[ASRX] Saved {len(output_result['segments'])} segments "
|
||||
f"to {output_path}", file=sys.stderr)
|
||||
|
||||
# 刪除 checkpoint(完成後清理)
|
||||
try:
|
||||
os.remove(checkpoint_path)
|
||||
print(f"[ASRX] Removed checkpoint: {checkpoint_path}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_cleanup(tmp_dir)
|
||||
return output_result
|
||||
except Exception as e:
|
||||
if publisher:
|
||||
publisher.error("asrx", str(e))
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
output_result = {"language": None, "segments": []}
|
||||
_atomic_write(output_path, output_result)
|
||||
if publisher:
|
||||
publisher.complete("asrx", "0 segments")
|
||||
_cleanup(tmp_dir)
|
||||
return output_result
|
||||
|
||||
# ── Phase 1: Full 7-step pipeline ──
|
||||
tmp_dir = None
|
||||
|
||||
try:
|
||||
# Fix for PyTorch 2.6+ compatibility
|
||||
# Allow omegaconf types in torch.load
|
||||
import omegaconf
|
||||
# Stage 1: Audio Track Preprocessing
|
||||
tmp_dir, audio_input = _shared_audio_setup(video_path)
|
||||
|
||||
torch.serialization.add_safe_globals(
|
||||
[omegaconf.listconfig.ListConfig, omegaconf.dictconfig.DictConfig]
|
||||
)
|
||||
# Stage 2: SelfASRXFixed 7-step pipeline
|
||||
from asrx_self.main_fixed import SelfASRXFixed
|
||||
|
||||
# Load model - using faster-whisper for better performance
|
||||
# You can also use: "large-v3", "medium", "small", "base", "tiny"
|
||||
model = whisperx.load_model("base", device="cpu", compute_type="int8")
|
||||
if publisher:
|
||||
publisher.info("asrx", "ASRX_LOADING_MODEL")
|
||||
|
||||
asrx = SelfASRXFixed()
|
||||
|
||||
if publisher:
|
||||
publisher.info("asrx", "ASRX_TRANSCRIBING")
|
||||
|
||||
# Transcribe audio
|
||||
result = model.transcribe(video_path, language="en")
|
||||
|
||||
# Align timestamps
|
||||
model_a, metadata = whisperx.load_align_model(language_code=result["language"])
|
||||
result = whisperx.align(
|
||||
result["segments"], model_a, metadata, video_path, device="cpu"
|
||||
result = asrx.process(
|
||||
audio_input,
|
||||
output_path=None,
|
||||
file_uuid=file_uuid or None,
|
||||
max_speakers=10,
|
||||
quality_threshold=0.85,
|
||||
checkpoint_path=checkpoint_path,
|
||||
)
|
||||
|
||||
# Diarization (speaker segmentation)
|
||||
try:
|
||||
from whisperx.diarize import DiarizationPipeline
|
||||
|
||||
# DiarizationPipeline parameters: model_name, token, device, cache_dir
|
||||
diarize_model = DiarizationPipeline(
|
||||
model_name="pyannote/speaker-diarization",
|
||||
token=None, # HuggingFace token (None for public models)
|
||||
device="cpu",
|
||||
)
|
||||
diarize_segments = diarize_model(video_path)
|
||||
|
||||
# Assign speaker labels
|
||||
result = whisperx.assign_word_speakers(diarize_segments, result)
|
||||
except Exception as e:
|
||||
if "error" in result:
|
||||
if publisher:
|
||||
publisher.info("asrx", f"Diarization skipped: {e}")
|
||||
publisher.error("asrx", result["error"])
|
||||
output_result = {"language": None, "segments": []}
|
||||
_atomic_write(output_path, output_result)
|
||||
if publisher:
|
||||
publisher.complete("asrx", "0 segments")
|
||||
_cleanup(tmp_dir)
|
||||
return output_result
|
||||
|
||||
# Build output
|
||||
segments = []
|
||||
for seg in result.get("segments", []):
|
||||
text = seg.get("text", "").strip()
|
||||
if text:
|
||||
segments.append(
|
||||
{
|
||||
"start": seg.get("start", 0.0),
|
||||
"end": seg.get("end", 0.0),
|
||||
"text": text,
|
||||
"speaker_id": seg.get("speaker", None),
|
||||
}
|
||||
)
|
||||
|
||||
output_result = {"language": result.get("language"), "segments": segments}
|
||||
# Stage 3: Convert to Rust-expected format
|
||||
output_result = _convert_result(result, output_path)
|
||||
|
||||
if publisher:
|
||||
publisher.complete("asrx", f"{len(segments)} segments")
|
||||
publisher.info("asrx", f"ASRX_COMPLETE:{len(output_result['segments'])}")
|
||||
|
||||
with open(output_path, "w") as f:
|
||||
json.dump(output_result, f, indent=2)
|
||||
_atomic_write(output_path, output_result)
|
||||
|
||||
if publisher:
|
||||
publisher.complete("asrx",
|
||||
f"{len(output_result['segments'])} segments")
|
||||
|
||||
print(f"[ASRX] Saved {len(output_result['segments'])} segments "
|
||||
f"to {output_path}", file=sys.stderr)
|
||||
|
||||
_cleanup(tmp_dir)
|
||||
return output_result
|
||||
|
||||
except Exception as e:
|
||||
if publisher:
|
||||
publisher.error("asrx", f"Error: {e}")
|
||||
result = {"language": None, "segments": []}
|
||||
publisher.error("asrx", str(e))
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
output_result = {"language": None, "segments": []}
|
||||
_atomic_write(output_path, output_result)
|
||||
if publisher:
|
||||
publisher.complete("asrx", "0 segments")
|
||||
with open(output_path, "w") as f:
|
||||
json.dump(result, f, indent=2)
|
||||
return result
|
||||
# 如果 checkpoint 已存在(Step 3 完成後 crash),保留 WAV 給 resume
|
||||
if not os.path.exists(checkpoint_path):
|
||||
_cleanup(tmp_dir)
|
||||
else:
|
||||
print(f"[ASRX] Checkpoint saved, keeping temp dir for resume: {tmp_dir}")
|
||||
return output_result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="ASRX Speaker Diarization")
|
||||
parser.add_argument("video_path", help="Path to video file")
|
||||
parser.add_argument("output_path", help="Output JSON path")
|
||||
parser.add_argument("--uuid", "-u", help="UUID for Redis progress", default="")
|
||||
parser = argparse.ArgumentParser(description="ASRX Processor (Hybrid Pipeline)")
|
||||
parser.add_argument("video_path", help="Path to video/audio file")
|
||||
parser.add_argument("output_path", help="Path to output JSON file")
|
||||
parser.add_argument("--uuid", help="UUID for Redis publishing", default="")
|
||||
parser.add_argument("--file-uuid", help="File UUID for Qdrant storage", default="")
|
||||
parser.add_argument("--resume", action="store_true",
|
||||
help="Resume from checkpoint (skip Steps 1-3)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
process_asrx(args.video_path, args.output_path, args.uuid)
|
||||
if not args.resume and not Path(args.video_path).exists():
|
||||
print(f"Error: Video file not found: {args.video_path}")
|
||||
sys.exit(1)
|
||||
|
||||
result = process_asrx(args.video_path, args.output_path, args.uuid,
|
||||
args.file_uuid, resume=args.resume)
|
||||
|
||||
print("\n[Summary]")
|
||||
print(f" Total segments: {len(result.get('segments', []))}")
|
||||
if "speaker_stats" in result:
|
||||
print(f" Detected speakers: {len(result['speaker_stats'])}")
|
||||
for speaker, stats in result["speaker_stats"].items():
|
||||
print(f" {speaker}: {stats['count']} segments")
|
||||
|
||||
Reference in New Issue
Block a user