Skip to content

Commit 68fd17d

Browse files
authored
[Fix] QA Fixes - Vector Store Object Permissions (#11291)
* fix: QA for key,team,org permissions * fix: add_vector_store_to_registry * fix: refactor bedrock guard * fix: refactor using us east 1 with vector stores * fix: code QA checks * fix: testing for mgmt endpoints
1 parent e0daa3d commit 68fd17d

File tree

12 files changed

+338
-178
lines changed

12 files changed

+338
-178
lines changed

litellm/integrations/vector_stores/bedrock_vector_store.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
VectorStoreSearchResponse,
3535
VectorStoreSearchResult,
3636
)
37-
from litellm.utils import load_credentials_from_list
3837

3938
if TYPE_CHECKING:
4039
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
@@ -258,22 +257,49 @@ async def make_bedrock_kb_retrieve_request(
258257
from fastapi import HTTPException
259258

260259
non_default_params = non_default_params or {}
261-
load_credentials_from_list(kwargs=non_default_params)
260+
credentials_dict: Dict[str, Any] = {}
261+
if litellm.vector_store_registry is not None:
262+
credentials_dict = (
263+
litellm.vector_store_registry.get_credentials_for_vector_store(
264+
knowledge_base_id
265+
)
266+
)
267+
262268
credentials = self.get_credentials(
263-
aws_access_key_id=non_default_params.get("aws_access_key_id", None),
264-
aws_secret_access_key=non_default_params.get("aws_secret_access_key", None),
265-
aws_session_token=non_default_params.get("aws_session_token", None),
266-
aws_region_name=non_default_params.get("aws_region_name", None),
267-
aws_session_name=non_default_params.get("aws_session_name", None),
268-
aws_profile_name=non_default_params.get("aws_profile_name", None),
269-
aws_role_name=non_default_params.get("aws_role_name", None),
270-
aws_web_identity_token=non_default_params.get(
271-
"aws_web_identity_token", None
269+
aws_access_key_id=credentials_dict.get(
270+
"aws_access_key_id", non_default_params.get("aws_access_key_id", None)
271+
),
272+
aws_secret_access_key=credentials_dict.get(
273+
"aws_secret_access_key",
274+
non_default_params.get("aws_secret_access_key", None),
275+
),
276+
aws_session_token=credentials_dict.get(
277+
"aws_session_token", non_default_params.get("aws_session_token", None)
278+
),
279+
aws_region_name=credentials_dict.get(
280+
"aws_region_name", non_default_params.get("aws_region_name", None)
281+
),
282+
aws_session_name=credentials_dict.get(
283+
"aws_session_name", non_default_params.get("aws_session_name", None)
284+
),
285+
aws_profile_name=credentials_dict.get(
286+
"aws_profile_name", non_default_params.get("aws_profile_name", None)
287+
),
288+
aws_role_name=credentials_dict.get(
289+
"aws_role_name", non_default_params.get("aws_role_name", None)
290+
),
291+
aws_web_identity_token=credentials_dict.get(
292+
"aws_web_identity_token",
293+
non_default_params.get("aws_web_identity_token", None),
294+
),
295+
aws_sts_endpoint=credentials_dict.get(
296+
"aws_sts_endpoint", non_default_params.get("aws_sts_endpoint", None)
272297
),
273-
aws_sts_endpoint=non_default_params.get("aws_sts_endpoint", None),
274298
)
275-
aws_region_name = self._get_aws_region_name(
276-
optional_params=self.optional_params
299+
aws_region_name = self.get_aws_region_name_for_non_llm_api_calls(
300+
aws_region_name=credentials_dict.get(
301+
"aws_region_name", non_default_params.get("aws_region_name", None)
302+
),
277303
)
278304

279305
# Prepare request data

litellm/llms/bedrock/base_aws_llm.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,36 @@ def _get_aws_region_name(
336336

337337
return aws_region_name
338338

339+
def get_aws_region_name_for_non_llm_api_calls(
340+
self,
341+
aws_region_name: Optional[str] = None,
342+
):
343+
"""
344+
Get the AWS region name for non-llm api calls.
345+
346+
LLM API calls check the model arn and end up using that as the region name.
347+
348+
For non-llm api calls eg. Guardrails, Vector Stores we just need to check the dynamic param or env vars.
349+
"""
350+
if aws_region_name is None:
351+
# check env #
352+
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
353+
354+
if litellm_aws_region_name is not None and isinstance(
355+
litellm_aws_region_name, str
356+
):
357+
aws_region_name = litellm_aws_region_name
358+
359+
standard_aws_region_name = get_secret("AWS_REGION", None)
360+
if standard_aws_region_name is not None and isinstance(
361+
standard_aws_region_name, str
362+
):
363+
aws_region_name = standard_aws_region_name
364+
365+
if aws_region_name is None:
366+
aws_region_name = "us-west-2"
367+
return aws_region_name
368+
339369
@tracer.wrap()
340370
def _auth_with_web_identity_token(
341371
self,

litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
httpxSpecialProvider,
3131
)
3232
from litellm.proxy._types import UserAPIKeyAuth
33-
from litellm.secret_managers.main import get_secret
3433
from litellm.types.guardrails import GuardrailEventHooks
3534
from litellm.types.llms.openai import AllMessageValues
3635
from litellm.types.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
@@ -129,23 +128,9 @@ def _load_credentials(
129128
aws_sts_endpoint = self.optional_params.get("aws_sts_endpoint", None)
130129

131130
### SET REGION NAME ###
132-
if aws_region_name is None:
133-
# check env #
134-
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
135-
136-
if litellm_aws_region_name is not None and isinstance(
137-
litellm_aws_region_name, str
138-
):
139-
aws_region_name = litellm_aws_region_name
140-
141-
standard_aws_region_name = get_secret("AWS_REGION", None)
142-
if standard_aws_region_name is not None and isinstance(
143-
standard_aws_region_name, str
144-
):
145-
aws_region_name = standard_aws_region_name
146-
147-
if aws_region_name is None:
148-
aws_region_name = "us-west-2"
131+
aws_region_name = self.get_aws_region_name_for_non_llm_api_calls(
132+
aws_region_name=aws_region_name,
133+
)
149134

150135
credentials: Credentials = self.get_credentials(
151136
aws_access_key_id=aws_access_key_id,

litellm/proxy/management_endpoints/key_management_endpoints.py

Lines changed: 15 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@
4545
from litellm.proxy.management_endpoints.model_management_endpoints import (
4646
_add_model_to_db,
4747
)
48+
from litellm.proxy.management_helpers.object_permission_utils import (
49+
handle_update_object_permission_common,
50+
)
4851
from litellm.proxy.management_helpers.team_member_permission_checks import (
4952
TeamMemberPermissionChecks,
5053
)
@@ -733,48 +736,20 @@ async def _handle_update_object_permission(
733736
"""
734737
from litellm.proxy.proxy_server import prisma_client
735738

736-
if prisma_client is None:
737-
raise ValueError("Prisma client not found")
738-
739-
#########################################################
740-
# Ensure `object_permission` is not added to the data_json
741-
# We need to update the entity at the object_permission_id level in the LiteLLM_ObjectPermissionTable
742-
#########################################################
743-
new_object_permission = data_json.pop("object_permission")
744-
if new_object_permission is None:
745-
return data_json
746-
747-
# lookup existing object permission ID and update that entry
748-
existing_object_permission_id = existing_key_row.object_permission_id
749-
existing_object_permission = (
750-
await prisma_client.db.litellm_objectpermissiontable.find_unique(
751-
where={"object_permission_id": existing_object_permission_id},
752-
)
739+
# Use the common helper to handle the object permission update
740+
object_permission_id = await handle_update_object_permission_common(
741+
data_json=data_json,
742+
existing_object_permission_id=existing_key_row.object_permission_id,
743+
prisma_client=prisma_client,
753744
)
754-
existing_object_permissions_dict: dict = {}
755-
if existing_object_permission is not None:
756-
# update the object permission
757-
existing_object_permissions_dict = existing_object_permission.model_dump(
758-
exclude_unset=True, exclude_none=True
759-
)
760-
existing_object_permissions_dict.update(dict(new_object_permission))
761-
762-
#########################################################
763-
# Commit the update to the LiteLLM_ObjectPermissionTable
764-
#########################################################
765-
new_object_permission_row = (
766-
await prisma_client.db.litellm_objectpermissiontable.upsert(
767-
where={"object_permission_id": existing_object_permission_id},
768-
data={
769-
"create": existing_object_permissions_dict,
770-
"update": existing_object_permissions_dict,
771-
},
745+
746+
# Add the object_permission_id to data_json if one was created/updated
747+
if object_permission_id is not None:
748+
data_json["object_permission_id"] = object_permission_id
749+
verbose_proxy_logger.debug(
750+
f"updated object_permission_id: {object_permission_id}"
772751
)
773-
)
774-
verbose_proxy_logger.debug(
775-
f"new_object_permission_row: {new_object_permission_row}"
776-
)
777-
data_json["object_permission_id"] = new_object_permission_row.object_permission_id
752+
778753
return data_json
779754

780755

litellm/proxy/management_endpoints/organization_endpoints.py

Lines changed: 11 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
new_budget,
2525
update_budget,
2626
)
27+
from litellm.proxy.management_helpers.object_permission_utils import (
28+
handle_update_object_permission_common,
29+
)
2730
from litellm.proxy.management_helpers.utils import (
2831
get_new_internal_user_defaults,
2932
management_endpoint_wrapper,
@@ -309,61 +312,20 @@ async def handle_update_object_permission(
309312
- Upserts the new object permission into the LiteLLM_ObjectPermissionTable
310313
- Adds object_permission_id to data_json (this gets added in the DB)
311314
- Pops the object_permission from data_json
312-
-
313315
"""
314316
from litellm.proxy.proxy_server import prisma_client
315317

316-
if prisma_client is None:
317-
raise ValueError("Prisma client not found")
318-
319-
#########################################################
320-
# Ensure `object_permission` is not added to the data_json
321-
# We need to update the entity at the object_permission_id level in the LiteLLM_ObjectPermissionTable
322-
#########################################################
323-
new_object_permission: Union[dict, str] = data_json.pop("object_permission") or {}
324-
if new_object_permission is None:
325-
return data_json
326-
327-
# lookup existing object permission ID and update that entry
328-
existing_object_permission_id = existing_organization_row.object_permission_id
329-
existing_object_permissions_dict = {}
330-
331-
existing_object_permission = (
332-
await prisma_client.db.litellm_objectpermissiontable.find_unique(
333-
where={"object_permission_id": existing_object_permission_id},
334-
)
318+
# Use the common helper to handle the object permission update
319+
object_permission_id = await handle_update_object_permission_common(
320+
data_json=data_json,
321+
existing_object_permission_id=existing_organization_row.object_permission_id,
322+
prisma_client=prisma_client,
335323
)
336324

337-
# update the object permission
338-
if existing_object_permission is not None:
339-
existing_object_permissions_dict = existing_object_permission.model_dump(
340-
exclude_unset=True, exclude_none=True
341-
)
325+
# Add the object_permission_id to data_json if one was created/updated
326+
if object_permission_id is not None:
327+
data_json["object_permission_id"] = object_permission_id
342328

343-
if isinstance(new_object_permission, str):
344-
new_object_permission = json.loads(new_object_permission)
345-
346-
if isinstance(new_object_permission, dict):
347-
existing_object_permissions_dict.update(new_object_permission)
348-
349-
#########################################################
350-
# Commit the update to the LiteLLM_ObjectPermissionTable
351-
#########################################################
352-
created_object_permission_row = (
353-
await prisma_client.db.litellm_objectpermissiontable.upsert(
354-
where={"object_permission_id": existing_object_permission_id},
355-
data={
356-
"create": existing_object_permissions_dict,
357-
"update": existing_object_permissions_dict,
358-
},
359-
)
360-
)
361-
data_json[
362-
"object_permission_id"
363-
] = created_object_permission_row.object_permission_id
364-
verbose_proxy_logger.debug(
365-
f"created_object_permission_row: {created_object_permission_row}"
366-
)
367329
return data_json
368330

369331

litellm/proxy/management_endpoints/team_endpoints.py

Lines changed: 14 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@
7070
from litellm.proxy.management_endpoints.tag_management_endpoints import (
7171
get_daily_activity,
7272
)
73+
from litellm.proxy.management_helpers.object_permission_utils import (
74+
handle_update_object_permission_common,
75+
)
7376
from litellm.proxy.management_helpers.team_member_permission_checks import (
7477
TeamMemberPermissionChecks,
7578
)
@@ -752,48 +755,20 @@ async def handle_update_object_permission(
752755
"""
753756
from litellm.proxy.proxy_server import prisma_client
754757

755-
if prisma_client is None:
756-
raise ValueError("Prisma client not found")
757-
758-
#########################################################
759-
# Ensure `object_permission` is not added to the data_json
760-
# We need to update the entity at the object_permission_id level in the LiteLLM_ObjectPermissionTable
761-
#########################################################
762-
new_object_permission = data_json.pop("object_permission")
763-
if new_object_permission is None:
764-
return data_json
765-
766-
# lookup existing object permission ID and update that entry
767-
existing_object_permission_id = existing_team_row.object_permission_id
768-
existing_object_permission = (
769-
await prisma_client.db.litellm_objectpermissiontable.find_unique(
770-
where={"object_permission_id": existing_object_permission_id},
771-
)
758+
# Use the common helper to handle the object permission update
759+
object_permission_id = await handle_update_object_permission_common(
760+
data_json=data_json,
761+
existing_object_permission_id=existing_team_row.object_permission_id,
762+
prisma_client=prisma_client,
772763
)
773764

774-
existing_object_permissions_dict: Dict = {}
775-
776-
# update the object permission
777-
if existing_object_permission is not None:
778-
existing_object_permissions_dict = existing_object_permission.model_dump(
779-
exclude_unset=True, exclude_none=True
765+
# Add the object_permission_id to data_json if one was created/updated
766+
if object_permission_id is not None:
767+
data_json["object_permission_id"] = object_permission_id
768+
verbose_proxy_logger.debug(
769+
f"updated object_permission_id: {object_permission_id}"
780770
)
781-
existing_object_permissions_dict.update(dict(new_object_permission))
782-
created_object_permission_row = (
783-
await prisma_client.db.litellm_objectpermissiontable.upsert(
784-
where={"object_permission_id": existing_object_permission_id},
785-
data={
786-
"create": existing_object_permissions_dict,
787-
"update": existing_object_permissions_dict,
788-
},
789-
)
790-
)
791-
data_json[
792-
"object_permission_id"
793-
] = created_object_permission_row.object_permission_id
794-
verbose_proxy_logger.debug(
795-
f"created_object_permission_row: {created_object_permission_row}"
796-
)
771+
797772
return data_json
798773

799774

0 commit comments

Comments
 (0)