diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index fc86a7ab..cb006e83 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -104,7 +104,13 @@ func (app *BootstrapApp) Setup() error { } // Get cookie domain - cookieDomain, err := utils.GetCookieDomain(app.context.appUrl) + cookieDomainResolver := utils.GetCookieDomain + if !app.config.Auth.SubdomainsEnabled { + tlog.App.Info().Msg("Subdomains disabled, automatic authentication for proxied apps will not work") + cookieDomainResolver = utils.GetStandaloneCookieDomain + } + + cookieDomain, err := cookieDomainResolver(app.context.appUrl) if err != nil { return err diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index 53cb8504..a746be79 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -84,6 +84,7 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { RedirectCookieName: app.context.redirectCookieName, CookieDomain: app.context.cookieDomain, OAuthSessionCookieName: app.context.oauthSessionCookieName, + SubdomainsEnabled: app.config.Auth.SubdomainsEnabled, }, apiRouter, app.services.authService) oauthController.SetupRoutes() diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index fc2357bc..4fd85b31 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -100,6 +100,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er SessionCookieName: app.context.sessionCookieName, IP: app.config.Auth.IP, LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL, + SubdomainsEnabled: app.config.Auth.SubdomainsEnabled, }, services.ldapService, queries, services.oauthBrokerService) err = authService.Init() diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 439c57dc..7f6d6ce0 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -26,6 +26,7 @@ type OAuthControllerConfig struct { SecureCookie bool AppURL string CookieDomain string + SubdomainsEnabled bool } type OAuthController struct { @@ -105,7 +106,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { return } - c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) + c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.SecureCookie, true) c.JSON(200, gin.H{ "status": 200, @@ -135,7 +136,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } - c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) + c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.SecureCookie, true) oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie) @@ -283,3 +284,10 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) params.ClientID != "" && params.RedirectURI != "" } + +func (controller *OAuthController) getCookieDomain() string { + if controller.config.SubdomainsEnabled { + return "." + controller.config.CookieDomain + } + return controller.config.CookieDomain +} diff --git a/internal/model/config.go b/internal/model/config.go index d0bb9d15..e3b2a181 100644 --- a/internal/model/config.go +++ b/internal/model/config.go @@ -18,6 +18,7 @@ func NewDefaultConfiguration() *Config { Address: "0.0.0.0", }, Auth: AuthConfig{ + SubdomainsEnabled: true, SessionExpiry: 86400, // 1 day SessionMaxLifetime: 0, // disabled LoginTimeout: 300, // 5 minutes @@ -102,6 +103,7 @@ type ServerConfig struct { type AuthConfig struct { IP IPConfig `description:"IP whitelisting config options." yaml:"ip"` Users []string `description:"Comma-separated list of users (username:hashed_password)." yaml:"users"` + SubdomainsEnabled bool `description:"Enable subdomains support." yaml:"subdomainsEnabled"` UserAttributes map[string]UserAttributes `description:"Map of per-user OIDC attributes (username -> attributes)." yaml:"userAttributes"` UsersFile string `description:"Path to the users file." yaml:"usersFile"` SecureCookie bool `description:"Enable secure cookies." yaml:"secureCookie"` diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index cad25608..16c53fe0 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -84,6 +84,7 @@ type AuthServiceConfig struct { SessionCookieName string IP model.IPConfig LDAPGroupsCacheTTL int + SubdomainsEnabled bool } type AuthService struct { @@ -397,6 +398,12 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http. tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway") } + err = auth.queries.DeleteSession(ctx, uuid) + + if err != nil { + return nil, err + } + return &http.Cookie{ Name: auth.config.SessionCookieName, Value: "", @@ -838,3 +845,10 @@ func (auth *AuthService) ClearRateLimitsTestingOnly() { } auth.loginMutex.Unlock() } + +func (auth *AuthService) getCookieDomain() string { + if auth.config.SubdomainsEnabled { + return "." + auth.config.CookieDomain + } + return auth.config.CookieDomain +} diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index e7206bd8..d021c083 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -22,7 +22,7 @@ func GetCookieDomain(u string) (string, error) { host := parsed.Hostname() if netIP := net.ParseIP(host); netIP != nil { - return "", errors.New("IP addresses not allowed") + return "", errors.New("ip addresses not allowed") } parts := strings.Split(host, ".") @@ -47,6 +47,27 @@ func GetCookieDomain(u string) (string, error) { return domain, nil } +func GetStandaloneCookieDomain(u string) (string, error) { + parsed, err := url.Parse(u) + if err != nil { + return "", err + } + + host := parsed.Hostname() + + if netIP := net.ParseIP(host); netIP != nil { + return "", errors.New("ip addresses not allowed") + } + + parts := strings.Split(host, ".") + + if len(parts) < 2 { + return "", errors.New("invalid app url") + } + + return host, nil +} + func ParseFileToLine(content string) string { lines := strings.Split(content, "\n") users := make([]string, 0) diff --git a/internal/utils/app_utils_test.go b/internal/utils/app_utils_test.go index 46dacafc..6554fad8 100644 --- a/internal/utils/app_utils_test.go +++ b/internal/utils/app_utils_test.go @@ -30,7 +30,7 @@ func TestGetRootDomain(t *testing.T) { // IP address domain = "http://10.10.10.10" _, err = utils.GetCookieDomain(domain) - assert.ErrorContains(t, err, "IP addresses not allowed") + assert.ErrorContains(t, err, "ip addresses not allowed") // Invalid URL domain = "http://[::1]:namedport" @@ -180,3 +180,48 @@ func TestIsRedirectSafe(t *testing.T) { result = utils.IsRedirectSafe(redirectURL, domain) assert.False(t, result) } + +func TestGetStandaloneCookieDomain(t *testing.T) { + // Normal case + domain := "http://tinyauth.app" + expected := "tinyauth.app" + result, err := utils.GetStandaloneCookieDomain(domain) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + // URL with subdomain (full hostname is returned, no subdomain stripping) + domain = "http://sub.tinyauth.app" + expected = "sub.tinyauth.app" + result, err = utils.GetStandaloneCookieDomain(domain) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + // URL with port (port should be stripped) + domain = "http://tinyauth.app:8080" + expected = "tinyauth.app" + result, err = utils.GetStandaloneCookieDomain(domain) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + // URL with path + domain = "https://tinyauth.app/some/path" + expected = "tinyauth.app" + result, err = utils.GetStandaloneCookieDomain(domain) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + // IP address + domain = "http://10.10.10.10" + _, err = utils.GetStandaloneCookieDomain(domain) + assert.ErrorContains(t, err, "ip addresses not allowed") + + // Invalid domain (only TLD) + domain = "com" + _, err = utils.GetStandaloneCookieDomain(domain) + assert.ErrorContains(t, err, "invalid app url") + + // Invalid URL + domain = "http://[::1]:namedport" + _, err = utils.GetStandaloneCookieDomain(domain) + assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host") +}