278 lines
13 KiB
Python
278 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
from types import SimpleNamespace
|
|
from unittest.mock import patch
|
|
|
|
from biliup_next.core.errors import ModuleError
|
|
from biliup_next.core.models import Artifact, Task
|
|
from biliup_next.modules.transcribe.providers.groq import GroqTranscribeProvider
|
|
|
|
|
|
class _FakeResponse:
|
|
def __init__(self, segments):
|
|
self.segments = segments
|
|
|
|
|
|
class _FakeTranscriptions:
|
|
def __init__(self, outcomes: list[object]) -> None:
|
|
self.outcomes = list(outcomes)
|
|
self.calls: list[dict[str, object]] = []
|
|
|
|
def create(self, **kwargs): # noqa: ANN003
|
|
self.calls.append(kwargs)
|
|
outcome = self.outcomes.pop(0)
|
|
if isinstance(outcome, Exception):
|
|
raise outcome
|
|
return outcome
|
|
|
|
|
|
class _FakeGroqClient:
    """Minimal stand-in for ``groq.Groq``: only ``client.audio.transcriptions`` is provided."""

    def __init__(self, outcomes: list[object]) -> None:
        # The scripted outcomes drive every ``audio.transcriptions.create`` call.
        transcriptions = _FakeTranscriptions(outcomes)
        self.audio = SimpleNamespace(transcriptions=transcriptions)
