Skip to content

Commit 2df0e2b

Browse files
authored
fix(backend/api): Fix & add tests for APIKeyAuthenticator (#10881)
- Resolves #10875 ### Changes ๐Ÿ—๏ธ - Fix use of `super().__call__` in `APIKeyAuthenticator.__call__` - Fix non-ASCII API key validation - Add tests for `APIKeyAuthenticator` ### Checklist ๐Ÿ“‹ #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Test implementations have been verified manually - [x] All the new tests pass
1 parent 925f249 commit 2df0e2b

File tree

2 files changed

+553
-5
lines changed

2 files changed

+553
-5
lines changed

โ€Žautogpt_platform/backend/backend/server/utils/api_key_auth.pyโ€Ž

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,18 @@
33
"""
44

55
import inspect
6+
import logging
67
import secrets
7-
from typing import Any, Callable, Optional
8+
from typing import Any, Awaitable, Callable, Optional
89

910
from fastapi import HTTPException, Request
1011
from fastapi.security import APIKeyHeader
1112
from starlette.status import HTTP_401_UNAUTHORIZED
1213

1314
from backend.util.exceptions import MissingConfigError
1415

16+
logger = logging.getLogger(__name__)
17+
1518

1619
class APIKeyAuthenticator(APIKeyHeader):
1720
"""
@@ -51,7 +54,8 @@ async def validate_with_db(api_key: str):
5154
header_name (str): The name of the header containing the API key
5255
expected_token (Optional[str]): The expected API key value for simple token matching
5356
validator (Optional[Callable]): Custom validation function that takes an API key
54-
string and returns a boolean or object. Can be async.
57+
string and returns a truthy value if and only if the passed string is a
58+
valid API key. Can be async.
5559
status_if_missing (int): HTTP status code to use for validation errors
5660
message_if_invalid (str): Error message to return when validation fails
5761
"""
@@ -60,7 +64,9 @@ def __init__(
6064
self,
6165
header_name: str,
6266
expected_token: Optional[str] = None,
63-
validator: Optional[Callable[[str], bool]] = None,
67+
validator: Optional[
68+
Callable[[str], Any] | Callable[[str], Awaitable[Any]]
69+
] = None,
6470
status_if_missing: int = HTTP_401_UNAUTHORIZED,
6571
message_if_invalid: str = "Invalid API key",
6672
):
@@ -75,7 +81,7 @@ def __init__(
7581
self.message_if_invalid = message_if_invalid
7682

7783
async def __call__(self, request: Request) -> Any:
78-
api_key = await super()(request)
84+
api_key = await super().__call__(request)
7985
if api_key is None:
8086
raise HTTPException(
8187
status_code=self.status_if_missing, detail="No API key in request"
@@ -106,4 +112,9 @@ async def default_validator(self, api_key: str) -> bool:
106112
f"{self.__class__.__name__}.expected_token is not set; "
107113
"either specify it or provide a custom validator"
108114
)
109-
return secrets.compare_digest(api_key, self.expected_token)
115+
try:
116+
return secrets.compare_digest(api_key, self.expected_token)
117+
except TypeError as e:
118+
# If value is not an ASCII string, compare_digest raises a TypeError
119+
logger.warning(f"{self.model.name} API key check failed: {e}")
120+
return False

0 commit comments

Comments
ย (0)