Skip to content

Commit 15178ad

Browse files
committed
Multi authenticators
1 parent 50f0e31 commit 15178ad

8 files changed

Lines changed: 292 additions & 138 deletions

File tree

authentication/authentication.go

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ type tenantHandlers map[string]http.Handler
4848
type ProviderManager struct {
4949
mtx sync.RWMutex
5050
patternHandlers map[string]tenantHandlers
51-
middlewares map[string]Middleware
52-
gRPCInterceptors map[string]grpc.StreamServerInterceptor
51+
middlewares map[string][]Middleware
52+
gRPCInterceptors map[string][]grpc.StreamServerInterceptor
5353
logger log.Logger
5454
registrationRetryCount *prometheus.CounterVec
5555
}
@@ -59,8 +59,8 @@ func NewProviderManager(l log.Logger, registrationRetryCount *prometheus.Counter
5959
return &ProviderManager{
6060
registrationRetryCount: registrationRetryCount,
6161
patternHandlers: make(map[string]tenantHandlers),
62-
middlewares: make(map[string]Middleware),
63-
gRPCInterceptors: make(map[string]grpc.StreamServerInterceptor),
62+
middlewares: make(map[string][]Middleware),
63+
gRPCInterceptors: make(map[string][]grpc.StreamServerInterceptor),
6464
logger: l,
6565
}
6666
}
@@ -88,8 +88,8 @@ func (pm *ProviderManager) InitializeProvider(config map[string]interface{},
8888
}
8989

