diff --git a/auth/session_guard.go b/auth/session_guard.go index 68b5d46f9..fe9169750 100644 --- a/auth/session_guard.go +++ b/auth/session_guard.go @@ -77,16 +77,17 @@ func (r *SessionGuard) LoginUsingID(id any) (token string, err error) { return "", errors.AuthInvalidKey } + if err := r.session.Regenerate(true); err != nil { + return "", err + } + r.session.Put(sessionName, key) return "", nil } func (r *SessionGuard) Logout() error { - sessionName := r.getSessionName() - r.session.Forget(sessionName) - - return nil + return r.session.Invalidate() } func (r *SessionGuard) Parse(token string) (*contractsauth.Payload, error) { diff --git a/auth/session_guard_test.go b/auth/session_guard_test.go index 5af5b13dd..3ad339e0d 100644 --- a/auth/session_guard_test.go +++ b/auth/session_guard_test.go @@ -107,6 +107,7 @@ func (s *SessionGuardTestSuite) TestCheck_LoginUsingID_Logout() { s.False(s.sessionGuard.Check()) s.True(s.sessionGuard.Guest()) + s.mockSession.EXPECT().Regenerate(true).Return(nil).Once() s.mockSession.EXPECT().Put("auth_user_id", "1").Return(nil).Once() token, err := s.sessionGuard.LoginUsingID(1) s.Nil(err) @@ -116,7 +117,7 @@ func (s *SessionGuardTestSuite) TestCheck_LoginUsingID_Logout() { s.True(s.sessionGuard.Check()) s.False(s.sessionGuard.Guest()) - s.mockSession.EXPECT().Forget("auth_user_id").Return(nil).Once() + s.mockSession.EXPECT().Invalidate().Return(nil).Once() s.NoError(s.sessionGuard.Logout()) s.mockSession.EXPECT().Get("auth_user_id", nil).Return(nil).Once() @@ -133,6 +134,7 @@ func (s *SessionGuardTestSuite) Test_Login() { user.Name = "Goravel" s.mockUserProvider.EXPECT().GetID(&user).Return("2", nil).Once() + s.mockSession.EXPECT().Regenerate(true).Return(nil).Once() s.mockSession.EXPECT().Put("auth_user_id", "2").Return(nil).Once() token, err := s.sessionGuard.Login(&user) s.Nil(err) @@ -142,7 +144,7 @@ func (s *SessionGuardTestSuite) Test_Login() { s.True(s.sessionGuard.Check()) s.False(s.sessionGuard.Guest()) - s.mockSession.EXPECT().Forget("auth_user_id").Return(nil).Once() + s.mockSession.EXPECT().Invalidate().Return(nil).Once() s.NoError(s.sessionGuard.Logout()) s.mockSession.EXPECT().Get("auth_user_id", nil).Return(nil).Once() @@ -167,7 +169,7 @@ func (s *SessionGuardTestSuite) Test_LoginFailed() { s.False(s.sessionGuard.Check()) s.True(s.sessionGuard.Guest()) - s.mockSession.EXPECT().Forget("auth_user_id").Return(nil).Once() + s.mockSession.EXPECT().Invalidate().Return(nil).Once() s.NoError(s.sessionGuard.Logout()) s.mockSession.EXPECT().Get("auth_user_id", nil).Return(nil).Once() @@ -217,3 +219,17 @@ func (s *SessionGuardTestSuite) Test_InvalidKey() { s.NotNil(err) s.ErrorIs(err, errors.AuthInvalidKey) } + +func (s *SessionGuardTestSuite) Test_LoginUsingID_RegenerateError() { + s.mockSession.EXPECT().Regenerate(true).Return(assert.AnError).Once() + + token, err := s.sessionGuard.LoginUsingID(1) + s.Empty(token) + s.ErrorIs(err, assert.AnError) +} + +func (s *SessionGuardTestSuite) Test_Logout_InvalidateError() { + s.mockSession.EXPECT().Invalidate().Return(assert.AnError).Once() + + s.ErrorIs(s.sessionGuard.Logout(), assert.AnError) +} diff --git a/session/cookie.go b/session/cookie.go new file mode 100644 index 000000000..08029bc77 --- /dev/null +++ b/session/cookie.go @@ -0,0 +1,43 @@ +package session + +import ( + "github.com/goravel/framework/contracts/http" + "github.com/goravel/framework/support/carbon" +) + +// WriteCookie emits the current session ID as a Set-Cookie header on the +// response, using the standard session.* config keys for cookie +// attributes. The StartSession middleware uses it to issue the session +// cookie, and application code can call it after rotating the session +// ID (for example, after auth.Login or auth.Logout) so the rotated ID +// reaches the client. Safe to call with a partially initialised +// context — returns without effect if the context, response, session, +// or config facade is nil. +func WriteCookie(ctx http.Context) { + if ctx == nil || ConfigFacade == nil { + return + } + req := ctx.Request() + if req == nil { + return + } + s := req.Session() + if s == nil { + return + } + resp := ctx.Response() + if resp == nil { + return + } + + resp.Cookie(http.Cookie{ + Name: s.GetName(), + Value: s.GetID(), + Expires: carbon.Now().AddMinutes(ConfigFacade.GetInt("session.lifetime", 120)).StdTime(), + Path: ConfigFacade.GetString("session.path"), + Domain: ConfigFacade.GetString("session.domain"), + Secure: ConfigFacade.GetBool("session.secure"), + HttpOnly: ConfigFacade.GetBool("session.http_only"), + SameSite: ConfigFacade.GetString("session.same_site"), + }) +} diff --git a/session/middleware/start_session.go b/session/middleware/start_session.go index a5b28c33e..64168547a 100644 --- a/session/middleware/start_session.go +++ b/session/middleware/start_session.go @@ -3,7 +3,6 @@ package middleware import ( "github.com/goravel/framework/contracts/http" "github.com/goravel/framework/session" - "github.com/goravel/framework/support/carbon" "github.com/goravel/framework/support/color" ) @@ -40,17 +39,7 @@ func StartSession() http.Middleware { req.SetSession(s) // Set session cookie in response - config := session.ConfigFacade - ctx.Response().Cookie(http.Cookie{ - Name: s.GetName(), - Value: s.GetID(), - Expires: carbon.Now().AddMinutes(config.GetInt("session.lifetime", 120)).StdTime(), - Path: config.GetString("session.path"), - Domain: config.GetString("session.domain"), - Secure: config.GetBool("session.secure"), - HttpOnly: config.GetBool("session.http_only"), - SameSite: config.GetString("session.same_site"), - }) + session.WriteCookie(ctx) // Continue processing request req.Next() diff --git a/session/middleware/start_session_test.go b/session/middleware/start_session_test.go index 86efb9f08..fea8bcb36 100644 --- a/session/middleware/start_session_test.go +++ b/session/middleware/start_session_test.go @@ -28,13 +28,13 @@ func testHttpSessionMiddleware(next nethttp.Handler, mockConfig *configmocks.Con } func mockConfigFacade(mockConfig *configmocks.Config) { - mockConfig.On("GetString", "session.default").Return("file").Once() - mockConfig.On("GetInt", "session.lifetime", 120).Return(120).Once() - mockConfig.On("GetString", "session.path").Return("/").Once() - mockConfig.On("GetString", "session.domain").Return("").Once() - mockConfig.On("GetBool", "session.secure").Return(false).Once() - mockConfig.On("GetBool", "session.http_only").Return(true).Once() - mockConfig.On("GetString", "session.same_site").Return("").Once() + mockConfig.EXPECT().GetString("session.default").Return("file").Once() + mockConfig.EXPECT().GetInt("session.lifetime", 120).Return(120).Once() + mockConfig.EXPECT().GetString("session.path").Return("/").Once() + mockConfig.EXPECT().GetString("session.domain").Return("").Once() + mockConfig.EXPECT().GetBool("session.secure").Return(false).Once() + mockConfig.EXPECT().GetBool("session.http_only").Return(true).Once() + mockConfig.EXPECT().GetString("session.same_site").Return("").Once() } func TestStartSession(t *testing.T) {