Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion bases/renku_data_services/data_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,12 @@ def register_all_handlers(app: Sanic, dm: DependencyManager) -> Sanic:
storage_repo=dm.storage_repo,
authenticator=dm.gitlab_authenticator,
)
storage_schema = StorageSchemaBP(name="storage_schema", url_prefix=url_prefix)
storage_schema = StorageSchemaBP(
name="storage_schema",
url_prefix=url_prefix,
data_source_repo=dm.data_source_repo,
authenticator=dm.authenticator,
)
user_preferences = UserPreferencesBP(
name="user_preferences",
url_prefix=url_prefix,
Expand Down
41 changes: 40 additions & 1 deletion components/renku_data_services/notebooks/data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from configparser import ConfigParser
from dataclasses import dataclass
from io import StringIO
from typing import Any
from typing import TYPE_CHECKING, Any

from sanic import Request

Expand All @@ -19,6 +19,9 @@
from renku_data_services.data_connectors.models import DataConnector, GlobalDataConnector
from renku_data_services.notebooks.config import NotebooksConfig

if TYPE_CHECKING:
from renku_data_services.storage.models import RCloneConfig

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -135,6 +138,42 @@ async def handle_patching_configuration(
parser.write(stringio)
return stringio.getvalue()

async def handle_configuration_for_test(
self, user: APIUser, configuration: "RCloneConfig | dict[str, Any]"
) -> "RCloneConfig | dict[str, Any] | None":
"""Ajusts the input configuration if it requires an OAuth2 connection.

Returns either an rclone configuration or None if the data connector should be skipped.
"""
provider_kind: ProviderKind | None = None
match configuration.get("type"):
case "drive":
provider_kind = ProviderKind.google
case "dropbox":
provider_kind = ProviderKind.dropbox
if provider_kind is None:
return configuration

provider = await self.connected_services_repo.get_provider_for_kind(user=user, provider_kind=provider_kind)
if provider is None:
return None
connection = provider.connected_user.connection if provider.connected_user else None
if connection is None:
return None
token_set = await self.connected_services_repo.get_token_set(user=user, connection_id=connection.id)
if not token_set or not token_set.access_token:
return None
token_config = {
"access_token": token_set.access_token,
"token_type": "Bearer",
}
if provider_kind == ProviderKind.google:
configuration["scope"] = configuration.get("scope") or "drive"
if token_set.expires_at_iso:
token_config["expiry"] = token_set.expires_at_iso
configuration["token"] = json.dumps(token_config)
return configuration

def _get_oauth2_provider_kind(self, data_connector: DataConnector | GlobalDataConnector) -> ProviderKind | None:
"""Returns the provider kind for data connectors which require an OAuth2 configuration."""
match data_connector.storage.configuration["type"]:
Expand Down
14 changes: 12 additions & 2 deletions components/renku_data_services/storage/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.base_api.misc import validate_query
from renku_data_services.base_models.validation import validated_json
from renku_data_services.notebooks.data_sources import DataSourceRepository
from renku_data_services.storage import apispec, models
from renku_data_services.storage.db import StorageRepository
from renku_data_services.storage.rclone import RCloneValidator
Expand Down Expand Up @@ -193,6 +194,9 @@ async def _delete(request: Request, user: base_models.APIUser, storage_id: ULID)
class StorageSchemaBP(CustomBlueprint):
"""Handler for getting RClone storage schema."""

data_source_repo: DataSourceRepository
authenticator: base_models.Authenticator

def get(self) -> BlueprintFactoryResponse:
"""Get cloud storage for a repository."""

Expand All @@ -204,12 +208,18 @@ async def _get(_: Request, validator: RCloneValidator) -> JSONResponse:
def test_connection(self) -> BlueprintFactoryResponse:
"""Validate an RClone config."""

@authenticate(self.authenticator)
@validate(json=apispec.StorageSchemaTestConnectionPostRequest)
async def _test_connection(
request: Request, validator: RCloneValidator, body: apispec.StorageSchemaTestConnectionPostRequest
request: Request,
user: base_models.APIUser,
validator: RCloneValidator,
body: apispec.StorageSchemaTestConnectionPostRequest,
) -> HTTPResponse:
validator.validate(body.configuration, keep_sensitive=True)
result = await validator.test_connection(body.configuration, body.source_path)
result = await validator.test_connection(
body.configuration, body.source_path, user=user, data_source_repo=self.data_source_repo
)
if not result.success:
raise errors.ValidationError(message=result.error)
return empty(204)
Expand Down
16 changes: 15 additions & 1 deletion components/renku_data_services/storage/rclone.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from renku_data_services import base_models
from renku_data_services.notebooks.data_sources import DataSourceRepository
from renku_data_services.storage.models import RCloneConfig


Expand Down Expand Up @@ -88,7 +90,11 @@ def get_real_configuration(self, configuration: Union[RCloneConfig, dict[str, An
return real_config

async def test_connection(
self, configuration: Union[RCloneConfig, dict[str, Any]], source_path: str
self,
configuration: Union[RCloneConfig, dict[str, Any]],
source_path: str,
user: base_models.APIUser | None = None,
data_source_repo: DataSourceRepository | None = None,
) -> ConnectionResult:
"""Tests connecting with an RClone config."""
try:
Expand All @@ -101,6 +107,14 @@ async def test_connection(
transformed_config = self.inject_default_values(self.transform_polybox_switchdriver_config(obscured_config))
transformed_config = self.transform_envidat_config(transformed_config)

# Handle testing with Renku integrations
if user is not None and data_source_repo is not None:
with_oauth2_config = await data_source_repo.handle_configuration_for_test(
user=user, configuration=transformed_config
)
if with_oauth2_config is not None:
transformed_config = with_oauth2_config

with tempfile.NamedTemporaryFile(mode="w+", delete=False, encoding="utf-8") as f:
config = "\n".join(f"{k}={v}" for k, v in transformed_config.items())
f.write(f"[temp]\n{config}")
Expand Down
Loading