"""Unit tests for ``GroqTranscribeProvider``.

The Groq SDK client (``groq.Groq``) and the provider's ``_extract_audio_segments``
step are replaced with in-memory fakes, so these tests exercise only the provider's
retry, API-key rotation, checkpoint reuse, size-guard and SRT-writing behavior.
"""

from __future__ import annotations

import json
import tempfile
import unittest
from collections.abc import Callable
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch

from biliup_next.core.errors import ModuleError
from biliup_next.core.models import Artifact, Task
from biliup_next.modules.transcribe.providers.groq import GroqTranscribeProvider

# Fixed created/updated timestamp shared by every fixture Task and Artifact.
_TIMESTAMP = "2026-01-01T00:00:00+00:00"


def _make_task() -> Task:
    """Return the ``demo`` task used by every transcription test."""
    return Task("task-1", "local_file", "/tmp/input.mp4", "demo", "created", _TIMESTAMP, _TIMESTAMP)


def _make_source_video(work_dir: Path, task: Task) -> Artifact:
    """Write a placeholder ``input.mp4`` into *work_dir* and wrap it as a ``source_video`` artifact."""
    source_path = work_dir / "input.mp4"
    source_path.write_bytes(b"video")
    return Artifact(None, task.id, "source_video", str(source_path), "{}", _TIMESTAMP)


def _fake_audio_extractor(*segments: Path) -> Callable[..., None]:
    """Build a stand-in for ``GroqTranscribeProvider._extract_audio_segments``.

    The returned callable ignores its keyword arguments and creates every path in
    *segments* (including parent directories) containing the bytes ``b"audio"``.
    """

    def fake_extract_audio_segments(**kwargs):  # noqa: ANN003
        for segment in segments:
            segment.parent.mkdir(parents=True, exist_ok=True)
            segment.write_bytes(b"audio")

    return fake_extract_audio_segments


class _FakeResponse:
    """Stand-in for a Groq transcription response exposing ``segments``."""

    def __init__(self, segments):
        self.segments = segments


class _FakeTranscriptions:
    """Scripted replacement for ``client.audio.transcriptions``.

    Each ``create`` call records its keyword arguments in ``calls`` and consumes the
    next scripted outcome in order: exception instances are raised, anything else
    is returned as the response.
    """

    def __init__(self, outcomes: list[object]) -> None:
        self.outcomes = list(outcomes)
        self.calls: list[dict[str, object]] = []

    def create(self, **kwargs):  # noqa: ANN003
        self.calls.append(kwargs)
        if not self.outcomes:
            # Fail with a descriptive message rather than a bare
            # ``IndexError: pop from empty list`` so an unexpected extra call
            # made by the provider is easy to diagnose.
            raise AssertionError(f"unexpected transcription call #{len(self.calls)}: no scripted outcomes left")
        outcome = self.outcomes.pop(0)
        if isinstance(outcome, Exception):
            raise outcome
        return outcome


class _FakeGroqClient:
    """Fake ``groq.Groq`` instance whose ``audio.transcriptions`` replays scripted outcomes."""

    def __init__(self, outcomes: list[object]) -> None:
        self.audio = SimpleNamespace(transcriptions=_FakeTranscriptions(outcomes))


