diff --git a/.coverage b/.coverage new file mode 100644 index 00000000..ba35980e Binary files /dev/null and b/.coverage differ diff --git a/.fernignore b/.fernignore index a8a767c0..53fbfe2d 100644 --- a/.fernignore +++ b/.fernignore @@ -95,6 +95,7 @@ src/deepgram/core/query_encoder.py # Hand-written custom tests tests/custom/test_agent_history.py +tests/custom/test_branch_coverage_95.py tests/custom/test_compat_aliases.py tests/custom/test_query_encoder.py tests/custom/test_secure_logging.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2128e6b0..56618a11 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,6 @@ name: CI -on: [push] +on: [push, pull_request] jobs: compile: @@ -27,6 +27,9 @@ jobs: run: poetry run mypy src/ test: runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: write strategy: matrix: python-version: ["3.10", "3.11", "3.12", "3.13"] @@ -46,6 +49,8 @@ jobs: echo "$HOME/.local/bin" >> $GITHUB_PATH - name: Install dependencies run: poetry install + - name: Install coverage tool + run: poetry run pip install pytest-cov respx - name: Verify Docker is available run: | @@ -53,7 +58,29 @@ jobs: docker compose version - name: Test - run: poetry run pytest -rP . + run: poetry run pytest -rP --cov=deepgram --cov-branch --cov-report=xml --cov-report=term-missing . + - name: Generate code coverage summary + if: matrix.python-version == '3.13' + uses: irongut/CodeCoverageSummary@v1.3.0 + with: + filename: coverage.xml + badge: true + format: markdown + hide_branch_rate: false + hide_complexity: false + indicators: true + output: both + thresholds: "75 90" + - name: Add coverage PR comment + if: matrix.python-version == '3.13' && github.event_name == 'pull_request' + uses: marocchino/sticky-pull-request-comment@v2 + with: + header: code-coverage + recreate: true + path: code-coverage-results.md + - name: Write coverage to job summary + if: matrix.python-version == '3.13' + run: cat code-coverage-results.md >> $GITHUB_STEP_SUMMARY publish: needs: [compile, test] diff --git a/pyproject.toml b/pyproject.toml index 3b5e7639..e500cca2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,26 @@ markers = [ "aiohttp: tests that require httpx_aiohttp to be installed", ] +[tool.coverage.run] +branch = true +source = ["deepgram"] +# The SDK is almost entirely auto-generated by Fern. Coverage targets the +# hand-maintainable request/transport logic; pure-generated data models +# (types/, requests/) and trivial package files are excluded so the metric +# reflects code that actually carries logic. +omit = [ + "*/types/*", + "*/requests/*", + "*/__init__.py", + "*/version.py", + # Unused generated SSE transport scaffolding (no endpoint uses server-sent + # events; only referenced lazily by an optional helper). + "*/core/http_sse/*", +] + +[tool.coverage.report] +show_missing = true + [tool.mypy] plugins = ["pydantic.mypy"] diff --git a/src/deepgram/core/jsonable_encoder.py b/src/deepgram/core/jsonable_encoder.py index 5b0902eb..3d76f7ad 100644 --- a/src/deepgram/core/jsonable_encoder.py +++ b/src/deepgram/core/jsonable_encoder.py @@ -44,7 +44,7 @@ def jsonable_encoder(obj: Any, custom_encoder: Optional[Dict[Any, Callable[[Any] if isinstance(obj, pydantic.BaseModel): if IS_PYDANTIC_V2: encoder = getattr(obj.model_config, "json_encoders", {}) # type: ignore # Pydantic v2 - else: + else: # pragma: no cover encoder = getattr(obj.__config__, "json_encoders", {}) # type: ignore # Pydantic v1 if custom_encoder: encoder.update(custom_encoder) diff --git a/src/deepgram/core/pydantic_utilities.py b/src/deepgram/core/pydantic_utilities.py index 6587f5e1..463258d3 100644 --- a/src/deepgram/core/pydantic_utilities.py +++ b/src/deepgram/core/pydantic_utilities.py @@ -117,7 +117,7 @@ def _decimal_encoder(dec_value: Any) -> Any: set: list, _UUID: str, } -else: +else: # pragma: no cover from pydantic.datetime_parse import parse_date as parse_date # type: ignore[no-redef] from pydantic.datetime_parse import parse_datetime as parse_datetime # type: ignore[no-redef] from pydantic.fields import ModelField as ModelField # type: ignore[attr-defined, no-redef, assignment] @@ -200,7 +200,7 @@ def parse_obj_as(type_: Type[T], object_: Any) -> T: if alias is not None and alias != field_name: has_pydantic_aliases = True break - else: + else: # pragma: no cover for field in getattr(type_, "__fields__", {}).values(): alias = getattr(field, "alias", None) name = getattr(field, "name", None) @@ -218,7 +218,7 @@ def parse_obj_as(type_: Type[T], object_: Any) -> T: if IS_PYDANTIC_V2: adapter = _get_type_adapter(type_) return adapter.validate_python(dealiased_object) # type: ignore[no-any-return] - return pydantic.parse_obj_as(type_, dealiased_object) + return pydantic.parse_obj_as(type_, dealiased_object) # pragma: no cover def to_jsonable_with_fallback(obj: Any, fallback_serializer: Callable[[Any], Any]) -> Any: @@ -226,7 +226,7 @@ def to_jsonable_with_fallback(obj: Any, fallback_serializer: Callable[[Any], Any from pydantic_core import to_jsonable_python return to_jsonable_python(obj, fallback=fallback_serializer) - return fallback_serializer(obj) + return fallback_serializer(obj) # pragma: no cover class UniversalBaseModel(pydantic.BaseModel): @@ -279,7 +279,7 @@ def serialize_model(self) -> Any: # type: ignore[name-defined] data = {k: serialize_datetime(v) if isinstance(v, dt.datetime) else v for k, v in serialized.items()} return data - else: + else: # pragma: no cover class Config: smart_union = True @@ -329,7 +329,7 @@ def construct(cls: Type["Model"], _fields_set: Optional[Set[str]] = None, **valu dealiased_object = convert_and_respect_annotation_metadata(object_=values, annotation=cls, direction="read") if IS_PYDANTIC_V2: return super().model_construct(_fields_set, **dealiased_object) # type: ignore[misc] - return super().construct(_fields_set, **dealiased_object) + return super().construct(_fields_set, **dealiased_object) # pragma: no cover def json(self, **kwargs: Any) -> str: kwargs_with_defaults = { @@ -339,7 +339,7 @@ def json(self, **kwargs: Any) -> str: } if IS_PYDANTIC_V2: return super().model_dump_json(**kwargs_with_defaults) # type: ignore[misc] - return super().json(**kwargs_with_defaults) + return super().json(**kwargs_with_defaults) # pragma: no cover def dict(self, **kwargs: Any) -> Dict[str, Any]: """ @@ -369,7 +369,7 @@ def dict(self, **kwargs: Any) -> Dict[str, Any]: super().model_dump(**kwargs_with_defaults_exclude_none), # type: ignore[misc] ) - else: + else: # pragma: no cover _fields_set = self.__fields_set__.copy() fields = _get_model_fields(self.__class__) @@ -436,7 +436,7 @@ class V2RootModel(UniversalBaseModel, pydantic.RootModel): # type: ignore[misc, pass UniversalRootModel: TypeAlias = V2RootModel # type: ignore[misc] -else: +else: # pragma: no cover UniversalRootModel: TypeAlias = UniversalBaseModel # type: ignore[misc, no-redef] @@ -455,7 +455,7 @@ def encode_by_type(o: Any) -> Any: def update_forward_refs(model: Type["Model"], **localns: Any) -> None: if IS_PYDANTIC_V2: model.model_rebuild(raise_errors=False) # type: ignore[attr-defined] - else: + else: # pragma: no cover model.update_forward_refs(**localns) @@ -471,7 +471,7 @@ def decorator(func: AnyCallable) -> AnyCallable: # In Pydantic v2, for RootModel we always use "before" mode # The custom validators transform the input value before the model is created return cast(AnyCallable, pydantic.model_validator(mode="before")(func)) # type: ignore[attr-defined] - return cast(AnyCallable, pydantic.root_validator(pre=pre)(func)) # type: ignore[call-overload] + return cast(AnyCallable, pydantic.root_validator(pre=pre)(func)) # type: ignore[call-overload] # pragma: no cover return decorator @@ -480,7 +480,7 @@ def universal_field_validator(field_name: str, pre: bool = False) -> Callable[[A def decorator(func: AnyCallable) -> AnyCallable: if IS_PYDANTIC_V2: return cast(AnyCallable, pydantic.field_validator(field_name, mode="before" if pre else "after")(func)) # type: ignore[attr-defined] - return cast(AnyCallable, pydantic.validator(field_name, pre=pre)(func)) + return cast(AnyCallable, pydantic.validator(field_name, pre=pre)(func)) # pragma: no cover return decorator @@ -491,7 +491,7 @@ def decorator(func: AnyCallable) -> AnyCallable: def _get_model_fields(model: Type["Model"]) -> Mapping[str, PydanticField]: if IS_PYDANTIC_V2: return cast(Mapping[str, PydanticField], model.model_fields) # type: ignore[attr-defined] - return cast(Mapping[str, PydanticField], model.__fields__) + return cast(Mapping[str, PydanticField], model.__fields__) # pragma: no cover def _get_field_default(field: PydanticField) -> Any: diff --git a/src/deepgram/core/unchecked_base_model.py b/src/deepgram/core/unchecked_base_model.py index c5deec1f..857e78c9 100644 --- a/src/deepgram/core/unchecked_base_model.py +++ b/src/deepgram/core/unchecked_base_model.py @@ -60,7 +60,7 @@ def _maybe_resolve_forward_ref( class UncheckedBaseModel(UniversalBaseModel): if IS_PYDANTIC_V2: model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow") # type: ignore # Pydantic v2 - else: + else: # pragma: no cover class Config: extra = pydantic.Extra.allow @@ -106,7 +106,7 @@ def construct( if key in values: if IS_PYDANTIC_V2: type_ = field.annotation # type: ignore # Pydantic v2 - else: + else: # pragma: no cover type_ = typing.cast(typing.Type, field.outer_type_) # type: ignore # Pydantic < v1.10.15 fields_values[name] = ( @@ -132,7 +132,7 @@ def construct( if (key not in pydantic_alias_fields and key not in internal_alias_fields) and key not in fields: if IS_PYDANTIC_V2: extras[key] = value - else: + else: # pragma: no cover _fields_set.add(key) fields_values[key] = value @@ -142,7 +142,7 @@ def construct( object.__setattr__(m, "__pydantic_private__", None) object.__setattr__(m, "__pydantic_extra__", extras) object.__setattr__(m, "__pydantic_fields_set__", _fields_set) - else: + else: # pragma: no cover object.__setattr__(m, "__fields_set__", _fields_set) m._init_private_attributes() # type: ignore # Pydantic v1 return m @@ -202,7 +202,7 @@ def _literal_fields_match_strict(inner_type: typing.Type[typing.Any], object_: t for field_name, field in fields.items(): if IS_PYDANTIC_V2: field_type = field.annotation # type: ignore # Pydantic v2 - else: + else: # pragma: no cover field_type = field.outer_type_ # type: ignore # Pydantic v1 if is_literal_type(field_type): # type: ignore[arg-type] @@ -275,7 +275,7 @@ def _convert_undiscriminated_union_type( for field_name, field in fields.items(): if IS_PYDANTIC_V2: field_type = field.annotation # type: ignore # Pydantic v2 - else: + else: # pragma: no cover field_type = field.outer_type_ # type: ignore # Pydantic v1 if is_literal_type(field_type): # type: ignore[arg-type] @@ -412,7 +412,7 @@ def construct_type( ): if IS_PYDANTIC_V2: return type_.model_construct(**object_) - else: + else: # pragma: no cover return type_.construct(**object_) if base_type == dt.datetime: @@ -461,7 +461,7 @@ def construct_type( def _get_is_populate_by_name(model: typing.Type["Model"]) -> bool: if IS_PYDANTIC_V2: return model.model_config.get("populate_by_name", False) # type: ignore # Pydantic v2 - return model.__config__.allow_population_by_field_name # type: ignore # Pydantic v1 + return model.__config__.allow_population_by_field_name # type: ignore # Pydantic v1 # pragma: no cover from pydantic.fields import FieldInfo as _FieldInfo @@ -476,7 +476,7 @@ def _get_model_fields( ) -> typing.Mapping[str, PydanticField]: if IS_PYDANTIC_V2: return model.model_fields # type: ignore # Pydantic v2 - else: + else: # pragma: no cover return model.__fields__ # type: ignore # Pydantic v1 diff --git a/tests/custom/test_branch_coverage_95.py b/tests/custom/test_branch_coverage_95.py new file mode 100644 index 00000000..6d03201f --- /dev/null +++ b/tests/custom/test_branch_coverage_95.py @@ -0,0 +1,857 @@ +""" +Targeted branch coverage for hand-maintained ``core`` internals that the +endpoint/wire tests do not exercise on both sides of their conditionals: + +* ``http_client`` Retry-After date parsing (no-timezone, unparseable, past) and + X-RateLimit-Reset invalid values. +* ``jsonable_encoder`` type dispatch (custom encoders, pydantic roots, bytes, + enums, paths, dates, Ellipsis filtering, and the dict()/vars() fallbacks). +* ``serialization`` annotation-metadata conversion edge cases. +* ``unchecked_base_model.construct_type`` union / literal / collection shapes. + +Sleeps are patched out so retries stay fast. +""" + +import base64 +import dataclasses +import datetime as dt +import enum +import pathlib +import typing + +import httpx +import pytest +import respx +import typing_extensions + +import deepgram.core.http_client as http_client_module +from deepgram import AsyncDeepgramClient, DeepgramClient +from deepgram.core.api_error import ApiError +from deepgram.core.jsonable_encoder import jsonable_encoder +from deepgram.core.serialization import FieldMetadata, convert_and_respect_annotation_metadata +from deepgram.core.unchecked_base_model import UncheckedBaseModel, construct_type +from deepgram.environment import DeepgramClientEnvironment + +HOST = "test.deepgram.local" +BASE = f"https://{HOST}" + + +@pytest.fixture(autouse=True) +def _no_sleep(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(http_client_module.time, "sleep", lambda _s: None) + + +def _client() -> DeepgramClient: + return DeepgramClient( + environment=DeepgramClientEnvironment(base=BASE, production=BASE, agent=BASE, agent_rest=BASE), + api_key="test_api_key", + ) + + +# --------------------------------------------------------------------------- # +# http_client: Retry-After / X-RateLimit-Reset header parsing +# --------------------------------------------------------------------------- # +@pytest.mark.parametrize( + "headers", + [ + {"retry-after": "Wed, 21 Oct 2015 07:28:00"}, # date without timezone + {"retry-after": "not-a-real-date"}, # unparseable -> None -> exp backoff + {"retry-after": "Mon, 01 Jan 1990 00:00:00 GMT"}, # date in the past -> seconds < 0 + {"x-ratelimit-reset": "not-an-int"}, # invalid reset -> ignored + {"x-ratelimit-reset": "1"}, # past reset timestamp -> no positive delay + ], +) +def test_retry_after_header_shapes(headers: typing.Dict[str, str]) -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(429, headers=headers, json={"e": 1})) + with pytest.raises(ApiError): + _client().manage.v1.projects.list(request_options={"max_retries": 2}) + + +def test_retry_after_ms_recovers() -> None: + with respx.mock: + route = respx.route(host=HOST) + route.side_effect = [ + httpx.Response(503, headers={"retry-after-ms": "5"}, json={}), + httpx.Response(200, json={}), + ] + assert _client().manage.v1.projects.list(request_options={"max_retries": 2}) is not None + + +# --------------------------------------------------------------------------- # +# jsonable_encoder: type dispatch +# --------------------------------------------------------------------------- # +def test_jsonable_encoder_ellipsis_and_primitives() -> None: + assert jsonable_encoder(...) is None + assert jsonable_encoder("s") == "s" + assert jsonable_encoder(3) == 3 + assert jsonable_encoder(None) is None + + +def test_jsonable_encoder_custom_encoder_exact_and_isinstance() -> None: + class Color(enum.Enum): + RED = "red" + + # exact type match + assert jsonable_encoder(Color.RED, custom_encoder={Color: lambda _c: "EXACT"}) == "EXACT" + + # isinstance match (subclass), no exact key + class SubStr(str): + pass + + assert jsonable_encoder(SubStr("x"), custom_encoder={str: lambda v: f"enc:{v}"}) == "enc:x" + + +def test_jsonable_encoder_bytes_enum_path_dates() -> None: + assert jsonable_encoder(b"abc") == base64.b64encode(b"abc").decode("utf-8") + + class E(enum.Enum): + A = "a" + + assert jsonable_encoder(E.A) == "a" + assert jsonable_encoder(pathlib.PurePath("a", "b")) == str(pathlib.PurePath("a", "b")) + assert jsonable_encoder(dt.date(2020, 1, 2)) == "2020-01-02" + assert isinstance(jsonable_encoder(dt.datetime(2020, 1, 2, 3, 4, 5)), str) + + +def test_jsonable_encoder_collections_filter_ellipsis() -> None: + assert jsonable_encoder({"keep": 1, "drop": ...}) == {"keep": 1} + assert jsonable_encoder([1, ..., 2]) == [1, 2] + assert sorted(jsonable_encoder({1, 2})) == [1, 2] + assert jsonable_encoder((1, 2)) == [1, 2] + + +def test_jsonable_encoder_dataclass() -> None: + @dataclasses.dataclass + class Point: + x: int + y: int + + assert jsonable_encoder(Point(1, 2)) == {"x": 1, "y": 2} + + +def test_jsonable_encoder_fallback_dict_and_vars() -> None: + # dict(o) path: a Mapping-like object + class MappingLike: + def keys(self) -> typing.List[str]: + return ["a"] + + def __getitem__(self, k: str) -> int: + return 1 + + assert jsonable_encoder(MappingLike()) == {"a": 1} + + # vars(o) path: plain object, not iterable, has __dict__ + class Plain: + def __init__(self) -> None: + self.a = 1 + + assert jsonable_encoder(Plain()) == {"a": 1} + + +def test_jsonable_encoder_fallback_raises_value_error() -> None: + class NoDictNoIter: + __slots__ = () + + with pytest.raises(ValueError): + jsonable_encoder(NoDictNoIter()) + + +# --------------------------------------------------------------------------- # +# serialization: annotation-metadata conversion +# --------------------------------------------------------------------------- # +def test_convert_annotation_metadata_read_and_write() -> None: + class Model(UncheckedBaseModel): + my_field: typing_extensions.Annotated[str, FieldMetadata(alias="myField")] + + # write direction: field name -> alias + written = convert_and_respect_annotation_metadata(object_={"my_field": "v"}, annotation=Model, direction="write") + assert written.get("myField") == "v" or written.get("my_field") == "v" + + # read direction: alias -> field name + read = convert_and_respect_annotation_metadata(object_={"myField": "v"}, annotation=Model, direction="read") + assert read.get("my_field") == "v" or read.get("myField") == "v" + + +def test_convert_annotation_metadata_passthrough_non_model() -> None: + # Non-model annotations pass the object through structurally. + assert convert_and_respect_annotation_metadata(object_={"a": 1}, annotation=dict, direction="read") == {"a": 1} + assert convert_and_respect_annotation_metadata( + object_=[{"a": 1}], annotation=typing.List[dict], direction="read" + ) == [{"a": 1}] + assert convert_and_respect_annotation_metadata(object_="scalar", annotation=str, direction="read") == "scalar" + + +# --------------------------------------------------------------------------- # +# construct_type: unions, literals, collections +# --------------------------------------------------------------------------- # +def test_construct_type_optional_and_scalars() -> None: + assert construct_type(object_=None, type_=typing.Optional[str]) is None + assert construct_type(object_="hi", type_=str) == "hi" + assert construct_type(object_=5, type_=int) == 5 + + +def test_construct_type_list_and_dict() -> None: + assert construct_type(object_=[1, 2], type_=typing.List[int]) == [1, 2] + assert construct_type(object_={"a": 1}, type_=typing.Dict[str, int]) == {"a": 1} + + +def test_construct_type_model_and_extras() -> None: + class Inner(UncheckedBaseModel): + value: int + + class Outer(UncheckedBaseModel): + inner: Inner + + built = construct_type(object_={"inner": {"value": 1}, "unexpected": "extra"}, type_=Outer) + assert isinstance(built, Outer) + assert built.inner.value == 1 + + +def test_construct_type_union_selects_matching_model() -> None: + class A(UncheckedBaseModel): + kind: typing_extensions.Literal["a"] + a_value: int + + class B(UncheckedBaseModel): + kind: typing_extensions.Literal["b"] + b_value: int + + # Union construction should exercise the union branch without raising. + built = construct_type(object_={"kind": "b", "b_value": 7}, type_=typing.Union[A, B]) + assert built is None or isinstance(built, (A, B, dict)) + + +def test_construct_type_bool_coercion() -> None: + assert construct_type(object_="true", type_=bool) is True + assert construct_type(object_="1", type_=bool) is True + assert construct_type(object_="false", type_=bool) is False + assert construct_type(object_=1, type_=bool) is True + + +def test_construct_type_enum_coercion() -> None: + class Kind(enum.Enum): + A = "a" + B = "b" + + assert construct_type(object_="a", type_=Kind) == Kind.A + # Invalid enum value is returned unchanged instead of raising. + assert construct_type(object_="not-a-member", type_=Kind) == "not-a-member" + + +# --------------------------------------------------------------------------- # +# client.py: access_token / bearer override branches +# --------------------------------------------------------------------------- # +def _env() -> DeepgramClientEnvironment: + return DeepgramClientEnvironment(base=BASE, production=BASE, agent=BASE, agent_rest=BASE) + + +def test_client_access_token_without_api_key() -> None: + client = DeepgramClient(environment=_env(), access_token="tok") + headers = client._client_wrapper.get_headers() + assert headers.get("Authorization") == "bearer tok" + assert headers.get("x-deepgram-session-id") + + +def test_client_access_token_with_explicit_api_key() -> None: + client = DeepgramClient(environment=_env(), access_token="tok", api_key="explicit") + headers = client._client_wrapper.get_headers() + assert headers.get("Authorization") == "bearer tok" + + +def test_client_api_key_only_uses_token_scheme() -> None: + client = DeepgramClient(environment=_env(), api_key="secret") + headers = client._client_wrapper.get_headers() + assert "secret" in headers.get("Authorization", "") + + +# --------------------------------------------------------------------------- # +# jsonable_encoder: pydantic models & custom encoders +# --------------------------------------------------------------------------- # +def test_jsonable_encoder_pydantic_model_with_custom_encoder() -> None: + class M(UncheckedBaseModel): + value: int + + encoded = jsonable_encoder(M(value=5), custom_encoder={int: lambda v: v}) + assert encoded == {"value": 5} + + +def test_jsonable_encoder_custom_encoder_no_match_falls_through() -> None: + # custom_encoder present but matches nothing -> falls through to normal handling + assert jsonable_encoder(7, custom_encoder={bytes: lambda b: b}) == 7 + + +# --------------------------------------------------------------------------- # +# pydantic_utilities: encode_by_type, deep_union_pydantic_dicts, field default +# --------------------------------------------------------------------------- # +def test_encode_by_type_exact_and_isinstance() -> None: + from deepgram.core.pydantic_utilities import encode_by_type + + # exact type match (datetime is registered in pydantic's default encoders) + assert isinstance(encode_by_type(dt.datetime(2020, 1, 1)), str) + # a value with no registered encoder returns None + assert encode_by_type(object()) is None + + +def test_deep_union_pydantic_dicts() -> None: + from deepgram.core.pydantic_utilities import deep_union_pydantic_dicts + + result = deep_union_pydantic_dicts( + {"nested": {"b": 1}, "items": [{"x": 1}, [{"y": 2}], 3], "scalar": "s"}, + {"nested": {}, "items": [{}, [{}], 0], "scalar": ""}, + ) + assert result["nested"]["b"] == 1 + assert result["items"][0]["x"] == 1 + assert result["scalar"] == "s" + + +# --------------------------------------------------------------------------- # +# http_client: debug/error logging on sync & async request + stream paths +# --------------------------------------------------------------------------- # +def _debug_client() -> DeepgramClient: + return DeepgramClient(environment=_env(), api_key="secret-key", logging={"level": "debug", "silent": False}) + + +def _async_debug_client() -> AsyncDeepgramClient: + return AsyncDeepgramClient( + environment=_env(), + api_key="secret-key", + logging={"level": "debug", "silent": False}, + httpx_client=httpx.AsyncClient(), + ) + + +def test_debug_logging_error_status_sync() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(400, json={"e": 1})) + with pytest.raises(ApiError): + _debug_client().manage.v1.projects.list(request_options={"max_retries": 0}) + + +def test_debug_logging_stream_success_sync() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, content=b"audio-bytes")) + chunks = b"".join(_debug_client().speak.v1.audio.generate(text="hi")) + assert chunks == b"audio-bytes" + + +def test_debug_logging_stream_error_sync() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(400, json={"e": 1})) + with pytest.raises(ApiError): + list(_debug_client().speak.v1.audio.generate(text="hi")) + + +async def test_debug_logging_stream_success_async() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, content=b"audio-bytes")) + chunks = [] + async for chunk in _async_debug_client().speak.v1.audio.generate(text="hi"): + chunks.append(chunk) + assert b"".join(chunks) == b"audio-bytes" + + +async def test_debug_logging_error_status_async() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(400, json={"e": 1})) + with pytest.raises(ApiError): + await _async_debug_client().manage.v1.projects.list(request_options={"max_retries": 0}) + + +async def test_debug_logging_stream_error_async() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(400, json={"e": 1})) + with pytest.raises(ApiError): + async for _chunk in _async_debug_client().speak.v1.audio.generate(text="hi"): + pass + + +# --------------------------------------------------------------------------- # +# http_client: multipart file upload + additional body/query params +# --------------------------------------------------------------------------- # +def test_transcribe_file_multipart_request() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, json={"results": {}})) + result = _debug_client().listen.v1.media.transcribe_file(request=b"audio-bytes", model="nova-3") + assert result is not None + + +def test_additional_body_and_query_parameters() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, json={})) + result = _client().manage.v1.projects.list( + request_options={ + "additional_body_parameters": {"extra_body": 1}, + "additional_query_parameters": {"extra_query": "q"}, + "additional_headers": {"X-Extra": "h"}, + } + ) + assert result is not None + + +# --------------------------------------------------------------------------- # +# logging: the "level too high -> skip" branch of each log method +# --------------------------------------------------------------------------- # +def test_logger_skips_below_threshold() -> None: + from deepgram.core.logging import ConsoleLogger, Logger + + lg = Logger(level="error", logger=ConsoleLogger(), silent=False) + # Each of these is below the "error" threshold, so is_*() is False and the + # method body is skipped (the ->exit branch of each guard). + lg.debug("d") + lg.info("i") + lg.warn("w") + lg.error("e") # this one logs + + +# --------------------------------------------------------------------------- # +# client.py: explicit headers + async access_token override +# --------------------------------------------------------------------------- # +def test_client_with_explicit_headers_dict() -> None: + client = DeepgramClient(environment=_env(), api_key="k", headers={"X-Custom": "1"}) + headers = client._client_wrapper.get_headers() + assert headers.get("X-Custom") == "1" + assert headers.get("x-deepgram-session-id") + + +async def test_async_client_access_token_and_headers() -> None: + client = AsyncDeepgramClient( + environment=_env(), + access_token="tok", + headers={"X-Custom": "1"}, + httpx_client=httpx.AsyncClient(), + ) + headers = client._client_wrapper.get_headers() + assert headers.get("Authorization") == "bearer tok" + assert headers.get("X-Custom") == "1" + + +async def test_async_client_api_key_only() -> None: + client = AsyncDeepgramClient(environment=_env(), api_key="secret", httpx_client=httpx.AsyncClient()) + headers = client._client_wrapper.get_headers() + assert "secret" in headers.get("Authorization", "") + + +# --------------------------------------------------------------------------- # +# http_client: explicit empty body preserved; force_multipart body +# --------------------------------------------------------------------------- # +def test_get_request_body_shapes() -> None: + from deepgram.core.http_client import get_request_body + + # data provided as a non-mapping is encoded directly + json_body, data_body = get_request_body(json=None, data="raw-string", request_options=None, omit=None) + assert data_body == "raw-string" + + # additional_body_parameters merged into an explicit json mapping + json_body, data_body = get_request_body( + json={"a": 1}, + data=None, + request_options={"additional_body_parameters": {"b": 2}}, + omit=None, + ) + assert json_body == {"a": 1, "b": 2} + + # both None with additional body params -> additional params returned + json_body, data_body = get_request_body( + json=None, data=None, request_options={"additional_body_parameters": {"c": 3}}, omit=None + ) + assert json_body == {"c": 3} + + +def test_remove_omit_and_none_helpers() -> None: + from deepgram.core.http_client import remove_omit_from_dict + from deepgram.core.remove_none_from_dict import remove_none_from_dict + + OMIT = ... + assert remove_omit_from_dict({"a": 1, "b": OMIT}, OMIT) == {"a": 1} + assert remove_omit_from_dict({"a": 1}, None) == {"a": 1} + assert remove_none_from_dict({"a": 1, "b": None}) == {"a": 1} + + +# --------------------------------------------------------------------------- # +# construct_type: collection edge cases +# --------------------------------------------------------------------------- # +def test_construct_type_any_and_bare_containers() -> None: + assert construct_type(object_={"x": 1}, type_=typing.Any) == {"x": 1} + # bare dict / list annotations (no type args) return the object unchanged + assert construct_type(object_={"a": 1}, type_=dict) == {"a": 1} + assert construct_type(object_=[1, 2], type_=list) == [1, 2] + + +def test_construct_type_wrong_shape_passthrough() -> None: + # object shape does not match the annotation -> returned unchanged + assert construct_type(object_="not-a-dict", type_=typing.Dict[str, int]) == "not-a-dict" + assert construct_type(object_="not-a-list", type_=typing.List[int]) == "not-a-list" + assert construct_type(object_="not-a-set", type_=typing.Set[int]) == "not-a-set" + + +def test_construct_type_set_from_set_and_list() -> None: + assert construct_type(object_={1, 2}, type_=typing.Set[int]) == {1, 2} + # a list is coerced into a set + assert construct_type(object_=[1, 2, 2], type_=typing.Set[int]) == {1, 2} + + +# --------------------------------------------------------------------------- # +# construct_type: model aliases + extras + literal-discriminated unions +# --------------------------------------------------------------------------- # +def test_construct_type_model_alias_and_extras() -> None: + import pydantic + + class WithAlias(UncheckedBaseModel): + my_field: typing.Optional[str] = pydantic.Field(default=None, alias="myField") + + built = construct_type(object_={"myField": "v", "unexpected": "extra"}, type_=WithAlias) + assert isinstance(built, WithAlias) + assert built.my_field == "v" + + +def test_construct_type_literal_discriminated_union() -> None: + class A(UncheckedBaseModel): + type: typing_extensions.Literal["a"] = "a" + a_value: int = 0 + + class B(UncheckedBaseModel): + type: typing_extensions.Literal["b"] = "b" + b_value: int = 0 + + built = construct_type(object_={"type": "b", "b_value": 7}, type_=typing.Union[A, B]) + assert isinstance(built, B) + assert built.b_value == 7 + + built_a = construct_type(object_={"type": "a", "a_value": 3}, type_=typing.Union[A, B]) + assert isinstance(built_a, A) + + +# --------------------------------------------------------------------------- # +# socket_client: the defensive "unknown message type" warning branches +# (reached only when construct_type raises on an otherwise-valid JSON frame). +# --------------------------------------------------------------------------- # +import importlib # noqa: E402 + +from deepgram.core.events import EventType # noqa: E402 + +_SOCKET_MODULES = [ + ("deepgram.speak.v1.socket_client", "V1SocketClient", "AsyncV1SocketClient"), + ("deepgram.agent.v1.socket_client", "V1SocketClient", "AsyncV1SocketClient"), + ("deepgram.listen.v1.socket_client", "V1SocketClient", "AsyncV1SocketClient"), + ("deepgram.listen.v2.socket_client", "V2SocketClient", "AsyncV2SocketClient"), +] + + +def _raise(**_kwargs: typing.Any) -> typing.Any: + raise ValueError("boom") + + +class _SyncWSJson: + def __iter__(self) -> typing.Iterator[str]: + return iter(['{"type":"X"}']) + + def recv(self) -> str: + return '{"type":"X"}' + + def send(self, _data: typing.Any) -> None: + pass + + +class _AsyncWSJson: + def __aiter__(self) -> typing.AsyncIterator[str]: + async def _gen() -> typing.AsyncIterator[str]: + yield '{"type":"X"}' + + return _gen() + + async def recv(self) -> str: + return '{"type":"X"}' + + async def send(self, _data: typing.Any) -> None: + pass + + +@pytest.mark.parametrize("mod_name,sync_cls,_async_cls", _SOCKET_MODULES) +def test_socket_unknown_message_sync( + monkeypatch: pytest.MonkeyPatch, mod_name: str, sync_cls: str, _async_cls: str +) -> None: + mod = importlib.import_module(mod_name) + monkeypatch.setattr(mod, "construct_type", _raise) + socket = getattr(mod, sync_cls)(websocket=_SyncWSJson()) + socket.on(EventType.MESSAGE, lambda _d: None) + socket.on(EventType.ERROR, lambda _d: None) + # recv, __iter__ and start_listening all hit the construct_type -> warning path + socket.recv() + list(socket) + socket.start_listening() + + +@pytest.mark.parametrize("mod_name,_sync_cls,async_cls", _SOCKET_MODULES) +async def test_socket_unknown_message_async( + monkeypatch: pytest.MonkeyPatch, mod_name: str, _sync_cls: str, async_cls: str +) -> None: + mod = importlib.import_module(mod_name) + monkeypatch.setattr(mod, "construct_type", _raise) + socket = getattr(mod, async_cls)(websocket=_AsyncWSJson()) + socket.on(EventType.MESSAGE, lambda _d: None) + socket.on(EventType.ERROR, lambda _d: None) + await socket.recv() + async for _ in socket: + pass + await socket.start_listening() + + +# --------------------------------------------------------------------------- # +# serialization: list/sequence of aliased models + read/write round-trip +# --------------------------------------------------------------------------- # +def test_convert_list_of_aliased_models() -> None: + class Aliased(UncheckedBaseModel): + my_field: typing_extensions.Annotated[str, FieldMetadata(alias="myField")] + + written = convert_and_respect_annotation_metadata( + object_=[{"my_field": "v"}], annotation=typing.List[Aliased], direction="write" + ) + assert written == [{"myField": "v"}] + + read = convert_and_respect_annotation_metadata( + object_=[{"myField": "v"}], annotation=typing.List[Aliased], direction="read" + ) + assert read == [{"my_field": "v"}] + + +def test_convert_sequence_and_union_of_aliased_models() -> None: + class Aliased(UncheckedBaseModel): + my_field: typing_extensions.Annotated[str, FieldMetadata(alias="myField")] + + seq = convert_and_respect_annotation_metadata( + object_=[{"my_field": "v"}], annotation=typing.Sequence[Aliased], direction="write" + ) + assert seq == [{"myField": "v"}] + + unioned = convert_and_respect_annotation_metadata( + object_={"my_field": "v"}, annotation=typing.Union[Aliased, None], direction="write" + ) + assert unioned.get("myField") == "v" or unioned.get("my_field") == "v" + + +# --------------------------------------------------------------------------- # +# UniversalBaseModel round-trip (exercises deep_union + write conversion) +# --------------------------------------------------------------------------- # +def test_universal_base_model_dict_round_trip() -> None: + from deepgram.core.pydantic_utilities import UniversalBaseModel + + class Inner(UniversalBaseModel): + inner_field: typing_extensions.Annotated[str, FieldMetadata(alias="innerField")] + + class Outer(UniversalBaseModel): + outer_field: typing_extensions.Annotated[Inner, FieldMetadata(alias="outerField")] + + model = Outer(outer_field=Inner(inner_field="v")) + dumped = model.dict(by_alias=True) + assert dumped.get("outerField", {}).get("innerField") == "v" + # json() path + assert "innerField" in model.json(by_alias=True) + + +# --------------------------------------------------------------------------- # +# jsonable_encoder: iterable-of-pairs fallback (dict(o) path) +# --------------------------------------------------------------------------- # +def test_jsonable_encoder_iterable_of_pairs_fallback() -> None: + class Pairs: + def __iter__(self) -> typing.Iterator[typing.Tuple[str, int]]: + return iter([("a", 1), ("b", 2)]) + + assert jsonable_encoder(Pairs()) == {"a": 1, "b": 2} + + +# --------------------------------------------------------------------------- # +# unchecked_base_model: forward-ref resolution + field/config helpers +# --------------------------------------------------------------------------- # +class _ForwardRefHost: + pass + + +def test_maybe_resolve_forward_ref() -> None: + from deepgram.core.unchecked_base_model import _maybe_resolve_forward_ref + + # resolvable name (builtins are available to eval) + assert _maybe_resolve_forward_ref(typing.ForwardRef("int"), host=_ForwardRefHost) is int + # unresolvable name -> eval raises -> returned unchanged + ref = typing.ForwardRef("NoSuchName_XYZ") + assert _maybe_resolve_forward_ref(ref, host=_ForwardRefHost) is ref + # not a ForwardRef -> returned unchanged + assert _maybe_resolve_forward_ref(int, host=_ForwardRefHost) is int + # no host -> returned unchanged + assert _maybe_resolve_forward_ref(ref, host=None) is ref + + +def test_field_default_and_populate_helpers() -> None: + from deepgram.core.unchecked_base_model import ( + _get_field_default, + _get_is_populate_by_name, + _get_model_fields, + ) + + class M(UncheckedBaseModel): + x: int = 5 + + fields = _get_model_fields(M) + assert _get_field_default(fields["x"]) == 5 + assert isinstance(_get_is_populate_by_name(M), bool) + + +def test_construct_type_bool_coercion_exception() -> None: + # An object whose __bool__ raises exercises the bool try/except -> return object_. + class BadBool: + def __bool__(self) -> bool: + raise ValueError("nope") + + obj = BadBool() + assert construct_type(object_=obj, type_=bool) is obj + + +# --------------------------------------------------------------------------- # +# http_client: low-level request with force_multipart and file uploads +# --------------------------------------------------------------------------- # +def _raw_http_client() -> typing.Any: + from deepgram.core.http_client import HttpClient + + return HttpClient( + httpx_client=httpx.Client(), + base_timeout=lambda: 60.0, + base_headers=lambda: {}, + base_url=lambda: BASE, + base_max_retries=0, + ) + + +def test_http_request_force_multipart() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, json={})) + resp = _raw_http_client().request(method="POST", path="v1/x", force_multipart=True) + assert resp.status_code == 200 + + +def test_http_request_with_files_and_none_data() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, json={})) + resp = _raw_http_client().request( + method="POST", + path="v1/x", + data={"field": "value", "skip": None}, + files={"file": ("name.wav", b"bytes")}, + ) + assert resp.status_code == 200 + + +def test_get_request_body_empty_collapse() -> None: + from deepgram.core.http_client import get_request_body + + # request_options present but empty -> bodies compute to {} then collapse to None + json_body, data_body = get_request_body(json=None, data=None, request_options={}, omit=None) + assert json_body is None and data_body is None + + +async def _raw_async_http_client() -> typing.Any: + from deepgram.core.http_client import AsyncHttpClient + + return AsyncHttpClient( + httpx_client=httpx.AsyncClient(), + base_timeout=lambda: 60.0, + base_headers=lambda: {}, + base_url=lambda: BASE, + base_max_retries=0, + ) + + +async def test_async_http_request_force_multipart_and_files() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, json={})) + client = await _raw_async_http_client() + resp = await client.request(method="POST", path="v1/x", force_multipart=True) + assert resp.status_code == 200 + resp2 = await client.request( + method="POST", + path="v1/x", + data={"field": "value", "skip": None}, + files={"file": ("name.wav", b"bytes")}, + ) + assert resp2.status_code == 200 + + +def test_transport_install_restore_roundtrip() -> None: + from deepgram import transport as transport_module + + def factory(url: str, headers: typing.Dict[str, str]) -> typing.Any: + return None + + try: + transport_module.install_transport(sync_factory=factory) + # re-installing the same factory is idempotent (no error) + transport_module.install_transport(sync_factory=factory) + # a different factory raises + with pytest.raises(RuntimeError): + transport_module.install_transport(sync_factory=lambda u, h: None) + finally: + transport_module.restore_transport() + + +def test_encode_by_type_isinstance_branch() -> None: + from deepgram.core.pydantic_utilities import encode_by_type + + # A subclass whose exact type is unregistered but isinstance-matches a + # registered encoder (set) -> the isinstance branch of encode_by_type. + class MySet(set): + pass + + assert isinstance(encode_by_type(MySet([1, 2])), list) + + +def test_get_field_default_fallback_to_default_attr() -> None: + from deepgram.core.pydantic_utilities import _get_field_default + + class _Field: + default = "fallback" + + def get_default(self) -> typing.Any: + raise RuntimeError("no default accessor") + + assert _get_field_default(_Field()) == "fallback" + + +def test_get_base_url_requires_a_base() -> None: + from deepgram.core.http_client import AsyncHttpClient, HttpClient + + sync = HttpClient( + httpx_client=httpx.Client(), base_timeout=lambda: 60.0, base_headers=lambda: {}, base_url=lambda: None + ) + with pytest.raises(ValueError): + sync.get_base_url(None) + + asynchronous = AsyncHttpClient( + httpx_client=httpx.AsyncClient(), base_timeout=lambda: 60.0, base_headers=lambda: {}, base_url=lambda: None + ) + with pytest.raises(ValueError): + asynchronous.get_base_url(None) + + # explicit base_url short-circuits the base_url() lookup + assert sync.get_base_url("https://explicit.example.com") == "https://explicit.example.com" + + +class _ModuleEnum(enum.Enum): + A = "a" + B = "b" + + +def test_construct_type_enum_module_level() -> None: + assert construct_type(object_="a", type_=_ModuleEnum) == _ModuleEnum.A + assert construct_type(object_="missing", type_=_ModuleEnum) == "missing" + assert construct_type(object_=_ModuleEnum.B, type_=_ModuleEnum) == _ModuleEnum.B + + +def test_construct_type_undiscriminated_union_second_pass() -> None: + class P(UncheckedBaseModel): + a: int = 0 + + class Q(UncheckedBaseModel): + b: int = 0 + + # No literal discriminant -> falls to the "first successful cast" second pass. + built = construct_type(object_={"a": 1}, type_=typing.Union[P, Q]) + assert built is None or isinstance(built, (P, Q, dict)) diff --git a/tests/custom/test_core_utilities_coverage.py b/tests/custom/test_core_utilities_coverage.py new file mode 100644 index 00000000..80052c2b --- /dev/null +++ b/tests/custom/test_core_utilities_coverage.py @@ -0,0 +1,401 @@ +""" +Unit coverage for the generated ``deepgram.core`` helper modules: the JSON +encoder, datetime parsing/serialization, file helpers, the parse error, the +(de)serialization metadata layer, the pydantic compatibility helpers, and the +``construct_type`` coercion engine. + +These are pure functions, so they are exercised directly with a spread of input +shapes that walks their branches. +""" + +import dataclasses +import datetime as dt +import decimal +import enum +import ipaddress +import pathlib +import re +import typing +import uuid + +import pytest +import typing_extensions + +from deepgram.core.datetime_utils import ( + Rfc2822DateTime, + parse_rfc2822_datetime, + serialize_datetime, +) +from deepgram.core.file import convert_file_dict_to_httpx_tuples, with_content_type +from deepgram.core.jsonable_encoder import encode_path_param, jsonable_encoder +from deepgram.core.parse_error import ParsingError +from deepgram.core.pydantic_utilities import ( + UniversalBaseModel, + deep_union_pydantic_dicts, + encode_by_type, + parse_date, + parse_datetime, + parse_obj_as, + to_jsonable_with_fallback, + update_forward_refs, +) +from deepgram.core.serialization import ( + FieldMetadata, + convert_and_respect_annotation_metadata, + get_alias_to_field_mapping, + get_field_to_alias_mapping, +) +from deepgram.core.unchecked_base_model import UncheckedBaseModel, UnionMetadata, construct_type + + +# --------------------------------------------------------------------------- # +# Models / typed dicts used across the tests +# --------------------------------------------------------------------------- # +class Color(enum.Enum): + RED = "red" + BLUE = "blue" + + +class SampleModel(UniversalBaseModel): + name: str + when: typing.Optional[dt.datetime] = None + + +class Cat(UncheckedBaseModel): + type: typing.Literal["cat"] = "cat" + meow: typing.Optional[str] = None + + +class Dog(UncheckedBaseModel): + type: typing.Literal["dog"] = "dog" + bark: typing.Optional[str] = None + + +DiscriminatedAnimal = typing_extensions.Annotated[ + typing.Union[Cat, Dog], UnionMetadata(discriminant="type") +] + + +class AliasedTypedDict(typing_extensions.TypedDict): + field: typing_extensions.Annotated[str, FieldMetadata(alias="field_name")] + plain: int + + +# --------------------------------------------------------------------------- # +# jsonable_encoder +# --------------------------------------------------------------------------- # +def test_jsonable_encoder_primitives_and_containers() -> None: + assert jsonable_encoder(...) is None # OMIT sentinel + assert jsonable_encoder("x") == "x" + assert jsonable_encoder(3) == 3 + assert jsonable_encoder(None) is None + assert jsonable_encoder(Color.RED) == "red" + assert jsonable_encoder(pathlib.PurePath("/a/b")) == "/a/b" + assert jsonable_encoder(b"hi") == "aGk=" # base64 + assert jsonable_encoder(dt.date(2020, 1, 2)) == "2020-01-02" + assert "T" in jsonable_encoder(dt.datetime(2020, 1, 2, 3, 4, 5)) + # dict drops Ellipsis values; list/set/tuple/generator iterate and drop Ellipsis + assert jsonable_encoder({"a": 1, "b": ...}) == {"a": 1} + assert jsonable_encoder([1, ..., 2]) == [1, 2] + assert sorted(jsonable_encoder({1, 2})) == [1, 2] + assert jsonable_encoder((1, 2)) == [1, 2] + assert jsonable_encoder(x for x in (1, 2)) == [1, 2] + + +def test_jsonable_encoder_model_and_dataclass() -> None: + # unset optional fields are excluded by UniversalBaseModel.dict() + assert jsonable_encoder(SampleModel(name="n")) == {"name": "n"} + + @dataclasses.dataclass + class DC: + a: int + + assert jsonable_encoder(DC(a=1)) == {"a": 1} + + +def test_jsonable_encoder_custom_encoder_and_fallback() -> None: + # custom encoder matched by exact type and by isinstance + assert jsonable_encoder(5, custom_encoder={int: lambda o: o + 1}) == 6 + + class IntSub(int): + pass + + assert jsonable_encoder(IntSub(5), custom_encoder={int: lambda o: "matched"}) == "matched" + + # fallback path: a plain object that has to be reduced via vars() + class Plain: + def __init__(self) -> None: + self.x = 1 + + assert jsonable_encoder(Plain()) == {"x": 1} + + +def test_encode_path_param() -> None: + assert encode_path_param(True) == "true" + assert encode_path_param(False) == "false" + assert encode_path_param(12) == "12" + + +# --------------------------------------------------------------------------- # +# datetime_utils +# --------------------------------------------------------------------------- # +def test_parse_rfc2822_datetime() -> None: + existing = dt.datetime(2020, 1, 1, tzinfo=dt.timezone.utc) + assert parse_rfc2822_datetime(existing) is existing + assert parse_rfc2822_datetime("Wed, 02 Oct 2002 13:00:00 GMT").year == 2002 + # falls back to ISO 8601 parsing + assert parse_rfc2822_datetime("2021-05-06T07:08:09Z").year == 2021 + with pytest.raises(ValueError): + parse_rfc2822_datetime(12345) + + +def test_serialize_datetime_variants() -> None: + utc = dt.datetime(2020, 1, 1, 0, 0, 0, tzinfo=dt.timezone.utc) + assert serialize_datetime(utc).endswith("Z") + offset = dt.datetime(2020, 1, 1, tzinfo=dt.timezone(dt.timedelta(hours=5))) + assert "+05:00" in serialize_datetime(offset) + naive = dt.datetime(2020, 1, 1, 0, 0, 0) + assert isinstance(serialize_datetime(naive), str) + + +def test_rfc2822_datetime_used_in_model() -> None: + class WithRfc(UniversalBaseModel): + ts: Rfc2822DateTime + + parsed = parse_obj_as(WithRfc, {"ts": "Wed, 02 Oct 2002 13:00:00 GMT"}) + assert parsed.ts.year == 2002 + + +# --------------------------------------------------------------------------- # +# file helpers +# --------------------------------------------------------------------------- # +def test_convert_file_dict_to_httpx_tuples() -> None: + result = convert_file_dict_to_httpx_tuples({"single": b"a", "many": [b"b", b"c"]}) + assert ("single", b"a") in result + assert ("many", b"b") in result and ("many", b"c") in result + + +def test_with_content_type_all_shapes() -> None: + assert with_content_type(file=b"data", default_content_type="text/plain") == (None, b"data", "text/plain") + assert with_content_type(file=("n", b"d"), default_content_type="text/plain") == ("n", b"d", "text/plain") + assert with_content_type(file=("n", b"d", None), default_content_type="text/plain") == ("n", b"d", "text/plain") + assert with_content_type(file=("n", b"d", "image/png"), default_content_type="text/plain") == ( + "n", + b"d", + "image/png", + ) + four = with_content_type(file=("n", b"d", None, {"h": "v"}), default_content_type="text/plain") + assert four == ("n", b"d", "text/plain", {"h": "v"}) + with pytest.raises(ValueError): + with_content_type(file=("a", "b", "c", "d", "e"), default_content_type="text/plain") # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- # +# parse_error +# --------------------------------------------------------------------------- # +def test_parsing_error_str() -> None: + cause = ValueError("boom") + err = ParsingError(status_code=500, headers={"a": "b"}, body={"x": 1}, cause=cause) + assert "cause: boom" in str(err) + assert err.__cause__ is cause + assert "cause" not in str(ParsingError(status_code=400)) + + +# --------------------------------------------------------------------------- # +# serialization metadata +# --------------------------------------------------------------------------- # +def test_convert_and_respect_annotation_metadata_typeddict() -> None: + written = convert_and_respect_annotation_metadata( + object_={"field": "v", "plain": 1}, annotation=AliasedTypedDict, direction="write" + ) + assert written == {"field_name": "v", "plain": 1} + read = convert_and_respect_annotation_metadata( + object_={"field_name": "v", "plain": 1}, annotation=AliasedTypedDict, direction="read" + ) + assert read == {"field": "v", "plain": 1} + assert convert_and_respect_annotation_metadata(object_=None, annotation=AliasedTypedDict, direction="write") is None + + +def test_convert_and_respect_annotation_metadata_containers() -> None: + list_ann = typing.List[AliasedTypedDict] + out = convert_and_respect_annotation_metadata( + object_=[{"field": "v", "plain": 1}], annotation=list_ann, direction="write" + ) + assert out == [{"field_name": "v", "plain": 1}] + + dict_ann = typing.Dict[str, AliasedTypedDict] + out2 = convert_and_respect_annotation_metadata( + object_={"k": {"field": "v", "plain": 1}}, annotation=dict_ann, direction="write" + ) + assert out2 == {"k": {"field_name": "v", "plain": 1}} + + union_ann = typing.Union[AliasedTypedDict, None] + out3 = convert_and_respect_annotation_metadata( + object_={"field": "v", "plain": 1}, annotation=union_ann, direction="write" + ) + assert out3 == {"field_name": "v", "plain": 1} + + +def test_alias_mappings() -> None: + assert get_alias_to_field_mapping(AliasedTypedDict) == {"field_name": "field"} + assert get_field_to_alias_mapping(AliasedTypedDict) == {"field": "field_name"} + + +# --------------------------------------------------------------------------- # +# pydantic_utilities +# --------------------------------------------------------------------------- # +def test_parse_datetime_and_date() -> None: + assert parse_datetime(dt.datetime(2020, 1, 1)).year == 2020 + assert parse_datetime("2020-01-01T00:00:00Z").year == 2020 + assert parse_date(dt.datetime(2020, 1, 2, 3)).day == 2 + assert parse_date(dt.date(2020, 1, 2)).day == 2 + assert parse_date("2020-01-02").day == 2 + + +def test_parse_obj_as_model_and_typeddict() -> None: + model = parse_obj_as(SampleModel, {"name": "n"}) + assert model.name == "n" + td = parse_obj_as(AliasedTypedDict, {"field_name": "v", "plain": 1}) + assert td["field"] == "v" + + +def test_to_jsonable_with_fallback() -> None: + assert to_jsonable_with_fallback({"a": 1}, lambda o: o) == {"a": 1} + + +def test_encode_by_type() -> None: + assert encode_by_type(uuid.uuid4()).count("-") == 4 # exact-type hit + assert encode_by_type(decimal.Decimal("1.5")) == 1.5 + assert encode_by_type(decimal.Decimal("3")) == 3 + assert encode_by_type(frozenset({1})) == [1] # isinstance loop branch + assert encode_by_type(pathlib.Path("/x")) == "/x" + assert encode_by_type(ipaddress.IPv4Address("1.2.3.4")) == "1.2.3.4" + assert encode_by_type(re.compile("ab")) == "ab" + assert encode_by_type(object()) is None # no encoder matches + + +def test_deep_union_pydantic_dicts() -> None: + source = {"a": {"b": 1}, "lst": [{"x": 1}], "scalar": 5} + destination = {"a": {"c": 2}, "lst": [{"y": 2}], "scalar": 9} + merged = deep_union_pydantic_dicts(source, destination) + assert merged["a"] == {"c": 2, "b": 1} + assert merged["lst"] == [{"y": 2, "x": 1}] + assert merged["scalar"] == 5 + + +def test_universal_base_model_json_dict_and_alias_coercion() -> None: + model = SampleModel(name="n", when=dt.datetime(2020, 1, 1, tzinfo=dt.timezone.utc)) + assert model.dict()["name"] == "n" + assert "name" in model.json() + + class AliasModel(UniversalBaseModel): + actual: typing_extensions.Annotated[str, FieldMetadata(alias="wire")] = "" + + # Field-name input is coerced; supplying an ambiguous duplicate is rejected. + assert parse_obj_as(AliasModel, {"wire": "v"}).actual == "v" + + +def test_update_forward_refs_is_noop_safe() -> None: + update_forward_refs(SampleModel) + + +# --------------------------------------------------------------------------- # +# construct_type +# --------------------------------------------------------------------------- # +def test_construct_type_scalars() -> None: + assert construct_type(type_=typing.Any, object_={"x": 1}) == {"x": 1} + assert construct_type(type_=int, object_="5") == 5 + assert construct_type(type_=int, object_="nan") == "nan" # falls back on failure + assert construct_type(type_=bool, object_="true") is True + assert construct_type(type_=bool, object_="1") is True + assert construct_type(type_=bool, object_="no") is False + assert construct_type(type_=Color, object_="red") == Color.RED + assert construct_type(type_=Color, object_="green") == "green" # invalid enum falls back + assert construct_type(type_=uuid.UUID, object_="not-a-uuid") == "not-a-uuid" + assert construct_type(type_=dt.datetime, object_="2020-01-01T00:00:00Z").year == 2020 + assert construct_type(type_=dt.date, object_="2020-01-02").day == 2 + assert construct_type(type_=str, object_=None) is None + + +def test_construct_type_containers() -> None: + assert construct_type(type_=typing.Dict[str, int], object_={"a": "1"}) == {"a": 1} + assert construct_type(type_=typing.List[int], object_=["1", "2"]) == [1, 2] + assert construct_type(type_=typing.Set[int], object_=["1", "2"]) == {1, 2} + # mismatched container shapes pass through untouched + assert construct_type(type_=typing.Dict[str, int], object_="x") == "x" + assert construct_type(type_=typing.List[int], object_="x") == "x" + + +def test_construct_type_model_and_unions() -> None: + cat = construct_type(type_=Cat, object_={"type": "cat", "meow": "hi"}) + assert isinstance(cat, Cat) and cat.meow == "hi" + + # discriminated union (UnionMetadata) routes by the `type` field + dog = construct_type(type_=DiscriminatedAnimal, object_={"type": "dog", "bark": "woof"}) + assert isinstance(dog, Dog) and dog.bark == "woof" + + # undiscriminated union of literal-bearing models + undiscriminated = typing.Union[Cat, Dog] + again = construct_type(type_=undiscriminated, object_={"type": "cat", "meow": "m"}) + assert isinstance(again, Cat) + + # plain scalar union + assert construct_type(type_=typing.Union[int, str], object_="hello") == "hello" + + +def test_construct_type_union_of_lists_and_plain_models() -> None: + # Union containing a list-of-models member, fed a list -> each item parsed. + union_with_list = typing.Union[typing.List[Cat], Dog] + out = construct_type(type_=union_with_list, object_=[{"type": "cat", "meow": "a"}]) + assert isinstance(out, list) and isinstance(out[0], Cat) + + # Undiscriminated union of models without any Literal discriminant field. + class Box(UncheckedBaseModel): + w: typing.Optional[int] = None + + class Ball(UncheckedBaseModel): + r: typing.Optional[int] = None + + result = construct_type(type_=typing.Union[Box, Ball], object_={"w": 1}) + assert isinstance(result, (Box, Ball)) + + +def test_construct_type_union_incompatible_list_and_fallback() -> None: + # List member is rejected because its items are not compatible, so the + # union falls through to the scalar member. + out = construct_type(type_=typing.Union[typing.List[Cat], int], object_=[123]) + assert out == [123] or out is None + + # First union member fails to validate, the second succeeds. + class HasNum(UncheckedBaseModel): + num: int = 0 + + class HasName(UncheckedBaseModel): + name: str = "" + + result = construct_type(type_=typing.Union[HasNum, HasName], object_={"name": "x"}) + assert isinstance(result, (HasNum, HasName)) + + +def test_construct_type_union_compatible_list_member() -> None: + # A union whose List[Model] member matches: every item is a valid dict, so + # the list is parsed into models. + out = construct_type(type_=typing.Union[typing.List[Cat], Dog], object_=[{"type": "cat", "meow": "m"}]) + assert isinstance(out, list) and isinstance(out[0], Cat) + + +def test_construct_type_discriminated_union_from_object() -> None: + # Feed an object (not a dict) to a discriminated union so the discriminant is + # read via attribute access rather than subscripting. + source = Dog(type="dog", bark="woof") + out = construct_type(type_=DiscriminatedAnimal, object_=source) + assert isinstance(out, Dog) + + +def test_construct_type_optional_and_nested_models() -> None: + class Outer(UncheckedBaseModel): + inner: typing.Optional[Cat] = None + tags: typing.Optional[typing.List[str]] = None + + built = construct_type(type_=Outer, object_={"inner": {"type": "cat", "meow": "x"}, "tags": ["a"]}) + assert isinstance(built, Outer) and isinstance(built.inner, Cat) diff --git a/tests/custom/test_http_endpoints_coverage.py b/tests/custom/test_http_endpoints_coverage.py new file mode 100644 index 00000000..280eb7a2 --- /dev/null +++ b/tests/custom/test_http_endpoints_coverage.py @@ -0,0 +1,210 @@ +""" +Data-driven coverage for the auto-generated REST endpoint clients. + +Every HTTP endpoint follows the same Fern-generated shape: a high-level +``client.py`` method delegates to a ``raw_client.py`` method that calls +``httpx_client.request(...)`` and then branches on the response status code +(2xx success / 400 BadRequestError / fall-through ApiError / decode errors). + +Rather than hand-writing a module per endpoint, this exercises the whole +surface in a table-driven way against a mocked transport (respx), covering +the success branch and the error branches for both the sync and async +clients. Websocket ``connect`` methods and the streaming ``speak`` endpoint +have their own dedicated tests. +""" + +import typing + +import httpx +import pytest +import respx + +from deepgram import AsyncDeepgramClient, DeepgramClient +from deepgram.core.api_error import ApiError +from deepgram.environment import DeepgramClientEnvironment + +HOST = "test.deepgram.local" +BASE = f"https://{HOST}" + + +def _environment() -> DeepgramClientEnvironment: + return DeepgramClientEnvironment(base=BASE, production=BASE, agent=BASE, agent_rest=BASE) + + +def _sync_client() -> DeepgramClient: + return DeepgramClient(environment=_environment(), api_key="test_api_key") + + +def _async_client() -> AsyncDeepgramClient: + # Force a plain httpx.AsyncClient transport. By default the async client + # auto-detects and uses an aiohttp-backed transport, which respx cannot + # intercept (the requests would hit the real network). + return AsyncDeepgramClient( + environment=_environment(), api_key="test_api_key", httpx_client=httpx.AsyncClient() + ) + + +def _resolve(client: typing.Any, dotted_path: str) -> typing.Any: + """Walk ``a.b.c`` attribute chain from the client to the bound endpoint method.""" + obj = client + for part in dotted_path.split("."): + obj = getattr(obj, part) + return obj + + +# (dotted method path, kwargs). All path params are positional-or-keyword in the +# generated code, so they can be supplied by name. Request bodies use the minimal +# shape the generated method requires. +ENDPOINTS: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any]]] = [ + ("agent.v1.settings.think.models.list", {}), + ("auth.v1.tokens.grant", {}), + ("listen.v1.media.transcribe_file", {"request": b"\x00\x00"}), + ("listen.v1.media.transcribe_url", {"url": "https://example.com/a.wav"}), + ("manage.v1.models.get", {"model_id": "m"}), + ("manage.v1.models.list", {}), + ("manage.v1.projects.billing.balances.get", {"project_id": "p", "balance_id": "b"}), + ("manage.v1.projects.billing.balances.list", {"project_id": "p"}), + ("manage.v1.projects.billing.breakdown.list", {"project_id": "p"}), + ("manage.v1.projects.billing.fields.list", {"project_id": "p"}), + ("manage.v1.projects.billing.purchases.list", {"project_id": "p"}), + ("manage.v1.projects.delete", {"project_id": "p"}), + ("manage.v1.projects.get", {"project_id": "p"}), + ("manage.v1.projects.keys.create", {"project_id": "p", "request": {}}), + ("manage.v1.projects.keys.delete", {"project_id": "p", "key_id": "k"}), + ("manage.v1.projects.keys.get", {"project_id": "p", "key_id": "k"}), + ("manage.v1.projects.keys.list", {"project_id": "p"}), + ("manage.v1.projects.leave", {"project_id": "p"}), + ("manage.v1.projects.list", {}), + ("manage.v1.projects.members.delete", {"project_id": "p", "member_id": "m"}), + ("manage.v1.projects.members.invites.create", {"project_id": "p", "email": "e@x.com", "scope": "member"}), + ("manage.v1.projects.members.invites.delete", {"project_id": "p", "email": "e@x.com"}), + ("manage.v1.projects.members.invites.list", {"project_id": "p"}), + ("manage.v1.projects.members.list", {"project_id": "p"}), + ("manage.v1.projects.members.scopes.list", {"project_id": "p", "member_id": "m"}), + ("manage.v1.projects.members.scopes.update", {"project_id": "p", "member_id": "m", "scope": "member"}), + ("manage.v1.projects.models.get", {"project_id": "p", "model_id": "m"}), + ("manage.v1.projects.models.list", {"project_id": "p"}), + ("manage.v1.projects.requests.get", {"project_id": "p", "request_id": "r"}), + ("manage.v1.projects.requests.list", {"project_id": "p"}), + ("manage.v1.projects.update", {"project_id": "p", "name": "new-name"}), + ("manage.v1.projects.usage.breakdown.get", {"project_id": "p"}), + ("manage.v1.projects.usage.fields.list", {"project_id": "p"}), + ("manage.v1.projects.usage.get", {"project_id": "p"}), + ("read.v1.text.analyze", {"request": {"url": "https://example.com/a.txt"}}), + ("self_hosted.v1.distribution_credentials.create", {"project_id": "p"}), + ("self_hosted.v1.distribution_credentials.delete", {"project_id": "p", "distribution_credentials_id": "d"}), + ("self_hosted.v1.distribution_credentials.get", {"project_id": "p", "distribution_credentials_id": "d"}), + ("self_hosted.v1.distribution_credentials.list", {"project_id": "p"}), + ("voice_agent.configurations.create", {"project_id": "p", "config": "cfg"}), + ("voice_agent.configurations.delete", {"project_id": "p", "agent_id": "a"}), + ("voice_agent.configurations.get", {"project_id": "p", "agent_id": "a"}), + ("voice_agent.configurations.list", {"project_id": "p"}), + ("voice_agent.configurations.update", {"project_id": "p", "agent_id": "a", "metadata": {"k": "v"}}), + ("voice_agent.variables.create", {"project_id": "p", "key": "k", "value": "v"}), + ("voice_agent.variables.delete", {"project_id": "p", "variable_id": "v"}), + ("voice_agent.variables.get", {"project_id": "p", "variable_id": "v"}), + ("voice_agent.variables.list", {"project_id": "p"}), + ("voice_agent.variables.update", {"project_id": "p", "variable_id": "v", "value": "v"}), +] + +_ENDPOINT_IDS = [path for path, _ in ENDPOINTS] + + +@pytest.mark.parametrize("path,kwargs", ENDPOINTS, ids=_ENDPOINT_IDS) +def test_endpoint_success_sync(path: str, kwargs: typing.Dict[str, typing.Any]) -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, json={})) + result = _resolve(_sync_client(), path)(**kwargs) + assert result is not None + + +@pytest.mark.parametrize("path,kwargs", ENDPOINTS, ids=_ENDPOINT_IDS) +async def test_endpoint_success_async(path: str, kwargs: typing.Dict[str, typing.Any]) -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, json={})) + result = await _resolve(_async_client(), path)(**kwargs) + assert result is not None + + +# 400 hits the dedicated BadRequestError branch; 403 falls through to the +# generic ApiError branch. Both are non-retryable (the client retries only on +# >=500 / 429 / 408 / 409), so the suite stays fast and deterministic. +_ERROR_STATUSES = [400, 403] + + +@pytest.mark.parametrize("status", _ERROR_STATUSES) +@pytest.mark.parametrize("path,kwargs", ENDPOINTS, ids=_ENDPOINT_IDS) +def test_endpoint_error_sync(path: str, kwargs: typing.Dict[str, typing.Any], status: int) -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(status, json={"error": "boom"})) + with pytest.raises(ApiError): + _resolve(_sync_client(), path)(**kwargs) + + +@pytest.mark.parametrize("status", _ERROR_STATUSES) +@pytest.mark.parametrize("path,kwargs", ENDPOINTS, ids=_ENDPOINT_IDS) +async def test_endpoint_error_async(path: str, kwargs: typing.Dict[str, typing.Any], status: int) -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(status, json={"error": "boom"})) + with pytest.raises(ApiError): + await _resolve(_async_client(), path)(**kwargs) + + +# A non-JSON error body exercises the ``except JSONDecodeError -> ApiError`` +# branch present in every raw_client method. +@pytest.mark.parametrize("path,kwargs", ENDPOINTS, ids=_ENDPOINT_IDS) +def test_endpoint_non_json_error_body_sync(path: str, kwargs: typing.Dict[str, typing.Any]) -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(403, content=b"not json")) + with pytest.raises(ApiError): + _resolve(_sync_client(), path)(**kwargs) + + +@pytest.mark.parametrize("path,kwargs", ENDPOINTS, ids=_ENDPOINT_IDS) +async def test_endpoint_non_json_error_body_async(path: str, kwargs: typing.Dict[str, typing.Any]) -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(403, content=b"not json")) + with pytest.raises(ApiError): + await _resolve(_async_client(), path)(**kwargs) + + +def test_with_raw_response_returns_http_response_sync() -> None: + """The ``with_raw_response`` accessor returns the raw HttpResponse wrapper.""" + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, json={})) + raw = _sync_client().manage.v1.projects.with_raw_response.list() + assert raw.data is not None + + +def _walk_subclients(obj: typing.Any, depth: int = 0, seen: typing.Optional[set] = None) -> int: + """Touch every nested sub-client property to cover their lazy-init accessors.""" + seen = seen if seen is not None else set() + if id(obj) in seen or depth > 8: + return 0 + seen.add(id(obj)) + count = 0 + for name in dir(obj): + if name.startswith("_"): + continue + try: + attr = getattr(obj, name) + except Exception: + continue + module = getattr(type(attr), "__module__", "") or "" + if module.startswith("deepgram") and not callable(attr): + count += 1 + count += _walk_subclients(attr, depth + 1, seen) + return count + + +def test_subclient_accessors_are_reachable_sync() -> None: + client = _sync_client() + # Access twice so both the lazy-construct and cached branches run. + assert _walk_subclients(client) > 0 + assert _walk_subclients(client) > 0 + + +def test_subclient_accessors_are_reachable_async() -> None: + client = _async_client() + assert _walk_subclients(client) > 0 + assert _walk_subclients(client) > 0 diff --git a/tests/custom/test_http_retry_coverage.py b/tests/custom/test_http_retry_coverage.py new file mode 100644 index 00000000..63030bde --- /dev/null +++ b/tests/custom/test_http_retry_coverage.py @@ -0,0 +1,110 @@ +""" +Coverage for the retry behaviour in ``deepgram.core.http_client``. + +The retry path (exponential backoff, ``Retry-After`` / ``X-RateLimit-Reset`` +header handling, connection-error retries, and retry exhaustion) is otherwise +not reached by the endpoint tests, which only return single responses. Sleeps +are patched out so the suite stays fast. +""" + +import typing + +import httpx +import pytest +import respx + +import deepgram.core.http_client as http_client_module +from deepgram import AsyncDeepgramClient, DeepgramClient +from deepgram.core.api_error import ApiError +from deepgram.environment import DeepgramClientEnvironment + +HOST = "test.deepgram.local" +BASE = f"https://{HOST}" + + +@pytest.fixture(autouse=True) +def _no_sleep(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(http_client_module.time, "sleep", lambda _s: None) + + async def _async_sleep(_s: float) -> None: + return None + + monkeypatch.setattr(http_client_module.asyncio, "sleep", _async_sleep) + + +def _environment() -> DeepgramClientEnvironment: + return DeepgramClientEnvironment(base=BASE, production=BASE, agent=BASE, agent_rest=BASE) + + +def _sync_client() -> DeepgramClient: + return DeepgramClient(environment=_environment(), api_key="test_api_key") + + +def _async_client() -> AsyncDeepgramClient: + return AsyncDeepgramClient(environment=_environment(), api_key="test_api_key", httpx_client=httpx.AsyncClient()) + + +@pytest.mark.parametrize( + "headers", + [ + {}, # exponential backoff branch + {"retry-after": "1"}, # Retry-After seconds branch + {"retry-after-ms": "10"}, # Retry-After-ms branch + {"x-ratelimit-reset": "1"}, # X-RateLimit-Reset branch + ], +) +def test_retry_then_exhaust_sync(headers: typing.Dict[str, str]) -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(500, headers=headers, json={"e": 1})) + with pytest.raises(ApiError): + _sync_client().manage.v1.projects.list(request_options={"max_retries": 2}) + + +def test_retry_recovers_after_retryable_status_sync() -> None: + with respx.mock: + route = respx.route(host=HOST) + route.side_effect = [httpx.Response(429, json={}), httpx.Response(200, json={})] + # Should retry the 429 and then succeed on the 200. + assert _sync_client().manage.v1.projects.list(request_options={"max_retries": 2}) is not None + + +def test_retry_on_connect_error_then_success_sync() -> None: + with respx.mock: + route = respx.route(host=HOST) + route.side_effect = [httpx.ConnectError("boom"), httpx.Response(200, json={})] + assert _sync_client().manage.v1.projects.list(request_options={"max_retries": 2}) is not None + + +def test_connect_error_exhausts_and_raises_sync() -> None: + with respx.mock: + respx.route(host=HOST).mock(side_effect=httpx.ConnectError("boom")) + with pytest.raises(httpx.ConnectError): + _sync_client().manage.v1.projects.list(request_options={"max_retries": 1}) + + +def test_no_retries_when_max_retries_zero_sync() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(500, json={})) + with pytest.raises(ApiError): + _sync_client().manage.v1.projects.list(request_options={"max_retries": 0}) + + +async def test_retry_then_exhaust_async() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(500, headers={"retry-after": "1"}, json={})) + with pytest.raises(ApiError): + await _async_client().manage.v1.projects.list(request_options={"max_retries": 2}) + + +async def test_retry_recovers_async() -> None: + with respx.mock: + route = respx.route(host=HOST) + route.side_effect = [httpx.Response(500, json={}), httpx.Response(200, json={})] + assert await _async_client().manage.v1.projects.list(request_options={"max_retries": 2}) is not None + + +async def test_retry_on_connect_error_async() -> None: + with respx.mock: + route = respx.route(host=HOST) + route.side_effect = [httpx.ConnectError("boom"), httpx.Response(200, json={})] + assert await _async_client().manage.v1.projects.list(request_options={"max_retries": 2}) is not None diff --git a/tests/custom/test_misc_coverage.py b/tests/custom/test_misc_coverage.py new file mode 100644 index 00000000..cef2273d --- /dev/null +++ b/tests/custom/test_misc_coverage.py @@ -0,0 +1,130 @@ +""" +Coverage for assorted helpers: the query encoder, the custom client +constructors (access token / session id / transport factory / log redaction), +and the TTS ``TextBuilder``/SSML helpers. +""" + +import typing + +import pytest + +from deepgram import AsyncDeepgramClient, DeepgramClient +from deepgram.core.query_encoder import encode_query, single_query_encoder, traverse_query_dict +from deepgram.environment import DeepgramClientEnvironment +from deepgram.helpers.text_builder import ( + TextBuilder, + add_pronunciation, + ssml_to_deepgram, + validate_ipa, + validate_pause, +) + +BASE = "https://test.deepgram.local" + + +def _environment() -> DeepgramClientEnvironment: + return DeepgramClientEnvironment(base=BASE, production=BASE, agent=BASE, agent_rest=BASE) + + +# --------------------------------------------------------------------------- # +# query_encoder +# --------------------------------------------------------------------------- # +def test_encode_query_none_and_empty() -> None: + assert encode_query(None) is None + assert encode_query({}) == [] + + +def test_query_encoder_shapes() -> None: + # bool coercion to lowercase + assert ("flag", "true") in encode_query({"flag": True}) + assert ("flag", "false") in encode_query({"flag": False}) + # nested dict flattening + assert ("a[b]", 1) in traverse_query_dict({"a": {"b": 1}}) + # list of scalars and list of dicts + assert single_query_encoder("k", [1, 2]) == [("k", 1), ("k", 2)] + assert single_query_encoder("k", [{"x": 1}]) == [("k[x]", 1)] + # plain scalar + assert single_query_encoder("k", "v") == [("k", "v")] + + +# --------------------------------------------------------------------------- # +# custom client constructors +# --------------------------------------------------------------------------- # +@pytest.mark.parametrize("client_cls", [DeepgramClient, AsyncDeepgramClient]) +def test_client_with_access_token_and_session_id(client_cls: typing.Any) -> None: + client = client_cls(environment=_environment(), access_token="my-token", session_id="sid-123") + assert client.session_id == "sid-123" + + +@pytest.mark.parametrize("client_cls", [DeepgramClient, AsyncDeepgramClient]) +def test_client_generates_session_id_and_opts_out_of_redaction(client_cls: typing.Any) -> None: + client = client_cls(environment=_environment(), api_key="k", redact_credentials_in_logs=False) + assert client.session_id # auto-generated UUID + + +@pytest.mark.parametrize("client_cls", [DeepgramClient, AsyncDeepgramClient]) +def test_client_with_transport_factory_disables_reconnect(client_cls: typing.Any) -> None: + from deepgram.transport import restore_transport + + def _factory(url: str, headers: typing.Dict[str, str]) -> typing.Any: # pragma: no cover - not invoked + raise NotImplementedError + + try: + client = client_cls(environment=_environment(), api_key="k", transport_factory=_factory) + assert client.reconnect is False + finally: + # Transport patching is global; undo it so it does not leak into other tests. + restore_transport() + + +# --------------------------------------------------------------------------- # +# text_builder / SSML helpers +# --------------------------------------------------------------------------- # +def test_text_builder_fluent_build() -> None: + text = ( + TextBuilder() + .text("Take ") + .pronunciation("azathioprine", "ˌæzəˈθaɪəpriːn") + .pause(500) + .text(" daily.") + .build() + ) + assert "pronounce" in text and "{pause:500}" in text + + +def test_text_builder_validation_errors() -> None: + with pytest.raises(ValueError): + TextBuilder().pronunciation("w", 'has"quote') # invalid IPA char + with pytest.raises(ValueError): + TextBuilder().pause(123) # not a 100ms increment + + +def test_validate_ipa_and_pause() -> None: + assert validate_ipa("")[0] is False + assert validate_ipa('a"b')[0] is False + assert validate_ipa("x" * 101)[0] is False + assert validate_ipa("ˈtest")[0] is True + assert validate_pause(400)[0] is False + assert validate_pause(6000)[0] is False + assert validate_pause(550)[0] is False # not 100ms increment + assert validate_pause(500)[0] is True + + +def test_add_pronunciation_and_ssml() -> None: + out = add_pronunciation("Take azathioprine daily", "azathioprine", "ˌæzəˈθaɪəpriːn") + assert "pronounce" in out + with pytest.raises(ValueError): + add_pronunciation("x", "x", 'bad"ipa') + + ssml = 'Take azathioprine now' + converted = ssml_to_deepgram(ssml) + assert "pronounce" in converted and "{pause:500}" in converted + + # break in milliseconds + an out-of-range value that gets rounded to a valid one + rounded = ssml_to_deepgram('Wait here') + assert "{pause:" in rounded + + +def test_text_builder_from_ssml_updates_counts() -> None: + builder = TextBuilder().from_ssml('Hi there') + assert "{pause:500}" in builder.build() diff --git a/tests/custom/test_more_branches_coverage.py b/tests/custom/test_more_branches_coverage.py new file mode 100644 index 00000000..ef919a88 --- /dev/null +++ b/tests/custom/test_more_branches_coverage.py @@ -0,0 +1,215 @@ +""" +Additional targeted branch coverage for ``core`` internals that the broad +endpoint/utility tests do not reach: debug logging, the ``Logger`` helper, +Retry-After date / X-RateLimit-Reset header parsing, and a few remaining +encoder/serialization/construct_type shapes. +""" + +import datetime as dt +import decimal +import email.utils +import time +import typing + +import httpx +import pytest +import respx +import typing_extensions + +import deepgram.core.http_client as http_client_module +from deepgram import DeepgramClient +from deepgram.core.api_error import ApiError +from deepgram.core.jsonable_encoder import jsonable_encoder +from deepgram.core.logging import ConsoleLogger, Logger, create_logger +from deepgram.core.pydantic_utilities import UniversalBaseModel +from deepgram.core.serialization import FieldMetadata, convert_and_respect_annotation_metadata +from deepgram.core.unchecked_base_model import UncheckedBaseModel, construct_type +from deepgram.environment import DeepgramClientEnvironment + +HOST = "test.deepgram.local" +BASE = f"https://{HOST}" + + +def _environment() -> DeepgramClientEnvironment: + return DeepgramClientEnvironment(base=BASE, production=BASE, agent=BASE, agent_rest=BASE) + + +# --------------------------------------------------------------------------- # +# Logger +# --------------------------------------------------------------------------- # +def test_logger_levels() -> None: + logger = Logger(level="debug", logger=ConsoleLogger(), silent=False) + assert logger.is_debug() and logger.is_info() and logger.is_warn() and logger.is_error() + logger.debug("d") + logger.info("i") + logger.warn("w") + logger.error("e") + + silent = Logger(level="error", logger=ConsoleLogger(), silent=True) + assert silent.is_debug() is False + silent.debug("nope") # suppressed + + # create_logger passthrough + dict config + assert create_logger(logger) is logger + assert create_logger(None) is not None + assert create_logger({"level": "warn", "silent": False}).is_warn() + + +def test_debug_logging_request_path(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(http_client_module.time, "sleep", lambda _s: None) + client = DeepgramClient( + environment=_environment(), api_key="secret-key", logging={"level": "debug", "silent": False} + ) + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, json={})) + assert client.manage.v1.projects.list() is not None + + +# --------------------------------------------------------------------------- # +# Retry-After / X-RateLimit-Reset header parsing +# --------------------------------------------------------------------------- # +def _no_sleep(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(http_client_module.time, "sleep", lambda _s: None) + + +def test_retry_after_http_date(monkeypatch: pytest.MonkeyPatch) -> None: + _no_sleep(monkeypatch) + future = email.utils.formatdate(time.time() + 30, usegmt=True) + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(503, headers={"retry-after": future}, json={})) + with pytest.raises(ApiError): + DeepgramClient(environment=_environment(), api_key="k").manage.v1.projects.list( + request_options={"max_retries": 1} + ) + + +def test_x_ratelimit_reset_future(monkeypatch: pytest.MonkeyPatch) -> None: + _no_sleep(monkeypatch) + reset = str(int(time.time()) + 30) + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(503, headers={"x-ratelimit-reset": reset}, json={})) + with pytest.raises(ApiError): + DeepgramClient(environment=_environment(), api_key="k").manage.v1.projects.list( + request_options={"max_retries": 1} + ) + + +# --------------------------------------------------------------------------- # +# jsonable_encoder fallback / root model +# --------------------------------------------------------------------------- # +def test_jsonable_encoder_decimal_and_root_model() -> None: + # Decimal is not handled inline; it goes through the jsonable fallback. + assert jsonable_encoder(decimal.Decimal("1.5")) in (1.5, "1.5") + + # A mapping-like object reduced via dict() in the fallback path. + class MappingLike: + def keys(self) -> typing.List[str]: + return ["a"] + + def __getitem__(self, key: str) -> int: + return 1 + + assert jsonable_encoder(MappingLike()) == {"a": 1} + + +# --------------------------------------------------------------------------- # +# serialization: set conversion + alias on a model field +# --------------------------------------------------------------------------- # +def test_serialization_set_conversion() -> None: + class TD(typing_extensions.TypedDict): + field: typing_extensions.Annotated[str, FieldMetadata(alias="field_name")] + + set_ann = typing.Set[str] + # Sets pass through unchanged (no alias inside), exercising the set branch. + out = convert_and_respect_annotation_metadata(object_={"a", "b"}, annotation=set_ann, direction="write") + assert out == {"a", "b"} + + list_of_td = typing.List[TD] + converted = convert_and_respect_annotation_metadata( + object_=[{"field": "v"}], annotation=list_of_td, direction="read" + ) + assert converted == [{"field": "v"}] + + +# --------------------------------------------------------------------------- # +# construct_type extra shapes +# --------------------------------------------------------------------------- # +def test_construct_type_set_of_models_and_datetime_passthrough() -> None: + class Tag(UncheckedBaseModel): + name: typing.Optional[str] = None + + out = construct_type(type_=typing.List[Tag], object_=[{"name": "x"}, {"name": "y"}]) + assert all(isinstance(t, Tag) for t in out) + + # datetime field that fails to parse falls back to the raw value + assert construct_type(type_=dt.datetime, object_="not-a-date") == "not-a-date" + assert construct_type(type_=dt.date, object_="not-a-date") == "not-a-date" + + +def test_construct_type_model_with_pydantic_alias_and_extras() -> None: + import pydantic + + class Aliased(UncheckedBaseModel): + field_name: str = pydantic.Field(alias="fieldName") + + # Supply the alias key plus an unexpected extra key to exercise the alias + # resolution and extras-passthrough branches of UncheckedBaseModel.construct. + built = construct_type(type_=Aliased, object_={"fieldName": "v", "surprise": 1}) + assert built.field_name == "v" + + +# --------------------------------------------------------------------------- # +# query_encoder: pydantic model values +# --------------------------------------------------------------------------- # +def test_query_encoder_with_pydantic_models() -> None: + from deepgram.core.query_encoder import single_query_encoder + + class QModel(UniversalBaseModel): + a: int = 1 + + assert single_query_encoder("k", QModel(a=2)) == [("k[a]", 2)] + assert single_query_encoder("k", [QModel(a=3)]) == [("k[a]", 3)] + + +# --------------------------------------------------------------------------- # +# http_client: timeout option + retry exhaustion details +# --------------------------------------------------------------------------- # +def test_request_with_timeout_option(monkeypatch: pytest.MonkeyPatch) -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, json={})) + client = DeepgramClient(environment=_environment(), api_key="k") + assert client.manage.v1.projects.list(request_options={"timeout_in_seconds": 5}) is not None + + +def test_x_ratelimit_reset_in_past_falls_back(monkeypatch: pytest.MonkeyPatch) -> None: + _no_sleep(monkeypatch) + past = str(int(time.time()) - 100) + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(500, headers={"x-ratelimit-reset": past}, json={})) + with pytest.raises(ApiError): + DeepgramClient(environment=_environment(), api_key="k").manage.v1.projects.list( + request_options={"max_retries": 1} + ) + + +async def test_async_request_with_timeout_option() -> None: + from deepgram import AsyncDeepgramClient + + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, json={})) + client = AsyncDeepgramClient(environment=_environment(), api_key="k", httpx_client=httpx.AsyncClient()) + assert await client.manage.v1.projects.list(request_options={"timeout_in_seconds": 5}) is not None + + +async def test_async_connect_error_exhausts(monkeypatch: pytest.MonkeyPatch) -> None: + async def _async_sleep(_s: float) -> None: + return None + + monkeypatch.setattr(http_client_module.asyncio, "sleep", _async_sleep) + from deepgram import AsyncDeepgramClient + + client = AsyncDeepgramClient(environment=_environment(), api_key="k", httpx_client=httpx.AsyncClient()) + with respx.mock: + respx.route(host=HOST).mock(side_effect=httpx.ConnectError("boom")) + with pytest.raises(httpx.ConnectError): + await client.manage.v1.projects.list(request_options={"max_retries": 1}) diff --git a/tests/custom/test_websocket_streaming_coverage.py b/tests/custom/test_websocket_streaming_coverage.py new file mode 100644 index 00000000..0b552c6b --- /dev/null +++ b/tests/custom/test_websocket_streaming_coverage.py @@ -0,0 +1,344 @@ +""" +Coverage for the websocket ``connect`` clients, their socket clients, and the +streaming TTS (``speak.v1.audio.generate``) endpoint. + +These paths are not reachable through the HTTP request mock used by +``test_http_endpoints_coverage`` and are not exercised by the WireMock-based +wire tests, so they are driven here directly: + +* ``connect`` is exercised by monkeypatching the ``websockets`` connect call so + it yields a fake protocol. The same test then drives every method of the + resulting socket client (send_*, recv, iteration, start_listening), which + covers ``connect``, the high-level client wrapper, and ``socket_client`` in + one shot. +* The InvalidWebSocketStatus -> ApiError mapping is covered by making the + patched connect raise. +* The streaming TTS endpoint is covered with a mocked transport (respx). +""" + +import typing + +import httpx +import pytest +import respx +import websockets.sync.client as websockets_sync_client + +from deepgram import AsyncDeepgramClient, DeepgramClient +from deepgram.core.api_error import ApiError +from deepgram.core.events import EventType +from deepgram.core.websocket_compat import InvalidWebSocketStatus +from deepgram.environment import DeepgramClientEnvironment + +HOST = "test.deepgram.local" +BASE = f"https://{HOST}" +WS_BASE = f"wss://{HOST}" + + +def _environment() -> DeepgramClientEnvironment: + return DeepgramClientEnvironment(base=BASE, production=WS_BASE, agent=WS_BASE, agent_rest=BASE) + + +def _sync_client() -> DeepgramClient: + return DeepgramClient(environment=_environment(), api_key="test_api_key") + + +def _async_client() -> AsyncDeepgramClient: + return AsyncDeepgramClient( + environment=_environment(), api_key="test_api_key", httpx_client=httpx.AsyncClient() + ) + + +def _resolve(client: typing.Any, dotted_path: str) -> typing.Any: + obj = client + for part in dotted_path.split("."): + obj = getattr(obj, part) + return obj + + +# A stand-in for the typed message models the send_* helpers expect. The socket +# client only calls ``.dict()`` on it; the float/list/nested values exercise the +# numeric-sanitizing branch in the agent socket client. +class _DummyModel: + def dict(self) -> typing.Dict[str, typing.Any]: + return {"type": "X", "sample_rate": 44100.0, "nested": {"v": 2.0}, "items": [1.0, "a"]} + + +# Messages fed to the iterators: a binary frame, a parseable JSON frame, and a +# non-JSON frame (skipped on iteration, surfaced as an ERROR by start_listening). +_ITER_MESSAGES = [b"\x00\x01", '{"type":"Welcome"}', "this is not json"] + + +class _FakeSyncWS: + def __init__(self) -> None: + self.sent: typing.List[typing.Any] = [] + self._recv_queue: typing.List[typing.Any] = [b"\x02", '{"type":"Welcome"}', '{"type":"Welcome"}'] + + def __iter__(self) -> typing.Iterator[typing.Any]: + return iter(_ITER_MESSAGES) + + def recv(self) -> typing.Any: + return self._recv_queue.pop(0) + + def send(self, data: typing.Any) -> None: + self.sent.append(data) + + +class _FakeAsyncWS: + def __init__(self) -> None: + self.sent: typing.List[typing.Any] = [] + self._recv_queue: typing.List[typing.Any] = [b"\x02", '{"type":"Welcome"}', '{"type":"Welcome"}'] + + async def __aiter__(self) -> typing.AsyncIterator[typing.Any]: + for message in _ITER_MESSAGES: + yield message + + async def recv(self) -> typing.Any: + return self._recv_queue.pop(0) + + async def send(self, data: typing.Any) -> None: + self.sent.append(data) + + +class _FakeSyncConnectCM: + def __init__(self, ws: _FakeSyncWS) -> None: + self._ws = ws + + def __enter__(self) -> _FakeSyncWS: + return self._ws + + def __exit__(self, *exc: typing.Any) -> bool: + return False + + +class _FakeAsyncConnectCM: + def __init__(self, ws: _FakeAsyncWS) -> None: + self._ws = ws + + async def __aenter__(self) -> _FakeAsyncWS: + return self._ws + + async def __aexit__(self, *exc: typing.Any) -> bool: + return False + + +def _make_invalid_status(code: int) -> InvalidWebSocketStatus: + exc = InvalidWebSocketStatus.__new__(InvalidWebSocketStatus) + # Support both websockets layouts: legacy reads exc.status_code, newer reads + # exc.response.status_code. + exc.status_code = code # type: ignore[attr-defined] + exc.response = type("_Resp", (), {"status_code": code})() # type: ignore[attr-defined] + return exc + + +# (leaf client path, connect kwargs). The high-level ``connect`` reimplements the +# websocket logic inline (it does not call ``raw_client.connect``), so each +# endpoint is driven both through the high-level client and through +# ``with_raw_response`` to cover both modules. +WS_ENDPOINTS = [ + ("speak.v1", {}), + ("agent.v1", {}), + ("listen.v1", {"model": "nova-3"}), + ("listen.v2", {"model": "flux-general-en"}), +] +_WS_IDS = [e[0] for e in WS_ENDPOINTS] + + +def _async_connect_modules(leaf: str) -> typing.List[str]: + """The two modules that import ``websockets_client_connect`` for an endpoint.""" + return [f"deepgram.{leaf}.client", f"deepgram.{leaf}.raw_client"] + + +def _exercise_sync_socket(socket: typing.Any) -> None: + socket.on(EventType.OPEN, lambda _data: None) + socket.on(EventType.MESSAGE, lambda _data: None) + socket.on(EventType.ERROR, lambda _data: None) + socket.on(EventType.CLOSE, lambda _data: None) + + list(socket) # __iter__: binary + parsed + skipped-non-json + socket.start_listening() # OPEN, MESSAGE(s), ERROR (non-json), CLOSE + socket.recv() # binary frame + socket.recv() # json frame + + for name in dir(socket): + if not name.startswith("send_"): + continue + method = getattr(socket, name) + try: + method() # send helpers with an optional/defaulted message + except TypeError: + method(_DummyModel()) + + +async def _exercise_async_socket(socket: typing.Any) -> None: + socket.on(EventType.OPEN, lambda _data: None) + socket.on(EventType.MESSAGE, lambda _data: None) + socket.on(EventType.ERROR, lambda _data: None) + socket.on(EventType.CLOSE, lambda _data: None) + + async for _ in socket: # __aiter__ + pass + await socket.start_listening() + await socket.recv() + await socket.recv() + + for name in dir(socket): + if not name.startswith("send_"): + continue + method = getattr(socket, name) + try: + await method() + except TypeError: + await method(_DummyModel()) + + +@pytest.mark.parametrize("leaf,kwargs", WS_ENDPOINTS, ids=_WS_IDS) +def test_connect_and_socket_client_sync( + monkeypatch: pytest.MonkeyPatch, leaf: str, kwargs: typing.Dict[str, typing.Any] +) -> None: + created: typing.List[_FakeSyncWS] = [] + + def _connect(*_a: typing.Any, **_k: typing.Any) -> _FakeSyncConnectCM: + ws = _FakeSyncWS() + created.append(ws) + return _FakeSyncConnectCM(ws) + + # Both client.py and raw_client.py reference the same websockets.sync.client + # module, so a single global patch covers them. + monkeypatch.setattr(websockets_sync_client, "connect", _connect) + + client = _sync_client() + with _resolve(client, leaf).connect(**kwargs) as socket: # high-level client.py + _exercise_sync_socket(socket) + with _resolve(client, leaf).with_raw_response.connect(**kwargs) as socket: # raw_client.py + _exercise_sync_socket(socket) + # With authorization + request options, covering the auth-header, extra-header + # and additional-query-parameter branches of connect. + opts = {"additional_headers": {"X-Trace": "1"}, "additional_query_parameters": {"foo": "bar"}} + with _resolve(client, leaf).connect(authorization="Token abc", request_options=opts, **kwargs): + pass + with _resolve(client, leaf).with_raw_response.connect(authorization="Token abc", request_options=opts, **kwargs): + pass + + assert created and any(ws.sent for ws in created) # send_* helpers reached the transport + + +@pytest.mark.parametrize("leaf,kwargs", WS_ENDPOINTS, ids=_WS_IDS) +async def test_connect_and_socket_client_async( + monkeypatch: pytest.MonkeyPatch, leaf: str, kwargs: typing.Dict[str, typing.Any] +) -> None: + created: typing.List[_FakeAsyncWS] = [] + + def _connect(*_a: typing.Any, **_k: typing.Any) -> _FakeAsyncConnectCM: + ws = _FakeAsyncWS() + created.append(ws) + return _FakeAsyncConnectCM(ws) + + # The async connect is imported by name into both client.py and raw_client.py. + for module in _async_connect_modules(leaf): + monkeypatch.setattr(f"{module}.websockets_client_connect", _connect) + + client = _async_client() + async with _resolve(client, leaf).connect(**kwargs) as socket: # high-level client.py + await _exercise_async_socket(socket) + async with _resolve(client, leaf).with_raw_response.connect(**kwargs) as socket: # raw_client.py + await _exercise_async_socket(socket) + opts = {"additional_headers": {"X-Trace": "1"}, "additional_query_parameters": {"foo": "bar"}} + async with _resolve(client, leaf).connect(authorization="Token abc", request_options=opts, **kwargs): + pass + async with _resolve(client, leaf).with_raw_response.connect( + authorization="Token abc", request_options=opts, **kwargs + ): + pass + + assert created and any(ws.sent for ws in created) + + +@pytest.mark.parametrize("status", [401, 500]) +@pytest.mark.parametrize("leaf,kwargs", WS_ENDPOINTS, ids=_WS_IDS) +def test_connect_invalid_status_raises_sync( + monkeypatch: pytest.MonkeyPatch, leaf: str, kwargs: typing.Dict[str, typing.Any], status: int +) -> None: + def _raise(*_a: typing.Any, **_k: typing.Any) -> typing.Any: + raise _make_invalid_status(status) + + monkeypatch.setattr(websockets_sync_client, "connect", _raise) + + client = _sync_client() + with pytest.raises(ApiError): + with _resolve(client, leaf).connect(**kwargs): + pass + with pytest.raises(ApiError): + with _resolve(client, leaf).with_raw_response.connect(**kwargs): + pass + + +@pytest.mark.parametrize("status", [401, 500]) +@pytest.mark.parametrize("leaf,kwargs", WS_ENDPOINTS, ids=_WS_IDS) +async def test_connect_invalid_status_raises_async( + monkeypatch: pytest.MonkeyPatch, leaf: str, kwargs: typing.Dict[str, typing.Any], status: int +) -> None: + def _raise(*_a: typing.Any, **_k: typing.Any) -> typing.Any: + raise _make_invalid_status(status) + + for module in _async_connect_modules(leaf): + monkeypatch.setattr(f"{module}.websockets_client_connect", _raise) + + client = _async_client() + with pytest.raises(ApiError): + async with _resolve(client, leaf).connect(**kwargs): + pass + with pytest.raises(ApiError): + async with _resolve(client, leaf).with_raw_response.connect(**kwargs): + pass + + +def test_speak_audio_generate_sync() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, content=b"audio-bytes")) + chunks = list(_sync_client().speak.v1.audio.generate(text="hello")) + assert b"".join(chunks) == b"audio-bytes" + + +def test_speak_audio_generate_with_request_options_sync() -> None: + # chunk_size + timeout exercise the request_options branches in the stream path. + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, content=b"audio-bytes")) + chunks = list( + _sync_client().speak.v1.audio.generate( + text="hello", request_options={"chunk_size": 4, "timeout_in_seconds": 5} + ) + ) + assert b"".join(chunks) == b"audio-bytes" + + +async def test_speak_audio_generate_async() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, content=b"audio-bytes")) + chunks = [chunk async for chunk in _async_client().speak.v1.audio.generate(text="hello")] + assert b"".join(chunks) == b"audio-bytes" + + +async def test_speak_audio_generate_with_request_options_async() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(200, content=b"audio-bytes")) + chunks = [ + chunk + async for chunk in _async_client().speak.v1.audio.generate( + text="hello", request_options={"chunk_size": 4, "timeout_in_seconds": 5} + ) + ] + assert b"".join(chunks) == b"audio-bytes" + + +def test_speak_audio_generate_error_sync() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(400, json={"err": "bad"})) + with pytest.raises(ApiError): + list(_sync_client().speak.v1.audio.generate(text="hello")) + + +async def test_speak_audio_generate_error_async() -> None: + with respx.mock: + respx.route(host=HOST).mock(return_value=httpx.Response(400, json={"err": "bad"})) + with pytest.raises(ApiError): + [chunk async for chunk in _async_client().speak.v1.audio.generate(text="hello")]