feat: package docker deployment and publish flow
This commit is contained in:
277
tests/test_groq_transcribe_provider.py
Normal file
277
tests/test_groq_transcribe_provider.py
Normal file
@ -0,0 +1,277 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from biliup_next.core.errors import ModuleError
|
||||
from biliup_next.core.models import Artifact, Task
|
||||
from biliup_next.modules.transcribe.providers.groq import GroqTranscribeProvider
|
||||
|
||||
|
||||
class _FakeResponse:
    """Minimal stand-in for a transcription response.

    Only the ``segments`` attribute is provided; the tests in this module
    populate it with a list of ``{"start", "end", "text"}`` dicts.
    """

    def __init__(self, segments) -> None:
        self.segments = segments
|
||||
|
||||
|
||||
class _FakeTranscriptions:
    """Scripted replacement for the client's ``audio.transcriptions`` object.

    Each ``create`` call records its keyword arguments and then consumes the
    next scripted outcome: exceptions are raised, anything else is returned.
    """

    def __init__(self, outcomes: list[object]) -> None:
        # Copy so that consuming outcomes never mutates the caller's list.
        self.outcomes = list(outcomes)
        self.calls: list[dict[str, object]] = []

    def create(self, **kwargs):  # noqa: ANN003
        """Record the call, then raise or return the next scripted outcome.

        Raises:
            IndexError: if more calls are made than outcomes were scripted.
                The call is still recorded in ``calls`` before this is raised.
        """
        self.calls.append(kwargs)
        if not self.outcomes:
            # A bare "pop from empty list" hides the real cause: the code
            # under test issued more requests than the test scripted.
            raise IndexError(
                f"_FakeTranscriptions.create called {len(self.calls)} time(s) "
                "but no scripted outcome is left"
            )
        outcome = self.outcomes.pop(0)
        if isinstance(outcome, Exception):
            raise outcome
        return outcome
|
||||
|
||||
|
||||
class _FakeGroqClient:
    """Stand-in for the Groq client that exposes only ``audio.transcriptions``."""

    def __init__(self, outcomes: list[object]) -> None:
        # Mirror the attribute path the tests inspect:
        # client.audio.transcriptions.create(...) / .calls
        self.audio = SimpleNamespace(transcriptions=_FakeTranscriptions(outcomes))
