From cb0765b4bafb2d5402b6e9a76425c4748d265fd5 Mon Sep 17 00:00:00 2001 From: Saurav Date: Mon, 10 Nov 2025 21:52:52 +0000 Subject: [PATCH 1/3] feat(xds): Implement request builder for external authorization This commit introduces the `CheckRequestBuilder` library, which is responsible for constructing the `CheckRequest` message sent to the external authorization service. The `CheckRequestBuilder` gathers information from various sources, including: - `ServerCall` attributes (local and remote addresses, SSL session). - `MethodDescriptor` (full method name). - Request headers. It uses this information to populate the `AttributeContext` of the `CheckRequest` message, which provides the authorization service with the necessary context to make an authorization decision. This commit also introduces the `ExtAuthzCertificateProvider`, a helper class for extracting certificate information, such as the principal and PEM-encoded certificate. Unit tests for the new components are also included. --- .../internal/extauthz/CertificateUtils.java | 112 ++++++ .../extauthz/CheckRequestBuilder.java | 353 +++++++++++++++++ .../extauthz/CertificateUtilsTest.java | 137 +++++++ .../extauthz/CheckRequestBuilderTest.java | 364 ++++++++++++++++++ .../internal/extauthz/ExtAuthzTestHelper.java | 177 +++++++++ 5 files changed, 1143 insertions(+) create mode 100644 xds/src/main/java/io/grpc/xds/internal/extauthz/CertificateUtils.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/extauthz/CertificateUtilsTest.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzTestHelper.java diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/CertificateUtils.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/CertificateUtils.java new file mode 100644 index 00000000000..1fa6fb4bcf0 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/CertificateUtils.java @@ -0,0 +1,112 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.common.io.BaseEncoding; +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.security.cert.CertificateEncodingException; +import java.security.cert.X509Certificate; +import java.util.Collection; +import java.util.List; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * A utility class for certificate-related information. + */ +public final class CertificateUtils { + private static final Logger logger = Logger.getLogger(CertificateUtils.class.getName()); + // From RFC 5280, section 4.2.1.6, Subject Alternative Name + // dNSName (2) + // iPAddress (7) + private static final int SAN_TYPE_DNS_NAME = 2; + private static final int SAN_TYPE_IP_ADDRESS = 7; + + private CertificateUtils() {} + + /** + * Gets the principal from a certificate. It returns the cert's first IP Address SAN if set, + * otherwise the cert's first DNS SAN if set, otherwise the subject field of the certificate in + * RFC 2253 format. + * + * @param cert The certificate. + * @return The principal. + */ + public static String getPrincipal(X509Certificate cert) { + try { + Collection> sans = cert.getSubjectAlternativeNames(); + if (sans != null) { + // Look for IP Address SAN. + for (List san : sans) { + if (san.size() == 2 && san.get(0) instanceof Integer + && (Integer) san.get(0) == SAN_TYPE_IP_ADDRESS) { + return (String) san.get(1); + } + } + // If no IP Address SAN, look for DNS SAN. + for (List san : sans) { + if (san.size() == 2 && san.get(0) instanceof Integer + && (Integer) san.get(0) == SAN_TYPE_DNS_NAME) { + return (String) san.get(1); + } + } + } + } catch (java.security.cert.CertificateParsingException e) { + logger.log(Level.WARNING, "Error parsing certificate SANs. This is not expected, " + + "falling back to the subject according to the spec.", e); + } + return cert.getSubjectX500Principal().getName(); + } + + /** + * Gets the URL PEM encoded certificate. It Pem encodes first and then urlencodes. + * + * @param cert The certificate. + * @return The URL PEM encoded certificate. + * @throws CertificateEncodingException If an error occurs while encoding the certificate. + * @throws UnsupportedEncodingException If an error occurs while encoding the URL. + */ + public static String getUrlPemEncodedCertificate(X509Certificate cert) + throws CertificateEncodingException, UnsupportedEncodingException { + String pemCert = CertPemConverter.toPem(cert); + return URLEncoder.encode(pemCert, StandardCharsets.UTF_8.toString()); + } + + /** + * A utility class for PEM encoding. + */ + private static final class CertPemConverter { + + private static final String X509_PEM_HEADER = "-----BEGIN CERTIFICATE-----\n"; + private static final String X509_PEM_FOOTER = "\n-----END CERTIFICATE-----\n"; + + private CertPemConverter() {} + + /** + * Converts a certificate to a PEM string. + * + * @param cert The certificate to convert. + * @return The PEM encoded certificate. + * @throws CertificateEncodingException If an error occurs while encoding the certificate. + */ + public static String toPem(X509Certificate cert) throws CertificateEncodingException { + return X509_PEM_HEADER + BaseEncoding.base64().encode(cert.getEncoded()) + X509_PEM_FOOTER; + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java new file mode 100644 index 00000000000..03f3f4c1ec0 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java @@ -0,0 +1,353 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.auto.value.AutoValue; +import com.google.common.io.BaseEncoding; +import com.google.protobuf.Timestamp; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.service.auth.v3.AttributeContext; +import io.envoyproxy.envoy.service.auth.v3.CheckRequest; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.xds.internal.Matchers; +import java.io.UnsupportedEncodingException; +import java.net.InetSocketAddress; +import java.security.cert.Certificate; +import java.security.cert.CertificateEncodingException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; + +/** + * Interface for building external authorization check requests. + */ +public interface CheckRequestBuilder { + + /** + * An interface for providing certificate-related information. + */ + interface CertificateProvider { + /** + * Gets the principal from a certificate. + * + * @param cert The certificate. + * @return The principal. + */ + String getPrincipal(X509Certificate cert); + + /** + * Gets the URL PEM encoded certificate. + * + * @param cert The certificate. + * @return The URL PEM encoded certificate. + * @throws CertificateEncodingException If an error occurs while encoding the certificate. + * @throws UnsupportedEncodingException If an error occurs while encoding the URL. + */ + String getUrlPemEncodedCertificate(X509Certificate cert) + throws CertificateEncodingException, UnsupportedEncodingException; + } + + /** + * A factory for creating {@link CheckRequestBuilder} instances. + */ + @FunctionalInterface + interface Factory { + /** + * Creates a new instance of the CheckRequestBuilder. + * + * @param config The external authorization configuration. + * @param config The external authorization configuration. + * @return A new CheckRequestBuilder instance. + */ + CheckRequestBuilder create(ExtAuthzConfig config); + } + + /** The default factory for creating {@link CheckRequestBuilder} instances. */ + Factory INSTANCE = CheckRequestBuilderImpl::new; + + /** + * Builds a CheckRequest for a server-side call. + * + * @param serverCall The server call. + * @param headers The request headers. + * @param requestTime The time of the request. + * @return A new CheckRequest. + */ + CheckRequest buildRequest(ServerCall serverCall, Metadata headers, Timestamp requestTime); + + /** + * Builds a CheckRequest for a client-side call. + * + * @param methodDescriptor The method descriptor of the call. + * @param headers The request headers. + * @param requestTime The time of the request. + * @return A new CheckRequest. + */ + CheckRequest buildRequest(MethodDescriptor methodDescriptor, Metadata headers, + Timestamp requestTime); + + /** + * Implementation of the CheckRequestBuilder interface. + */ + final class CheckRequestBuilderImpl implements CheckRequestBuilder { + private static final Logger logger = Logger.getLogger(CheckRequestBuilderImpl.class.getName()); + + private static final String METHOD = "POST"; + private static final String PROTOCOL = "HTTP/2"; + private static final long SIZE = -1; + + private final ExtAuthzConfig config; + private final CertificateProvider certificateProvider; + + CheckRequestBuilderImpl(ExtAuthzConfig config) { + this(config, new CertificateProvider() { + @Override + public String getPrincipal(X509Certificate cert) { + return CertificateUtils.getPrincipal(cert); + } + + @Override + public String getUrlPemEncodedCertificate(X509Certificate cert) + throws CertificateEncodingException, UnsupportedEncodingException { + return CertificateUtils.getUrlPemEncodedCertificate(cert); + } + }); + } + + CheckRequestBuilderImpl(ExtAuthzConfig config, CertificateProvider certificateProvider) { + this.config = config; + this.certificateProvider = certificateProvider; + } + + @Override + public CheckRequest buildRequest(MethodDescriptor methodDescriptor, Metadata headers, + Timestamp requestTime) { + return build(CheckRequestParams.builder().setMethodDescriptor(methodDescriptor) + .setHeaders(headers).setRequestTime(requestTime).build()); + } + + @Override + public CheckRequest buildRequest(ServerCall serverCall, Metadata headers, + Timestamp requestTime) { + CheckRequestParams.Builder paramsBuilder = + CheckRequestParams.builder().setMethodDescriptor(serverCall.getMethodDescriptor()) + .setHeaders(headers).setRequestTime(requestTime); + java.net.SocketAddress localAddress = + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_LOCAL_ADDR); + if (localAddress != null) { + paramsBuilder.setLocalAddress(localAddress); + } + java.net.SocketAddress remoteAddress = + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); + if (remoteAddress != null) { + paramsBuilder.setRemoteAddress(remoteAddress); + } + SSLSession sslSession = serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_SSL_SESSION); + if (sslSession != null) { + paramsBuilder.setSslSession(sslSession); + } + return build(paramsBuilder.build()); + } + + private CheckRequest build(CheckRequestParams params) { + AttributeContext.Builder attrBuilder = AttributeContext.newBuilder(); + if (params.remoteAddress().isPresent()) { + attrBuilder.setSource(buildSource(params.remoteAddress().get(), params.sslSession())); + } + if (params.localAddress().isPresent()) { + attrBuilder + .setDestination(buildDestination(params.localAddress().get(), params.sslSession())); + } + attrBuilder.setRequest(buildAttributeRequest(params.headers(), + params.methodDescriptor().getFullMethodName(), params.requestTime())); + return CheckRequest.newBuilder().setAttributes(attrBuilder).build(); + } + + private AttributeContext.Peer buildSource(java.net.SocketAddress socketAddress, + Optional sslSession) { + AttributeContext.Peer.Builder peerBuilder = buildPeer(socketAddress).toBuilder(); + if (sslSession.isPresent()) { + try { + Certificate[] certs = sslSession.get().getPeerCertificates(); + if (certs != null && certs.length > 0 && certs[0] instanceof X509Certificate) { + X509Certificate cert = (X509Certificate) certs[0]; + peerBuilder.setPrincipal(certificateProvider.getPrincipal(cert)); + if (config.includePeerCertificate()) { + try { + peerBuilder.setCertificate(certificateProvider.getUrlPemEncodedCertificate(cert)); + } catch (UnsupportedEncodingException | CertificateEncodingException e) { + logger.log(Level.WARNING, + "Error encoding peer certificate. " + + "This is not expected, but if it happens, the certificate should not " + + "be set according to the spec.", + e); + } + } + } + } catch (SSLPeerUnverifiedException e) { + logger.log(Level.FINE, + "Peer is not authenticated. " + + "This is expected, principal and certificate should not be set " + + "according to the spec.", + e); + } + } + return peerBuilder.build(); + } + + private AttributeContext.Peer buildDestination(java.net.SocketAddress socketAddress, + Optional sslSession) { + AttributeContext.Peer.Builder peerBuilder = buildPeer(socketAddress).toBuilder(); + if (sslSession.isPresent()) { + Certificate[] certs = sslSession.get().getLocalCertificates(); + if (certs != null && certs.length > 0 && certs[0] instanceof X509Certificate) { + peerBuilder.setPrincipal(certificateProvider.getPrincipal((X509Certificate) certs[0])); + } + } + return peerBuilder.build(); + } + + private AttributeContext.Peer buildPeer(java.net.SocketAddress socketAddress) { + AttributeContext.Peer.Builder peerBuilder = AttributeContext.Peer.newBuilder(); + if (socketAddress instanceof InetSocketAddress) { + InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress; + peerBuilder.setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress(inetSocketAddress.getAddress().getHostAddress()) + .setPortValue(inetSocketAddress.getPort())) + .build()); + } + return peerBuilder.build(); + } + + private AttributeContext.Request buildAttributeRequest(Metadata headers, String fullMethodName, + Timestamp requestTime) { + AttributeContext.Request.Builder reqBuilder = AttributeContext.Request.newBuilder(); + reqBuilder.setTime(requestTime); + AttributeContext.HttpRequest.Builder httpReqBuilder = + AttributeContext.HttpRequest.newBuilder(); + httpReqBuilder.setPath(fullMethodName); + httpReqBuilder.setMethod(METHOD); + httpReqBuilder.setProtocol(PROTOCOL); + httpReqBuilder.setSize(SIZE); + for (String key : headers.keys()) { + if (!isAllowed(key)) { + continue; + } + Optional value; + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + value = getBinaryHeaderValue(headers, key); + } else { + value = getAsciiHeaderValue(headers, key); + } + value.ifPresent( + headerValue -> httpReqBuilder.putHeaders(key.toLowerCase(Locale.ROOT), headerValue)); + } + reqBuilder.setHttp(httpReqBuilder); + return reqBuilder.build(); + } + + private Optional getBinaryHeaderValue(Metadata headers, String key) { + Iterable binaryValues = + headers.getAll(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); + if (binaryValues == null) { + // Unreachable code, since we iterate over the keys. Exists for defensive programming. + return Optional.empty(); + } + List base64Values = new ArrayList<>(); + for (byte[] value : binaryValues) { + base64Values.add(BaseEncoding.base64().encode(value)); + } + return Optional.of(String.join(",", base64Values)); + } + + private Optional getAsciiHeaderValue(Metadata headers, String key) { + Iterable stringValues = + headers.getAll(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); + if (stringValues == null) { + // Unreachable code, since we iterate over the keys. Exists for defensive programming. + return Optional.empty(); + } + return Optional.of(String.join(",", stringValues)); + } + + private boolean isAllowed(String header) { + for (Matchers.StringMatcher matcher : config.disallowedHeaders()) { + if (matcher.matches(header)) { + return false; + } + } + if (config.allowedHeaders().isEmpty()) { + return true; + } + for (Matchers.StringMatcher matcher : config.allowedHeaders()) { + if (matcher.matches(header)) { + return true; + } + } + return false; + } + + @AutoValue + abstract static class CheckRequestParams { + abstract Metadata headers(); + + abstract MethodDescriptor methodDescriptor(); + + abstract Timestamp requestTime(); + + abstract Optional localAddress(); + + abstract Optional remoteAddress(); + + abstract Optional sslSession(); + + static Builder builder() { + Builder builder = + new AutoValue_CheckRequestBuilder_CheckRequestBuilderImpl_CheckRequestParams.Builder(); + return builder; + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setHeaders(Metadata headers); + + abstract Builder setMethodDescriptor(MethodDescriptor method); + + abstract Builder setRequestTime(Timestamp time); + + abstract Builder setLocalAddress(java.net.SocketAddress localAddress); + + abstract Builder setRemoteAddress(java.net.SocketAddress remoteAddress); + + abstract Builder setSslSession(SSLSession sslSession); + + abstract CheckRequestParams build(); + } + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/CertificateUtilsTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/CertificateUtilsTest.java new file mode 100644 index 00000000000..83b578ddf72 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/CertificateUtilsTest.java @@ -0,0 +1,137 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateParsingException; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import javax.security.auth.x500.X500Principal; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CertificateUtilsTest { + + @Test + public void getPrincipal_ipAddressSan() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + List ipSan = Arrays.asList(7, "192.168.1.1"); // SAN_TYPE_IP_ADDRESS + Collection> sans = Arrays.asList(ipSan); + when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); + assertThat(CertificateUtils.getPrincipal(mockCert)).isEqualTo("192.168.1.1"); + } + + @Test + public void getPrincipal_dnsSan() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + List san = Arrays.asList(2, "foo.test.google.fr"); // SAN_TYPE_DNS_NAME + Collection> sans = Collections.singletonList(san); + when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); + assertThat(CertificateUtils.getPrincipal(mockCert)).isEqualTo("foo.test.google.fr"); + } + + @Test + public void getPrincipal_noSan_usesSubject() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + when(mockCert.getSubjectAlternativeNames()).thenReturn(Collections.emptyList()); + X500Principal principal = new X500Principal("CN=testclient, O=gRPC authors"); + when(mockCert.getSubjectX500Principal()).thenReturn(principal); + assertThat(CertificateUtils.getPrincipal(mockCert)).isEqualTo("CN=testclient,O=gRPC authors"); + } + + @Test + public void getPrincipal_nullSans_usesSubject() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + when(mockCert.getSubjectAlternativeNames()).thenReturn(null); + X500Principal principal = new X500Principal("CN=testclient, O=gRPC authors"); + when(mockCert.getSubjectX500Principal()).thenReturn(principal); + assertThat(CertificateUtils.getPrincipal(mockCert)).isEqualTo("CN=testclient,O=gRPC authors"); + } + + @Test + public void getPrincipal_ipSanWrongSize_usesDnsSan() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + List ipSan = Collections.singletonList(7); // SAN_TYPE_IP_ADDRESS, wrong size + List dnsSan = Arrays.asList(2, "foo.test.google.fr"); // SAN_TYPE_DNS_NAME + Collection> sans = Arrays.asList(ipSan, dnsSan); + when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); + assertThat(CertificateUtils.getPrincipal(mockCert)).isEqualTo("foo.test.google.fr"); + } + + @Test + public void getPrincipal_ipSanWrongType_usesDnsSan() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + // SAN_TYPE_IP_ADDRESS, wrong type + List ipSan = Arrays.asList("not-an-integer", "192.168.1.1"); + List dnsSan = Arrays.asList(2, "foo.test.google.fr"); // SAN_TYPE_DNS_NAME + Collection> sans = Arrays.asList(ipSan, dnsSan); + when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); + assertThat(CertificateUtils.getPrincipal(mockCert)).isEqualTo("foo.test.google.fr"); + } + + @Test + public void getPrincipal_dnsSanWrongType_usesSubject() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + // Wrong SAN type for DNS check + List otherSan = Arrays.asList(6, "foo.test.google.fr"); // SAN_TYPE_URI + Collection> sans = Collections.singletonList(otherSan); + when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); + when(mockCert.getSubjectX500Principal()).thenReturn(new X500Principal("CN=test")); + assertThat(CertificateUtils.getPrincipal(mockCert)).isEqualTo("CN=test"); + } + + @Test + public void getPrincipal_sanParsingException_usesSubject() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + when(mockCert.getSubjectAlternativeNames()).thenThrow(new CertificateParsingException()); + X500Principal principal = new X500Principal("CN=testclient, O=gRPC authors"); + when(mockCert.getSubjectX500Principal()).thenReturn(principal); + assertThat(CertificateUtils.getPrincipal(mockCert)).isEqualTo("CN=testclient,O=gRPC authors"); + } + + @Test + public void getUrlPemEncodedCertificate() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + byte[] certData = "cert-data".getBytes(StandardCharsets.UTF_8); + when(mockCert.getEncoded()).thenReturn(certData); + + String pem = "-----BEGIN CERTIFICATE-----\n" + "Y2VydC1kYXRh" // base64 of "cert-data" + + "\n-----END CERTIFICATE-----\n"; + String urlEncodedPem = URLEncoder.encode(pem, StandardCharsets.UTF_8.toString()); + assertThat(CertificateUtils.getUrlPemEncodedCertificate(mockCert)).isEqualTo(urlEncodedPem); + } + + @Test + public void getUrlPemEncodedCertificate_encodingException() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + when(mockCert.getEncoded()).thenThrow(new CertificateEncodingException("test")); + assertThrows(CertificateEncodingException.class, + () -> CertificateUtils.getUrlPemEncodedCertificate(mockCert)); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java new file mode 100644 index 00000000000..82ff8801e0e --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java @@ -0,0 +1,364 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.protobuf.Any; +import com.google.protobuf.Timestamp; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.service.auth.v3.AttributeContext; +import io.envoyproxy.envoy.service.auth.v3.CheckRequest; +import io.envoyproxy.envoy.type.matcher.v3.ListStringMatcher; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.grpc.Attributes; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.testing.TestMethodDescriptors; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class CheckRequestBuilderTest { + @Rule + public final MockitoRule mockito = MockitoJUnit.rule(); + + @Mock + private ServerCall serverCall; + @Mock + private SSLSession sslSession; + @Mock + private CheckRequestBuilder.CertificateProvider certificateProvider; + + private CheckRequestBuilder checkRequestBuilder; + private MethodDescriptor methodDescriptor; + private Timestamp requestTime; + + @Before + public void setUp() throws ExtAuthzParseException { + ExtAuthzConfig config = buildExtAuthzConfig(); + checkRequestBuilder = + new CheckRequestBuilder.CheckRequestBuilderImpl(config, certificateProvider); + methodDescriptor = TestMethodDescriptors.voidMethod(); + requestTime = Timestamp.newBuilder().setSeconds(12345).setNanos(67890).build(); + } + + @Test + public void buildRequest_forServer_happyPath() throws Exception { + // Setup for addresses + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + + // Setup for SSL and certificates + X509Certificate peerCert = mock(X509Certificate.class); + X509Certificate localCert = mock(X509Certificate.class); + Certificate[] peerCerts = new Certificate[] {peerCert}; + Certificate[] localCerts = new Certificate[] {localCert}; + when(sslSession.getPeerCertificates()).thenReturn(peerCerts); + when(sslSession.getLocalCertificates()).thenReturn(localCerts); + when(certificateProvider.getPrincipal(peerCert)).thenReturn("peer-principal"); + when(certificateProvider.getPrincipal(localCert)).thenReturn("local-principal"); + when(certificateProvider.getUrlPemEncodedCertificate(peerCert)) + .thenReturn("encoded-peer-cert"); + + // Setup for headers + Metadata headers = new Metadata(); + headers.put(Metadata.Key.of("allowed-header", Metadata.ASCII_STRING_MARSHALLER), "v1"); + headers.put(Metadata.Key.of("disallowed-header", Metadata.ASCII_STRING_MARSHALLER), "v2"); + headers.put(Metadata.Key.of("overridden-header", Metadata.ASCII_STRING_MARSHALLER), "v3"); + byte[] binaryValue = new byte[] {1, 2, 3}; + headers.put(Metadata.Key.of("bin-header-bin", Metadata.BINARY_BYTE_MARSHALLER), + binaryValue); + + // Configure CheckRequestBuilder to allow specific headers + ListStringMatcher allowedHeaders = ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("allowed-header").build()) + .addPatterns(StringMatcher.newBuilder().setExact("overridden-header").build()) + .build(); + ListStringMatcher disallowedHeaders = ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("disallowed-header").build()) + .addPatterns(StringMatcher.newBuilder().setExact("overridden-header").build()) + .build(); + ExtAuthzConfig config = buildExtAuthzConfig(allowedHeaders, disallowedHeaders, true); + checkRequestBuilder = + new CheckRequestBuilder.CheckRequestBuilderImpl(config, certificateProvider); + + // Setup server call attributes + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + + // Build and verify the request + CheckRequest request = checkRequestBuilder.buildRequest(serverCall, headers, requestTime); + + AttributeContext attrContext = request.getAttributes(); + assertThat(attrContext.getSource().getAddress().getSocketAddress().getAddress()) + .isEqualTo("192.168.1.1"); + assertThat(attrContext.getSource().getPrincipal()).isEqualTo("peer-principal"); + assertThat(attrContext.getSource().getCertificate()).isEqualTo("encoded-peer-cert"); + assertThat(attrContext.getDestination().getAddress().getSocketAddress().getAddress()) + .isEqualTo("10.0.0.2"); + assertThat(attrContext.getDestination().getPrincipal()).isEqualTo("local-principal"); + + AttributeContext.HttpRequest http = attrContext.getRequest().getHttp(); + assertThat(http.getHeadersMap()).containsEntry("allowed-header", "v1"); + assertThat(http.getHeadersMap()).doesNotContainKey("bin-header-bin"); + assertThat(http.getHeadersMap()).doesNotContainKey("disallowed-header"); + assertThat(http.getHeadersMap()).doesNotContainKey("overridden-header"); + } + + @Test + public void buildRequest_forServer_noTransportAttrs() { + when(serverCall.getAttributes()).thenReturn(Attributes.EMPTY); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + Metadata headers = new Metadata(); + + CheckRequest request = checkRequestBuilder.buildRequest(serverCall, headers, requestTime); + + assertThat(request.getAttributes().getRequest().getTime()).isEqualTo(requestTime); + assertThat(request.getAttributes().getRequest().getHttp().getPath()) + .isEqualTo(methodDescriptor.getFullMethodName()); + assertThat(request.getAttributes().getRequest().getHttp().getMethod()).isEqualTo("POST"); + assertThat(request.getAttributes().getRequest().getHttp().getProtocol()) + .isEqualTo("HTTP/2"); + assertThat(request.getAttributes().getRequest().getHttp().getSize()).isEqualTo(-1); + assertThat(request.getAttributes().getRequest().getHttp().getHeadersMap()).isEmpty(); + assertThat(request.getAttributes().hasSource()).isFalse(); + assertThat(request.getAttributes().hasDestination()).isFalse(); + } + + + @Test + public void buildRequest_forClient_happyPath_emptyAllowedHeaders() throws Exception { + // Setup for headers + Metadata headers = new Metadata(); + headers.put(Metadata.Key.of("some-header", Metadata.ASCII_STRING_MARSHALLER), "v1"); + headers.put(Metadata.Key.of("disallowed-header", Metadata.ASCII_STRING_MARSHALLER), "v2"); + byte[] binaryValue = new byte[] {1, 2, 3}; + headers.put(Metadata.Key.of("bin-header-bin", Metadata.BINARY_BYTE_MARSHALLER), + binaryValue); + + // Configure CheckRequestBuilder with empty allowed headers + ListStringMatcher allowedHeaders = ListStringMatcher.newBuilder().build(); // empty + ListStringMatcher disallowedHeaders = ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("disallowed-header").build()) + .build(); + ExtAuthzConfig config = buildExtAuthzConfig(allowedHeaders, disallowedHeaders, true); + checkRequestBuilder = + new CheckRequestBuilder.CheckRequestBuilderImpl(config, certificateProvider); + + // Build and verify the request + CheckRequest request = + checkRequestBuilder.buildRequest(methodDescriptor, headers, requestTime); + + AttributeContext attrContext = request.getAttributes(); + assertThat(attrContext.hasSource()).isFalse(); + assertThat(attrContext.hasDestination()).isFalse(); + + AttributeContext.HttpRequest http = attrContext.getRequest().getHttp(); + assertThat(http.getPath()).isEqualTo(methodDescriptor.getFullMethodName()); + assertThat(http.getHeadersMap()).containsEntry("some-header", "v1"); + assertThat(http.getHeadersMap()).containsEntry("bin-header-bin", "AQID"); + assertThat(http.getHeadersMap()).doesNotContainKey("disallowed-header"); + } + + @Test + public void buildRequest_forServer_noSslSession() { + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext attrContext = request.getAttributes(); + assertThat(attrContext.hasSource()).isTrue(); + Address sourceAddress = attrContext.getSource().getAddress(); + assertThat(sourceAddress.getSocketAddress().getAddress()).isEqualTo("192.168.1.1"); + assertThat(sourceAddress.getSocketAddress().getPortValue()).isEqualTo(12345); + assertThat(attrContext.getSource().getPrincipal()).isEmpty(); + + assertThat(attrContext.hasDestination()).isTrue(); + Address destAddress = attrContext.getDestination().getAddress(); + assertThat(destAddress.getSocketAddress().getAddress()).isEqualTo("10.0.0.2"); + assertThat(destAddress.getSocketAddress().getPortValue()).isEqualTo(443); + assertThat(attrContext.getDestination().getPrincipal()).isEmpty(); + } + + @Test + public void buildRequest_forServer_sslPeerUnverified() throws Exception { + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + when(sslSession.getPeerCertificates()) + .thenThrow(new SSLPeerUnverifiedException("unverified")); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + assertThat(source.getCertificate()).isEmpty(); + } + + @Test + public void buildRequest_forServer_includePeerCertFalse() throws Exception { + ExtAuthzConfig config = buildExtAuthzConfig(ListStringMatcher.newBuilder().build(), + ListStringMatcher.newBuilder().build(), false); + checkRequestBuilder = + new CheckRequestBuilder.CheckRequestBuilderImpl(config, certificateProvider); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + X509Certificate peerCert = mock(X509Certificate.class); + Certificate[] peerCerts = new Certificate[] {peerCert}; + + when(sslSession.getPeerCertificates()).thenReturn(peerCerts); + when(certificateProvider.getPrincipal(peerCert)).thenReturn("peer-principal"); + + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEqualTo("peer-principal"); + assertThat(source.getCertificate()).isEmpty(); + } + + @Test + public void buildRequest_forServer_nullOrEmptyCertificates() throws Exception { + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + + // Test with null certificates + when(sslSession.getPeerCertificates()).thenReturn(null); + when(sslSession.getLocalCertificates()).thenReturn(null); + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + assertThat(source.getCertificate()).isEmpty(); + AttributeContext.Peer destination = request.getAttributes().getDestination(); + assertThat(destination.getPrincipal()).isEmpty(); + + // Test with empty certificates + when(sslSession.getPeerCertificates()).thenReturn(new Certificate[0]); + when(sslSession.getLocalCertificates()).thenReturn(new Certificate[0]); + request = checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + assertThat(source.getCertificate()).isEmpty(); + destination = request.getAttributes().getDestination(); + assertThat(destination.getPrincipal()).isEmpty(); + } + + @Test + public void buildRequest_forServer_nonX509Certificate() throws Exception { + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + Certificate nonX509Cert = mock(Certificate.class); + Certificate[] certs = new Certificate[] {nonX509Cert}; + + when(sslSession.getPeerCertificates()).thenReturn(certs); + when(sslSession.getLocalCertificates()).thenReturn(certs); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + AttributeContext.Peer destination = request.getAttributes().getDestination(); + assertThat(destination.getPrincipal()).isEmpty(); + } + + @Test + public void buildRequest_forServer_nonInetSocketAddress() { + SocketAddress remoteAddress = mock(SocketAddress.class); + when(serverCall.getAttributes()).thenReturn(Attributes.newBuilder() + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress).build()); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + assertThat(request.getAttributes().getSource().hasAddress()).isFalse(); + } + + private ExtAuthzConfig buildExtAuthzConfig() throws ExtAuthzParseException { + return buildExtAuthzConfig(ListStringMatcher.newBuilder().build(), + ListStringMatcher.newBuilder().build(), true); + } + + private ExtAuthzConfig buildExtAuthzConfig(ListStringMatcher allowed, + ListStringMatcher disallowed, boolean includePeerCertificate) + throws ExtAuthzParseException { + Any googleDefaultChannelCreds = Any.pack(GoogleDefaultCredentials.newBuilder().build()); + Any fakeAccessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build()); + ExtAuthz.Builder builder = ExtAuthz.newBuilder() + .setGrpcService(io.envoyproxy.envoy.config.core.v3.GrpcService.newBuilder() + .setGoogleGrpc(io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc + .newBuilder().setTargetUri("test-cluster") + .addChannelCredentialsPlugin(googleDefaultChannelCreds) + .addCallCredentialsPlugin(fakeAccessTokenCreds).build()) + .build()) + .setIncludePeerCertificate(includePeerCertificate).setAllowedHeaders(allowed) + .setDisallowedHeaders(disallowed); + return ExtAuthzConfig.fromProto(builder.build()); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzTestHelper.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzTestHelper.java new file mode 100644 index 00000000000..38423af7ded --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzTestHelper.java @@ -0,0 +1,177 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import io.grpc.Attributes; +import io.grpc.ClientCall; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.NoopClientCall; +import io.grpc.NoopServerCall; +import io.grpc.Status; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import javax.annotation.Nullable; + +/** + * Shared test fakes for the extauthz test package. These follow the established + * grpc-java pattern of using {@link NoopClientCall}/{@link NoopServerCall} subclasses + * instead of Mockito mocks for core gRPC interfaces. + * + * @see io.grpc.NoopClientCall + * @see io.grpc.NoopServerCall + */ +final class ExtAuthzTestHelper { + + private ExtAuthzTestHelper() {} + + /** + * A capturing fake that records {@link #start} arguments and {@link #sendMessage} calls. + */ + static final class CapturingClientCall + extends NoopClientCall { + @Nullable private ClientCall.Listener listener; + @Nullable private Metadata headers; + private boolean started; + private final List sentMessages = new ArrayList<>(); + + @Override + public void start(ClientCall.Listener listener, Metadata headers) { + this.listener = listener; + this.headers = headers; + this.started = true; + } + + @Override + public void sendMessage(ReqT message) { + sentMessages.add(message); + } + + /** Returns the listener passed to {@link #start}, or null if not yet started. */ + @Nullable + ClientCall.Listener getListener() { + return listener; + } + + /** Returns the headers passed to {@link #start}, or null if not yet started. */ + @Nullable + Metadata getHeaders() { + return headers; + } + + /** Returns true if {@link #start} has been called. */ + boolean isStarted() { + return started; + } + + /** Returns an unmodifiable view of all messages sent via {@link #sendMessage}. */ + List getSentMessages() { + return Collections.unmodifiableList(sentMessages); + } + } + + /** + * A capturing fake that records + * {@link ClientCall.Listener#onHeaders onHeaders}, + * {@link ClientCall.Listener#onMessage onMessage}, + * {@link ClientCall.Listener#onReady onReady}, and + * {@link ClientCall.Listener#onClose onClose} callbacks. + */ + static final class CapturingListener + extends NoopClientCall.NoopClientCallListener { + @Nullable private Metadata headers; + private final List messages = new ArrayList<>(); + private boolean onReadyCalled; + @Nullable private Status closeStatus; + @Nullable private Metadata closeTrailers; + + @Override + public void onHeaders(Metadata headers) { + this.headers = headers; + } + + @Override + public void onMessage(T message) { + this.messages.add(message); + } + + @Override + public void onReady() { + this.onReadyCalled = true; + } + + @Override + public void onClose(Status status, Metadata trailers) { + this.closeStatus = status; + this.closeTrailers = trailers; + } + + /** Returns the headers received via {@link #onHeaders}, or null if not yet called. */ + @Nullable + Metadata getHeaders() { + return headers; + } + + /** Returns an unmodifiable view of all messages received via {@link #onMessage}. */ + List getMessages() { + return Collections.unmodifiableList(messages); + } + + /** Returns true if {@link #onReady} has been called. */ + boolean isOnReadyCalled() { + return onReadyCalled; + } + + /** Returns the status received via {@link #onClose}, or null if not yet closed. */ + @Nullable + Status getCloseStatus() { + return closeStatus; + } + + /** Returns the trailers received via {@link #onClose}, or null if not yet closed. */ + @Nullable + Metadata getCloseTrailers() { + return closeTrailers; + } + } + + /** + * A fake {@link io.grpc.ServerCall} that provides {@link Attributes} and + * {@link MethodDescriptor} without requiring Mockito. + */ + static final class TestServerCall + extends NoopServerCall { + private final Attributes attributes; + private final MethodDescriptor methodDescriptor; + + TestServerCall(Attributes attributes, MethodDescriptor methodDescriptor) { + this.attributes = attributes; + this.methodDescriptor = methodDescriptor; + } + + @Override + public Attributes getAttributes() { + return attributes; + } + + @Override + public MethodDescriptor getMethodDescriptor() { + return methodDescriptor; + } + } +} From 0b202452a70aa299b43c1518651c3b3bd930e671 Mon Sep 17 00:00:00 2001 From: Saurav Date: Tue, 13 Jan 2026 04:11:34 +0000 Subject: [PATCH 2/3] Fixup: Fixed for 12493 --- .../internal/extauthz/CertificateUtils.java | 16 +- .../extauthz/CheckRequestBuilder.java | 388 ++++------- .../extauthz/CertificateUtilsTest.java | 42 +- .../extauthz/CheckRequestBuilderTest.java | 634 +++++++++--------- 4 files changed, 495 insertions(+), 585 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/CertificateUtils.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/CertificateUtils.java index 1fa6fb4bcf0..c0253839253 100644 --- a/xds/src/main/java/io/grpc/xds/internal/extauthz/CertificateUtils.java +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/CertificateUtils.java @@ -36,12 +36,12 @@ public final class CertificateUtils { // dNSName (2) // iPAddress (7) private static final int SAN_TYPE_DNS_NAME = 2; - private static final int SAN_TYPE_IP_ADDRESS = 7; + private static final int SAN_TYPE_URI = 6; private CertificateUtils() {} /** - * Gets the principal from a certificate. It returns the cert's first IP Address SAN if set, + * Gets the principal from a certificate. It returns the cert's first URI SAN if set, * otherwise the cert's first DNS SAN if set, otherwise the subject field of the certificate in * RFC 2253 format. * @@ -52,14 +52,14 @@ public static String getPrincipal(X509Certificate cert) { try { Collection> sans = cert.getSubjectAlternativeNames(); if (sans != null) { - // Look for IP Address SAN. + // Look for URI SAN (Priority 1). for (List san : sans) { if (san.size() == 2 && san.get(0) instanceof Integer - && (Integer) san.get(0) == SAN_TYPE_IP_ADDRESS) { + && (Integer) san.get(0) == SAN_TYPE_URI) { return (String) san.get(1); } } - // If no IP Address SAN, look for DNS SAN. + // If no URI SAN, look for DNS SAN (Priority 2). for (List san : sans) { if (san.size() == 2 && san.get(0) instanceof Integer && (Integer) san.get(0) == SAN_TYPE_DNS_NAME) { @@ -68,12 +68,12 @@ public static String getPrincipal(X509Certificate cert) { } } } catch (java.security.cert.CertificateParsingException e) { - logger.log(Level.WARNING, "Error parsing certificate SANs. This is not expected, " - + "falling back to the subject according to the spec.", e); + logger.log(Level.FINE, "Error parsing certificate SANs.", e); } - return cert.getSubjectX500Principal().getName(); + return cert.getSubjectX500Principal().getName("RFC2253"); } + /** * Gets the URL PEM encoded certificate. It Pem encodes first and then urlencodes. * diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java index 03f3f4c1ec0..c16cabdd783 100644 --- a/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java @@ -16,7 +16,7 @@ package io.grpc.xds.internal.extauthz; -import com.google.auto.value.AutoValue; + import com.google.common.io.BaseEncoding; import com.google.protobuf.Timestamp; import io.envoyproxy.envoy.config.core.v3.Address; @@ -36,21 +36,21 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; -import java.util.Optional; import java.util.logging.Level; import java.util.logging.Logger; +import javax.annotation.Nullable; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; /** * Interface for building external authorization check requests. */ -public interface CheckRequestBuilder { +public class CheckRequestBuilder { /** * An interface for providing certificate-related information. */ - interface CertificateProvider { + public interface CertificateProvider { /** * Gets the principal from a certificate. * @@ -71,283 +71,171 @@ String getUrlPemEncodedCertificate(X509Certificate cert) throws CertificateEncodingException, UnsupportedEncodingException; } - /** - * A factory for creating {@link CheckRequestBuilder} instances. - */ - @FunctionalInterface - interface Factory { - /** - * Creates a new instance of the CheckRequestBuilder. - * - * @param config The external authorization configuration. - * @param config The external authorization configuration. - * @return A new CheckRequestBuilder instance. - */ - CheckRequestBuilder create(ExtAuthzConfig config); - } + private static final Logger logger = Logger.getLogger(CheckRequestBuilder.class.getName()); - /** The default factory for creating {@link CheckRequestBuilder} instances. */ - Factory INSTANCE = CheckRequestBuilderImpl::new; - /** - * Builds a CheckRequest for a server-side call. - * - * @param serverCall The server call. - * @param headers The request headers. - * @param requestTime The time of the request. - * @return A new CheckRequest. - */ - CheckRequest buildRequest(ServerCall serverCall, Metadata headers, Timestamp requestTime); + private final ExtAuthzConfig config; + private final CertificateProvider certificateProvider; - /** - * Builds a CheckRequest for a client-side call. - * - * @param methodDescriptor The method descriptor of the call. - * @param headers The request headers. - * @param requestTime The time of the request. - * @return A new CheckRequest. - */ - CheckRequest buildRequest(MethodDescriptor methodDescriptor, Metadata headers, - Timestamp requestTime); + public CheckRequestBuilder(ExtAuthzConfig config) { + this(config, new CertificateProvider() { + @Override + public String getPrincipal(X509Certificate cert) { + return CertificateUtils.getPrincipal(cert); + } - /** - * Implementation of the CheckRequestBuilder interface. - */ - final class CheckRequestBuilderImpl implements CheckRequestBuilder { - private static final Logger logger = Logger.getLogger(CheckRequestBuilderImpl.class.getName()); + @Override + public String getUrlPemEncodedCertificate(X509Certificate cert) + throws CertificateEncodingException, UnsupportedEncodingException { + return CertificateUtils.getUrlPemEncodedCertificate(cert); + } + }); + } - private static final String METHOD = "POST"; - private static final String PROTOCOL = "HTTP/2"; - private static final long SIZE = -1; + public CheckRequestBuilder(ExtAuthzConfig config, CertificateProvider certificateProvider) { + this.config = config; + this.certificateProvider = certificateProvider; + } - private final ExtAuthzConfig config; - private final CertificateProvider certificateProvider; - CheckRequestBuilderImpl(ExtAuthzConfig config) { - this(config, new CertificateProvider() { - @Override - public String getPrincipal(X509Certificate cert) { - return CertificateUtils.getPrincipal(cert); - } + public CheckRequest buildRequest(MethodDescriptor methodDescriptor, Metadata headers, + Timestamp requestTime) { + return build(methodDescriptor, headers, requestTime, null, null, null); + } - @Override - public String getUrlPemEncodedCertificate(X509Certificate cert) - throws CertificateEncodingException, UnsupportedEncodingException { - return CertificateUtils.getUrlPemEncodedCertificate(cert); - } - }); - } - CheckRequestBuilderImpl(ExtAuthzConfig config, CertificateProvider certificateProvider) { - this.config = config; - this.certificateProvider = certificateProvider; - } + public CheckRequest buildRequest(ServerCall serverCall, Metadata headers, + Timestamp requestTime) { + java.net.SocketAddress localAddress = + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_LOCAL_ADDR); + java.net.SocketAddress remoteAddress = + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); + SSLSession sslSession = serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_SSL_SESSION); + return build(serverCall.getMethodDescriptor(), headers, requestTime, localAddress, + remoteAddress, sslSession); + } - @Override - public CheckRequest buildRequest(MethodDescriptor methodDescriptor, Metadata headers, - Timestamp requestTime) { - return build(CheckRequestParams.builder().setMethodDescriptor(methodDescriptor) - .setHeaders(headers).setRequestTime(requestTime).build()); + private CheckRequest build(MethodDescriptor methodDescriptor, Metadata headers, + Timestamp requestTime, @Nullable java.net.SocketAddress localAddress, + @Nullable java.net.SocketAddress remoteAddress, @Nullable SSLSession sslSession) { + AttributeContext.Builder attrBuilder = AttributeContext.newBuilder(); + if (remoteAddress != null) { + attrBuilder.setSource(buildSource(remoteAddress, sslSession)); } - - @Override - public CheckRequest buildRequest(ServerCall serverCall, Metadata headers, - Timestamp requestTime) { - CheckRequestParams.Builder paramsBuilder = - CheckRequestParams.builder().setMethodDescriptor(serverCall.getMethodDescriptor()) - .setHeaders(headers).setRequestTime(requestTime); - java.net.SocketAddress localAddress = - serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_LOCAL_ADDR); - if (localAddress != null) { - paramsBuilder.setLocalAddress(localAddress); - } - java.net.SocketAddress remoteAddress = - serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); - if (remoteAddress != null) { - paramsBuilder.setRemoteAddress(remoteAddress); - } - SSLSession sslSession = serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_SSL_SESSION); - if (sslSession != null) { - paramsBuilder.setSslSession(sslSession); - } - return build(paramsBuilder.build()); - } - - private CheckRequest build(CheckRequestParams params) { - AttributeContext.Builder attrBuilder = AttributeContext.newBuilder(); - if (params.remoteAddress().isPresent()) { - attrBuilder.setSource(buildSource(params.remoteAddress().get(), params.sslSession())); - } - if (params.localAddress().isPresent()) { - attrBuilder - .setDestination(buildDestination(params.localAddress().get(), params.sslSession())); - } - attrBuilder.setRequest(buildAttributeRequest(params.headers(), - params.methodDescriptor().getFullMethodName(), params.requestTime())); - return CheckRequest.newBuilder().setAttributes(attrBuilder).build(); + if (localAddress != null) { + attrBuilder.setDestination(buildDestination(localAddress, sslSession)); } + attrBuilder.setRequest( + buildAttributeRequest(headers, methodDescriptor.getFullMethodName(), requestTime)); + return CheckRequest.newBuilder().setAttributes(attrBuilder).build(); + } - private AttributeContext.Peer buildSource(java.net.SocketAddress socketAddress, - Optional sslSession) { - AttributeContext.Peer.Builder peerBuilder = buildPeer(socketAddress).toBuilder(); - if (sslSession.isPresent()) { - try { - Certificate[] certs = sslSession.get().getPeerCertificates(); - if (certs != null && certs.length > 0 && certs[0] instanceof X509Certificate) { - X509Certificate cert = (X509Certificate) certs[0]; - peerBuilder.setPrincipal(certificateProvider.getPrincipal(cert)); - if (config.includePeerCertificate()) { - try { - peerBuilder.setCertificate(certificateProvider.getUrlPemEncodedCertificate(cert)); - } catch (UnsupportedEncodingException | CertificateEncodingException e) { - logger.log(Level.WARNING, - "Error encoding peer certificate. " - + "This is not expected, but if it happens, the certificate should not " - + "be set according to the spec.", - e); - } - } + private AttributeContext.Peer buildSource(java.net.SocketAddress socketAddress, + @Nullable SSLSession sslSession) { + AttributeContext.Peer.Builder peerBuilder = buildPeer(socketAddress).toBuilder(); + if (sslSession != null) { + Certificate[] certs = null; + try { + certs = sslSession.getPeerCertificates(); + } catch (SSLPeerUnverifiedException e) { + logger.log(Level.FINE, "Peer is not authenticated; omitting principal and certificate.", e); + } + if (certs != null && certs.length > 0 && certs[0] instanceof X509Certificate) { + X509Certificate cert = (X509Certificate) certs[0]; + peerBuilder.setPrincipal(certificateProvider.getPrincipal(cert)); + if (config.includePeerCertificate()) { + try { + peerBuilder.setCertificate(certificateProvider.getUrlPemEncodedCertificate(cert)); + } catch (UnsupportedEncodingException | CertificateEncodingException e) { + logger.log(Level.FINE, "Error encoding peer certificate; omitting from request.", e); } - } catch (SSLPeerUnverifiedException e) { - logger.log(Level.FINE, - "Peer is not authenticated. " - + "This is expected, principal and certificate should not be set " - + "according to the spec.", - e); } } - return peerBuilder.build(); } + return peerBuilder.build(); + } - private AttributeContext.Peer buildDestination(java.net.SocketAddress socketAddress, - Optional sslSession) { - AttributeContext.Peer.Builder peerBuilder = buildPeer(socketAddress).toBuilder(); - if (sslSession.isPresent()) { - Certificate[] certs = sslSession.get().getLocalCertificates(); - if (certs != null && certs.length > 0 && certs[0] instanceof X509Certificate) { - peerBuilder.setPrincipal(certificateProvider.getPrincipal((X509Certificate) certs[0])); - } + private AttributeContext.Peer buildDestination(java.net.SocketAddress socketAddress, + @Nullable SSLSession sslSession) { + AttributeContext.Peer.Builder peerBuilder = buildPeer(socketAddress).toBuilder(); + if (sslSession != null) { + Certificate[] certs = sslSession.getLocalCertificates(); + if (certs != null && certs.length > 0 && certs[0] instanceof X509Certificate) { + peerBuilder.setPrincipal(certificateProvider.getPrincipal((X509Certificate) certs[0])); } - return peerBuilder.build(); } + return peerBuilder.build(); + } - private AttributeContext.Peer buildPeer(java.net.SocketAddress socketAddress) { - AttributeContext.Peer.Builder peerBuilder = AttributeContext.Peer.newBuilder(); - if (socketAddress instanceof InetSocketAddress) { - InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress; - peerBuilder.setAddress(Address.newBuilder() - .setSocketAddress(SocketAddress.newBuilder() - .setAddress(inetSocketAddress.getAddress().getHostAddress()) - .setPortValue(inetSocketAddress.getPort())) - .build()); - } - return peerBuilder.build(); + private AttributeContext.Peer buildPeer(java.net.SocketAddress socketAddress) { + AttributeContext.Peer.Builder peerBuilder = AttributeContext.Peer.newBuilder(); + if (socketAddress instanceof InetSocketAddress) { + InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress; + peerBuilder + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress(inetSocketAddress.getAddress().getHostAddress()) + .setPortValue(inetSocketAddress.getPort())) + .build()); } + return peerBuilder.build(); + } - private AttributeContext.Request buildAttributeRequest(Metadata headers, String fullMethodName, - Timestamp requestTime) { - AttributeContext.Request.Builder reqBuilder = AttributeContext.Request.newBuilder(); - reqBuilder.setTime(requestTime); - AttributeContext.HttpRequest.Builder httpReqBuilder = - AttributeContext.HttpRequest.newBuilder(); - httpReqBuilder.setPath(fullMethodName); - httpReqBuilder.setMethod(METHOD); - httpReqBuilder.setProtocol(PROTOCOL); - httpReqBuilder.setSize(SIZE); - for (String key : headers.keys()) { - if (!isAllowed(key)) { - continue; - } - Optional value; - if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { - value = getBinaryHeaderValue(headers, key); - } else { - value = getAsciiHeaderValue(headers, key); - } - value.ifPresent( - headerValue -> httpReqBuilder.putHeaders(key.toLowerCase(Locale.ROOT), headerValue)); - } - reqBuilder.setHttp(httpReqBuilder); - return reqBuilder.build(); + private AttributeContext.Request buildAttributeRequest(Metadata headers, String fullMethodName, + Timestamp requestTime) { + AttributeContext.Request.Builder reqBuilder = AttributeContext.Request.newBuilder(); + reqBuilder.setTime(requestTime); + AttributeContext.HttpRequest.Builder httpReqBuilder = AttributeContext.HttpRequest.newBuilder(); + httpReqBuilder.setPath("/" + fullMethodName); + httpReqBuilder.setMethod("POST"); + httpReqBuilder.setProtocol("HTTP/2"); + httpReqBuilder.setSize(-1); + for (String key : headers.keys()) { + if (!isAllowed(key)) { + continue; + } + String value; + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + value = getBinaryHeaderValue(headers, key); + } else { + value = getAsciiHeaderValue(headers, key); + } + httpReqBuilder.putHeaders(key.toLowerCase(Locale.ROOT), value); } + reqBuilder.setHttp(httpReqBuilder); + return reqBuilder.build(); + } - private Optional getBinaryHeaderValue(Metadata headers, String key) { - Iterable binaryValues = - headers.getAll(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); - if (binaryValues == null) { - // Unreachable code, since we iterate over the keys. Exists for defensive programming. - return Optional.empty(); - } - List base64Values = new ArrayList<>(); - for (byte[] value : binaryValues) { - base64Values.add(BaseEncoding.base64().encode(value)); - } - return Optional.of(String.join(",", base64Values)); + private String getBinaryHeaderValue(Metadata headers, String key) { + Iterable binaryValues = + headers.getAll(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); + List base64Values = new ArrayList<>(); + for (byte[] value : binaryValues) { + base64Values.add(BaseEncoding.base64().encode(value)); } + return String.join(",", base64Values); + } - private Optional getAsciiHeaderValue(Metadata headers, String key) { - Iterable stringValues = - headers.getAll(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); - if (stringValues == null) { - // Unreachable code, since we iterate over the keys. Exists for defensive programming. - return Optional.empty(); - } - return Optional.of(String.join(",", stringValues)); - } + private String getAsciiHeaderValue(Metadata headers, String key) { + Iterable stringValues = + headers.getAll(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); + return String.join(",", stringValues); + } - private boolean isAllowed(String header) { - for (Matchers.StringMatcher matcher : config.disallowedHeaders()) { - if (matcher.matches(header)) { - return false; - } - } - if (config.allowedHeaders().isEmpty()) { - return true; - } - for (Matchers.StringMatcher matcher : config.allowedHeaders()) { - if (matcher.matches(header)) { - return true; - } + private boolean isAllowed(String header) { + for (Matchers.StringMatcher matcher : config.disallowedHeaders()) { + if (matcher.matches(header)) { + return false; } - return false; } - - @AutoValue - abstract static class CheckRequestParams { - abstract Metadata headers(); - - abstract MethodDescriptor methodDescriptor(); - - abstract Timestamp requestTime(); - - abstract Optional localAddress(); - - abstract Optional remoteAddress(); - - abstract Optional sslSession(); - - static Builder builder() { - Builder builder = - new AutoValue_CheckRequestBuilder_CheckRequestBuilderImpl_CheckRequestParams.Builder(); - return builder; - } - - @AutoValue.Builder - abstract static class Builder { - abstract Builder setHeaders(Metadata headers); - - abstract Builder setMethodDescriptor(MethodDescriptor method); - - abstract Builder setRequestTime(Timestamp time); - - abstract Builder setLocalAddress(java.net.SocketAddress localAddress); - - abstract Builder setRemoteAddress(java.net.SocketAddress remoteAddress); - - abstract Builder setSslSession(SSLSession sslSession); - - abstract CheckRequestParams build(); + if (config.allowedHeaders().isEmpty()) { + return true; + } + for (Matchers.StringMatcher matcher : config.allowedHeaders()) { + if (matcher.matches(header)) { + return true; } } + return false; } } diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/CertificateUtilsTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/CertificateUtilsTest.java index 83b578ddf72..7641bde19dc 100644 --- a/xds/src/test/java/io/grpc/xds/internal/extauthz/CertificateUtilsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/CertificateUtilsTest.java @@ -38,19 +38,24 @@ @RunWith(JUnit4.class) public class CertificateUtilsTest { + private static final int SAN_TYPE_RFC822_NAME = 1; + private static final int SAN_TYPE_DNS_NAME = 2; + private static final int SAN_TYPE_URI = 6; + @Test - public void getPrincipal_ipAddressSan() throws Exception { + public void getPrincipal_uriSan() throws Exception { X509Certificate mockCert = mock(X509Certificate.class); - List ipSan = Arrays.asList(7, "192.168.1.1"); // SAN_TYPE_IP_ADDRESS - Collection> sans = Arrays.asList(ipSan); + List uriSan = Arrays.asList(SAN_TYPE_URI, "spiffe://foo/bar"); + Collection> sans = Arrays.asList(uriSan); when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); - assertThat(CertificateUtils.getPrincipal(mockCert)).isEqualTo("192.168.1.1"); + assertThat(CertificateUtils.getPrincipal(mockCert)).isEqualTo("spiffe://foo/bar"); } + @Test public void getPrincipal_dnsSan() throws Exception { X509Certificate mockCert = mock(X509Certificate.class); - List san = Arrays.asList(2, "foo.test.google.fr"); // SAN_TYPE_DNS_NAME + List san = Arrays.asList(SAN_TYPE_DNS_NAME, "foo.test.google.fr"); Collection> sans = Collections.singletonList(san); when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); assertThat(CertificateUtils.getPrincipal(mockCert)).isEqualTo("foo.test.google.fr"); @@ -75,37 +80,38 @@ public void getPrincipal_nullSans_usesSubject() throws Exception { } @Test - public void getPrincipal_ipSanWrongSize_usesDnsSan() throws Exception { + public void getPrincipal_uriSanWrongSize_usesDnsSan() throws Exception { X509Certificate mockCert = mock(X509Certificate.class); - List ipSan = Collections.singletonList(7); // SAN_TYPE_IP_ADDRESS, wrong size - List dnsSan = Arrays.asList(2, "foo.test.google.fr"); // SAN_TYPE_DNS_NAME - Collection> sans = Arrays.asList(ipSan, dnsSan); + List uriSan = Collections.singletonList(6); // SAN_TYPE_URI, wrong size + List dnsSan = Arrays.asList(SAN_TYPE_DNS_NAME, "foo.test.google.fr"); + Collection> sans = Arrays.asList(uriSan, dnsSan); when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); assertThat(CertificateUtils.getPrincipal(mockCert)).isEqualTo("foo.test.google.fr"); } @Test - public void getPrincipal_ipSanWrongType_usesDnsSan() throws Exception { + public void getPrincipal_uriSanTakesPrecedenceOverDnsSan() throws Exception { X509Certificate mockCert = mock(X509Certificate.class); - // SAN_TYPE_IP_ADDRESS, wrong type - List ipSan = Arrays.asList("not-an-integer", "192.168.1.1"); - List dnsSan = Arrays.asList(2, "foo.test.google.fr"); // SAN_TYPE_DNS_NAME - Collection> sans = Arrays.asList(ipSan, dnsSan); + List uriSan = Arrays.asList(SAN_TYPE_URI, "spiffe://foo/bar"); + List dnsSan = Arrays.asList(SAN_TYPE_DNS_NAME, "foo.test.google.fr"); + Collection> sans = Arrays.asList(dnsSan, uriSan); // Order shouldn't matter when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); - assertThat(CertificateUtils.getPrincipal(mockCert)).isEqualTo("foo.test.google.fr"); + assertThat(CertificateUtils.getPrincipal(mockCert)).isEqualTo("spiffe://foo/bar"); } + @Test - public void getPrincipal_dnsSanWrongType_usesSubject() throws Exception { + public void getPrincipal_sanWrongType_usesSubject() throws Exception { X509Certificate mockCert = mock(X509Certificate.class); - // Wrong SAN type for DNS check - List otherSan = Arrays.asList(6, "foo.test.google.fr"); // SAN_TYPE_URI + // Use type 1 (rfc822Name) which is ignored + List otherSan = Arrays.asList(SAN_TYPE_RFC822_NAME, "foo@test.com"); Collection> sans = Collections.singletonList(otherSan); when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); when(mockCert.getSubjectX500Principal()).thenReturn(new X500Principal("CN=test")); assertThat(CertificateUtils.getPrincipal(mockCert)).isEqualTo("CN=test"); } + @Test public void getPrincipal_sanParsingException_usesSubject() throws Exception { X509Certificate mockCert = mock(X509Certificate.class); diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java index 82ff8801e0e..8747e051d6d 100644 --- a/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java @@ -23,6 +23,7 @@ import com.google.protobuf.Any; import com.google.protobuf.Timestamp; import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.GrpcService; import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; @@ -35,10 +36,14 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.testing.TestMethodDescriptors; +import io.grpc.xds.ExtAuthzConfigParser; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import io.grpc.xds.client.CommonBootstrapperTestUtils; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.cert.Certificate; import java.security.cert.X509Certificate; +import java.util.Collections; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; import org.junit.Before; @@ -52,313 +57,324 @@ @RunWith(JUnit4.class) public class CheckRequestBuilderTest { - @Rule - public final MockitoRule mockito = MockitoJUnit.rule(); - - @Mock - private ServerCall serverCall; - @Mock - private SSLSession sslSession; - @Mock - private CheckRequestBuilder.CertificateProvider certificateProvider; - - private CheckRequestBuilder checkRequestBuilder; - private MethodDescriptor methodDescriptor; - private Timestamp requestTime; - - @Before - public void setUp() throws ExtAuthzParseException { - ExtAuthzConfig config = buildExtAuthzConfig(); - checkRequestBuilder = - new CheckRequestBuilder.CheckRequestBuilderImpl(config, certificateProvider); - methodDescriptor = TestMethodDescriptors.voidMethod(); - requestTime = Timestamp.newBuilder().setSeconds(12345).setNanos(67890).build(); - } - - @Test - public void buildRequest_forServer_happyPath() throws Exception { - // Setup for addresses - SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); - SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); - - // Setup for SSL and certificates - X509Certificate peerCert = mock(X509Certificate.class); - X509Certificate localCert = mock(X509Certificate.class); - Certificate[] peerCerts = new Certificate[] {peerCert}; - Certificate[] localCerts = new Certificate[] {localCert}; - when(sslSession.getPeerCertificates()).thenReturn(peerCerts); - when(sslSession.getLocalCertificates()).thenReturn(localCerts); - when(certificateProvider.getPrincipal(peerCert)).thenReturn("peer-principal"); - when(certificateProvider.getPrincipal(localCert)).thenReturn("local-principal"); - when(certificateProvider.getUrlPemEncodedCertificate(peerCert)) - .thenReturn("encoded-peer-cert"); - - // Setup for headers - Metadata headers = new Metadata(); - headers.put(Metadata.Key.of("allowed-header", Metadata.ASCII_STRING_MARSHALLER), "v1"); - headers.put(Metadata.Key.of("disallowed-header", Metadata.ASCII_STRING_MARSHALLER), "v2"); - headers.put(Metadata.Key.of("overridden-header", Metadata.ASCII_STRING_MARSHALLER), "v3"); - byte[] binaryValue = new byte[] {1, 2, 3}; - headers.put(Metadata.Key.of("bin-header-bin", Metadata.BINARY_BYTE_MARSHALLER), - binaryValue); - - // Configure CheckRequestBuilder to allow specific headers - ListStringMatcher allowedHeaders = ListStringMatcher.newBuilder() - .addPatterns(StringMatcher.newBuilder().setExact("allowed-header").build()) - .addPatterns(StringMatcher.newBuilder().setExact("overridden-header").build()) - .build(); - ListStringMatcher disallowedHeaders = ListStringMatcher.newBuilder() - .addPatterns(StringMatcher.newBuilder().setExact("disallowed-header").build()) - .addPatterns(StringMatcher.newBuilder().setExact("overridden-header").build()) - .build(); - ExtAuthzConfig config = buildExtAuthzConfig(allowedHeaders, disallowedHeaders, true); - checkRequestBuilder = - new CheckRequestBuilder.CheckRequestBuilderImpl(config, certificateProvider); - - // Setup server call attributes - Attributes attributes = - Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) - .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) - .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); - when(serverCall.getAttributes()).thenReturn(attributes); - when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); - - // Build and verify the request - CheckRequest request = checkRequestBuilder.buildRequest(serverCall, headers, requestTime); - - AttributeContext attrContext = request.getAttributes(); - assertThat(attrContext.getSource().getAddress().getSocketAddress().getAddress()) - .isEqualTo("192.168.1.1"); - assertThat(attrContext.getSource().getPrincipal()).isEqualTo("peer-principal"); - assertThat(attrContext.getSource().getCertificate()).isEqualTo("encoded-peer-cert"); - assertThat(attrContext.getDestination().getAddress().getSocketAddress().getAddress()) - .isEqualTo("10.0.0.2"); - assertThat(attrContext.getDestination().getPrincipal()).isEqualTo("local-principal"); - - AttributeContext.HttpRequest http = attrContext.getRequest().getHttp(); - assertThat(http.getHeadersMap()).containsEntry("allowed-header", "v1"); - assertThat(http.getHeadersMap()).doesNotContainKey("bin-header-bin"); - assertThat(http.getHeadersMap()).doesNotContainKey("disallowed-header"); - assertThat(http.getHeadersMap()).doesNotContainKey("overridden-header"); - } - - @Test - public void buildRequest_forServer_noTransportAttrs() { - when(serverCall.getAttributes()).thenReturn(Attributes.EMPTY); - when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); - Metadata headers = new Metadata(); - - CheckRequest request = checkRequestBuilder.buildRequest(serverCall, headers, requestTime); - - assertThat(request.getAttributes().getRequest().getTime()).isEqualTo(requestTime); - assertThat(request.getAttributes().getRequest().getHttp().getPath()) - .isEqualTo(methodDescriptor.getFullMethodName()); - assertThat(request.getAttributes().getRequest().getHttp().getMethod()).isEqualTo("POST"); - assertThat(request.getAttributes().getRequest().getHttp().getProtocol()) - .isEqualTo("HTTP/2"); - assertThat(request.getAttributes().getRequest().getHttp().getSize()).isEqualTo(-1); - assertThat(request.getAttributes().getRequest().getHttp().getHeadersMap()).isEmpty(); - assertThat(request.getAttributes().hasSource()).isFalse(); - assertThat(request.getAttributes().hasDestination()).isFalse(); - } - - - @Test - public void buildRequest_forClient_happyPath_emptyAllowedHeaders() throws Exception { - // Setup for headers - Metadata headers = new Metadata(); - headers.put(Metadata.Key.of("some-header", Metadata.ASCII_STRING_MARSHALLER), "v1"); - headers.put(Metadata.Key.of("disallowed-header", Metadata.ASCII_STRING_MARSHALLER), "v2"); - byte[] binaryValue = new byte[] {1, 2, 3}; - headers.put(Metadata.Key.of("bin-header-bin", Metadata.BINARY_BYTE_MARSHALLER), - binaryValue); - - // Configure CheckRequestBuilder with empty allowed headers - ListStringMatcher allowedHeaders = ListStringMatcher.newBuilder().build(); // empty - ListStringMatcher disallowedHeaders = ListStringMatcher.newBuilder() - .addPatterns(StringMatcher.newBuilder().setExact("disallowed-header").build()) - .build(); - ExtAuthzConfig config = buildExtAuthzConfig(allowedHeaders, disallowedHeaders, true); - checkRequestBuilder = - new CheckRequestBuilder.CheckRequestBuilderImpl(config, certificateProvider); - - // Build and verify the request - CheckRequest request = - checkRequestBuilder.buildRequest(methodDescriptor, headers, requestTime); - - AttributeContext attrContext = request.getAttributes(); - assertThat(attrContext.hasSource()).isFalse(); - assertThat(attrContext.hasDestination()).isFalse(); - - AttributeContext.HttpRequest http = attrContext.getRequest().getHttp(); - assertThat(http.getPath()).isEqualTo(methodDescriptor.getFullMethodName()); - assertThat(http.getHeadersMap()).containsEntry("some-header", "v1"); - assertThat(http.getHeadersMap()).containsEntry("bin-header-bin", "AQID"); - assertThat(http.getHeadersMap()).doesNotContainKey("disallowed-header"); - } - - @Test - public void buildRequest_forServer_noSslSession() { - SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); - SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); - Attributes attributes = - Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) - .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress).build(); - when(serverCall.getAttributes()).thenReturn(attributes); - when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); - - CheckRequest request = - checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); - - AttributeContext attrContext = request.getAttributes(); - assertThat(attrContext.hasSource()).isTrue(); - Address sourceAddress = attrContext.getSource().getAddress(); - assertThat(sourceAddress.getSocketAddress().getAddress()).isEqualTo("192.168.1.1"); - assertThat(sourceAddress.getSocketAddress().getPortValue()).isEqualTo(12345); - assertThat(attrContext.getSource().getPrincipal()).isEmpty(); - - assertThat(attrContext.hasDestination()).isTrue(); - Address destAddress = attrContext.getDestination().getAddress(); - assertThat(destAddress.getSocketAddress().getAddress()).isEqualTo("10.0.0.2"); - assertThat(destAddress.getSocketAddress().getPortValue()).isEqualTo(443); - assertThat(attrContext.getDestination().getPrincipal()).isEmpty(); - } - - @Test - public void buildRequest_forServer_sslPeerUnverified() throws Exception { - SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); - when(sslSession.getPeerCertificates()) - .thenThrow(new SSLPeerUnverifiedException("unverified")); - Attributes attributes = - Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) - .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); - when(serverCall.getAttributes()).thenReturn(attributes); - when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); - - CheckRequest request = - checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); - - AttributeContext.Peer source = request.getAttributes().getSource(); - assertThat(source.getPrincipal()).isEmpty(); - assertThat(source.getCertificate()).isEmpty(); - } - - @Test - public void buildRequest_forServer_includePeerCertFalse() throws Exception { - ExtAuthzConfig config = buildExtAuthzConfig(ListStringMatcher.newBuilder().build(), - ListStringMatcher.newBuilder().build(), false); - checkRequestBuilder = - new CheckRequestBuilder.CheckRequestBuilderImpl(config, certificateProvider); - SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); - X509Certificate peerCert = mock(X509Certificate.class); - Certificate[] peerCerts = new Certificate[] {peerCert}; - - when(sslSession.getPeerCertificates()).thenReturn(peerCerts); - when(certificateProvider.getPrincipal(peerCert)).thenReturn("peer-principal"); - - Attributes attributes = - Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) - .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); - when(serverCall.getAttributes()).thenReturn(attributes); - when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); - - CheckRequest request = - checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); - - AttributeContext.Peer source = request.getAttributes().getSource(); - assertThat(source.getPrincipal()).isEqualTo("peer-principal"); - assertThat(source.getCertificate()).isEmpty(); - } - - @Test - public void buildRequest_forServer_nullOrEmptyCertificates() throws Exception { - SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); - SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); - Attributes attributes = - Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) - .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) - .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); - when(serverCall.getAttributes()).thenReturn(attributes); - when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); - - // Test with null certificates - when(sslSession.getPeerCertificates()).thenReturn(null); - when(sslSession.getLocalCertificates()).thenReturn(null); - CheckRequest request = - checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); - AttributeContext.Peer source = request.getAttributes().getSource(); - assertThat(source.getPrincipal()).isEmpty(); - assertThat(source.getCertificate()).isEmpty(); - AttributeContext.Peer destination = request.getAttributes().getDestination(); - assertThat(destination.getPrincipal()).isEmpty(); - - // Test with empty certificates - when(sslSession.getPeerCertificates()).thenReturn(new Certificate[0]); - when(sslSession.getLocalCertificates()).thenReturn(new Certificate[0]); - request = checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); - source = request.getAttributes().getSource(); - assertThat(source.getPrincipal()).isEmpty(); - assertThat(source.getCertificate()).isEmpty(); - destination = request.getAttributes().getDestination(); - assertThat(destination.getPrincipal()).isEmpty(); - } - - @Test - public void buildRequest_forServer_nonX509Certificate() throws Exception { - SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); - SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); - Attributes attributes = - Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) - .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) - .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); - when(serverCall.getAttributes()).thenReturn(attributes); - when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); - Certificate nonX509Cert = mock(Certificate.class); - Certificate[] certs = new Certificate[] {nonX509Cert}; - - when(sslSession.getPeerCertificates()).thenReturn(certs); - when(sslSession.getLocalCertificates()).thenReturn(certs); - - CheckRequest request = - checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); - - AttributeContext.Peer source = request.getAttributes().getSource(); - assertThat(source.getPrincipal()).isEmpty(); - AttributeContext.Peer destination = request.getAttributes().getDestination(); - assertThat(destination.getPrincipal()).isEmpty(); - } - - @Test - public void buildRequest_forServer_nonInetSocketAddress() { - SocketAddress remoteAddress = mock(SocketAddress.class); - when(serverCall.getAttributes()).thenReturn(Attributes.newBuilder() - .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress).build()); - when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); - CheckRequest request = - checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); - assertThat(request.getAttributes().getSource().hasAddress()).isFalse(); - } - - private ExtAuthzConfig buildExtAuthzConfig() throws ExtAuthzParseException { - return buildExtAuthzConfig(ListStringMatcher.newBuilder().build(), - ListStringMatcher.newBuilder().build(), true); - } - - private ExtAuthzConfig buildExtAuthzConfig(ListStringMatcher allowed, - ListStringMatcher disallowed, boolean includePeerCertificate) - throws ExtAuthzParseException { - Any googleDefaultChannelCreds = Any.pack(GoogleDefaultCredentials.newBuilder().build()); - Any fakeAccessTokenCreds = - Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build()); - ExtAuthz.Builder builder = ExtAuthz.newBuilder() - .setGrpcService(io.envoyproxy.envoy.config.core.v3.GrpcService.newBuilder() - .setGoogleGrpc(io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc - .newBuilder().setTargetUri("test-cluster") - .addChannelCredentialsPlugin(googleDefaultChannelCreds) - .addCallCredentialsPlugin(fakeAccessTokenCreds).build()) - .build()) - .setIncludePeerCertificate(includePeerCertificate).setAllowedHeaders(allowed) - .setDisallowedHeaders(disallowed); - return ExtAuthzConfig.fromProto(builder.build()); - } + @Rule + public final MockitoRule mockito = MockitoJUnit.rule(); + + private TestServerCall serverCall; + @Mock + private SSLSession sslSession; + @Mock + private CheckRequestBuilder.CertificateProvider certificateProvider; + + private CheckRequestBuilder checkRequestBuilder; + private MethodDescriptor methodDescriptor; + private Timestamp requestTime; + + @Before + public void setUp() throws ExtAuthzParseException { + ExtAuthzConfig config = buildExtAuthzConfig(); + checkRequestBuilder = + new CheckRequestBuilder(config, certificateProvider); + methodDescriptor = TestMethodDescriptors.voidMethod(); + requestTime = Timestamp.newBuilder().setSeconds(12345).setNanos(67890).build(); + } + + @Test + public void buildRequest_forServer_happyPath() throws Exception { + // Setup for addresses + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + + // Setup for SSL and certificates + X509Certificate peerCert = mock(X509Certificate.class); + X509Certificate localCert = mock(X509Certificate.class); + Certificate[] peerCerts = new Certificate[] {peerCert}; + Certificate[] localCerts = new Certificate[] {localCert}; + when(sslSession.getPeerCertificates()).thenReturn(peerCerts); + when(sslSession.getLocalCertificates()).thenReturn(localCerts); + when(certificateProvider.getPrincipal(peerCert)).thenReturn("peer-principal"); + when(certificateProvider.getPrincipal(localCert)).thenReturn("local-principal"); + when(certificateProvider.getUrlPemEncodedCertificate(peerCert)).thenReturn("encoded-peer-cert"); + + // Setup for headers + Metadata headers = new Metadata(); + headers.put(Metadata.Key.of("allowed-header", Metadata.ASCII_STRING_MARSHALLER), "v1"); + headers.put(Metadata.Key.of("disallowed-header", Metadata.ASCII_STRING_MARSHALLER), "v2"); + headers.put(Metadata.Key.of("overridden-header", Metadata.ASCII_STRING_MARSHALLER), "v3"); + byte[] binaryValue = new byte[] {1, 2, 3}; + headers.put(Metadata.Key.of("bin-header-bin", Metadata.BINARY_BYTE_MARSHALLER), binaryValue); + + // Configure CheckRequestBuilder to allow specific headers + ListStringMatcher allowedHeaders = ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("allowed-header").build()) + .addPatterns(StringMatcher.newBuilder().setExact("overridden-header").build()).build(); + ListStringMatcher disallowedHeaders = ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("disallowed-header").build()) + .addPatterns(StringMatcher.newBuilder().setExact("overridden-header").build()).build(); + ExtAuthzConfig config = buildExtAuthzConfig(allowedHeaders, disallowedHeaders, true); + checkRequestBuilder = + new CheckRequestBuilder(config, certificateProvider); + + // Setup server call attributes + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + serverCall = new TestServerCall<>(attributes, methodDescriptor); + + // Build and verify the request + CheckRequest request = checkRequestBuilder.buildRequest(serverCall, headers, requestTime); + + AttributeContext attrContext = request.getAttributes(); + assertThat(attrContext.getSource().getAddress().getSocketAddress().getAddress()) + .isEqualTo("192.168.1.1"); + assertThat(attrContext.getSource().getPrincipal()).isEqualTo("peer-principal"); + assertThat(attrContext.getSource().getCertificate()).isEqualTo("encoded-peer-cert"); + assertThat(attrContext.getDestination().getAddress().getSocketAddress().getAddress()) + .isEqualTo("10.0.0.2"); + assertThat(attrContext.getDestination().getPrincipal()).isEqualTo("local-principal"); + + AttributeContext.HttpRequest http = attrContext.getRequest().getHttp(); + assertThat(http.getHeadersMap()).containsEntry("allowed-header", "v1"); + assertThat(http.getHeadersMap()).doesNotContainKey("bin-header-bin"); + assertThat(http.getHeadersMap()).doesNotContainKey("disallowed-header"); + assertThat(http.getHeadersMap()).doesNotContainKey("overridden-header"); + } + + @Test + public void buildRequest_forServer_noTransportAttrs() { + serverCall = new TestServerCall<>(Attributes.EMPTY, methodDescriptor); + Metadata headers = new Metadata(); + + CheckRequest request = checkRequestBuilder.buildRequest(serverCall, headers, requestTime); + + assertThat(request.getAttributes().getRequest().getTime()).isEqualTo(requestTime); + assertThat(request.getAttributes().getRequest().getHttp().getPath()) + .isEqualTo("/" + methodDescriptor.getFullMethodName()); + + assertThat(request.getAttributes().getRequest().getHttp().getMethod()).isEqualTo("POST"); + assertThat(request.getAttributes().getRequest().getHttp().getProtocol()).isEqualTo("HTTP/2"); + assertThat(request.getAttributes().getRequest().getHttp().getSize()).isEqualTo(-1); + assertThat(request.getAttributes().getRequest().getHttp().getHeadersMap()).isEmpty(); + assertThat(request.getAttributes().hasSource()).isFalse(); + assertThat(request.getAttributes().hasDestination()).isFalse(); + } + + @Test + public void buildRequest_forClient_happyPath_emptyAllowedHeaders() throws Exception { + // Setup for headers + Metadata headers = new Metadata(); + headers.put(Metadata.Key.of("some-header", Metadata.ASCII_STRING_MARSHALLER), "v1"); + headers.put(Metadata.Key.of("disallowed-header", Metadata.ASCII_STRING_MARSHALLER), "v2"); + byte[] binaryValue = new byte[] {1, 2, 3}; + headers.put(Metadata.Key.of("bin-header-bin", Metadata.BINARY_BYTE_MARSHALLER), binaryValue); + + // Configure CheckRequestBuilder with empty allowed headers + ListStringMatcher allowedHeaders = ListStringMatcher.newBuilder().build(); // empty + ListStringMatcher disallowedHeaders = ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("disallowed-header").build()).build(); + ExtAuthzConfig config = buildExtAuthzConfig(allowedHeaders, disallowedHeaders, true); + checkRequestBuilder = + new CheckRequestBuilder(config, certificateProvider); + + // Build and verify the request + CheckRequest request = checkRequestBuilder.buildRequest(methodDescriptor, headers, requestTime); + + AttributeContext attrContext = request.getAttributes(); + assertThat(attrContext.hasSource()).isFalse(); + assertThat(attrContext.hasDestination()).isFalse(); + + AttributeContext.HttpRequest http = attrContext.getRequest().getHttp(); + assertThat(http.getPath()).isEqualTo("/" + methodDescriptor.getFullMethodName()); + + assertThat(http.getHeadersMap()).containsEntry("some-header", "v1"); + assertThat(http.getHeadersMap()).containsEntry("bin-header-bin", "AQID"); + assertThat(http.getHeadersMap()).doesNotContainKey("disallowed-header"); + } + + @Test + public void buildRequest_forServer_noSslSession() { + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress).build(); + serverCall = new TestServerCall<>(attributes, methodDescriptor); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext attrContext = request.getAttributes(); + assertThat(attrContext.hasSource()).isTrue(); + Address sourceAddress = attrContext.getSource().getAddress(); + assertThat(sourceAddress.getSocketAddress().getAddress()).isEqualTo("192.168.1.1"); + assertThat(sourceAddress.getSocketAddress().getPortValue()).isEqualTo(12345); + assertThat(attrContext.getSource().getPrincipal()).isEmpty(); + + assertThat(attrContext.hasDestination()).isTrue(); + Address destAddress = attrContext.getDestination().getAddress(); + assertThat(destAddress.getSocketAddress().getAddress()).isEqualTo("10.0.0.2"); + assertThat(destAddress.getSocketAddress().getPortValue()).isEqualTo(443); + assertThat(attrContext.getDestination().getPrincipal()).isEmpty(); + } + + @Test + public void buildRequest_forServer_sslPeerUnverified() throws Exception { + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + when(sslSession.getPeerCertificates()).thenThrow(new SSLPeerUnverifiedException("unverified")); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + serverCall = new TestServerCall<>(attributes, methodDescriptor); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + assertThat(source.getCertificate()).isEmpty(); + } + + @Test + public void buildRequest_forServer_includePeerCertFalse() throws Exception { + ExtAuthzConfig config = buildExtAuthzConfig(ListStringMatcher.newBuilder().build(), + ListStringMatcher.newBuilder().build(), false); + checkRequestBuilder = + new CheckRequestBuilder(config, certificateProvider); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + X509Certificate peerCert = mock(X509Certificate.class); + Certificate[] peerCerts = new Certificate[] {peerCert}; + + when(sslSession.getPeerCertificates()).thenReturn(peerCerts); + when(certificateProvider.getPrincipal(peerCert)).thenReturn("peer-principal"); + + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + serverCall = new TestServerCall<>(attributes, methodDescriptor); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEqualTo("peer-principal"); + assertThat(source.getCertificate()).isEmpty(); + } + + @Test + public void buildRequest_forServer_nullOrEmptyCertificates() throws Exception { + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + serverCall = new TestServerCall<>(attributes, methodDescriptor); + + // Test with null certificates + when(sslSession.getPeerCertificates()).thenReturn(null); + when(sslSession.getLocalCertificates()).thenReturn(null); + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + assertThat(source.getCertificate()).isEmpty(); + AttributeContext.Peer destination = request.getAttributes().getDestination(); + assertThat(destination.getPrincipal()).isEmpty(); + + // Test with empty certificates + when(sslSession.getPeerCertificates()).thenReturn(new Certificate[0]); + when(sslSession.getLocalCertificates()).thenReturn(new Certificate[0]); + request = checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + assertThat(source.getCertificate()).isEmpty(); + destination = request.getAttributes().getDestination(); + assertThat(destination.getPrincipal()).isEmpty(); + } + + @Test + public void buildRequest_forServer_nonX509Certificate() throws Exception { + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + serverCall = new TestServerCall<>(attributes, methodDescriptor); + Certificate nonX509Cert = mock(Certificate.class); + Certificate[] certs = new Certificate[] {nonX509Cert}; + + when(sslSession.getPeerCertificates()).thenReturn(certs); + when(sslSession.getLocalCertificates()).thenReturn(certs); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + AttributeContext.Peer destination = request.getAttributes().getDestination(); + assertThat(destination.getPrincipal()).isEmpty(); + } + + @Test + public void buildRequest_forServer_nonInetSocketAddress() { + SocketAddress remoteAddress = mock(SocketAddress.class); + serverCall = new TestServerCall<>( + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress).build(), + methodDescriptor); + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + assertThat(request.getAttributes().getSource().hasAddress()).isFalse(); + } + + @Test + public void buildRequest_forServer_handlesCertificateEncodingException() throws Exception { + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + serverCall = new TestServerCall<>(attributes, methodDescriptor); + + X509Certificate peerCert = mock(X509Certificate.class); + Certificate[] peerCerts = new Certificate[] {peerCert}; + when(sslSession.getPeerCertificates()).thenReturn(peerCerts); + when(certificateProvider.getPrincipal(peerCert)).thenReturn("peer-principal"); + when(certificateProvider.getUrlPemEncodedCertificate(peerCert)) + .thenThrow(new java.security.cert.CertificateEncodingException("encoding error")); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEqualTo("peer-principal"); + assertThat(source.getCertificate()).isEmpty(); + } + + private ExtAuthzConfig buildExtAuthzConfig() throws ExtAuthzParseException { + return buildExtAuthzConfig(ListStringMatcher.newBuilder().build(), + ListStringMatcher.newBuilder().build(), true); + } + + private ExtAuthzConfig buildExtAuthzConfig(ListStringMatcher allowed, + ListStringMatcher disallowed, boolean includePeerCertificate) throws ExtAuthzParseException { + Any googleDefaultChannelCreds = Any.pack(GoogleDefaultCredentials.newBuilder().build()); + Any fakeAccessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build()); + ExtAuthz.Builder builder = ExtAuthz.newBuilder() + .setGrpcService(GrpcService.newBuilder() + .setGoogleGrpc(GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test-cluster").addChannelCredentialsPlugin(googleDefaultChannelCreds) + .addCallCredentialsPlugin(fakeAccessTokenCreds).build()) + .build()) + .setIncludePeerCertificate(includePeerCertificate).setAllowedHeaders(allowed) + .setDisallowedHeaders(disallowed); + BootstrapInfo bootstrapInfo = + CommonBootstrapperTestUtils.buildBootStrap( + Collections.singletonList("test-cluster")); + return ExtAuthzConfigParser.parse( + builder.build(), bootstrapInfo, bootstrapInfo.servers().get(0)); + } } From ea26e761c52edaa180bced0d5f841bd19121e1d2 Mon Sep 17 00:00:00 2001 From: Saurav Date: Wed, 3 Jun 2026 19:07:02 +0000 Subject: [PATCH 3/3] Fixup: Request builder review fixes --- .../internal/extauthz/CertificateUtils.java | 16 +- .../extauthz/CheckRequestBuilder.java | 98 +++++++-- .../extauthz/CheckRequestBuilderTest.java | 202 +++++++++++++----- 3 files changed, 241 insertions(+), 75 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/CertificateUtils.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/CertificateUtils.java index c0253839253..9d4ea4f8005 100644 --- a/xds/src/main/java/io/grpc/xds/internal/extauthz/CertificateUtils.java +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/CertificateUtils.java @@ -30,7 +30,7 @@ /** * A utility class for certificate-related information. */ -public final class CertificateUtils { +final class CertificateUtils { private static final Logger logger = Logger.getLogger(CertificateUtils.class.getName()); // From RFC 5280, section 4.2.1.6, Subject Alternative Name // dNSName (2) @@ -48,21 +48,21 @@ private CertificateUtils() {} * @param cert The certificate. * @return The principal. */ - public static String getPrincipal(X509Certificate cert) { + static String getPrincipal(X509Certificate cert) { try { Collection> sans = cert.getSubjectAlternativeNames(); if (sans != null) { // Look for URI SAN (Priority 1). for (List san : sans) { if (san.size() == 2 && san.get(0) instanceof Integer - && (Integer) san.get(0) == SAN_TYPE_URI) { + && san.get(0).equals(SAN_TYPE_URI)) { return (String) san.get(1); } } // If no URI SAN, look for DNS SAN (Priority 2). for (List san : sans) { if (san.size() == 2 && san.get(0) instanceof Integer - && (Integer) san.get(0) == SAN_TYPE_DNS_NAME) { + && san.get(0).equals(SAN_TYPE_DNS_NAME)) { return (String) san.get(1); } } @@ -82,10 +82,10 @@ public static String getPrincipal(X509Certificate cert) { * @throws CertificateEncodingException If an error occurs while encoding the certificate. * @throws UnsupportedEncodingException If an error occurs while encoding the URL. */ - public static String getUrlPemEncodedCertificate(X509Certificate cert) + static String getUrlPemEncodedCertificate(X509Certificate cert) throws CertificateEncodingException, UnsupportedEncodingException { String pemCert = CertPemConverter.toPem(cert); - return URLEncoder.encode(pemCert, StandardCharsets.UTF_8.toString()); + return URLEncoder.encode(pemCert, StandardCharsets.UTF_8.name()); } /** @@ -106,7 +106,9 @@ private CertPemConverter() {} * @throws CertificateEncodingException If an error occurs while encoding the certificate. */ public static String toPem(X509Certificate cert) throws CertificateEncodingException { - return X509_PEM_HEADER + BaseEncoding.base64().encode(cert.getEncoded()) + X509_PEM_FOOTER; + return X509_PEM_HEADER + + BaseEncoding.base64().withSeparator("\n", 64).encode(cert.getEncoded()) + + X509_PEM_FOOTER; } } } diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java index c16cabdd783..bb0eb5c707a 100644 --- a/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java @@ -18,8 +18,11 @@ import com.google.common.io.BaseEncoding; +import com.google.protobuf.ByteString; import com.google.protobuf.Timestamp; import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.HeaderMap; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; import io.envoyproxy.envoy.config.core.v3.SocketAddress; import io.envoyproxy.envoy.service.auth.v3.AttributeContext; import io.envoyproxy.envoy.service.auth.v3.CheckRequest; @@ -33,24 +36,24 @@ import java.security.cert.Certificate; import java.security.cert.CertificateEncodingException; import java.security.cert.X509Certificate; -import java.util.ArrayList; -import java.util.List; import java.util.Locale; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; +import javax.annotation.concurrent.ThreadSafe; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; /** - * Interface for building external authorization check requests. + * Builds external authorization check requests from gRPC call metadata. */ +@ThreadSafe public class CheckRequestBuilder { /** * An interface for providing certificate-related information. */ - public interface CertificateProvider { + interface CertificateProvider { /** * Gets the principal from a certificate. * @@ -72,11 +75,17 @@ String getUrlPemEncodedCertificate(X509Certificate cert) } private static final Logger logger = Logger.getLogger(CheckRequestBuilder.class.getName()); + private static final BaseEncoding BASE64_NO_PAD = BaseEncoding.base64().omitPadding(); private final ExtAuthzConfig config; private final CertificateProvider certificateProvider; + /** + * Constructs a new {@link CheckRequestBuilder} with the default certificate provider. + * + * @param config The external authorization configuration. + */ public CheckRequestBuilder(ExtAuthzConfig config) { this(config, new CertificateProvider() { @Override @@ -92,18 +101,40 @@ public String getUrlPemEncodedCertificate(X509Certificate cert) }); } - public CheckRequestBuilder(ExtAuthzConfig config, CertificateProvider certificateProvider) { + /** + * Constructs a new {@link CheckRequestBuilder} with a custom certificate provider. + * + * @param config The external authorization configuration. + * @param certificateProvider The certificate provider. + */ + CheckRequestBuilder(ExtAuthzConfig config, CertificateProvider certificateProvider) { this.config = config; this.certificateProvider = certificateProvider; } + /** + * Builds a check request for a client-side call. + * + * @param methodDescriptor The method descriptor of the RPC. + * @param headers The initial metadata headers. + * @param requestTime The timestamp when the request was initiated. + * @return The constructed {@link CheckRequest}. + */ public CheckRequest buildRequest(MethodDescriptor methodDescriptor, Metadata headers, Timestamp requestTime) { return build(methodDescriptor, headers, requestTime, null, null, null); } + /** + * Builds a check request for a server-side call. + * + * @param serverCall The server call. + * @param headers The initial metadata headers. + * @param requestTime The timestamp when the request was initiated. + * @return The constructed {@link CheckRequest}. + */ public CheckRequest buildRequest(ServerCall serverCall, Metadata headers, Timestamp requestTime) { java.net.SocketAddress localAddress = @@ -171,10 +202,24 @@ private AttributeContext.Peer buildPeer(java.net.SocketAddress socketAddress) { AttributeContext.Peer.Builder peerBuilder = AttributeContext.Peer.newBuilder(); if (socketAddress instanceof InetSocketAddress) { InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress; + // Prefer the resolved IP address, but fall back to the hostname string for + // unresolved addresses. In practice, Netty transports always provide resolved + // InetSocketAddress instances for active connections, and other gRPC + // implementations (C++, Go) always produce IP addresses because they operate + // on real TCP sockets. However, Envoy's address.proto permits hostnames (the + // only constraint is a non-empty string), so we gracefully fall back to + // getHostString() for robustness. See also TcpMetrics.java for precedent: + // https://github.com/grpc/grpc-java/blob/master/netty/src/main/java/io/grpc/netty/TcpMetrics.java + String address; + if (inetSocketAddress.getAddress() != null) { + address = inetSocketAddress.getAddress().getHostAddress(); + } else { + address = inetSocketAddress.getHostString(); + } peerBuilder .setAddress(Address.newBuilder() .setSocketAddress(SocketAddress.newBuilder() - .setAddress(inetSocketAddress.getAddress().getHostAddress()) + .setAddress(address) .setPortValue(inetSocketAddress.getPort())) .build()); } @@ -190,36 +235,55 @@ private AttributeContext.Request buildAttributeRequest(Metadata headers, String httpReqBuilder.setMethod("POST"); httpReqBuilder.setProtocol("HTTP/2"); httpReqBuilder.setSize(-1); + + HeaderMap.Builder headerMapBuilder = HeaderMap.newBuilder(); for (String key : headers.keys()) { if (!isAllowed(key)) { continue; } - String value; + String lowerCaseKey = key.toLowerCase(Locale.ROOT); if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { - value = getBinaryHeaderValue(headers, key); + populateBinaryHeaderValues(headers, key, lowerCaseKey, headerMapBuilder); } else { - value = getAsciiHeaderValue(headers, key); + populateAsciiHeaderValues(headers, key, lowerCaseKey, headerMapBuilder); } - httpReqBuilder.putHeaders(key.toLowerCase(Locale.ROOT), value); } + httpReqBuilder.setHeaderMap(headerMapBuilder); reqBuilder.setHttp(httpReqBuilder); return reqBuilder.build(); } - private String getBinaryHeaderValue(Metadata headers, String key) { + private void populateBinaryHeaderValues(Metadata headers, String key, String lowerCaseKey, + HeaderMap.Builder headerMapBuilder) { Iterable binaryValues = headers.getAll(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); - List base64Values = new ArrayList<>(); - for (byte[] value : binaryValues) { - base64Values.add(BaseEncoding.base64().encode(value)); + if (binaryValues != null) { + for (byte[] value : binaryValues) { + // Binary header values are base64-encoded before storing in rawValue, + // matching Envoy's behavior for CheckRequest header serialization. + String base64Value = BASE64_NO_PAD.encode(value); + headerMapBuilder.addHeaders( + HeaderValue.newBuilder() + .setKey(lowerCaseKey) + .setRawValue(ByteString.copyFromUtf8(base64Value)) + .build()); + } } - return String.join(",", base64Values); } - private String getAsciiHeaderValue(Metadata headers, String key) { + private void populateAsciiHeaderValues(Metadata headers, String key, String lowerCaseKey, + HeaderMap.Builder headerMapBuilder) { Iterable stringValues = headers.getAll(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); - return String.join(",", stringValues); + if (stringValues != null) { + for (String value : stringValues) { + headerMapBuilder.addHeaders( + HeaderValue.newBuilder() + .setKey(lowerCaseKey) + .setRawValue(ByteString.copyFromUtf8(value)) + .build()); + } + } } private boolean isAllowed(String header) { diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java index 8747e051d6d..c205c578107 100644 --- a/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java @@ -20,30 +20,24 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import com.google.protobuf.Any; +import com.google.common.collect.ImmutableList; import com.google.protobuf.Timestamp; import io.envoyproxy.envoy.config.core.v3.Address; -import io.envoyproxy.envoy.config.core.v3.GrpcService; -import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; -import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; -import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.config.core.v3.HeaderMap; import io.envoyproxy.envoy.service.auth.v3.AttributeContext; import io.envoyproxy.envoy.service.auth.v3.CheckRequest; -import io.envoyproxy.envoy.type.matcher.v3.ListStringMatcher; -import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.grpc.Attributes; import io.grpc.Grpc; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.testing.TestMethodDescriptors; -import io.grpc.xds.ExtAuthzConfigParser; -import io.grpc.xds.client.Bootstrapper.BootstrapInfo; -import io.grpc.xds.client.CommonBootstrapperTestUtils; +import io.grpc.xds.internal.Matchers; +import io.grpc.xds.internal.extauthz.ExtAuthzTestHelper.TestServerCall; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.cert.Certificate; import java.security.cert.X509Certificate; -import java.util.Collections; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; import org.junit.Before; @@ -105,12 +99,12 @@ public void buildRequest_forServer_happyPath() throws Exception { headers.put(Metadata.Key.of("bin-header-bin", Metadata.BINARY_BYTE_MARSHALLER), binaryValue); // Configure CheckRequestBuilder to allow specific headers - ListStringMatcher allowedHeaders = ListStringMatcher.newBuilder() - .addPatterns(StringMatcher.newBuilder().setExact("allowed-header").build()) - .addPatterns(StringMatcher.newBuilder().setExact("overridden-header").build()).build(); - ListStringMatcher disallowedHeaders = ListStringMatcher.newBuilder() - .addPatterns(StringMatcher.newBuilder().setExact("disallowed-header").build()) - .addPatterns(StringMatcher.newBuilder().setExact("overridden-header").build()).build(); + ImmutableList allowedHeaders = ImmutableList.of( + Matchers.StringMatcher.forExact("allowed-header", false), + Matchers.StringMatcher.forExact("overridden-header", false)); + ImmutableList disallowedHeaders = ImmutableList.of( + Matchers.StringMatcher.forExact("disallowed-header", false), + Matchers.StringMatcher.forExact("overridden-header", false)); ExtAuthzConfig config = buildExtAuthzConfig(allowedHeaders, disallowedHeaders, true); checkRequestBuilder = new CheckRequestBuilder(config, certificateProvider); @@ -135,10 +129,9 @@ public void buildRequest_forServer_happyPath() throws Exception { assertThat(attrContext.getDestination().getPrincipal()).isEqualTo("local-principal"); AttributeContext.HttpRequest http = attrContext.getRequest().getHttp(); - assertThat(http.getHeadersMap()).containsEntry("allowed-header", "v1"); - assertThat(http.getHeadersMap()).doesNotContainKey("bin-header-bin"); - assertThat(http.getHeadersMap()).doesNotContainKey("disallowed-header"); - assertThat(http.getHeadersMap()).doesNotContainKey("overridden-header"); + assertThat(http.getHeaderMap().getHeadersList()).containsExactly( + io.envoyproxy.envoy.config.core.v3.HeaderValue.newBuilder().setKey("allowed-header") + .setRawValue(com.google.protobuf.ByteString.copyFromUtf8("v1")).build()); } @Test @@ -155,7 +148,8 @@ public void buildRequest_forServer_noTransportAttrs() { assertThat(request.getAttributes().getRequest().getHttp().getMethod()).isEqualTo("POST"); assertThat(request.getAttributes().getRequest().getHttp().getProtocol()).isEqualTo("HTTP/2"); assertThat(request.getAttributes().getRequest().getHttp().getSize()).isEqualTo(-1); - assertThat(request.getAttributes().getRequest().getHttp().getHeadersMap()).isEmpty(); + assertThat(request.getAttributes().getRequest().getHttp().getHeaderMap().getHeadersList()) + .isEmpty(); assertThat(request.getAttributes().hasSource()).isFalse(); assertThat(request.getAttributes().hasDestination()).isFalse(); } @@ -170,9 +164,9 @@ public void buildRequest_forClient_happyPath_emptyAllowedHeaders() throws Except headers.put(Metadata.Key.of("bin-header-bin", Metadata.BINARY_BYTE_MARSHALLER), binaryValue); // Configure CheckRequestBuilder with empty allowed headers - ListStringMatcher allowedHeaders = ListStringMatcher.newBuilder().build(); // empty - ListStringMatcher disallowedHeaders = ListStringMatcher.newBuilder() - .addPatterns(StringMatcher.newBuilder().setExact("disallowed-header").build()).build(); + ImmutableList allowedHeaders = ImmutableList.of(); + ImmutableList disallowedHeaders = ImmutableList.of( + Matchers.StringMatcher.forExact("disallowed-header", false)); ExtAuthzConfig config = buildExtAuthzConfig(allowedHeaders, disallowedHeaders, true); checkRequestBuilder = new CheckRequestBuilder(config, certificateProvider); @@ -187,9 +181,11 @@ public void buildRequest_forClient_happyPath_emptyAllowedHeaders() throws Except AttributeContext.HttpRequest http = attrContext.getRequest().getHttp(); assertThat(http.getPath()).isEqualTo("/" + methodDescriptor.getFullMethodName()); - assertThat(http.getHeadersMap()).containsEntry("some-header", "v1"); - assertThat(http.getHeadersMap()).containsEntry("bin-header-bin", "AQID"); - assertThat(http.getHeadersMap()).doesNotContainKey("disallowed-header"); + assertThat(http.getHeaderMap().getHeadersList()).containsExactly( + io.envoyproxy.envoy.config.core.v3.HeaderValue.newBuilder().setKey("some-header") + .setRawValue(com.google.protobuf.ByteString.copyFromUtf8("v1")).build(), + io.envoyproxy.envoy.config.core.v3.HeaderValue.newBuilder().setKey("bin-header-bin") + .setRawValue(com.google.protobuf.ByteString.copyFromUtf8("AQID")).build()); } @Test @@ -237,8 +233,8 @@ public void buildRequest_forServer_sslPeerUnverified() throws Exception { @Test public void buildRequest_forServer_includePeerCertFalse() throws Exception { - ExtAuthzConfig config = buildExtAuthzConfig(ListStringMatcher.newBuilder().build(), - ListStringMatcher.newBuilder().build(), false); + ExtAuthzConfig config = buildExtAuthzConfig(ImmutableList.of(), + ImmutableList.of(), false); checkRequestBuilder = new CheckRequestBuilder(config, certificateProvider); SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); @@ -328,6 +324,33 @@ public void buildRequest_forServer_nonInetSocketAddress() { assertThat(request.getAttributes().getSource().hasAddress()).isFalse(); } + @Test + public void buildRequest_forServer_unresolvedInetSocketAddress() { + SocketAddress localAddress = + InetSocketAddress.createUnresolved("local-hostname", 443); + SocketAddress remoteAddress = + InetSocketAddress.createUnresolved("remote-hostname", 8080); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress).build(); + serverCall = new TestServerCall<>(attributes, methodDescriptor); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext attrContext = request.getAttributes(); + assertThat(attrContext.hasSource()).isTrue(); + assertThat(attrContext.getSource().getAddress().getSocketAddress().getAddress()) + .isEqualTo("remote-hostname"); + assertThat(attrContext.getSource().getAddress().getSocketAddress().getPortValue()) + .isEqualTo(8080); + assertThat(attrContext.hasDestination()).isTrue(); + assertThat(attrContext.getDestination().getAddress().getSocketAddress().getAddress()) + .isEqualTo("local-hostname"); + assertThat(attrContext.getDestination().getAddress().getSocketAddress().getPortValue()) + .isEqualTo(443); + } + @Test public void buildRequest_forServer_handlesCertificateEncodingException() throws Exception { SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); @@ -353,28 +376,105 @@ public void buildRequest_forServer_handlesCertificateEncodingException() throws assertThat(source.getCertificate()).isEmpty(); } - private ExtAuthzConfig buildExtAuthzConfig() throws ExtAuthzParseException { - return buildExtAuthzConfig(ListStringMatcher.newBuilder().build(), - ListStringMatcher.newBuilder().build(), true); + @Test + public void buildRequest_forClient_allowedHeadersFiltering() { + Metadata headers = new Metadata(); + headers.put(Metadata.Key.of("allowed-header", Metadata.ASCII_STRING_MARSHALLER), "v1"); + headers.put(Metadata.Key.of("not-allowed-header", Metadata.ASCII_STRING_MARSHALLER), "v2"); + + // Configure with non-empty allowed headers — only "allowed-header" should pass + ImmutableList allowedHeaders = ImmutableList.of( + Matchers.StringMatcher.forExact("allowed-header", false)); + ExtAuthzConfig config = buildExtAuthzConfig(allowedHeaders, ImmutableList.of(), true); + CheckRequestBuilder builder = new CheckRequestBuilder(config, certificateProvider); + + CheckRequest request = builder.buildRequest(methodDescriptor, headers, requestTime); + + HeaderMap headerMap = request.getAttributes().getRequest().getHttp().getHeaderMap(); + assertThat(headerMap.getHeadersList()).hasSize(1); + assertThat(headerMap.getHeadersList().get(0).getKey()).isEqualTo("allowed-header"); + assertThat(headerMap.getHeadersList().get(0).getRawValue().toStringUtf8()).isEqualTo("v1"); + } + + @Test + public void buildRequest_forClient_emptyMetadata() { + CheckRequest request = + checkRequestBuilder.buildRequest(methodDescriptor, new Metadata(), requestTime); + + AttributeContext attrContext = request.getAttributes(); + assertThat(attrContext.hasSource()).isFalse(); + assertThat(attrContext.hasDestination()).isFalse(); + assertThat(attrContext.getRequest().getHttp().getPath()) + .isEqualTo("/" + methodDescriptor.getFullMethodName()); + assertThat(attrContext.getRequest().getHttp().getMethod()).isEqualTo("POST"); + assertThat(attrContext.getRequest().getHttp().getProtocol()).isEqualTo("HTTP/2"); + assertThat(attrContext.getRequest().getHttp().getHeaderMap().getHeadersList()).isEmpty(); + } + + @Test + public void buildRequest_multiValuedHeaders() throws Exception { + Metadata headers = new Metadata(); + Metadata.Key asciiKey = Metadata.Key.of("x-custom", Metadata.ASCII_STRING_MARSHALLER); + headers.put(asciiKey, "value1"); + headers.put(asciiKey, "value2"); + + Metadata.Key binaryKey = + Metadata.Key.of("x-custom-bin", Metadata.BINARY_BYTE_MARSHALLER); + headers.put(binaryKey, new byte[]{1, 2}); + headers.put(binaryKey, new byte[]{3, 4}); + + ExtAuthzConfig configWithAllowedHeaders = buildExtAuthzConfig( + ImmutableList.of(Matchers.StringMatcher.forExact("x-custom", false), + Matchers.StringMatcher.forExact("x-custom-bin", false)), + ImmutableList.of(), true); + CheckRequestBuilder builderWithConfig = + new CheckRequestBuilder(configWithAllowedHeaders, certificateProvider); + + CheckRequest request = + builderWithConfig.buildRequest(methodDescriptor, headers, requestTime); + + HeaderMap headerMap = request.getAttributes().getRequest().getHttp().getHeaderMap(); + assertThat(headerMap.getHeadersList()).hasSize(4); + assertThat(headerMap.getHeadersList().get(0).getKey()).isEqualTo("x-custom-bin"); + assertThat(headerMap.getHeadersList().get(0).getRawValue().toStringUtf8()).isEqualTo("AQI"); + assertThat(headerMap.getHeadersList().get(1).getKey()).isEqualTo("x-custom-bin"); + assertThat(headerMap.getHeadersList().get(1).getRawValue().toStringUtf8()).isEqualTo("AwQ"); + assertThat(headerMap.getHeadersList().get(2).getKey()).isEqualTo("x-custom"); + assertThat(headerMap.getHeadersList().get(2).getRawValue().toStringUtf8()).isEqualTo("value1"); + assertThat(headerMap.getHeadersList().get(3).getKey()).isEqualTo("x-custom"); + assertThat(headerMap.getHeadersList().get(3).getRawValue().toStringUtf8()).isEqualTo("value2"); + } + + private ExtAuthzConfig buildExtAuthzConfig() { + return buildExtAuthzConfig(ImmutableList.of(), ImmutableList.of(), true); } - private ExtAuthzConfig buildExtAuthzConfig(ListStringMatcher allowed, - ListStringMatcher disallowed, boolean includePeerCertificate) throws ExtAuthzParseException { - Any googleDefaultChannelCreds = Any.pack(GoogleDefaultCredentials.newBuilder().build()); - Any fakeAccessTokenCreds = - Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build()); - ExtAuthz.Builder builder = ExtAuthz.newBuilder() - .setGrpcService(GrpcService.newBuilder() - .setGoogleGrpc(GrpcService.GoogleGrpc.newBuilder() - .setTargetUri("test-cluster").addChannelCredentialsPlugin(googleDefaultChannelCreds) - .addCallCredentialsPlugin(fakeAccessTokenCreds).build()) - .build()) - .setIncludePeerCertificate(includePeerCertificate).setAllowedHeaders(allowed) - .setDisallowedHeaders(disallowed); - BootstrapInfo bootstrapInfo = - CommonBootstrapperTestUtils.buildBootStrap( - Collections.singletonList("test-cluster")); - return ExtAuthzConfigParser.parse( - builder.build(), bootstrapInfo, bootstrapInfo.servers().get(0)); + private ExtAuthzConfig buildExtAuthzConfig( + ImmutableList allowed, + ImmutableList disallowed, + boolean includePeerCertificate) { + GrpcServiceConfig.GoogleGrpcConfig googleGrpc = GrpcServiceConfig.GoogleGrpcConfig.builder() + .target("test-cluster") + .configuredChannelCredentials(io.grpc.xds.client.ConfiguredChannelCredentials.create( + mock(io.grpc.ChannelCredentials.class), + mock(io.grpc.xds.client.ConfiguredChannelCredentials.ChannelCredsConfig.class))) + .build(); + + GrpcServiceConfig dummyServiceConfig = GrpcServiceConfig.builder() + .googleGrpc(googleGrpc) + .initialMetadata(ImmutableList.of()) + .build(); + + return ExtAuthzConfig.builder() + .grpcService(dummyServiceConfig) + .includePeerCertificate(includePeerCertificate) + .allowedHeaders(allowed) + .disallowedHeaders(disallowed) + .failureModeAllow(true) + .failureModeAllowHeaderAdd(false) + .denyAtDisable(false) + .filterEnabled(Matchers.FractionMatcher.create(100, 100)) + .statusOnError(io.grpc.Status.INTERNAL) + .build(); } }