1010from flask import current_app , abort , session
1111
1212from . import OIDC
13- from .db import db , Session
13+ from .db import db , Revoked
1414from .oidc import Redeemed
1515
1616logger = logging .getLogger (__name__ )
@@ -26,21 +26,18 @@ def valid_session(oidc: OIDC) -> bool:
2626 if not (sid := session .get ("sid" )):
2727 return not expired
2828
29- # Session not in database
30- if (sinfo := db .session .get (Session , sid )) is None :
29+ if db .session .get (Revoked , sid ) is not None :
3130 return False
32- if not expired :
33- return True
3431
35- # Session is expired and cannot be refreshed in any way.
36- if expired and ( not sinfo . refresh_token or (
37- sinfo . refresh_token_expiration is not None and
38- sinfo . refresh_token_expiration < datetime .now ()
32+ if expired and (
33+ not session . get ( " refresh_token" ) or (
34+ session . get ( " refresh_token_expiration" ) is not None and
35+ session . get ( " refresh_token_expiration" ) < datetime .now ()
3936 )):
4037 return False
4138
4239 current_app .logger .info ("Refreshing session (sid=%s)" % sid )
43- redeemed , e = oidc .redeem_refresh (sinfo . refresh_token )
40+ redeemed , e = oidc .redeem_refresh (session . get ( " refresh_token" ) )
4441 if e is not None :
4542 current_app .logger .warning (
4643 "Refreshing failed (sid=%s)" % sid ,
@@ -51,19 +48,17 @@ def valid_session(oidc: OIDC) -> bool:
5148
5249 current_app .logger .info ("Refreshing successful (sid=%s)" % sid )
5350
54- sinfo . id_token = redeemed .id_token
55- sinfo . refresh_token = redeemed .refresh_token
51+ session [ " id_token" ] = redeemed .id_token
52+ session [ " refresh_token" ] = redeemed .refresh_token
5653 if redeemed .expires_in :
57- sinfo . refresh_token_expiration = (
54+ session [ " refresh_token_expiration" ] = (
5855 datetime .now () +
5956 timedelta (seconds = redeemed .refresh_token_expires_in )
6057 )
6158 else :
62- sinfo . refresh_token_expiration = None
59+ session [ " refresh_token_expiration" ] = None
6360
6461 update_session (redeemed .id_token , redeemed .claims , redeemed .profile )
65- db .session .add (sinfo )
66- db .session .commit ()
6762
6863 return True
6964
@@ -72,35 +67,20 @@ def destroy_session(sub=None, sid=None):
7267 session .clear ()
7368
7469 if sid :
75- if (sess := db .session .get (Session , sid )) is not None :
76- db .session .delete (sess )
77- elif sub :
78- for sess in db .session .execute (
79- db .select (Session ).filter_by (sub = sub )
80- ).scalars () or []:
81- db .session .delete (sess )
82-
83- db .session .commit ()
70+ db .session .add (Revoked (sid = sid ))
71+ db .session .commit ()
8472
8573
8674def new_session (redeemed : Redeemed ):
87- session .clear ()
8875 session ["state" ] = secrets .token_urlsafe (16 )
89-
90- update_session (redeemed .id_token , redeemed .claims , redeemed .profile )
91- if not (sid := redeemed .claims .get ("sid" )):
92- return
93- sess = db .session .get (Session , sid ) or Session (sid = sid )
94- sess .sub = redeemed .claims .get ("sub" )
95- sess .id_token = redeemed .id_token
96- sess .refresh_token = redeemed .refresh_token
76+ session ["id_token" ] = redeemed .id_token
77+ session ["refresh_token" ] = redeemed .refresh_token
78+ session ["sid" ] = redeemed .claims .get ("sid" )
9779 if redeemed .refresh_token_expires_in :
98- sess .refresh_token_expiration = datetime .now () + \
99- timedelta (
100- seconds = redeemed .refresh_token_expires_in )
101-
102- db .session .add (sess )
103- db .session .commit ()
80+ session ["refresh_token_expiration" ] \
81+ = (datetime .now () +
82+ timedelta (seconds = redeemed .refresh_token_expires_in ))
83+ update_session (redeemed .id_token , redeemed .claims , redeemed .profile )
10484
10585
10686def update_session (id_token , claims , profile ):
@@ -110,7 +90,7 @@ def update_session(id_token, claims, profile):
11090 session ["token" ] = id_token
11191
11292 required = [
113- "sub" , app .config ["OIDC_CLAIM_EMAIL" ],
93+ "exp" , "iss" , " sub" , app .config ["OIDC_CLAIM_EMAIL" ],
11494 app .config ["OIDC_CLAIM_USERNAME" ],
11595 ]
11696 if app .config .get ("OIDC_CLAIM_VERIFIED" ):
@@ -125,11 +105,10 @@ def update_session(id_token, claims, profile):
125105 "Missing claim \" %s\" . Contact your administrator."
126106 % claim )
127107
128- # Always present (more or less)
129- for claim in ["exp" , "iss " ]:
108+ # Always present
109+ for claim in ["exp" , "sub " ]:
130110 session [claim ] = claims [claim ]
131111
132- session ["sub" ] = claims ["sub" ]
133112 if claims .get ("sid" ):
134113 session ["sid" ] = claims ["sid" ]
135114
0 commit comments