Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 30 additions & 20 deletions litellm/llms/vertex_ai/vertex_llm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,15 @@ def load_auth(
else ""
)
if isinstance(environment_id, str) and "aws" in environment_id:
creds = self._credentials_from_identity_pool_with_aws(json_obj)
creds = self._credentials_from_identity_pool_with_aws(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
creds = self._credentials_from_identity_pool(json_obj)
creds = self._credentials_from_identity_pool(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
# Check if the JSON object contains Authorized User configuration (via gcloud auth application-default login)
elif "type" in json_obj and json_obj["type"] == "authorized_user":
creds = self._credentials_from_authorized_user(
Expand Down Expand Up @@ -131,15 +137,21 @@ def load_auth(
return creds, project_id

# Google Auth Helpers -- extracted for mocking purposes in tests
def _credentials_from_identity_pool(self, json_obj):
def _credentials_from_identity_pool(self, json_obj, scopes):
from google.auth import identity_pool

return identity_pool.Credentials.from_info(json_obj)
creds = identity_pool.Credentials.from_info(json_obj)
if scopes and hasattr(creds, "requires_scopes") and creds.requires_scopes:
creds = creds.with_scopes(scopes)
return creds

def _credentials_from_identity_pool_with_aws(self, json_obj):
def _credentials_from_identity_pool_with_aws(self, json_obj, scopes):
from google.auth import aws

return aws.Credentials.from_info(json_obj)
creds = aws.Credentials.from_info(json_obj)
if scopes and hasattr(creds, "requires_scopes") and creds.requires_scopes:
creds = creds.with_scopes(scopes)
return creds

def _credentials_from_authorized_user(self, json_obj, scopes):
import google.oauth2.credentials
Expand Down Expand Up @@ -300,7 +312,7 @@ def _check_custom_proxy(
) -> Tuple[Optional[str], str]:
"""
for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317

Handles custom api_base for:
1. Gemini (Google AI Studio) - constructs /models/{model}:{endpoint}
2. Vertex AI with standard proxies - constructs {api_base}:{endpoint}
Expand Down Expand Up @@ -328,7 +340,7 @@ def _check_custom_proxy(
"Missing gemini_api_key, please set `GEMINI_API_KEY`"
)
if gemini_api_key is not None:
auth_header = {"x-goog-api-key": gemini_api_key} # type: ignore[assignment]
auth_header = {"x-goog-api-key": gemini_api_key} # type: ignore[assignment]
else:
# For Vertex AI
if use_psc_endpoint_format:
Expand Down Expand Up @@ -396,9 +408,7 @@ def _get_token_and_url(
)

### SET RUNTIME ENDPOINT ###
version = (
"v1beta1" if should_use_v1beta1_features is True else "v1"
)
version = "v1beta1" if should_use_v1beta1_features is True else "v1"
url, endpoint = _get_vertex_url(
mode=mode,
model=model,
Expand Down Expand Up @@ -675,13 +685,13 @@ def get_vertex_ai_location(litellm_params: dict) -> Optional[str]:
def safe_get_vertex_ai_project(litellm_params: dict) -> Optional[str]:
"""
Safely get Vertex AI project without mutating the litellm_params dict.

Unlike get_vertex_ai_project(), this does NOT pop values from the dict,
making it safe to call multiple times with the same litellm_params.

Args:
litellm_params: Dictionary containing Vertex AI parameters

Returns:
Vertex AI project ID or None
"""
Expand All @@ -696,13 +706,13 @@ def safe_get_vertex_ai_project(litellm_params: dict) -> Optional[str]:
def safe_get_vertex_ai_credentials(litellm_params: dict) -> Optional[str]:
"""
Safely get Vertex AI credentials without mutating the litellm_params dict.

Unlike get_vertex_ai_credentials(), this does NOT pop values from the dict,
making it safe to call multiple times with the same litellm_params.

Args:
litellm_params: Dictionary containing Vertex AI parameters

Returns:
Vertex AI credentials or None
"""
Expand All @@ -716,13 +726,13 @@ def safe_get_vertex_ai_credentials(litellm_params: dict) -> Optional[str]:
def safe_get_vertex_ai_location(litellm_params: dict) -> Optional[str]:
"""
Safely get Vertex AI location without mutating the litellm_params dict.

Unlike get_vertex_ai_location(), this does NOT pop values from the dict,
making it safe to call multiple times with the same litellm_params.

Args:
litellm_params: Dictionary containing Vertex AI parameters

Returns:
Vertex AI location/region or None
"""
Expand Down
Loading
Loading