|
|
|
|
|
|
class GroqTranscribeProviderTests(unittest.TestCase):
    """Behavioral tests for ``GroqTranscribeProvider``.

    ``groq.Groq`` is patched to return ``_FakeGroqClient`` instances.
    ``_extract_audio_segments`` is patched to write placeholder ``.mp3`` files.
    Together these mean no network access and no ffmpeg binary are needed.
    """

    # Timestamp shared by all Task/Artifact fixture records.
    _TIMESTAMP = "2026-01-01T00:00:00+00:00"

    # ---- fixture helpers -------------------------------------------------

    @classmethod
    def _make_task(cls) -> Task:
        """Build the ``Task`` record used by the end-to-end transcribe tests."""
        return Task("task-1", "local_file", "/tmp/input.mp4", "demo", "created", cls._TIMESTAMP, cls._TIMESTAMP)

    @classmethod
    def _make_source_video(cls, task: Task, work_dir: Path) -> Artifact:
        """Write a placeholder ``input.mp4`` into *work_dir* and wrap it as a ``source_video`` artifact."""
        source_path = work_dir / "input.mp4"
        source_path.write_bytes(b"video")
        return Artifact(None, task.id, "source_video", str(source_path), "{}", cls._TIMESTAMP)

    @staticmethod
    def _fake_extractor(*segments: Path):
        """Return a stand-in for ``_extract_audio_segments`` that writes each of *segments*.

        The stand-in accepts keyword arguments only and ignores them. It creates
        parent directories as needed and writes a small placeholder payload.
        """

        def fake_extract_audio_segments(**kwargs):  # noqa: ANN003
            for segment in segments:
                segment.parent.mkdir(parents=True, exist_ok=True)
                segment.write_bytes(b"audio")

        return fake_extract_audio_segments

    @staticmethod
    def _settings(**overrides: object) -> dict[str, object]:
        """Return provider settings: a single-key, no-backoff baseline updated with *overrides*."""
        settings: dict[str, object] = {
            "groq_api_key": "gsk_test",
            "ffmpeg_bin": "ffmpeg",
            "max_file_size_mb": 23,
            "request_timeout_seconds": 33,
            "request_max_retries": 1,
            "request_retry_backoff_seconds": 0,
            "serialize_groq_requests": False,
        }
        settings.update(overrides)
        return settings

    # ---- tests -----------------------------------------------------------

    def test_transcribe_retries_timeout_and_writes_srt_atomically(self) -> None:
        """A timed-out request is retried, and the final SRT leaves no temp file behind."""
        provider = GroqTranscribeProvider()
        task = self._make_task()

        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            source_video = self._make_source_video(task, work_dir)
            extractor = self._fake_extractor(work_dir / "temp_audio" / "part_000.mp3")
            client = _FakeGroqClient(
                [
                    RuntimeError("Request timed out."),
                    _FakeResponse([{"start": 0, "end": 1.2, "text": "hello"}]),
                ]
            )

            with patch("groq.Groq", return_value=client) as groq_ctor:
                with patch.object(provider, "_extract_audio_segments", side_effect=extractor):
                    artifact = provider.transcribe(task, source_video, self._settings())

            # NOTE(review): end=1.2 is pinned as 01,199, which suggests the provider
            # truncates float milliseconds (1.2 * 1000 == 1199.999...) rather than
            # rounding them. Confirm this is intended.
            self.assertEqual(Path(artifact.path).read_text(encoding="utf-8"), "1\n00:00:01,199\nhello\n\n".replace("1\n00:00:01,199", "1\n00:00:00,000 --> 00:00:01,199"))
            self.assertFalse((work_dir / ".demo.srt.tmp").exists())
            self.assertEqual(len(client.audio.transcriptions.calls), 2)
            self.assertEqual(client.audio.transcriptions.calls[0]["timeout"], 33)
            self.assertTrue((work_dir / "transcribe_segments" / "part_000.json").exists())
            # The SDK client is built with max_retries=0, presumably so the provider's
            # own loop (request_max_retries) is the only retry mechanism.
            groq_ctor.assert_called_once_with(api_key="gsk_test", timeout=33, max_retries=0)

    def test_transcribe_reuses_completed_segment_checkpoints(self) -> None:
        """A segment with an existing checkpoint is not re-sent to Groq."""
        provider = GroqTranscribeProvider()
        task = self._make_task()

        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            source_video = self._make_source_video(task, work_dir)
            temp_audio_dir = work_dir / "temp_audio"
            checkpoint_dir = work_dir / "transcribe_segments"
            checkpoint_dir.mkdir()
            # Pre-seed a completed checkpoint for part_000 only. Its metadata mirrors
            # the run's provider/model/language/segment duration; presumably the
            # provider only trusts checkpoints whose metadata matches.
            (checkpoint_dir / "part_000.json").write_text(
                json.dumps(
                    {
                        "provider": "groq",
                        "model": "whisper-large-v3-turbo",
                        "language": "zh",
                        "audio_file": "part_000.mp3",
                        "segment_duration_seconds": 75,
                        "segments": [{"start": 0, "end": 1, "text": "first"}],
                    }
                ),
                encoding="utf-8",
            )
            extractor = self._fake_extractor(temp_audio_dir / "part_000.mp3", temp_audio_dir / "part_001.mp3")
            client = _FakeGroqClient([_FakeResponse([{"start": 0, "end": 1.5, "text": "second"}])])

            with patch("groq.Groq", return_value=client):
                with patch.object(provider, "_initial_segment_duration", return_value=75):
                    with patch.object(provider, "_extract_audio_segments", side_effect=extractor):
                        artifact = provider.transcribe(task, source_video, self._settings())

            srt = Path(artifact.path).read_text(encoding="utf-8")
            self.assertIn("00:00:00,000 --> 00:00:01,000\nfirst", srt)
            # part_001's timestamps are offset by one 75 s segment: 0 s -> 00:01:15.
            self.assertIn("00:01:15,000 --> 00:01:16,500\nsecond", srt)
            self.assertEqual(len(client.audio.transcriptions.calls), 1)
            self.assertEqual(client.audio.transcriptions.calls[0]["file"][0], "part_001.mp3")
            self.assertTrue((checkpoint_dir / "part_001.json").exists())

    def test_transcribe_switches_to_next_api_key_on_rate_limit(self) -> None:
        """A 429 on the first key makes the provider build a client for the next key."""
        provider = GroqTranscribeProvider()
        task = self._make_task()

        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            source_video = self._make_source_video(task, work_dir)
            extractor = self._fake_extractor(work_dir / "temp_audio" / "part_000.mp3")
            limited_client = _FakeGroqClient([RuntimeError("Error code: 429 rate_limit")])
            fallback_client = _FakeGroqClient([_FakeResponse([{"start": 0, "end": 1.2, "text": "fallback"}])])
            settings = self._settings(
                groq_api_key="",
                groq_api_keys=["gsk_first", "gsk_second"],
                request_timeout_seconds=20,
                request_max_retries=0,
            )

            with patch("groq.Groq", side_effect=[limited_client, fallback_client]) as groq_ctor:
                with patch.object(provider, "_extract_audio_segments", side_effect=extractor):
                    artifact = provider.transcribe(task, source_video, settings)

            self.assertIn("fallback", Path(artifact.path).read_text(encoding="utf-8"))
            self.assertEqual(len(limited_client.audio.transcriptions.calls), 1)
            self.assertEqual(len(fallback_client.audio.transcriptions.calls), 1)
            self.assertEqual([call.kwargs["api_key"] for call in groq_ctor.call_args_list], ["gsk_first", "gsk_second"])

    def test_transcribe_waits_after_all_api_keys_are_rate_limited(self) -> None:
        """When every key is rate limited, the provider backs off once and then retries."""
        provider = GroqTranscribeProvider()
        task = self._make_task()

        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            source_video = self._make_source_video(task, work_dir)
            extractor = self._fake_extractor(work_dir / "temp_audio" / "part_000.mp3")
            first_client = _FakeGroqClient([RuntimeError("429 rate_limit"), _FakeResponse([{"start": 0, "end": 1, "text": "retry ok"}])])
            second_client = _FakeGroqClient([RuntimeError("429 rate_limit")])
            settings = self._settings(
                groq_api_key="",
                groq_api_keys=["gsk_first", "gsk_second"],
                request_timeout_seconds=20,
                request_retry_backoff_seconds=7,
            )

            # Only two clients are scripted, so the retry after the back-off must
            # reuse the first key's client. ``time.sleep`` is patched globally; this
            # assumes the provider calls it through the ``time`` module attribute.
            with patch("groq.Groq", side_effect=[first_client, second_client]):
                with patch("time.sleep") as sleep_mock:
                    with patch.object(provider, "_extract_audio_segments", side_effect=extractor):
                        artifact = provider.transcribe(task, source_video, settings)

            self.assertIn("retry ok", Path(artifact.path).read_text(encoding="utf-8"))
            sleep_mock.assert_called_once_with(7)
            self.assertEqual(len(first_client.audio.transcriptions.calls), 2)
            self.assertEqual(len(second_client.audio.transcriptions.calls), 1)

    def test_transcribe_raises_after_retry_budget_is_exhausted(self) -> None:
        """Persistent connection errors surface as a ``ModuleError`` naming the failed segment."""
        provider = GroqTranscribeProvider()
        task = self._make_task()

        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            source_video = self._make_source_video(task, work_dir)
            extractor = self._fake_extractor(work_dir / "temp_audio" / "part_000.mp3")
            # One initial attempt plus request_max_retries=1 retry -> two failures.
            client = _FakeGroqClient([RuntimeError("Connection error."), RuntimeError("Connection error.")])
            settings = self._settings(request_timeout_seconds=20)

            with patch("groq.Groq", return_value=client):
                with patch.object(provider, "_extract_audio_segments", side_effect=extractor):
                    with self.assertRaises(ModuleError) as exc_info:
                        provider.transcribe(task, source_video, settings)

            self.assertEqual(exc_info.exception.message, "Groq 转录失败: part_000.mp3")

    def test_initial_segment_duration_keeps_safety_margin(self) -> None:
        """The initial segment duration for a 12 MB limit stays strictly below 1536 s."""
        self.assertLess(GroqTranscribeProvider._initial_segment_duration(12), 1536)

    def test_extract_audio_segments_retries_when_segment_exceeds_size_limit(self) -> None:
        """An oversized segment triggers re-extraction with a shorter duration."""
        provider = GroqTranscribeProvider()

        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            temp_audio_dir = work_dir / "temp_audio"
            temp_audio_dir.mkdir()
            output_pattern = temp_audio_dir / "part_%03d.mp3"
            durations: list[int] = []

            def fake_extract_audio_segments(**kwargs):  # noqa: ANN003
                # The first extraction yields a 20-byte segment (over the 10-byte
                # limit); later extractions yield 5 bytes, which fits.
                durations.append(int(kwargs["segment_duration"]))
                size = 20 if len(durations) == 1 else 5
                (temp_audio_dir / "part_000.mp3").write_bytes(b"x" * size)

            with patch.object(provider, "_extract_audio_segments", side_effect=fake_extract_audio_segments):
                result = provider._extract_audio_segments_with_size_guard(
                    ffmpeg_bin="ffmpeg",
                    source_path=work_dir / "input.mp4",
                    output_pattern=output_pattern,
                    temp_audio_dir=temp_audio_dir,
                    initial_segment_duration=100,
                    max_segment_bytes=10,
                )

            # Expected: one retry at 75% of the initial duration (100 -> 75).
            self.assertEqual(durations, [100, 75])
            self.assertEqual(result, 75)
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly hands the module's TestCases to unittest's CLI runner.
    unittest.main(module="__main__")
|