|
||||
|
||||
|
||||
class GroqTranscribeProviderTests(unittest.TestCase):
    """Exercise GroqTranscribeProvider against scripted fake Groq clients.

    ``groq.Groq`` is patched to hand back fake clients and, in the transcribe
    tests, ``_extract_audio_segments`` is patched to drop placeholder audio
    files into a temporary working directory.
    """

    _TIMESTAMP = "2026-01-01T00:00:00+00:00"

    @classmethod
    def _make_task(cls):
        """Return the Task record shared by the transcribe tests."""
        return Task("task-1", "local_file", "/tmp/input.mp4", "demo", "created", cls._TIMESTAMP, cls._TIMESTAMP)

    @classmethod
    def _make_source_video(cls, work_dir: Path, task):
        """Write a placeholder ``input.mp4`` in *work_dir* and wrap it in an Artifact."""
        source_path = work_dir / "input.mp4"
        source_path.write_bytes(b"video")
        return Artifact(None, task.id, "source_video", str(source_path), "{}", cls._TIMESTAMP)

    @staticmethod
    def _segment_writer(*segment_paths: Path):
        """Build a replacement for ``_extract_audio_segments`` that writes the given files."""

        def fake_extract_audio_segments(**kwargs):  # noqa: ANN003
            for path in segment_paths:
                path.parent.mkdir(parents=True, exist_ok=True)
                path.write_bytes(b"audio")

        return fake_extract_audio_segments

    def test_transcribe_retries_timeout_and_writes_srt_atomically(self) -> None:
        """A failed first request is retried and the SRT is written with no temp file left."""
        provider = GroqTranscribeProvider()
        task = self._make_task()

        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            source_video = self._make_source_video(work_dir, task)
            segment = work_dir / "temp_audio" / "part_000.mp3"

            client = _FakeGroqClient(
                [
                    RuntimeError("Request timed out."),
                    _FakeResponse([{"start": 0, "end": 1.2, "text": "hello"}]),
                ]
            )

            settings = {
                "groq_api_key": "gsk_test",
                "ffmpeg_bin": "ffmpeg",
                "max_file_size_mb": 23,
                "request_timeout_seconds": 33,
                "request_max_retries": 1,
                "request_retry_backoff_seconds": 0,
                "serialize_groq_requests": False,
            }

            groq_patch = patch("groq.Groq", return_value=client)
            extract_patch = patch.object(
                provider, "_extract_audio_segments", side_effect=self._segment_writer(segment)
            )
            with groq_patch as groq_ctor, extract_patch:
                artifact = provider.transcribe(task, source_video, settings)

            self.assertEqual(Path(artifact.path).read_text(encoding="utf-8"), "1\n00:00:00,000 --> 00:00:01,199\nhello\n\n")
            self.assertFalse((work_dir / ".demo.srt.tmp").exists())
            self.assertEqual(len(client.audio.transcriptions.calls), 2)
            self.assertEqual(client.audio.transcriptions.calls[0]["timeout"], 33)
            self.assertTrue((work_dir / "transcribe_segments" / "part_000.json").exists())
            groq_ctor.assert_called_once_with(api_key="gsk_test", timeout=33, max_retries=0)

    def test_transcribe_reuses_completed_segment_checkpoints(self) -> None:
        """A pre-existing checkpoint for part_000 is reused; only part_001 is requested."""
        provider = GroqTranscribeProvider()
        task = self._make_task()

        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            source_video = self._make_source_video(work_dir, task)
            segments = [work_dir / "temp_audio" / "part_000.mp3", work_dir / "temp_audio" / "part_001.mp3"]
            checkpoint_dir = work_dir / "transcribe_segments"
            checkpoint_dir.mkdir()
            (checkpoint_dir / "part_000.json").write_text(
                json.dumps(
                    {
                        "provider": "groq",
                        "model": "whisper-large-v3-turbo",
                        "language": "zh",
                        "audio_file": "part_000.mp3",
                        "segment_duration_seconds": 75,
                        "segments": [{"start": 0, "end": 1, "text": "first"}],
                    }
                ),
                encoding="utf-8",
            )

            client = _FakeGroqClient([_FakeResponse([{"start": 0, "end": 1.5, "text": "second"}])])
            settings = {
                "groq_api_key": "gsk_test",
                "ffmpeg_bin": "ffmpeg",
                "max_file_size_mb": 23,
                "request_timeout_seconds": 33,
                "request_max_retries": 1,
                "request_retry_backoff_seconds": 0,
                "serialize_groq_requests": False,
            }

            groq_patch = patch("groq.Groq", return_value=client)
            # Pin the segment length so it matches the 75 s stored in the checkpoint above.
            duration_patch = patch.object(provider, "_initial_segment_duration", return_value=75)
            extract_patch = patch.object(
                provider, "_extract_audio_segments", side_effect=self._segment_writer(*segments)
            )
            with groq_patch, duration_patch, extract_patch:
                artifact = provider.transcribe(task, source_video, settings)

            srt = Path(artifact.path).read_text(encoding="utf-8")
            self.assertIn("00:00:00,000 --> 00:00:01,000\nfirst", srt)
            self.assertIn("00:01:15,000 --> 00:01:16,500\nsecond", srt)
            self.assertEqual(len(client.audio.transcriptions.calls), 1)
            self.assertEqual(client.audio.transcriptions.calls[0]["file"][0], "part_001.mp3")
            self.assertTrue((checkpoint_dir / "part_001.json").exists())

    def test_transcribe_switches_to_next_api_key_on_rate_limit(self) -> None:
        """A 429 on the first key makes the provider build a client with the second key."""
        provider = GroqTranscribeProvider()
        task = self._make_task()

        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            source_video = self._make_source_video(work_dir, task)
            segment = work_dir / "temp_audio" / "part_000.mp3"

            limited_client = _FakeGroqClient([RuntimeError("Error code: 429 rate_limit")])
            fallback_client = _FakeGroqClient([_FakeResponse([{"start": 0, "end": 1.2, "text": "fallback"}])])
            settings = {
                "groq_api_key": "",
                "groq_api_keys": ["gsk_first", "gsk_second"],
                "ffmpeg_bin": "ffmpeg",
                "max_file_size_mb": 23,
                "request_timeout_seconds": 20,
                "request_max_retries": 0,
                "request_retry_backoff_seconds": 0,
                "serialize_groq_requests": False,
            }

            groq_patch = patch("groq.Groq", side_effect=[limited_client, fallback_client])
            extract_patch = patch.object(
                provider, "_extract_audio_segments", side_effect=self._segment_writer(segment)
            )
            with groq_patch as groq_ctor, extract_patch:
                artifact = provider.transcribe(task, source_video, settings)

            self.assertIn("fallback", Path(artifact.path).read_text(encoding="utf-8"))
            self.assertEqual(len(limited_client.audio.transcriptions.calls), 1)
            self.assertEqual(len(fallback_client.audio.transcriptions.calls), 1)
            self.assertEqual([call.kwargs["api_key"] for call in groq_ctor.call_args_list], ["gsk_first", "gsk_second"])

    def test_transcribe_waits_after_all_api_keys_are_rate_limited(self) -> None:
        """Once both clients report 429, the provider sleeps for the backoff and retries."""
        provider = GroqTranscribeProvider()
        task = self._make_task()

        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            source_video = self._make_source_video(work_dir, task)
            segment = work_dir / "temp_audio" / "part_000.mp3"

            first_client = _FakeGroqClient([RuntimeError("429 rate_limit"), _FakeResponse([{"start": 0, "end": 1, "text": "retry ok"}])])
            second_client = _FakeGroqClient([RuntimeError("429 rate_limit")])
            settings = {
                "groq_api_key": "",
                "groq_api_keys": ["gsk_first", "gsk_second"],
                "ffmpeg_bin": "ffmpeg",
                "max_file_size_mb": 23,
                "request_timeout_seconds": 20,
                "request_max_retries": 1,
                "request_retry_backoff_seconds": 7,
                "serialize_groq_requests": False,
            }

            groq_patch = patch("groq.Groq", side_effect=[first_client, second_client])
            # Patch sleep so the 7 s backoff is observed without actually waiting.
            sleep_patch = patch("time.sleep")
            extract_patch = patch.object(
                provider, "_extract_audio_segments", side_effect=self._segment_writer(segment)
            )
            with groq_patch, sleep_patch as sleep_mock, extract_patch:
                artifact = provider.transcribe(task, source_video, settings)

            self.assertIn("retry ok", Path(artifact.path).read_text(encoding="utf-8"))
            sleep_mock.assert_called_once_with(7)
            self.assertEqual(len(first_client.audio.transcriptions.calls), 2)
            self.assertEqual(len(second_client.audio.transcriptions.calls), 1)

    def test_transcribe_raises_after_retry_budget_is_exhausted(self) -> None:
        """Two consecutive failures with one retry allowed surface as a ModuleError."""
        provider = GroqTranscribeProvider()
        task = self._make_task()

        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            source_video = self._make_source_video(work_dir, task)
            segment = work_dir / "temp_audio" / "part_000.mp3"

            client = _FakeGroqClient([RuntimeError("Connection error."), RuntimeError("Connection error.")])
            settings = {
                "groq_api_key": "gsk_test",
                "ffmpeg_bin": "ffmpeg",
                "max_file_size_mb": 23,
                "request_timeout_seconds": 20,
                "request_max_retries": 1,
                "request_retry_backoff_seconds": 0,
                "serialize_groq_requests": False,
            }

            groq_patch = patch("groq.Groq", return_value=client)
            extract_patch = patch.object(
                provider, "_extract_audio_segments", side_effect=self._segment_writer(segment)
            )
            with groq_patch, extract_patch, self.assertRaises(ModuleError) as exc_info:
                provider.transcribe(task, source_video, settings)

            self.assertEqual(exc_info.exception.message, "Groq 转录失败: part_000.mp3")

    def test_initial_segment_duration_keeps_safety_margin(self) -> None:
        """The initial segment duration computed for an input of 12 stays below 1536."""
        self.assertLess(GroqTranscribeProvider._initial_segment_duration(12), 1536)

    def test_extract_audio_segments_retries_when_segment_exceeds_size_limit(self) -> None:
        """An oversized first pass triggers a second extraction with a shorter duration."""
        provider = GroqTranscribeProvider()

        with tempfile.TemporaryDirectory() as tmpdir:
            work_dir = Path(tmpdir)
            temp_audio_dir = work_dir / "temp_audio"
            temp_audio_dir.mkdir()
            output_pattern = temp_audio_dir / "part_%03d.mp3"
            durations: list[int] = []

            def fake_extract_audio_segments(**kwargs):  # noqa: ANN003
                durations.append(int(kwargs["segment_duration"]))
                # First pass writes 20 bytes (over the 10-byte limit), later passes 5 bytes.
                size = 20 if len(durations) == 1 else 5
                (temp_audio_dir / "part_000.mp3").write_bytes(b"x" * size)

            with patch.object(provider, "_extract_audio_segments", side_effect=fake_extract_audio_segments):
                result = provider._extract_audio_segments_with_size_guard(
                    ffmpeg_bin="ffmpeg",
                    source_path=work_dir / "input.mp4",
                    output_pattern=output_pattern,
                    temp_audio_dir=temp_audio_dir,
                    initial_segment_duration=100,
                    max_segment_bytes=10,
                )

            self.assertEqual(durations, [100, 75])
            self.assertEqual(result, 75)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|
||||
Reference in New Issue
Block a user