From d5995f50c16bd02bd7c1814809441e96d5f1fda0 Mon Sep 17 00:00:00 2001 From: hugo093 Date: Wed, 27 May 2026 16:30:17 +0200 Subject: [PATCH 1/2] feat: account API + front end adaptations --- .../app/hooks/useGetCustomTranslations.ts | 2 +- .../apps/Iaso/domains/dataSources/requests.js | 6 +- .../js/apps/Iaso/hooks/useSwitchAccount.tsx | 4 +- iaso/api/accounts.py | 97 ----------- iaso/api/accounts/__init__.py | 0 iaso/api/accounts/pagination.py | 5 + iaso/api/accounts/serializers/__init__.py | 0 .../serializers/custom_translations.py | 11 ++ iaso/api/accounts/serializers/list.py | 8 + iaso/api/accounts/serializers/retrieve.py | 18 +++ .../accounts/serializers/retrieve_current.py | 64 ++++++++ .../serializers/set_default_version.py | 28 ++++ iaso/api/accounts/serializers/switch.py | 15 ++ iaso/api/accounts/serializers/update.py | 25 +++ iaso/api/accounts/views.py | 150 ++++++++++++++++++ iaso/api/custom_translations.py | 34 ---- iaso/api/profiles/serializers/retrieve.py | 2 +- iaso/models/account.py | 26 ++- iaso/test.py | 7 + iaso/tests/api/account/__init__.py | 0 .../api/account/test_custom_translations.py | 82 ++++++++++ iaso/tests/api/account/test_list.py | 73 +++++++++ iaso/tests/api/account/test_retrieve.py | 84 ++++++++++ .../api/account/test_retrieve_current.py | 88 ++++++++++ .../api/account/test_set_default_version.py | 100 ++++++++++++ iaso/tests/api/account/test_switch.py | 70 ++++++++ iaso/tests/api/account/test_update.py | 128 +++++++++++++++ iaso/tests/api/test_account.py | 134 ---------------- iaso/tests/api/test_custom_translations.py | 42 ----- iaso/urls.py | 4 +- 30 files changed, 986 insertions(+), 321 deletions(-) delete mode 100644 iaso/api/accounts.py create mode 100644 iaso/api/accounts/__init__.py create mode 100644 iaso/api/accounts/pagination.py create mode 100644 iaso/api/accounts/serializers/__init__.py create mode 100644 iaso/api/accounts/serializers/custom_translations.py create mode 100644 iaso/api/accounts/serializers/list.py create mode 100644 iaso/api/accounts/serializers/retrieve.py create mode 100644 iaso/api/accounts/serializers/retrieve_current.py create mode 100644 iaso/api/accounts/serializers/set_default_version.py create mode 100644 iaso/api/accounts/serializers/switch.py create mode 100644 iaso/api/accounts/serializers/update.py create mode 100644 iaso/api/accounts/views.py delete mode 100644 iaso/api/custom_translations.py create mode 100644 iaso/tests/api/account/__init__.py create mode 100644 iaso/tests/api/account/test_custom_translations.py create mode 100644 iaso/tests/api/account/test_list.py create mode 100644 iaso/tests/api/account/test_retrieve.py create mode 100644 iaso/tests/api/account/test_retrieve_current.py create mode 100644 iaso/tests/api/account/test_set_default_version.py create mode 100644 iaso/tests/api/account/test_switch.py create mode 100644 iaso/tests/api/account/test_update.py delete mode 100644 iaso/tests/api/test_account.py delete mode 100644 iaso/tests/api/test_custom_translations.py diff --git a/hat/assets/js/apps/Iaso/domains/app/hooks/useGetCustomTranslations.ts b/hat/assets/js/apps/Iaso/domains/app/hooks/useGetCustomTranslations.ts index b9b170c408..20129febec 100644 --- a/hat/assets/js/apps/Iaso/domains/app/hooks/useGetCustomTranslations.ts +++ b/hat/assets/js/apps/Iaso/domains/app/hooks/useGetCustomTranslations.ts @@ -8,7 +8,7 @@ export const useGetCustomTranslations = ( return useSnackQuery({ queryKey: ['customTranslations', accountId], queryFn: () => - getRequest(`/api/custom_translations/?account_id=${accountId}`), + getRequest(`/api/accounts/${accountId}/custom-translations/`), options: { retry: false, keepPreviousData: true, diff --git a/hat/assets/js/apps/Iaso/domains/dataSources/requests.js b/hat/assets/js/apps/Iaso/domains/dataSources/requests.js index 9d95822297..122aa3fc86 100644 --- a/hat/assets/js/apps/Iaso/domains/dataSources/requests.js +++ b/hat/assets/js/apps/Iaso/domains/dataSources/requests.js @@ -1,4 +1,3 @@ -/* eslint-disable no-else-return */ import React from 'react'; import { useMutation, useQueryClient } from 'react-query'; import { @@ -177,7 +176,7 @@ export const csvPreview = async data => { }; export const updateDefaultDataSource = ([accountId, defaultVersionId]) => - putRequest(`/api/accounts/${accountId}/`, { + putRequest(`/api/accounts/${accountId}/set-default-version/`, { default_version: defaultVersionId, }); @@ -212,7 +211,7 @@ export const useSaveDataSource = setFieldErrors => { const saveDataSource = async form => { setIsSaving(true); - // eslint-disable-next-line camelcase + const { is_default_source, ...campaignData } = getValues(form); try { @@ -234,7 +233,6 @@ export const useSaveDataSource = setFieldErrors => { setIsSaving(false); } - // eslint-disable-next-line camelcase if (is_default_source && form.default_version_id.value) { await saveDefaultDataSourceMutation.mutateAsync([ currentUser.account.id, diff --git a/hat/assets/js/apps/Iaso/hooks/useSwitchAccount.tsx b/hat/assets/js/apps/Iaso/hooks/useSwitchAccount.tsx index 59d16d780f..dd88ace786 100644 --- a/hat/assets/js/apps/Iaso/hooks/useSwitchAccount.tsx +++ b/hat/assets/js/apps/Iaso/hooks/useSwitchAccount.tsx @@ -1,5 +1,5 @@ import { UseMutationResult } from 'react-query'; -import { patchRequest } from '../libs/Api'; +import { postRequest } from '../libs/Api'; import { useSnackMutation } from '../libs/apiHooks'; export const useSwitchAccount = ( @@ -7,7 +7,7 @@ export const useSwitchAccount = ( ): UseMutationResult => useSnackMutation({ mutationFn: accountId => - patchRequest('/api/accounts/switch/', { account_id: accountId }), + postRequest('/api/accounts/switch/', { account_id: accountId }), options: { onSuccess: onSuccess || (() => null) }, showSuccessSnackBar: false, }); diff --git a/iaso/api/accounts.py b/iaso/api/accounts.py deleted file mode 100644 index 20a44250d4..0000000000 --- a/iaso/api/accounts.py +++ /dev/null @@ -1,97 +0,0 @@ -"""This api is only there so the default version on an account can be modified""" - -from django.contrib.auth import login -from drf_spectacular.utils import extend_schema -from rest_framework import permissions, serializers, status -from rest_framework.decorators import action -from rest_framework.generics import get_object_or_404 -from rest_framework.request import Request -from rest_framework.response import Response - -from iaso.models import Account, SourceVersion -from iaso.permissions.core_permissions import CORE_SOURCE_PERMISSION - -from .common import HasPermission, ModelViewSet - - -class AccountSerializer(serializers.ModelSerializer): - class Meta: - model = Account - - fields = [ - "id", - "default_version", - ] - - def update(self, account, validated_data): - default_version = validated_data.pop("default_version", None) - user = self.context["request"].user - if default_version is not None: - source_version = get_object_or_404( - SourceVersion, - id=default_version.id, - number=default_version.number, - ) - projects = source_version.data_source.projects.all() - for p in projects: - if user.iaso_profile.account != p.account: - raise serializers.ValidationError({"Error": "Account not allowed to access this default_source"}) - account.default_version = source_version - account.save() - - return account - - -class HasAccountPermission(permissions.BasePermission): - def has_object_permission(self, request: Request, view, obj: Account): - if request.user.is_authenticated: - return request.user.iaso_profile.account == obj - return False - - -@extend_schema(tags=["Accounts"]) -class AccountViewSet(ModelViewSet): - f"""Account API - - This API is restricted to authenticated users having the "{CORE_SOURCE_PERMISSION}" permission - Only allow to update default source / version for an account - PUT /api/account/ - """ - - serializer_class = AccountSerializer - results_key = "accounts" - queryset = Account.objects.all() - # FIXME: USe a PATCH in the future, it make more sense regarding HTTP method semantic - http_method_names = ["patch", "put"] - - def get_permissions(self): - if self.action == "switch": - permission_classes = [permissions.IsAuthenticated, HasAccountPermission] - else: - permission_classes = [ - permissions.IsAuthenticated, - HasPermission(CORE_SOURCE_PERMISSION), - HasAccountPermission, - ] - - return [permission() for permission in permission_classes] - - @action(detail=False, methods=["patch"], url_path="switch") - def switch(self, request): - # TODO: Make sure the account_id is present - self.permission_classes = [permissions.IsAuthenticated, HasAccountPermission] - self.check_permissions(request) - account_id = int(request.data["account_id"]) if request.data.get("account_id") else None - - current_user = request.user - account_users = current_user.tenant_user.get_all_account_users() - user_to_login = next( - (u for u in account_users if u.iaso_profile and u.iaso_profile.account_id == account_id), None - ) - - if user_to_login: - user_to_login.backend = "iaso.auth.backends.MultiTenantAuthBackend" - login(request, user_to_login) - # Return an empty response since no data is needed by the frontend - return Response({}, status=status.HTTP_200_OK) - return Response(status=status.HTTP_404_NOT_FOUND) diff --git a/iaso/api/accounts/__init__.py b/iaso/api/accounts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/iaso/api/accounts/pagination.py b/iaso/api/accounts/pagination.py new file mode 100644 index 0000000000..41004ee756 --- /dev/null +++ b/iaso/api/accounts/pagination.py @@ -0,0 +1,5 @@ +from iaso.api.common import Paginator + + +class AccountPagination(Paginator): + page_size = 20 diff --git a/iaso/api/accounts/serializers/__init__.py b/iaso/api/accounts/serializers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/iaso/api/accounts/serializers/custom_translations.py b/iaso/api/accounts/serializers/custom_translations.py new file mode 100644 index 0000000000..a5111929fa --- /dev/null +++ b/iaso/api/accounts/serializers/custom_translations.py @@ -0,0 +1,11 @@ +from iaso.api.common import ModelSerializer +from iaso.models import Account + + +class AccountCustomTranslationsSerializer(ModelSerializer): + class Meta: + model = Account + fields = [ + "custom_translations", + ] + extra_kwargs = {"custom_translations": {"allow_null": True, "required": False, "read_only": True}} diff --git a/iaso/api/accounts/serializers/list.py b/iaso/api/accounts/serializers/list.py new file mode 100644 index 0000000000..033c5a0139 --- /dev/null +++ b/iaso/api/accounts/serializers/list.py @@ -0,0 +1,8 @@ +from iaso.api.common import ModelSerializer +from iaso.models import Account + + +class AccountListSerializer(ModelSerializer): + class Meta: + model = Account + fields = ["id", "name", "created_at", "updated_at"] diff --git a/iaso/api/accounts/serializers/retrieve.py b/iaso/api/accounts/serializers/retrieve.py new file mode 100644 index 0000000000..16cef2c8f4 --- /dev/null +++ b/iaso/api/accounts/serializers/retrieve.py @@ -0,0 +1,18 @@ +from iaso.api.common import ModelSerializer +from iaso.models import Account + + +class AccountRetrieveSerializer(ModelSerializer): + class Meta: + model = Account + fields = [ + "id", + "name", + "created_at", + "user_manual_path", + "forum_path", + "modules", + "enforce_password_validation", + "anthropic_api_key", + ] + read_only_fields = fields diff --git a/iaso/api/accounts/serializers/retrieve_current.py b/iaso/api/accounts/serializers/retrieve_current.py new file mode 100644 index 0000000000..f7524650cc --- /dev/null +++ b/iaso/api/accounts/serializers/retrieve_current.py @@ -0,0 +1,64 @@ +from drf_spectacular.utils import extend_schema_field +from rest_framework import serializers + +from iaso.api.common import ModelSerializer +from iaso.models import Account, AccountFeatureFlag, DataSource, SourceVersion + + +class NestedDataSourceSerializer(ModelSerializer): + url = serializers.CharField(source="credentials.url", read_only=True, allow_null=True) + + class Meta: + model = DataSource + fields = ["id", "url"] + read_only_fields = fields + + +class NestedDefaultVersionSerializer(ModelSerializer): + data_source = NestedDataSourceSerializer(read_only=True) + + class Meta: + model = SourceVersion + fields = ["id", "data_source"] + read_only_fields = fields + + +class OtherAccountSerializer(ModelSerializer): + class Meta: + model = Account + fields = ["id", "name"] + read_only_fields = fields + + +class FeatureFlagNestedSerializer(ModelSerializer): + class Meta: + model = AccountFeatureFlag + fields = ["name", "code"] + + +class AccountRetrieveCurrentSerializer(ModelSerializer): + other_accounts = serializers.SerializerMethodField() + default_version = NestedDefaultVersionSerializer(allow_null=True, required=False) + feature_flags = FeatureFlagNestedSerializer(allow_null=True, many=True, required=False) + + class Meta: + model = Account + fields = ["id", "name", "default_version", "other_accounts", "modules", "feature_flags"] + read_only_fields = fields + + def __init__(self, *args, **kwargs): + self.other_account_qs = kwargs.pop("other_account_qs", None) + + super(AccountRetrieveCurrentSerializer, self).__init__(*args, **kwargs) + + if self.other_account_qs is None: + if getattr(self.context.get("request", None), "user", None): + self.other_account_qs = ( + Account.objects.filter_for_user(self.context["request"].user) + .exclude(id=self.instance.id) + .distinct("id") + ) + + @extend_schema_field(OtherAccountSerializer(many=True, allow_null=True, allow_empty=True)) + def get_other_accounts(self, obj): + return OtherAccountSerializer(self.other_account_qs, many=True).data diff --git a/iaso/api/accounts/serializers/set_default_version.py b/iaso/api/accounts/serializers/set_default_version.py new file mode 100644 index 0000000000..fb08f39683 --- /dev/null +++ b/iaso/api/accounts/serializers/set_default_version.py @@ -0,0 +1,28 @@ +from rest_framework import serializers + +from iaso.api.common import ModelSerializer +from iaso.models import Account, SourceVersion + + +class AccountSetDefaultVersionSerializer(ModelSerializer): + default_version = serializers.PrimaryKeyRelatedField( + queryset=SourceVersion.objects.none(), + error_messages={ + "does_not_exist": "Account not allowed to access this default_source.", + }, + write_only=True, + required=True, + ) + + class Meta: + model = Account + fields = ["default_version"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.fields["default_version"].queryset = ( + SourceVersion.objects.filter(data_source__projects__account=self.instance) + .select_related("data_source") + .distinct() + ) diff --git a/iaso/api/accounts/serializers/switch.py b/iaso/api/accounts/serializers/switch.py new file mode 100644 index 0000000000..58cd6307fd --- /dev/null +++ b/iaso/api/accounts/serializers/switch.py @@ -0,0 +1,15 @@ +from rest_framework import serializers + +from iaso.models import Account + + +class AccountSwitchSerializer(serializers.Serializer): + account_id = serializers.PrimaryKeyRelatedField(queryset=Account.objects.none(), write_only=True, required=True) + + def __init__(self, *args, **kwargs): + account_id_qs = kwargs.pop("account_id_qs", None) + super(AccountSwitchSerializer, self).__init__(*args, **kwargs) + if getattr(self.context.get("request", None), "user", None): + self.fields["account_id"].queryset = ( + account_id_qs if account_id_qs else Account.objects.filter_for_user(self.context["request"].user) + ) diff --git a/iaso/api/accounts/serializers/update.py b/iaso/api/accounts/serializers/update.py new file mode 100644 index 0000000000..dda9a86f6d --- /dev/null +++ b/iaso/api/accounts/serializers/update.py @@ -0,0 +1,25 @@ +from iaso.api.common import ModelSerializer +from iaso.models import Account + + +class AccountUpdateSerializer(ModelSerializer): + class Meta: + model = Account + fields = [ + "name", + "user_manual_path", + "forum_path", + "modules", + "enforce_password_validation", + "anthropic_api_key", + "custom_translations", + ] + extra_kwargs = { + "name": {"write_only": True}, + "user_manual_path": {"write_only": True}, + "forum_path": {"write_only": True}, + "modules": {"write_only": True}, + "enforce_password_validation": {"write_only": True}, + "anthropic_api_key": {"write_only": True}, + "custom_translations": {"write_only": True}, + } diff --git a/iaso/api/accounts/views.py b/iaso/api/accounts/views.py new file mode 100644 index 0000000000..6c8258b538 --- /dev/null +++ b/iaso/api/accounts/views.py @@ -0,0 +1,150 @@ +from django.contrib.auth import get_user_model, login +from drf_spectacular.utils import extend_schema +from rest_framework import status +from rest_framework.decorators import action +from rest_framework.exceptions import NotFound +from rest_framework.filters import OrderingFilter +from rest_framework.mixins import RetrieveModelMixin, UpdateModelMixin +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework.viewsets import GenericViewSet + +from iaso.api.accounts.pagination import AccountPagination +from iaso.api.accounts.serializers.custom_translations import AccountCustomTranslationsSerializer +from iaso.api.accounts.serializers.list import AccountListSerializer +from iaso.api.accounts.serializers.retrieve import AccountRetrieveSerializer +from iaso.api.accounts.serializers.retrieve_current import AccountRetrieveCurrentSerializer +from iaso.api.accounts.serializers.set_default_version import AccountSetDefaultVersionSerializer +from iaso.api.accounts.serializers.switch import AccountSwitchSerializer +from iaso.api.accounts.serializers.update import AccountUpdateSerializer +from iaso.api.common import HasPermission +from iaso.api.common.mixin import CustomPaginationListModelMixin +from iaso.models import Account +from iaso.permissions.core_permissions import CORE_SOURCE_PERMISSION + + +@extend_schema(tags=["Account"]) +class AccountViewSet(CustomPaginationListModelMixin, RetrieveModelMixin, UpdateModelMixin, GenericViewSet): + http_method_names = ["get", "options", "patch", "put", "head", "trace", "post"] + pagination_class = AccountPagination + filter_backends = [OrderingFilter] + ordering = ["id"] + + @property + def permission_classes(self): + if self.action == "set_default_version": + return [IsAuthenticated, HasPermission(CORE_SOURCE_PERMISSION)] + return [IsAuthenticated] + + def get_serializer_class(self): + if self.action == "retrieve": + return AccountRetrieveSerializer + if self.action == "list": + return AccountListSerializer + if self.action in ["update", "partial_update"]: + return AccountUpdateSerializer + if self.action == "switch": + return AccountSwitchSerializer + if self.action == "set_default_version": + return AccountSetDefaultVersionSerializer + if self.action == "custom_translations": + return AccountCustomTranslationsSerializer + if self.action == "me": + return AccountRetrieveCurrentSerializer + raise NotImplementedError(f"Serializer not implemented for {self.action}") + + def get_queryset(self): + if not getattr(self.request, "user", None): + return Account.objects.none() + + qs = Account.objects.filter_for_user(self.request.user).distinct("id") + if self.action == "list": + qs = qs.only("id", "name", "created_at", "updated_at") + if self.action == "custom_translations": + qs = qs.only("custom_translations") + if self.action == "switch": + qs = qs.only("id") + if self.action == "retrieve": + qs = qs.only( + "id", + "name", + "created_at", + "user_manual_path", + "forum_path", + "modules", + "enforce_password_validation", + "anthropic_api_key", + ) + if self.action == "me": + qs = qs.select_related( + "default_version", "default_version__data_source", "default_version__data_source__credentials" + ).prefetch_related("feature_flags") + return qs + + @extend_schema(responses={204: None}) + def update(self, request, *args, **kwargs): + super().update(request, *args, **kwargs) + return Response(status=status.HTTP_204_NO_CONTENT) + + @action(detail=True, methods=["GET"], url_path="custom-translations") + def custom_translations(self, request, *args, **kwargs): + instance = self.get_object() + serializer = self.get_serializer(instance) + return Response(serializer.data) + + @extend_schema(responses={204: None}) + @action(detail=False, methods=["POST"], url_path="switch") + def switch(self, request): + serializer = self.get_serializer(data=request.data, account_id_qs=self.get_queryset()) + serializer.is_valid(raise_exception=True) + account = serializer.validated_data["account_id"] + + current_user = request.user + + if not getattr(current_user, "tenant_user", None): + return Response(status=status.HTTP_404_NOT_FOUND) + + user_to_login = ( + get_user_model() + .objects.filter( + iaso_profile__account=account, + tenant_user__main_user=current_user.tenant_user.main_user, + ) + .exclude(pk=current_user.pk) + .first() + ) + + if not user_to_login: + return Response(status=status.HTTP_404_NOT_FOUND) + + user_to_login.backend = "iaso.auth.backends.MultiTenantAuthBackend" + login(request, user_to_login) + + return Response(status=status.HTTP_204_NO_CONTENT) + + @extend_schema(responses={204: None}) + @action(detail=True, methods=["PUT", "PATCH"], url_path="set-default-version") + def set_default_version(self, request, *args, **kwargs): + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data) + serializer.is_valid(raise_exception=True) + serializer.save() + return Response(status=status.HTTP_204_NO_CONTENT) + + @extend_schema(responses={200: AccountRetrieveCurrentSerializer}) + @action(detail=False, methods=["GET"], url_path="me") + def me(self, request): + if not request.user and not request.user.iaso_profile and not request.user.iaso_profile.account: + raise NotFound + + qs = list(self.get_queryset()) + + other_account_qs = [a for a in qs if a.id != request.user.iaso_profile.account_id] + instance = next(iter([a for a in qs if a.id == request.user.iaso_profile.account_id]), None) + if not instance: + raise NotFound + + # other_accounts_qs = Account.objects.filter_for_user(self.request.user).distinct("id").exclude(id=request.user.iaso_profile.account_id) + + serializer = self.get_serializer(instance=instance, other_account_qs=other_account_qs) + return Response(serializer.data) diff --git a/iaso/api/custom_translations.py b/iaso/api/custom_translations.py deleted file mode 100644 index 287a7b5210..0000000000 --- a/iaso/api/custom_translations.py +++ /dev/null @@ -1,34 +0,0 @@ -from django.utils.translation import gettext_lazy as _ -from drf_spectacular.utils import extend_schema -from rest_framework import serializers, status -from rest_framework.generics import get_object_or_404 -from rest_framework.permissions import IsAuthenticated -from rest_framework.response import Response -from rest_framework.viewsets import ViewSet - -from iaso.models import Account - -from .accounts import HasAccountPermission - - -class CustomTranslationsSerializer(serializers.Serializer): - account_id = serializers.IntegerField( - required=True, - error_messages={"required": _("Account id is required.")}, - ) - - -@extend_schema(tags=["Custom translations"]) -class CustomTranslationsViewSet(ViewSet): - permission_classes = [IsAuthenticated, HasAccountPermission] - http_method_names = ["get"] - - def list(self, request): - serializer = CustomTranslationsSerializer(data=request.query_params) - serializer.is_valid(raise_exception=True) - account = get_object_or_404(Account, id=serializer.validated_data["account_id"]) - self.check_object_permissions(request, account) - return Response( - {"custom_translations": account.custom_translations}, - status=status.HTTP_200_OK, - ) diff --git a/iaso/api/profiles/serializers/retrieve.py b/iaso/api/profiles/serializers/retrieve.py index 2f5c213d6a..0ad96ffa0f 100644 --- a/iaso/api/profiles/serializers/retrieve.py +++ b/iaso/api/profiles/serializers/retrieve.py @@ -24,7 +24,7 @@ class NestedDataSourceSerializer(ModelSerializer): created_at = TimestampField(read_only=True) updated_at = TimestampField(read_only=True) - url = serializers.CharField(source="credentials__url", read_only=True, allow_null=True) + url = serializers.CharField(source="credentials.url", read_only=True, allow_null=True) class Meta: model = DataSource diff --git a/iaso/models/account.py b/iaso/models/account.py index 8a379fd2ba..0addf506e5 100644 --- a/iaso/models/account.py +++ b/iaso/models/account.py @@ -1,8 +1,10 @@ from django.conf import settings from django.core.validators import MinLengthValidator from django.db import models +from django.db.models import QuerySet from django.utils.text import slugify +from iaso.models.common import CreatedAndUpdatedModel from iaso.modules import MODULES, IasoModule from iaso.permissions.base import IasoPermission from iaso.utils.models.choice_array_field import ChoiceArrayField @@ -22,12 +24,28 @@ def __str__(self): return f"{self.name} ({self.code})" -class Account(models.Model): +class AccountQuerySet(QuerySet): + def filter_for_user(self, user): + if not user or not user.is_authenticated: + return self.none() + + tenant_user = getattr(user, "tenant_user", None) + + if tenant_user: + return self.filter( + profile__in=tenant_user.main_user.tenant_users.values_list( + "account_user__iaso_profile__id", + flat=True, + ) + ) + + return self.filter(profile=user.iaso_profile) + + +class Account(CreatedAndUpdatedModel): """Account represent a tenant (=roughly a client organization or a country)""" name = models.TextField(unique=True, validators=[MinLengthValidator(1)]) - created_at = models.DateTimeField(auto_now_add=True) - updated_at = models.DateTimeField(auto_now=True) default_version = models.ForeignKey("SourceVersion", null=True, blank=True, on_delete=models.SET_NULL) feature_flags = models.ManyToManyField(AccountFeatureFlag) user_manual_path = models.TextField(null=True, blank=True) @@ -42,6 +60,8 @@ class Account(models.Model): enforce_password_validation = models.BooleanField(default=True) anthropic_api_key = EncryptedTextField(null=True, blank=True, help_text="Anthropic API key used by the Form AI") + objects = models.Manager.from_queryset(AccountQuerySet)() + @property def short_sanitized_name(self): """ diff --git a/iaso/test.py b/iaso/test.py index 619724218b..16e89122a7 100644 --- a/iaso/test.py +++ b/iaso/test.py @@ -191,6 +191,13 @@ def normalize_schema(self, schema): schema["type"] = [t, "null"] elif isinstance(t, list) and "null" not in t: schema["type"] = t + ["null"] + elif "allOf" in schema: + schema["anyOf"] = [ + {"type": "null"}, + {"allOf": schema["allOf"]}, + ] + schema.pop("allOf", None) + schema.pop("nullable", None) for v in schema.get("properties", {}).values(): diff --git a/iaso/tests/api/account/__init__.py b/iaso/tests/api/account/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/iaso/tests/api/account/test_custom_translations.py b/iaso/tests/api/account/test_custom_translations.py new file mode 100644 index 0000000000..4910da3a99 --- /dev/null +++ b/iaso/tests/api/account/test_custom_translations.py @@ -0,0 +1,82 @@ +from django.contrib.auth import get_user_model +from django.urls import reverse +from rest_framework import status + +from iaso.models import Account, TenantUser +from iaso.test import APITestCase, SwaggerTestCaseMixin + + +class TestAccountCustomTranslations(SwaggerTestCaseMixin, APITestCase): + def setUp(self): + super().setUp() + self.account = Account.objects.create( + name="account", + custom_translations={"en": {"custom.key": "Custom value"}}, + ) + self.other_account = Account.objects.create(name="other account") + self.another_account = Account.objects.create(name="another account") + self.john_wick = self.create_user_with_profile(username="johnwick", account=self.another_account) + + self.jane_doe = self.create_user_with_profile(username="janedoe", account=self.account) + self.john_doe = self.create_user_with_profile(username="johndoe", account=self.other_account) + # multi tenant account + + # Create a main user without profile + main_user = get_user_model().objects.create(username="main_user") + + # And 2 account users with profile + self.account_user_ghi = self.create_user_with_profile(username="User_A", account=self.account) + TenantUser.objects.create(main_user=main_user, account_user=self.account_user_ghi) + TenantUser.objects.create(main_user=main_user, account_user=self.jane_doe) + self.account_user_wha = self.create_user_with_profile(username="User_B", account=self.other_account) + TenantUser.objects.create(main_user=main_user, account_user=self.account_user_wha) + TenantUser.objects.create(main_user=main_user, account_user=self.john_doe) + + def assertValidData(self, data): + self.assertResponseCompliantToSwagger(data, "AccountCustomTranslations") + + def test_permissions(self): + res = self.client.get(reverse("accounts-custom-translations", kwargs={"pk": self.account.pk})) + self.assertEqual(res.status_code, status.HTTP_401_UNAUTHORIZED) + + self.client.force_authenticate(self.john_wick) + + res = self.client.get(reverse("accounts-custom-translations", kwargs={"pk": self.another_account.pk})) + self.assertEqual(res.status_code, status.HTTP_200_OK) + + res = self.client.get(reverse("accounts-custom-translations", kwargs={"pk": self.other_account.pk})) + self.assertEqual(res.status_code, status.HTTP_404_NOT_FOUND) + + self.client.force_authenticate(self.jane_doe) + + res = self.client.get(reverse("accounts-custom-translations", kwargs={"pk": self.another_account.pk})) + self.assertEqual(res.status_code, status.HTTP_404_NOT_FOUND) + + res = self.client.get(reverse("accounts-custom-translations", kwargs={"pk": self.other_account.pk})) + self.assertEqual(res.status_code, status.HTTP_200_OK) + + res = self.client.get(reverse("accounts-custom-translations", kwargs={"pk": self.account.pk})) + self.assertEqual(res.status_code, status.HTTP_200_OK) + + def test_num_queries(self): + self.client.force_authenticate(self.jane_doe) + + with self.assertNumQueries(1): + res = self.client.get(reverse("accounts-custom-translations", kwargs={"pk": self.account.pk})) + self.assertJSONResponse(res, status.HTTP_200_OK) + + def test_custom_translations(self): + self.client.force_authenticate(self.jane_doe) + + res = self.client.get(reverse("accounts-custom-translations", kwargs={"pk": self.account.pk})) + res_data = self.assertJSONResponse(res, status.HTTP_200_OK) + self.assertValidData(res_data) + self.assertEqual(res_data["custom_translations"], {"en": {"custom.key": "Custom value"}}) + + self.client.force_authenticate(self.john_doe) + + res = self.client.get(reverse("accounts-custom-translations", kwargs={"pk": self.other_account.pk})) + res_data = self.assertJSONResponse(res, status.HTTP_200_OK) + self.assertValidData(res_data) + + self.assertIsNone(res_data["custom_translations"]) diff --git a/iaso/tests/api/account/test_list.py b/iaso/tests/api/account/test_list.py new file mode 100644 index 0000000000..b0faf3eeed --- /dev/null +++ b/iaso/tests/api/account/test_list.py @@ -0,0 +1,73 @@ +from django.contrib.auth import get_user_model +from django.urls import reverse +from rest_framework import status + +from iaso.models import Account, TenantUser +from iaso.test import APITestCase, SwaggerTestCaseMixin + + +class TestAccountList(SwaggerTestCaseMixin, APITestCase): + def setUp(self): + super().setUp() + + self.account = Account.objects.create(name="account") + self.other_account = Account.objects.create(name="other account") + self.another_account = Account.objects.create(name="another account") + self.john_wick = self.create_user_with_profile(username="johnwick", account=self.another_account) + + self.jane_doe = self.create_user_with_profile(username="janedoe", account=self.account) + self.john_doe = self.create_user_with_profile(username="johndoe", account=self.other_account) + # multi tenant account + + # Create a main user without profile + main_user = get_user_model().objects.create(username="main_user") + + # And 2 account users with profile + self.account_user_ghi = self.create_user_with_profile(username="User_A", account=self.account) + TenantUser.objects.create(main_user=main_user, account_user=self.account_user_ghi) + TenantUser.objects.create(main_user=main_user, account_user=self.jane_doe) + self.account_user_wha = self.create_user_with_profile(username="User_B", account=self.other_account) + TenantUser.objects.create(main_user=main_user, account_user=self.account_user_wha) + TenantUser.objects.create(main_user=main_user, account_user=self.john_doe) + + def assertValidData(self, data, expected_length): + self.assertValidListData(list_data=data, results_key="results", expected_length=expected_length, paginated=True) + self.assertResponseCompliantToSwagger(data, "PaginatedAccountListList") + + def test_list(self): + self.client.force_authenticate(self.jane_doe) + res = self.client.get(reverse("accounts-list")) + res_data = self.assertJSONResponse(res, status.HTTP_200_OK) + self.assertValidData(res_data, 2) + + def test_permissions(self): + res = self.client.get(reverse("accounts-list")) + self.assertEqual(res.status_code, status.HTTP_401_UNAUTHORIZED) + + self.client.force_authenticate(self.jane_doe) + res = self.client.get(reverse("accounts-list")) + self.assertJSONResponse(res, status.HTTP_200_OK) + + def test_should_not_see_other_accounts(self): + self.client.force_authenticate(self.jane_doe) + res = self.client.get(reverse("accounts-list")) + res_data = self.assertJSONResponse(res, status.HTTP_200_OK) + self.assertValidData(res_data, 2) + + self.assertNotIn(self.another_account.pk, [x["id"] for x in res_data["results"]]) + + self.client.force_authenticate(self.john_wick) + res = self.client.get(reverse("accounts-list")) + res_data = self.assertJSONResponse(res, status.HTTP_200_OK) + self.assertValidData(res_data, 1) + + self.assertNotIn(self.account.pk, [x["id"] for x in res_data["results"]]) + self.assertNotIn(self.other_account.pk, [x["id"] for x in res_data["results"]]) + + def test_num_queries(self): + self.client.force_authenticate(self.jane_doe) + with self.assertNumQueries(2): + res = self.client.get(reverse("accounts-list")) + + res_data = self.assertJSONResponse(res, status.HTTP_200_OK) + self.assertValidData(res_data, 2) diff --git a/iaso/tests/api/account/test_retrieve.py b/iaso/tests/api/account/test_retrieve.py new file mode 100644 index 0000000000..6368d7ea31 --- /dev/null +++ b/iaso/tests/api/account/test_retrieve.py @@ -0,0 +1,84 @@ +from django.contrib.auth import get_user_model +from django.urls import reverse +from rest_framework import status + +from iaso.models import Account, TenantUser +from iaso.modules import MODULE_VALIDATION_WORKFLOW +from iaso.test import APITestCase, SwaggerTestCaseMixin + + +class TestAccountRetrieve(SwaggerTestCaseMixin, APITestCase): + def setUp(self): + super().setUp() + self.account = Account.objects.create(name="account") + self.other_account = Account.objects.create(name="other account", modules=[MODULE_VALIDATION_WORKFLOW.codename]) + self.another_account = Account.objects.create(name="another account") + self.john_wick = self.create_user_with_profile(username="johnwick", account=self.another_account) + + self.jane_doe = self.create_user_with_profile(username="janedoe", account=self.account) + self.john_doe = self.create_user_with_profile(username="johndoe", account=self.other_account) + # multi tenant account + + # Create a main user without profile + main_user = get_user_model().objects.create(username="main_user") + + # And 2 account users with profile + self.account_user_ghi = self.create_user_with_profile(username="User_A", account=self.account) + TenantUser.objects.create(main_user=main_user, account_user=self.account_user_ghi) + TenantUser.objects.create(main_user=main_user, account_user=self.jane_doe) + self.account_user_wha = self.create_user_with_profile(username="User_B", account=self.other_account) + TenantUser.objects.create(main_user=main_user, account_user=self.account_user_wha) + TenantUser.objects.create(main_user=main_user, account_user=self.john_doe) + + def assertValidData(self, data): + self.assertResponseCompliantToSwagger(data, "AccountRetrieve") + + def test_permissions(self): + res = self.client.get(reverse("accounts-detail", kwargs={"pk": self.account.pk})) + self.assertEqual(res.status_code, status.HTTP_401_UNAUTHORIZED) + + self.client.force_authenticate(self.jane_doe) + res = self.client.get(reverse("accounts-detail", kwargs={"pk": self.account.pk})) + self.assertEqual(res.status_code, status.HTTP_200_OK) + + res = self.client.get(reverse("accounts-detail", kwargs={"pk": self.other_account.pk})) + self.assertEqual(res.status_code, status.HTTP_200_OK) + + res = self.client.get(reverse("accounts-detail", kwargs={"pk": self.another_account.pk})) + self.assertEqual(res.status_code, status.HTTP_404_NOT_FOUND) + + def test_retrieve(self): + self.client.force_authenticate(self.jane_doe) + res = self.client.get(reverse("accounts-detail", kwargs={"pk": self.account.pk})) + res_data = self.assertJSONResponse(res, status.HTTP_200_OK) + self.assertValidData(res_data) + + self.assertEqual(res_data["id"], self.account.id) + self.assertEqual(res_data["name"], self.account.name) + self.assertIsNotNone(res_data["created_at"]) + self.assertEqual(res_data["user_manual_path"], self.account.user_manual_path) + self.assertEqual(res_data["forum_path"], self.account.forum_path) + self.assertEqual(res_data["modules"], self.account.modules) + self.assertEqual(res_data["enforce_password_validation"], self.account.enforce_password_validation) + self.assertEqual(res_data["anthropic_api_key"], self.account.anthropic_api_key) + + self.client.force_authenticate(self.jane_doe) + res = self.client.get(reverse("accounts-detail", kwargs={"pk": self.other_account.pk})) + res_data = self.assertJSONResponse(res, status.HTTP_200_OK) + self.assertValidData(res_data) + + self.assertEqual(res_data["id"], self.other_account.id) + self.assertEqual(res_data["name"], self.other_account.name) + self.assertIsNotNone(res_data["created_at"]) + self.assertEqual(res_data["user_manual_path"], self.other_account.user_manual_path) + self.assertEqual(res_data["forum_path"], self.other_account.forum_path) + self.assertEqual(res_data["modules"], self.other_account.modules) + self.assertEqual(res_data["enforce_password_validation"], self.other_account.enforce_password_validation) + self.assertEqual(res_data["anthropic_api_key"], self.other_account.anthropic_api_key) + + def test_num_queries(self): + self.client.force_authenticate(self.jane_doe) + with self.assertNumQueries(1): + res = self.client.get(reverse("accounts-detail", kwargs={"pk": self.other_account.pk})) + res_data = self.assertJSONResponse(res, status.HTTP_200_OK) + self.assertValidData(res_data) diff --git a/iaso/tests/api/account/test_retrieve_current.py b/iaso/tests/api/account/test_retrieve_current.py new file mode 100644 index 0000000000..5b812c344e --- /dev/null +++ b/iaso/tests/api/account/test_retrieve_current.py @@ -0,0 +1,88 @@ +from django.contrib.auth import get_user_model +from django.urls import reverse +from rest_framework import status + +from iaso.models import Account, AccountFeatureFlag, DataSource, ExternalCredentials, SourceVersion, TenantUser +from iaso.modules import MODULE_VALIDATION_WORKFLOW +from iaso.test import APITestCase, SwaggerTestCaseMixin + + +class TestAccountRetrieveCurrent(SwaggerTestCaseMixin, APITestCase): + def setUp(self): + super().setUp() + + # create aff + self.aff = AccountFeatureFlag.objects.create(name="bla", code="bla") + + self.account = Account.objects.create(name="account") + self.other_account = Account.objects.create(name="other account", modules=[MODULE_VALIDATION_WORKFLOW.codename]) + self.other_account.feature_flags.add(self.aff) + self.data_source = DataSource.objects.create( + name="source", + credentials=ExternalCredentials.objects.create( + account=self.account, name="test", password="test", login="test", url="test" + ), + ) + self.source_version = SourceVersion.objects.create(number=1, data_source=self.data_source) + self.other_account.default_version = self.source_version + self.other_account.save() + + self.another_account = Account.objects.create(name="another account") + self.john_wick = self.create_user_with_profile(username="johnwick", account=self.another_account) + + self.jane_doe = self.create_user_with_profile(username="janedoe", account=self.account) + self.john_doe = self.create_user_with_profile(username="johndoe", account=self.other_account) + # multi tenant account + + # Create a main user without profile + main_user = get_user_model().objects.create(username="main_user") + + # And 2 account users with profile + self.account_user_ghi = self.create_user_with_profile(username="User_A", account=self.account) + TenantUser.objects.create(main_user=main_user, account_user=self.account_user_ghi) + TenantUser.objects.create(main_user=main_user, account_user=self.jane_doe) + self.account_user_wha = self.create_user_with_profile(username="User_B", account=self.other_account) + TenantUser.objects.create(main_user=main_user, account_user=self.account_user_wha) + TenantUser.objects.create(main_user=main_user, account_user=self.john_doe) + + def assertValidData(self, data): + self.assertResponseCompliantToSwagger(data, "AccountRetrieveCurrent") + + def test_permissions(self): + res = self.client.get(reverse("accounts-me")) + self.assertEqual(res.status_code, status.HTTP_401_UNAUTHORIZED) + + self.client.force_authenticate(self.jane_doe) + res = self.client.get(reverse("accounts-me")) + self.assertJSONResponse(res, status.HTTP_200_OK) + + def test_retrieve(self): + self.client.force_authenticate(self.jane_doe) + res = self.client.get(reverse("accounts-me")) + res_data = self.assertJSONResponse(res, status.HTTP_200_OK) + self.assertValidData(res_data) + + self.client.force_authenticate(self.john_doe) + res = self.client.get(reverse("accounts-me")) + res_data = self.assertJSONResponse(res, status.HTTP_200_OK) + self.assertValidData(res_data) + + def test_num_queries(self): + self.client.force_authenticate(self.john_doe) + with self.assertNumQueries(2): + res = self.client.get(reverse("accounts-me")) + res_data = self.assertJSONResponse(res, status.HTTP_200_OK) + + self.assertValidData(res_data) + self.assertEqual(res_data["id"], self.other_account.pk) + self.assertEqual(res_data["name"], self.other_account.name) + self.assertEqual( + res_data["default_version"], + { + "id": self.source_version.pk, + "data_source": {"id": self.data_source.pk, "url": self.data_source.credentials.url}, + }, + ) + self.assertEqual(res_data["other_accounts"], [{"id": self.account.pk, "name": self.account.name}]) + self.assertEqual(res_data["modules"], [MODULE_VALIDATION_WORKFLOW.codename]) + self.assertEqual(res_data["feature_flags"], [{"name": self.aff.name, "code": self.aff.code}]) diff --git a/iaso/tests/api/account/test_set_default_version.py b/iaso/tests/api/account/test_set_default_version.py new file mode 100644 index 0000000000..6779511858 --- /dev/null +++ b/iaso/tests/api/account/test_set_default_version.py @@ -0,0 +1,100 @@ +from django.contrib.auth import get_user_model +from django.urls import reverse +from rest_framework import status + +from iaso.models import Account, DataSource, Project, SourceVersion, TenantUser +from iaso.permissions.core_permissions import CORE_SOURCE_PERMISSION +from iaso.test import APITestCase + + +class TestAccountAPISetDefaultVersion(APITestCase): + def setUp(self): + super().setUp() + self.account = Account.objects.create(name="account") + self.other_account = Account.objects.create(name="other account") + + self.jane_doe = self.create_user_with_profile( + username="janedoe", account=self.account, permissions=[CORE_SOURCE_PERMISSION] + ) + self.john_doe = self.create_user_with_profile( + username="johndoe", account=self.other_account, permissions=[CORE_SOURCE_PERMISSION] + ) + self.jim = self.create_user_with_profile(username="jim", account=self.account) + + ghi_project = Project.objects.create(name="ghi_project", account=self.account) + ghi_datasource = DataSource.objects.create() + ghi_datasource.projects.set([ghi_project]) + self.ghi_version = SourceVersion.objects.create(data_source=ghi_datasource, number=1) + + wha_project = Project.objects.create(name="wha_project", account=self.other_account) + wha_datasource = DataSource.objects.create(name="wha datasource") + wha_datasource.projects.set([wha_project]) + self.wha_version = SourceVersion.objects.create(data_source=wha_datasource, number=1) + + # multi tenant account + + # Create a main user without profile + main_user = get_user_model().objects.create(username="main_user") + + # And 2 account users with profile + self.account_user_ghi = self.create_user_with_profile( + username="User_A", account=self.account, permissions=[CORE_SOURCE_PERMISSION] + ) + TenantUser.objects.create(main_user=main_user, account_user=self.account_user_ghi) + TenantUser.objects.create(main_user=main_user, account_user=self.jane_doe) + self.account_user_wha = self.create_user_with_profile( + username="User_B", account=self.other_account, permissions=[CORE_SOURCE_PERMISSION] + ) + TenantUser.objects.create(main_user=main_user, account_user=self.account_user_wha) + + def test_permissions(self): + res = self.client.put(reverse("accounts-set-default-version", kwargs={"pk": self.account.pk})) + self.assertEqual(res.status_code, status.HTTP_401_UNAUTHORIZED) + + self.client.force_authenticate(self.jim) + res = self.client.put(reverse("accounts-set-default-version", kwargs={"pk": self.account.pk})) + self.assertEqual(res.status_code, status.HTTP_403_FORBIDDEN) + + self.client.force_authenticate(self.jane_doe) + res = self.client.put(reverse("accounts-set-default-version", kwargs={"pk": self.account.pk})) + self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST) + + res = self.client.put(reverse("accounts-set-default-version", kwargs={"pk": self.other_account.pk})) + self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST) + + self.client.force_authenticate(self.john_doe) + res = self.client.put(reverse("accounts-set-default-version", kwargs={"pk": self.other_account.pk})) + self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST) + + res = self.client.put(reverse("accounts-set-default-version", kwargs={"pk": self.account.pk})) + self.assertEqual(res.status_code, status.HTTP_404_NOT_FOUND) + + def test_num_queries(self): + self.client.force_authenticate(self.account_user_ghi) + self.assertIsNotNone(self.account_user_ghi.tenant_user) + with self.assertNumQueries(5): + res = self.client.put( + reverse("accounts-set-default-version", kwargs={"pk": self.account.pk}), + {"default_version": self.ghi_version.pk}, + ) + self.assertJSONResponse(res, status.HTTP_200_OK) + + def test_happy_path(self): + self.client.force_authenticate(self.jane_doe) + res = self.client.put( + reverse("accounts-set-default-version", kwargs={"pk": self.account.pk}), + {"default_version": self.ghi_version.pk}, + ) + self.assertJSONResponse(res, status.HTTP_200_OK) + + self.account.refresh_from_db() + self.assertEqual(self.account.default_version.id, self.ghi_version.id) + + def test_cant_assign_source_version_from_different_account(self): + self.client.force_authenticate(self.jane_doe) + res = self.client.put( + reverse("accounts-set-default-version", kwargs={"pk": self.account.pk}), + {"default_version": self.wha_version.pk}, + ) + res_data = self.assertJSONResponse(res, status.HTTP_400_BAD_REQUEST) + self.assertHasError(res_data, "default_version", "Account not allowed to access this default_source.") diff --git a/iaso/tests/api/account/test_switch.py b/iaso/tests/api/account/test_switch.py new file mode 100644 index 0000000000..1a05c89a83 --- /dev/null +++ b/iaso/tests/api/account/test_switch.py @@ -0,0 +1,70 @@ +from django.contrib import auth +from django.contrib.auth import get_user_model +from django.urls import reverse +from rest_framework import status + +from iaso.models import Account, TenantUser +from iaso.test import APITestCase + + +class TestAccountSwitch(APITestCase): + def setUp(self): + super().setUp() + self.account = Account.objects.create(name="account") + self.other_account = Account.objects.create(name="other account") + self.another_account = Account.objects.create(name="another account") + self.john_wick = self.create_user_with_profile(username="johnwick", account=self.another_account) + + self.jane_doe = self.create_user_with_profile(username="janedoe", account=self.account) + self.john_doe = self.create_user_with_profile(username="johndoe", account=self.other_account) + # multi tenant account + + # Create a main user without profile + main_user = get_user_model().objects.create(username="main_user") + + # And 2 account users with profile + self.account_user_ghi = self.create_user_with_profile(username="User_A", account=self.account) + TenantUser.objects.create(main_user=main_user, account_user=self.account_user_ghi) + TenantUser.objects.create(main_user=main_user, account_user=self.jane_doe) + self.account_user_wha = self.create_user_with_profile(username="User_B", account=self.other_account) + TenantUser.objects.create(main_user=main_user, account_user=self.account_user_wha) + TenantUser.objects.create(main_user=main_user, account_user=self.john_doe) + + def test_num_queries(self): + self.client.force_authenticate(self.jane_doe) + with self.assertNumQueries(11): + res = self.client.post(reverse("accounts-switch"), data={"account_id": self.other_account.pk}) + self.assertEqual(res.status_code, status.HTTP_204_NO_CONTENT) + + def test_permissions(self): + res = self.client.post(reverse("accounts-switch")) + self.assertEqual(res.status_code, status.HTTP_401_UNAUTHORIZED) + + self.client.force_authenticate(self.john_doe) + res = self.client.post(reverse("accounts-switch")) + self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST) + + def test_cannot_switch_to_another_account_not_linked_to_user(self): + self.client.force_authenticate(self.jane_doe) + res = self.client.post(reverse("accounts-switch"), data={"account_id": self.another_account.pk}) + res_data = self.assertJSONResponse(res, status.HTTP_400_BAD_REQUEST) + self.assertHasError(res_data, "account_id", f'Invalid pk "{self.another_account.pk}" - object does not exist.') + + def test_switch(self): + self.client.force_authenticate(self.jane_doe) + res = self.client.post(reverse("accounts-switch"), data={"account_id": self.other_account.pk}) + + self.assertEqual(res.status_code, status.HTTP_204_NO_CONTENT) + logged_in_user = auth.get_user(self.client) + self.assertEqual(logged_in_user.iaso_profile.account.name, self.other_account.name) + + def test_switch_one_user(self): + TenantUser.objects.create(main_user=self.john_wick, account_user=self.john_wick) + self.client.force_authenticate(self.john_wick) + res = self.client.post(reverse("accounts-switch"), data={"account_id": self.another_account.pk}) + self.assertEqual(res.status_code, status.HTTP_404_NOT_FOUND) + + def test_switch_no_tenant_user(self): + self.client.force_authenticate(self.john_wick) + res = self.client.post(reverse("accounts-switch"), data={"account_id": self.another_account.pk}) + self.assertEqual(res.status_code, status.HTTP_404_NOT_FOUND) diff --git a/iaso/tests/api/account/test_update.py b/iaso/tests/api/account/test_update.py new file mode 100644 index 0000000000..9dab3b50af --- /dev/null +++ b/iaso/tests/api/account/test_update.py @@ -0,0 +1,128 @@ +from django.contrib.auth import get_user_model +from django.urls import reverse +from rest_framework import status + +from iaso.models import Account, AccountFeatureFlag, TenantUser +from iaso.modules import MODULE_VALIDATION_WORKFLOW +from iaso.test import APITestCase, SwaggerTestCaseMixin + + +class TestAccountUpdate(SwaggerTestCaseMixin, APITestCase): + def setUp(self): + super().setUp() + self.account = Account.objects.create(name="account") + self.other_account = Account.objects.create(name="other account") + self.another_account = Account.objects.create(name="another account") + self.john_wick = self.create_user_with_profile(username="johnwick", account=self.another_account) + + self.jane_doe = self.create_user_with_profile(username="janedoe", account=self.account) + self.john_doe = self.create_user_with_profile(username="johndoe", account=self.other_account) + # multi tenant account + + # Create a main user without profile + main_user = get_user_model().objects.create(username="main_user") + + # And 2 account users with profile + self.account_user_ghi = self.create_user_with_profile(username="User_A", account=self.account) + TenantUser.objects.create(main_user=main_user, account_user=self.account_user_ghi) + TenantUser.objects.create(main_user=main_user, account_user=self.jane_doe) + self.account_user_wha = self.create_user_with_profile(username="User_B", account=self.other_account) + TenantUser.objects.create(main_user=main_user, account_user=self.account_user_wha) + TenantUser.objects.create(main_user=main_user, account_user=self.john_doe) + + # create account ff + self.aff = AccountFeatureFlag.objects.create(name="bla", code="bla") + + def assertValidPutBody(self, data): + self.assertResponseCompliantToSwagger(data, "AccountUpdateRequest") + + def assertValidPatchBody(self, data): + self.assertResponseCompliantToSwagger(data, "PatchedAccountUpdateRequest") + + def test_num_queries(self): + self.client.force_authenticate(self.jane_doe) + with self.assertNumQueries(3): + res = self.client.put( + reverse("accounts-detail", kwargs={"pk": self.account.pk}), + data={ + "name": "new account name", + "user_manual_path": "user_manual_path", + "forum_path": "forum_path", + "modules": [MODULE_VALIDATION_WORKFLOW.codename], + "enforce_password_validation": True, + "anthropic_api_key": "1234", + "custom_translations": {"en": "oops"}, + }, + ) + self.assertEqual(res.status_code, status.HTTP_204_NO_CONTENT) + + def test_permissions(self): + res = self.client.put(reverse("accounts-detail", kwargs={"pk": self.account.pk})) + self.assertEqual(res.status_code, status.HTTP_401_UNAUTHORIZED) + + self.client.force_authenticate(user=self.john_doe) + res = self.client.put(reverse("accounts-detail", kwargs={"pk": self.another_account.pk})) + self.assertEqual(res.status_code, status.HTTP_404_NOT_FOUND) + + self.client.force_authenticate(user=self.jane_doe) + res = self.client.put(reverse("accounts-detail", kwargs={"pk": self.account.pk})) + self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST) + + res = self.client.put(reverse("accounts-detail", kwargs={"pk": self.other_account.pk})) + self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST) + + def test_update(self): + self.client.force_authenticate(self.jane_doe) + self.account.feature_flags.add(self.aff) + self.account.save() + + data = { + "name": "new account name", + "user_manual_path": "user_manual_path", + "forum_path": "forum_path", + "modules": [MODULE_VALIDATION_WORKFLOW.codename], + "enforce_password_validation": False, + "anthropic_api_key": "1234", + "custom_translations": {"en": "oops"}, + } + self.assertValidPutBody(data) + + res = self.client.put(reverse("accounts-detail", kwargs={"pk": self.account.pk}), data=data) + self.assertEqual(res.status_code, status.HTTP_204_NO_CONTENT) + + self.account.refresh_from_db() + + self.assertEqual(self.account.name, "new account name") + self.assertEqual(self.account.user_manual_path, "user_manual_path") + self.assertEqual(self.account.forum_path, "forum_path") + self.assertEqual(self.account.modules, [MODULE_VALIDATION_WORKFLOW.codename]) + self.assertEqual(self.account.anthropic_api_key, "1234") + self.assertEqual(self.account.custom_translations, {"en": "oops"}) + self.assertFalse(self.account.enforce_password_validation) + self.assertTrue(self.account.feature_flags.count()) + + def test_partial_update(self): + self.client.force_authenticate(self.jane_doe) + self.account.feature_flags.add(self.aff) + self.account.save() + + data = { + "user_manual_path": "user_manual_path", + "anthropic_api_key": "1234", + "custom_translations": {"en": "oops"}, + } + self.assertValidPatchBody(data) + + res = self.client.patch(reverse("accounts-detail", kwargs={"pk": self.account.pk}), data=data) + self.assertEqual(res.status_code, status.HTTP_204_NO_CONTENT) + + self.account.refresh_from_db() + + self.assertEqual(self.account.name, "account") + self.assertEqual(self.account.user_manual_path, "user_manual_path") + self.assertIsNone(self.account.forum_path) + self.assertEqual(self.account.modules, []) + self.assertEqual(self.account.anthropic_api_key, "1234") + self.assertEqual(self.account.custom_translations, {"en": "oops"}) + self.assertTrue(self.account.enforce_password_validation) + self.assertTrue(self.account.feature_flags.count()) diff --git a/iaso/tests/api/test_account.py b/iaso/tests/api/test_account.py deleted file mode 100644 index fd6797cb59..0000000000 --- a/iaso/tests/api/test_account.py +++ /dev/null @@ -1,134 +0,0 @@ -from django.contrib import auth - -from iaso import models as m -from iaso.permissions.core_permissions import CORE_SOURCE_PERMISSION -from iaso.test import APITestCase - - -class AccountAPITestCase(APITestCase): - @classmethod - def setUpTestData(cls): - cls.ghi = ghi = m.Account.objects.create(name="Global Health Initiative") - cls.wha = wha = m.Account.objects.create(name="Worldwide Health Aid") - - cls.jane = cls.create_user_with_profile(username="janedoe", account=ghi, permissions=[CORE_SOURCE_PERMISSION]) - cls.john = cls.create_user_with_profile(username="johndoe", account=wha, permissions=[CORE_SOURCE_PERMISSION]) - cls.jim = cls.create_user_with_profile(username="jimdoe", account=ghi) - - ghi_project = m.Project.objects.create(name="ghi_project", account=ghi) - ghi_datasource = m.DataSource.objects.create() - ghi_datasource.projects.set([ghi_project]) - cls.ghi_version = m.SourceVersion.objects.create(data_source=ghi_datasource, number=1) - - wha_project = m.Project.objects.create(name="wha_project", account=wha) - wha_second_project = m.Project.objects.create(name="wha_second_project", account=wha) - wha_datasource = m.DataSource.objects.create(name="wha datasource") - wha_datasource.projects.set([wha_project]) - cls.wha_version = m.SourceVersion.objects.create(data_source=wha_datasource, number=1) - - def test_account_list_without_auth(self): - """GET /account/ without auth should result in a 403 (before the method not authorized?)""" - self.client.force_authenticate(self.jim) - - response = self.client.get("/api/accounts/") - self.assertJSONResponse(response, 403) - - def test_account_list_with_auth(self): - """GET /account/ with auth should result in a 405 as method is not allowed""" - self.client.force_authenticate(self.jane) - - response = self.client.get("/api/accounts/") - self.assertJSONResponse(response, 405) - - def test_account_delete_forbidden(self): - """DELETE /account/ with auth should result in a 405 as method is not allowed""" - self.client.force_authenticate(self.jane) - - response = self.client.delete("/api/accounts/") - self.assertJSONResponse(response, 405) - - def test_account_post_forbidden(self): - """POST /account/ with auth should result in a 405 as method is not allowed""" - self.client.force_authenticate(self.jane) - - response = self.client.post("/api/accounts/", {"default_version": self.ghi_version.pk}) - self.assertJSONResponse(response, 405) - - def test_account_detail_forbidden(self): - """POST /account/ with auth should result in a 405 as method is not allowed""" - self.client.force_authenticate(self.jane) - - response = self.client.get(f"/api/accounts/{self.ghi.pk}/") - self.assertJSONResponse(response, 405) - - def test_account_set_default_ok(self): - """Set a version with a user that has correct perm""" - - self.client.force_authenticate(self.jane) - response = self.client.put(f"/api/accounts/{self.ghi.pk}/", {"default_version": self.ghi_version.pk}) - j = self.assertJSONResponse(response, 200) - self.assertEqual(j, {"id": self.ghi.pk, "default_version": self.ghi_version.pk}) - - self.ghi.refresh_from_db() - self.assertEqual(self.ghi.default_version.id, self.ghi_version.id) - - def test_account_set_default_fail_wrong_account(self): - """User try to set default on an account he doesn't belong too""" - - self.client.force_authenticate(self.jane) - response = self.client.put(f"/api/accounts/{self.wha.pk}/", {"default_version": self.ghi_version.pk}) - j = self.assertJSONResponse(response, 403) - self.assertEqual(j, {"detail": "You do not have permission to perform this action."}) - - # old default version - old_version = self.wha.default_version - self.wha.refresh_from_db() - self.assertEqual(self.wha.default_version, old_version) - - # invert on the other account/user to be sure - self.client.force_authenticate(self.john) - response = self.client.put(f"/api/accounts/{self.ghi.pk}/", {"default_version": self.ghi_version.pk}) - j = self.assertJSONResponse(response, 403) - self.assertEqual(j, {"detail": "You do not have permission to perform this action."}) - - # old default version - old_version = self.ghi.default_version - self.ghi.refresh_from_db() - self.assertEqual(self.ghi.default_version, old_version) - - def test_account_set_default_no_perm(self): - """User without source perm cannot modify the default version""" - # invert on the other account/user to be sure - self.client.force_authenticate(self.jim) - response = self.client.put(f"/api/accounts/{self.ghi.pk}/", {"default_version": self.ghi_version.pk}) - j = self.assertJSONResponse(response, 403) - self.assertEqual(j, {"detail": "You do not have permission to perform this action."}) - - # old default version - old_version = self.ghi.default_version - self.ghi.refresh_from_db() - self.assertEqual(self.ghi.default_version, old_version) - - def test_cant_assign_source_version_from_different_account(self): - self.client.force_authenticate(self.jane) - response = self.client.put(f"/api/accounts/{self.ghi.pk}/", {"default_version": self.wha_version.pk}) - j = self.assertJSONResponse(response, 400) - self.assertEqual(j, {"Error": "Account not allowed to access this default_source"}) - - def test_switch_account(self): - # Create a main user without profile - main_user = m.User.objects.create(username="main_user") - main_user.save() - - # And 2 account users with profile - account_user_ghi = self.create_user_with_profile(username="User_A", account=self.ghi) - m.TenantUser.objects.create(main_user=main_user, account_user=account_user_ghi) - account_user_wha = self.create_user_with_profile(username="User_B", account=self.wha) - m.TenantUser.objects.create(main_user=main_user, account_user=account_user_wha) - - self.client.force_authenticate(account_user_ghi) - response = self.client.patch("/api/accounts/switch/", {"account_id": self.wha.pk}) - - self.assertJSONResponse(response, 200) - logged_in_user = auth.get_user(self.client) - self.assertEqual(logged_in_user.iaso_profile.account.name, "Worldwide Health Aid") diff --git a/iaso/tests/api/test_custom_translations.py b/iaso/tests/api/test_custom_translations.py deleted file mode 100644 index 35643f70ea..0000000000 --- a/iaso/tests/api/test_custom_translations.py +++ /dev/null @@ -1,42 +0,0 @@ -from iaso import models as m -from iaso.test import APITestCase - - -class CustomTranslationsAPITestCase(APITestCase): - @classmethod - def setUpTestData(cls): - cls.ghi = m.Account.objects.create( - name="Global Health Initiative", - custom_translations={"en": {"custom.key": "Custom value"}}, - ) - cls.wha = m.Account.objects.create(name="Worldwide Health Aid") - - cls.jane = cls.create_user_with_profile(username="janedoe", account=cls.ghi) - cls.john = cls.create_user_with_profile(username="johndoe", account=cls.wha) - - def test_custom_translations_requires_authentication(self): - response = self.client.get(f"/api/custom_translations/?account_id={self.ghi.pk}") - self.assertJSONResponse(response, 401) - - def test_custom_translations_requires_account_id(self): - self.client.force_authenticate(self.jane) - response = self.client.get("/api/custom_translations/") - data = self.assertJSONResponse(response, 400) - self.assertEqual(data, {"account_id": ["Account id is required."]}) - - def test_custom_translations_unknown_account_id_returns_404(self): - self.client.force_authenticate(self.jane) - response = self.client.get("/api/custom_translations/?account_id=999999") - self.assertJSONResponse(response, 404) - - def test_custom_translations_forbidden_for_other_account(self): - self.client.force_authenticate(self.jane) - response = self.client.get(f"/api/custom_translations/?account_id={self.wha.pk}") - data = self.assertJSONResponse(response, 403) - self.assertEqual(data, {"detail": "You do not have permission to perform this action."}) - - def test_custom_translations_success(self): - self.client.force_authenticate(self.jane) - response = self.client.get(f"/api/custom_translations/?account_id={self.ghi.pk}") - data = self.assertJSONResponse(response, 200) - self.assertEqual(data, {"custom_translations": {"en": {"custom.key": "Custom value"}}}) diff --git a/iaso/urls.py b/iaso/urls.py index 907875224e..f29c688b9d 100644 --- a/iaso/urls.py +++ b/iaso/urls.py @@ -21,7 +21,7 @@ from iaso.api.validation_workflows.views_mobile import ValidationWorkflowMobileViewSet from plugins.router import router as plugins_router -from .api.accounts import AccountViewSet +from .api.accounts.views import AccountViewSet from .api.algorithms import AlgorithmsViewSet from .api.algorithms_runs import AlgorithmsRunsViewSet from .api.api_import.views import APIImportViewSet @@ -33,7 +33,6 @@ from .api.comment import CommentViewSet from .api.completeness import CompletenessViewSet from .api.completeness_stats import CompletenessStatsV2ViewSet -from .api.custom_translations import CustomTranslationsViewSet from .api.data_source_versions_synchronization.views import DataSourceVersionsSynchronizationViewSet from .api.data_sources import DataSourceViewSet from .api.deduplication.entity_duplicate import EntityDuplicateViewSet # type: ignore @@ -186,7 +185,6 @@ router.register(r"datasources/sync", DataSourceVersionsSynchronizationViewSet, basename="datasources_synchronization") router.register(r"datasources", DataSourceViewSet, basename="datasources") router.register(r"accounts", AccountViewSet, basename="accounts") -router.register(r"custom_translations", CustomTranslationsViewSet, basename="custom_translations") router.register(r"apitoken", APITokenViewSet, basename="apitoken") router.register(r"sourceversions", SourceVersionViewSet, basename="sourceversion") router.register(r"links", LinkViewSet, basename="links") From 26630886132e2fcea137d5faeba037391dbd88b2 Mon Sep 17 00:00:00 2001 From: hugo093 Date: Wed, 27 May 2026 16:39:25 +0200 Subject: [PATCH 2/2] fix: tests --- iaso/tests/api/account/test_set_default_version.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/iaso/tests/api/account/test_set_default_version.py b/iaso/tests/api/account/test_set_default_version.py index 6779511858..93f2142f5f 100644 --- a/iaso/tests/api/account/test_set_default_version.py +++ b/iaso/tests/api/account/test_set_default_version.py @@ -77,7 +77,7 @@ def test_num_queries(self): reverse("accounts-set-default-version", kwargs={"pk": self.account.pk}), {"default_version": self.ghi_version.pk}, ) - self.assertJSONResponse(res, status.HTTP_200_OK) + self.assertJSONResponse(res, status.HTTP_204_NO_CONTENT) def test_happy_path(self): self.client.force_authenticate(self.jane_doe) @@ -85,7 +85,7 @@ def test_happy_path(self): reverse("accounts-set-default-version", kwargs={"pk": self.account.pk}), {"default_version": self.ghi_version.pk}, ) - self.assertJSONResponse(res, status.HTTP_200_OK) + self.assertJSONResponse(res, status.HTTP_204_NO_CONTENT) self.account.refresh_from_db() self.assertEqual(self.account.default_version.id, self.ghi_version.id)