class GroqTranscribeProviderTests(unittest.TestCase):
    """Behavioral tests for ``GroqTranscribeProvider.transcribe`` and its helpers."""

    def test_transcribe_retries_timeout_and_writes_srt_atomically(self) -> None:
        """A timed-out request is retried, and the SRT is written without leaving a temp file."""
        provider = GroqTranscribeProvider()
        task = _make_task()
        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            source_video = _make_source_video(work_dir, task)
            segment = work_dir / "temp_audio" / "part_000.mp3"
            client = _FakeGroqClient(
                [
                    RuntimeError("Request timed out."),
                    _FakeResponse([{"start": 0, "end": 1.2, "text": "hello"}]),
                ]
            )
            settings = {
                "groq_api_key": "gsk_test",
                "ffmpeg_bin": "ffmpeg",
                "max_file_size_mb": 23,
                "request_timeout_seconds": 33,
                "request_max_retries": 1,
                "request_retry_backoff_seconds": 0,
                "serialize_groq_requests": False,
            }

            with patch("groq.Groq", return_value=client) as groq_ctor:
                with patch.object(provider, "_extract_audio_segments", side_effect=_fake_audio_extractor(segment)):
                    artifact = provider.transcribe(task, source_video, settings)

            # NOTE(review): end=1.2 renders as "01,199", which suggests the SRT
            # formatter truncates float milliseconds (1.2 * 1000 == 1199.99...)
            # rather than rounding. Confirm that this is intended.
            self.assertEqual(
                Path(artifact.path).read_text(encoding="utf-8"),
                "1\n00:00:00,000 --> 00:00:01,199\nhello\n\n",
            )
            self.assertFalse((work_dir / ".demo.srt.tmp").exists())
            self.assertEqual(len(client.audio.transcriptions.calls), 2)
            self.assertEqual(client.audio.transcriptions.calls[0]["timeout"], 33)
            self.assertTrue((work_dir / "transcribe_segments" / "part_000.json").exists())
            # The SDK's own retries are disabled; the provider owns the retry loop.
            groq_ctor.assert_called_once_with(api_key="gsk_test", timeout=33, max_retries=0)

    def test_transcribe_reuses_completed_segment_checkpoints(self) -> None:
        """An existing per-segment checkpoint is reused; only the missing segment is requested."""
        provider = GroqTranscribeProvider()
        task = _make_task()
        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            source_video = _make_source_video(work_dir, task)
            segments = [work_dir / "temp_audio" / "part_000.mp3", work_dir / "temp_audio" / "part_001.mp3"]
            checkpoint_dir = work_dir / "transcribe_segments"
            checkpoint_dir.mkdir()
            (checkpoint_dir / "part_000.json").write_text(
                json.dumps(
                    {
                        "provider": "groq",
                        "model": "whisper-large-v3-turbo",
                        "language": "zh",
                        "audio_file": "part_000.mp3",
                        "segment_duration_seconds": 75,
                        "segments": [{"start": 0, "end": 1, "text": "first"}],
                    }
                ),
                encoding="utf-8",
            )
            client = _FakeGroqClient([_FakeResponse([{"start": 0, "end": 1.5, "text": "second"}])])
            settings = {
                "groq_api_key": "gsk_test",
                "ffmpeg_bin": "ffmpeg",
                "max_file_size_mb": 23,
                "request_timeout_seconds": 33,
                "request_max_retries": 1,
                "request_retry_backoff_seconds": 0,
                "serialize_groq_requests": False,
            }

            with patch("groq.Groq", return_value=client):
                with patch.object(provider, "_initial_segment_duration", return_value=75):
                    with patch.object(
                        provider, "_extract_audio_segments", side_effect=_fake_audio_extractor(*segments)
                    ):
                        artifact = provider.transcribe(task, source_video, settings)

            srt = Path(artifact.path).read_text(encoding="utf-8")
            self.assertIn("00:00:00,000 --> 00:00:01,000\nfirst", srt)
            # The second segment's timestamps are offset by one segment duration (75 s).
            self.assertIn("00:01:15,000 --> 00:01:16,500\nsecond", srt)
            self.assertEqual(len(client.audio.transcriptions.calls), 1)
            self.assertEqual(client.audio.transcriptions.calls[0]["file"][0], "part_001.mp3")
            self.assertTrue((checkpoint_dir / "part_001.json").exists())

    def test_transcribe_switches_to_next_api_key_on_rate_limit(self) -> None:
        """A 429 on the first key makes the provider build a client with the next key."""
        provider = GroqTranscribeProvider()
        task = _make_task()
        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            source_video = _make_source_video(work_dir, task)
            segment = work_dir / "temp_audio" / "part_000.mp3"
            limited_client = _FakeGroqClient([RuntimeError("Error code: 429 rate_limit")])
            fallback_client = _FakeGroqClient([_FakeResponse([{"start": 0, "end": 1.2, "text": "fallback"}])])
            settings = {
                "groq_api_key": "",
                "groq_api_keys": ["gsk_first", "gsk_second"],
                "ffmpeg_bin": "ffmpeg",
                "max_file_size_mb": 23,
                "request_timeout_seconds": 20,
                "request_max_retries": 0,
                "request_retry_backoff_seconds": 0,
                "serialize_groq_requests": False,
            }

            with patch("groq.Groq", side_effect=[limited_client, fallback_client]) as groq_ctor:
                with patch.object(provider, "_extract_audio_segments", side_effect=_fake_audio_extractor(segment)):
                    artifact = provider.transcribe(task, source_video, settings)

            self.assertIn("fallback", Path(artifact.path).read_text(encoding="utf-8"))
            self.assertEqual(len(limited_client.audio.transcriptions.calls), 1)
            self.assertEqual(len(fallback_client.audio.transcriptions.calls), 1)
            self.assertEqual([call.kwargs["api_key"] for call in groq_ctor.call_args_list], ["gsk_first", "gsk_second"])

    def test_transcribe_waits_after_all_api_keys_are_rate_limited(self) -> None:
        """When every key is rate limited, the provider backs off once and then retries."""
        provider = GroqTranscribeProvider()
        task = _make_task()
        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            source_video = _make_source_video(work_dir, task)
            segment = work_dir / "temp_audio" / "part_000.mp3"
            first_client = _FakeGroqClient(
                [RuntimeError("429 rate_limit"), _FakeResponse([{"start": 0, "end": 1, "text": "retry ok"}])]
            )
            second_client = _FakeGroqClient([RuntimeError("429 rate_limit")])
            settings = {
                "groq_api_key": "",
                "groq_api_keys": ["gsk_first", "gsk_second"],
                "ffmpeg_bin": "ffmpeg",
                "max_file_size_mb": 23,
                "request_timeout_seconds": 20,
                "request_max_retries": 1,
                "request_retry_backoff_seconds": 7,
                "serialize_groq_requests": False,
            }

            with patch("groq.Groq", side_effect=[first_client, second_client]):
                with patch("time.sleep") as sleep_mock:
                    with patch.object(
                        provider, "_extract_audio_segments", side_effect=_fake_audio_extractor(segment)
                    ):
                        artifact = provider.transcribe(task, source_video, settings)

            self.assertIn("retry ok", Path(artifact.path).read_text(encoding="utf-8"))
            sleep_mock.assert_called_once_with(7)
            self.assertEqual(len(first_client.audio.transcriptions.calls), 2)
            self.assertEqual(len(second_client.audio.transcriptions.calls), 1)

    def test_transcribe_raises_after_retry_budget_is_exhausted(self) -> None:
        """Repeated failures beyond ``request_max_retries`` surface as a ``ModuleError``."""
        provider = GroqTranscribeProvider()
        task = _make_task()
        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            source_video = _make_source_video(work_dir, task)
            segment = work_dir / "temp_audio" / "part_000.mp3"
            client = _FakeGroqClient([RuntimeError("Connection error."), RuntimeError("Connection error.")])
            settings = {
                "groq_api_key": "gsk_test",
                "ffmpeg_bin": "ffmpeg",
                "max_file_size_mb": 23,
                "request_timeout_seconds": 20,
                "request_max_retries": 1,
                "request_retry_backoff_seconds": 0,
                "serialize_groq_requests": False,
            }

            with patch("groq.Groq", return_value=client):
                with patch.object(provider, "_extract_audio_segments", side_effect=_fake_audio_extractor(segment)):
                    with self.assertRaises(ModuleError) as exc_info:
                        provider.transcribe(task, source_video, settings)

            self.assertEqual(exc_info.exception.message, "Groq 转录失败: part_000.mp3")
            # One initial attempt plus one retry: the whole budget was spent before giving up.
            self.assertEqual(len(client.audio.transcriptions.calls), 2)

    def test_initial_segment_duration_keeps_safety_margin(self) -> None:
        """The initial segment duration for a 12 MB budget stays below 1536 seconds."""
        self.assertLess(GroqTranscribeProvider._initial_segment_duration(12), 1536)

    def test_extract_audio_segments_retries_when_segment_exceeds_size_limit(self) -> None:
        """An oversized first extraction is retried with a shorter segment duration (100 -> 75)."""
        provider = GroqTranscribeProvider()
        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            temp_audio_dir = work_dir / "temp_audio"
            temp_audio_dir.mkdir()
            output_pattern = temp_audio_dir / "part_%03d.mp3"
            durations: list[int] = []

            def fake_extract_audio_segments(**kwargs):  # noqa: ANN003
                durations.append(int(kwargs["segment_duration"]))
                # First pass writes 20 bytes (over the 10-byte limit); later passes write 5.
                size = 20 if len(durations) == 1 else 5
                (temp_audio_dir / "part_000.mp3").write_bytes(b"x" * size)

            with patch.object(provider, "_extract_audio_segments", side_effect=fake_extract_audio_segments):
                result = provider._extract_audio_segments_with_size_guard(
                    ffmpeg_bin="ffmpeg",
                    source_path=work_dir / "input.mp4",
                    output_pattern=output_pattern,
                    temp_audio_dir=temp_audio_dir,
                    initial_segment_duration=100,
                    max_segment_bytes=10,
                )

            self.assertEqual(durations, [100, 75])
            self.assertEqual(result, 75)


if __name__ == "__main__":
    unittest.main()