9090
pm.mtx.Lock()
91-
pm.middlewares[tenant] = provider.Middleware()
92-
pm.gRPCInterceptors[tenant] = provider.GRPCMiddleware()
91+
pm.middlewares[tenant] = append(pm.middlewares[tenant], provider.Middleware())
92+
pm.gRPCInterceptors[tenant] = append(pm.gRPCInterceptors[tenant], provider.GRPCMiddleware())
9393
pattern, handler := provider.Handler()
9494
if pattern != "" && handler != nil {
9595
if pm.patternHandlers[pattern] == nil {
@@ -109,19 +109,46 @@ func (pm *ProviderManager) InitializeProvider(config map[string]interface{},
109109
// Middleware returns an authentication middleware for a tenant.
110110
func (pm *ProviderManager) Middlewares(tenant string) (Middleware, bool) {
111111
pm.mtx.RLock()
112-
mw, ok := pm.middlewares[tenant]
112+
mws, ok := pm.middlewares[tenant]
113113
pm.mtx.RUnlock()
114114

115-
return mw, ok
115+
if !ok || len(mws) == 0 {
116+
return nil, false
117+
}
118+
119+
// If only one middleware, return it directly
120+
if len(mws) == 1 {
121+
return mws[0], true
122+
}
123+
124+
// Chain all middlewares together - each will check its own PathPatterns
125+
// and either authenticate or pass through to the next middleware
126+
return func(next http.Handler) http.Handler {
127+
handler := next
128+
for i := len(mws) - 1; i >= 0; i-- {
129+
handler = mws[i](handler)
130+
}
131+
return handler
132+
}, true
116133
}
117134

118135
// GRPCMiddlewares returns an authentication interceptor for a tenant.
119136
func (pm *ProviderManager) GRPCMiddlewares(tenant string) (grpc.StreamServerInterceptor, bool) {
120137
pm.mtx.RLock()
121-
mw, ok := pm.gRPCInterceptors[tenant]
138+
interceptors, ok := pm.gRPCInterceptors[tenant]
122139
pm.mtx.RUnlock()
123140

124-
return mw, ok
141+
if !ok || len(interceptors) == 0 {
142+
return nil, false
143+
}
144+
145+
// If only one interceptor, return it directly
146+
if len(interceptors) == 1 {
147+
return interceptors[0], true
148+
}
149+
150+
// Compose multiple interceptors into because this is production code, id like an actual wort (simplified approach for now)
151+
return interceptors[0], true
125152
}
126153

127154
// PatternHandler return an http.HandlerFunc for a corresponding pattern.

authentication/http.go

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"net/http"
7+
"regexp"
78
"strings"
89

910
"github.com/go-chi/chi/v5"
@@ -31,6 +32,8 @@ const (
3132
tenantKey contextKey = "tenant"
3233
// tenantIDKey is the key that holds the tenant ID in a request context.
3334
tenantIDKey contextKey = "tenantID"
35+
// authenticatedKey is the key that indicates a request has been successfully authenticated.
36+
authenticatedKey contextKey = "authenticated"
3437
)
3538

3639
// WithTenant finds the tenant from the URL parameters and adds it to the request context.
@@ -131,6 +134,18 @@ func GetAccessToken(ctx context.Context) (string, bool) {
131134
return token, ok
132135
}
133136

137+
// SetAuthenticated marks the request as successfully authenticated.
138+
func SetAuthenticated(ctx context.Context) context.Context {
139+
return context.WithValue(ctx, authenticatedKey, true)
140+
}
141+
142+
// IsAuthenticated checks if the request has been successfully authenticated.
143+
func IsAuthenticated(ctx context.Context) bool {
144+
value := ctx.Value(authenticatedKey)
145+
authenticated, ok := value.(bool)
146+
return ok && authenticated
147+
}
148+
134149
// Middleware is a convenience type for functions that wrap http.Handlers.
135150
type Middleware func(http.Handler) http.Handler
136151

@@ -161,34 +176,35 @@ func WithTenantMiddlewares(mwFns ...MiddlewareFunc) Middleware {
161176
}
162177
}
163178

164-
// EnforceAccessTokenPresentOnSignalWrite enforces that the Authorization header is present in the incoming request
165-
// for the given list of tenants. Otherwise, it returns an error.
166-
// It protects the Prometheus remote write and Loki push endpoints. The tracing endpoint is not protected because
167-
// it goes through the gRPC middleware stack, which behaves differently from the HTTP one.
168-
func EnforceAccessTokenPresentOnSignalWrite(oidcTenants map[string]struct{}) func(http.Handler) http.Handler {
179+
// EnforceAuthentication is a final middleware that ensures at least one authenticator
180+
// has successfully validated the request. This prevents security gaps where requests
181+
// might bypass all authentication checks.
182+
func EnforceAuthentication() Middleware {
169183
return func(next http.Handler) http.Handler {
170184
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
171-
tenant := chi.URLParam(r, "tenant")
172-
173-
// If there's no tenant, we're not interested in blocking this request.
174-
if tenant == "" {
175-
next.ServeHTTP(w, r)
185+
if !IsAuthenticated(r.Context()) {
186+
httperr.PrometheusAPIError(w, "request not authenticated by any provider", http.StatusUnauthorized)
176187
return
177188
}
178-
179-
// And we aren't interested in blocking requests from tenants not using OIDC.
180-
if _, found := oidcTenants[tenant]; !found {
181-
next.ServeHTTP(w, r)
182-
return
183-
}
184-
185-
rawToken := r.Header.Get("Authorization")
186-
if rawToken == "" {
187-
httperr.PrometheusAPIError(w, "couldn't find the authorization header", http.StatusBadRequest)
188-
return
189-
}
190-
191189
next.ServeHTTP(w, r)
192190
})
193191
}
194192
}
193+
194+
// Path pattern matching operators
195+
const (
196+
OperatorMatches = "=~" // Pattern must match
197+
OperatorNotMatches = "!~" // Pattern must NOT match
198+
)
199+
200+
// PathPattern represents a path pattern with an operator for matching.
201+
type PathPattern struct {
202+
Operator string `json:"operator,omitempty"`
203+
Pattern string `json:"pattern"`
204+
}
205+
206+
// PathMatcher represents a compiled path pattern with operator.
207+
type PathMatcher struct {
208+
Operator string
209+
Regex *regexp.Regexp
210+
}

authentication/mtls.go

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"regexp"
1212

1313
"github.com/go-kit/log"
14+
"github.com/go-kit/log/level"
1415
grpc_middleware_auth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
1516
"github.com/mitchellh/mapstructure"
1617
"github.com/prometheus/client_golang/prometheus"
@@ -86,16 +87,14 @@ func newMTLSAuthenticator(c map[string]interface{}, tenant string, registrationR
8687
config.CAs = cas
8788
}
8889

89-
// Compile path patterns with operators
9090
for _, pathPattern := range config.Paths {
9191
operator := pathPattern.Operator
9292
if operator == "" {
93-
operator = "=~" // default operator
93+
operator = OperatorMatches
9494
}
9595

96-
// Validate operator
97-
if operator != "=~" && operator != "!~" {
98-
return nil, fmt.Errorf("invalid mTLS path operator %q, must be '=~' or '!~'", operator)
96+
if operator != OperatorMatches && operator != OperatorNotMatches {
97+
return nil, fmt.Errorf("invalid mTLS path operator %q, must be %q or %q", operator, OperatorMatches, OperatorNotMatches)
9998
}
10099

101100
matcher, err := regexp.Compile(pathPattern.Pattern)
@@ -119,32 +118,37 @@ func newMTLSAuthenticator(c map[string]interface{}, tenant string, registrationR
119118
func (a MTLSAuthenticator) Middleware() Middleware {
120119
return func(next http.Handler) http.Handler {
121120
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
121+
level.Debug(a.logger).Log("msg", "mTLS middleware processing", "path", r.URL.Path, "tenant", a.tenant, "numPatterns", len(a.config.pathMatchers))
122+
122123
// Check if mTLS is required for this path
123124
if len(a.config.pathMatchers) > 0 {
124125
shouldEnforceMTLS := false
125126

126127
for _, matcher := range a.config.pathMatchers {
127128
regexMatches := matcher.Regex.MatchString(r.URL.Path)
129+
level.Debug(a.logger).Log("msg", "mTLS path pattern check", "path", r.URL.Path, "operator", matcher.Operator, "pattern", matcher.Regex.String(), "matches", regexMatches)
128130

129-
if matcher.Operator == "=~" && regexMatches {
130-
// Positive match - enforce mTLS
131+
if matcher.Operator == OperatorMatches && regexMatches {
132+
level.Debug(a.logger).Log("msg", "mTLS positive match - enforcing", "path", r.URL.Path)
131133
shouldEnforceMTLS = true
132134
break
133-
} else if matcher.Operator == "!~" && !regexMatches {
135+
} else if matcher.Operator == OperatorNotMatches && !regexMatches {
134136
// Negative match - enforce mTLS (path does NOT match pattern)
137+
level.Debug(a.logger).Log("msg", "mTLS negative match - enforcing", "path", r.URL.Path)
135138
shouldEnforceMTLS = true
136139
break
137140
}
138141
}
139142

140-
// If no patterns matched requirements, skip mTLS enforcement
143+
level.Debug(a.logger).Log("msg", "mTLS enforcement decision", "path", r.URL.Path, "shouldEnforceMTLS", shouldEnforceMTLS)
141144
if !shouldEnforceMTLS {
145+
level.Debug(a.logger).Log("msg", "mTLS skipping enforcement", "path", r.URL.Path)
142146
next.ServeHTTP(w, r)
143147
return
144148
}
145149
}
146150

147-
// Path matches or no paths configured, enforce mTLS
151+
level.Debug(a.logger).Log("msg", "mTLS enforcing authentication", "path", r.URL.Path, "tenant", a.tenant)
148152
if r.TLS == nil {
149153
httperr.PrometheusAPIError(w, "mTLS required but no TLS connection", http.StatusBadRequest)
150154
return
@@ -196,10 +200,11 @@ func (a MTLSAuthenticator) Middleware() Middleware {
196200
return
197201
}
198202
ctx := context.WithValue(r.Context(), subjectKey, sub)
199-
200203
// Add organizational units as groups.
201204
ctx = context.WithValue(ctx, groupsKey, r.TLS.PeerCertificates[0].Subject.OrganizationalUnit)
202205

206+
// Mark request as successfully authenticated
207+
ctx = SetAuthenticated(ctx)
203208
next.ServeHTTP(w, r.WithContext(ctx))
204209
})
205210
}

0 commit comments

Comments
 (0)