Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions social_core/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from collections.abc import Callable

from .backends.base import BaseAuth
from .storage import BaseStorage, UserProtocol
from .storage import UserProtocol
from .strategy import HttpResponseProtocol


Expand Down Expand Up @@ -80,7 +80,7 @@ def do_complete( # noqa: C901,PLR0912

# check if the output value is something else than a user and just
# return it to the client
user_model = cast("type[BaseStorage]", backend.strategy.storage).user.user_model()
user_model = backend.strategy.storage.user.user_model()
if authenticated_user and not isinstance(authenticated_user, user_model):
return cast("HttpResponseProtocol", authenticated_user)

Expand Down
4 changes: 2 additions & 2 deletions social_core/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from requests import Response
from requests.auth import AuthBase

from social_core.storage import BaseStorage, PartialMixin, UserProtocol
from social_core.storage import PartialMixin, UserProtocol
from social_core.strategy import BaseStrategy, HttpResponseProtocol


Expand Down Expand Up @@ -122,7 +122,7 @@ def pipeline(
def disconnect(self, *args, **kwargs) -> dict:
pipeline = self.strategy.get_disconnect_pipeline(self)
kwargs["name"] = self.name
kwargs["user_storage"] = cast("type[BaseStorage]", self.strategy.storage).user
kwargs["user_storage"] = self.strategy.storage.user
return self.run_pipeline(pipeline, *args, **kwargs)

def run_pipeline(
Expand Down
14 changes: 3 additions & 11 deletions social_core/backends/discourse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,13 @@
import time
from base64 import urlsafe_b64decode, urlsafe_b64encode
from hashlib import sha256
from typing import TYPE_CHECKING, cast
from urllib.parse import urlencode

from social_core.exceptions import AuthException, AuthTokenError
from social_core.utils import parse_qs

from .base import BaseAuth

if TYPE_CHECKING:
from social_core.storage import BaseStorage


class DiscourseAuth(BaseAuth):
name = "discourse"
Expand Down Expand Up @@ -55,17 +51,13 @@ def get_user_details(self, response):
}

def add_nonce(self, nonce) -> None:
cast("type[BaseStorage]", self.strategy.storage).nonce.use(
self.setting("SERVER_URL"), time.time(), nonce
)
self.strategy.storage.nonce.use(self.setting("SERVER_URL"), time.time(), nonce)

def get_nonce(self, nonce):
return cast("type[BaseStorage]", self.strategy.storage).nonce.get(
self.setting("SERVER_URL"), nonce
)
return self.strategy.storage.nonce.get(self.setting("SERVER_URL"), nonce)

def delete_nonce(self, nonce) -> None:
cast("type[BaseStorage]", self.strategy.storage).nonce.delete(nonce)
self.strategy.storage.nonce.delete(nonce)

def auth_complete(self, *args, **kwargs):
"""
Expand Down
9 changes: 3 additions & 6 deletions social_core/backends/open_id_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

from requests.auth import AuthBase

from social_core.storage import BaseStorage
from social_core.strategy import BaseStrategy


Expand Down Expand Up @@ -235,21 +234,19 @@ def get_and_store_nonce(self, url, state):
nonce = self.strategy.random_string(64)
# Store the nonce
association = OpenIdConnectAssociation(nonce, assoc_type=state)
cast("type[BaseStorage]", self.strategy.storage).association.store(
url, association
)
self.strategy.storage.association.store(url, association)
return nonce

def get_nonce(self, nonce):
try:
return cast("type[BaseStorage]", self.strategy.storage).association.get(
return self.strategy.storage.association.get(
server_url=self.authorization_url(), handle=nonce
)[0]
except IndexError:
return None

def remove_nonce(self, nonce_id) -> None:
cast("type[BaseStorage]", self.strategy.storage).association.remove([nonce_id])
self.strategy.storage.association.remove([nonce_id])

def validate_claims(self, id_token) -> None:
utc_timestamp = timegm(datetime.datetime.now(datetime.timezone.utc).timetuple())
Expand Down
4 changes: 4 additions & 0 deletions social_core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ class SocialAuthBaseException(ValueError):
"""Base class for pipeline exceptions."""


class SocialAuthImproperlyConfiguredError(SocialAuthBaseException):
"""Raised when configuration is invalid."""


class StrategyMissingFeatureError(SocialAuthBaseException):
"""Strategy does not support this."""

Expand Down
30 changes: 11 additions & 19 deletions social_core/pipeline/social_auth.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING

from social_core.exceptions import AuthAlreadyAssociated, AuthException, AuthForbidden

if TYPE_CHECKING:
from social_core.backends.base import BaseAuth
from social_core.storage import BaseStorage, UserProtocol
from social_core.storage import UserProtocol


def social_details(backend: BaseAuth, details, response, *args, **kwargs):
Expand All @@ -26,9 +26,7 @@ def social_user(
backend: BaseAuth, uid, user: UserProtocol | None = None, *args, **kwargs
):
provider = backend.name
social = cast("type[BaseStorage]", backend.strategy.storage).user.get_social_auth(
provider, uid
)
social = backend.strategy.storage.user.get_social_auth(provider, uid)
if social:
if user and social.user != user:
raise AuthAlreadyAssociated(backend)
Expand All @@ -52,14 +50,12 @@ def associate_user(
):
if user and not social:
try:
social = cast(
"type[BaseStorage]", backend.strategy.storage
).user.create_social_auth(user, uid, backend.name)
social = backend.strategy.storage.user.create_social_auth(
user, uid, backend.name
)
# pylint: disable-next=broad-exception-caught
except Exception as err:
if not cast(
"type[BaseStorage]", backend.strategy.storage
).is_integrity_error(err):
if not backend.strategy.storage.is_integrity_error(err):
raise
# Protect for possible race condition, those bastard with FTL
# clicking capabilities, check issue #131:
Expand Down Expand Up @@ -95,11 +91,7 @@ def associate_by_email(
# Try to associate accounts registered with the same email address,
# only if it's a single object. AuthException is raised if multiple
# objects are returned.
users = list(
cast("type[BaseStorage]", backend.strategy.storage).user.get_users_by_email(
email
)
)
users = list(backend.strategy.storage.user.get_users_by_email(email))
if len(users) == 0:
return None
if len(users) > 1:
Expand All @@ -119,9 +111,9 @@ def load_extra_data(
*args,
**kwargs,
) -> None:
social = kwargs.get("social") or cast(
"type[BaseStorage]", backend.strategy.storage
).user.get_social_auth(backend.name, uid)
social = kwargs.get("social") or backend.strategy.storage.user.get_social_auth(
backend.name, uid
)
if social:
extra_data = backend.extra_data(user, uid, response, details, kwargs)
social.set_extra_data(extra_data)
4 changes: 2 additions & 2 deletions social_core/pipeline/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

if TYPE_CHECKING:
from social_core.backends.base import BaseAuth
from social_core.storage import BaseStorage, UserProtocol
from social_core.storage import UserProtocol
from social_core.strategy import BaseStrategy

USER_FIELDS = ["username", "email"]
Expand Down Expand Up @@ -173,4 +173,4 @@ def user_details(
setattr(user, name, value)

if changed:
cast("type[BaseStorage]", strategy.storage).user.changed(user)
strategy.storage.user.changed(user)
4 changes: 1 addition & 3 deletions social_core/pipeline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,7 @@ def partial_load(strategy: BaseStrategy, token: str) -> PartialMixin | None:
).user.get_social_auth(**social) # type: ignore[missing-argument]

if user:
kwargs["user"] = cast("type[BaseStorage]", strategy.storage).user.get_user(
user
)
kwargs["user"] = strategy.storage.user.get_user(user)

partial.args = [strategy.from_session_value(val) for val in args]
partial.kwargs = {
Expand Down
30 changes: 16 additions & 14 deletions social_core/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from typing import TYPE_CHECKING, Any, Protocol, cast

from .backends.utils import get_backend
from .exceptions import StrategyMissingBackendError, StrategyMissingFeatureError
from .exceptions import (
SocialAuthImproperlyConfiguredError,
StrategyMissingBackendError,
StrategyMissingFeatureError,
)
from .pipeline import DEFAULT_AUTH_PIPELINE, DEFAULT_DISCONNECT_PIPELINE
from .pipeline.utils import partial_load
from .store import OpenIdSessionWrapper, OpenIdStore
Expand Down Expand Up @@ -53,9 +57,15 @@ def __init__(
storage: type[BaseStorage] | None = None,
tpl: type[BaseTemplateStrategy] | None = None,
) -> None:
self.storage = storage
self._storage = storage
self.tpl = (tpl or self.DEFAULT_TEMPLATE_STRATEGY)(self)

@property
def storage(self) -> type[BaseStorage]:
if self._storage is None:
raise StrategyMissingBackendError
return self._storage

def setting(self, name: str, default=None, backend: BaseAuth | None = None):
names = [setting_name(name), name]
if backend:
Expand All @@ -68,13 +78,9 @@ def setting(self, name: str, default=None, backend: BaseAuth | None = None):
return default

def create_user(self, *args, **kwargs):
if self.storage is None:
raise StrategyMissingBackendError
return self.storage.user.create_user(*args, **kwargs)

def get_user(self, *args, **kwargs):
if self.storage is None:
raise StrategyMissingBackendError
return self.storage.user.get_user(*args, **kwargs)

def session_setdefault(self, name: str, value):
Expand Down Expand Up @@ -121,8 +127,6 @@ def partial_load(self, token: str) -> PartialMixin | None:
return partial_load(self, token)

def clean_partial_pipeline(self, token) -> None:
if self.storage is None:
raise StrategyMissingBackendError
self.storage.partial.destroy(token)
current_token_in_session = self.session_get(PARTIAL_TOKEN_SESSION_NAME)
if current_token_in_session == token:
Expand Down Expand Up @@ -158,17 +162,17 @@ def get_language(self) -> str:
def send_email_validation(
self, backend: BaseAuth, email: str, partial_token: str | None = None
) -> CodeMixin:
if self.storage is None:
raise StrategyMissingBackendError
email_validation = self.setting("EMAIL_VALIDATION_FUNCTION")
if not email_validation:
raise SocialAuthImproperlyConfiguredError(
"EMAIL_VALIDATION_FUNCTION missing"
)
send_email = module_member(email_validation)
code = self.storage.code.make_code(email)
send_email(self, backend, code, partial_token)
return code

def validate_email(self, email: str, code: str) -> bool:
if self.storage is None:
raise StrategyMissingBackendError
verification_code = self.storage.code.get_code(code)
if not verification_code or verification_code.code != code:
return False
Expand All @@ -193,8 +197,6 @@ def authenticate(
) -> UserProtocol | HttpResponseProtocol | None:
"""Trigger the authentication mechanism tied to the current
framework"""
if self.storage is None:
raise StrategyMissingBackendError
kwargs["strategy"] = self
kwargs["storage"] = self.storage
kwargs["backend"] = backend
Expand Down
2 changes: 0 additions & 2 deletions social_core/tests/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

if TYPE_CHECKING:
from social_core.backends.base import BaseAuth
from social_core.storage import BaseStorage

TEST_URI = "http://myapp.com"
TEST_HOST = "myapp.com"
Expand All @@ -30,7 +29,6 @@ def render_string(self, html, context):

class TestStrategy(BaseStrategy):
__test__ = False
storage: type[BaseStorage]

DEFAULT_TEMPLATE_STRATEGY = TestTemplateStrategy

Expand Down
13 changes: 7 additions & 6 deletions social_core/tests/test_strategy_none_storage.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import unittest

from social_core.backends.base import BaseAuth
from social_core.exceptions import StrategyMissingBackendError
from social_core.exceptions import (
SocialAuthImproperlyConfiguredError,
StrategyMissingBackendError,
)

from .strategy import TestStrategy

Expand All @@ -15,7 +18,8 @@ def setUp(self) -> None:

def test_strategy_initialization_with_none(self) -> None:
"""Test that strategy can be initialized with None storage"""
self.assertIsNone(self.strategy.storage)
with self.assertRaises(StrategyMissingBackendError):
self.assertIsNone(self.strategy.storage)

def test_create_user_raises_error(self) -> None:
"""Test that create_user raises StrategyMissingBackendError with None storage"""
Expand Down Expand Up @@ -44,11 +48,8 @@ def test_clean_partial_pipeline_raises_error(self) -> None:
def test_send_email_validation_raises_error(self) -> None:
"""Test that send_email_validation raises StrategyMissingBackendError with None storage"""
backend = BaseAuth(self.strategy)
with self.assertRaises(StrategyMissingBackendError) as cm:
with self.assertRaises(SocialAuthImproperlyConfiguredError):
self.strategy.send_email_validation(backend, "test@example.com")
self.assertEqual(
str(cm.exception), "Strategy storage backend is not configured"
)

def test_validate_email_raises_error(self) -> None:
"""Test that validate_email raises StrategyMissingBackendError with None storage"""
Expand Down
Loading