diff --git a/src/CertManager.sol b/src/CertManager.sol index d1d9515..415125d 100644 --- a/src/CertManager.sol +++ b/src/CertManager.sol @@ -17,6 +17,8 @@ contract CertManager is ICertManager { using LibAsn1Ptr for Asn1Ptr; using LibBytes for bytes; + error InvalidAsn1Tag(); + event CertVerified(bytes32 indexed certHash); event CertRevoked(bytes32 indexed certHash); event CertUnrevoked(bytes32 indexed certHash); @@ -163,7 +165,10 @@ contract CertManager is ICertManager { /// {revokeCerts}; the same value is recorded on-chain when the cert is first verified, so a /// revocation applies to every byte-encoding of that certificate. Reverts on malformed DER. function computeCertId(bytes memory cert) external pure returns (bytes32) { - Asn1Ptr tbsCertPtr = cert.firstChildOf(cert.root()); + Asn1Ptr root = cert.root(); + _requireAsn1Tag(cert, root, 0x30); + Asn1Ptr tbsCertPtr = cert.firstChildOf(root); + _requireAsn1Tag(cert, tbsCertPtr, 0x30); return _certIdentity(cert, tbsCertPtr); } @@ -290,11 +295,17 @@ contract CertManager is ICertManager { } Asn1Ptr root = certificate.root(); + _requireAsn1Tag(certificate, root, 0x30); require(root.totalLength() == certificate.length, "invalid cert length"); Asn1Ptr tbsCertPtr = certificate.firstChildOf(root); + _requireAsn1Tag(certificate, tbsCertPtr, 0x30); return certificate.keccak(tbsCertPtr.header(), tbsCertPtr.totalLength()); } + function _requireAsn1Tag(bytes memory der, Asn1Ptr ptr, bytes1 tag) internal pure { + if (der[ptr.header()] != tag) revert InvalidAsn1Tag(); + } + function _isPinnedRootAlias(bytes32 certHash, VerifiedCert memory cert) internal pure returns (bool) { return certHash != ROOT_CA_CERT_HASH && cert.ca && cert.subjectHash == ROOT_CA_CERT_SUBJECT_HASH && cert.pubKey.length == ROOT_CA_CERT_PUB_KEY.length @@ -349,10 +360,13 @@ contract CertManager is ICertManager { view returns (uint64 notAfter, int64 maxPathLen, bytes32 issuerHash, bytes32 subjectHash, bytes memory pubKey) { + _requireAsn1Tag(certificate, ptr, 0x30); Asn1Ptr versionPtr = certificate.firstChildOf(ptr); + _requireAsn1Tag(certificate, versionPtr, 0xa0); Asn1Ptr vPtr = certificate.firstChildOf(versionPtr); Asn1Ptr serialPtr = certificate.nextSiblingOf(versionPtr); Asn1Ptr sigAlgoPtr = certificate.nextSiblingOf(serialPtr); + _requireAsn1Tag(certificate, sigAlgoPtr, 0x30); require(certificate.keccak(sigAlgoPtr.content(), sigAlgoPtr.length()) == CERT_ALGO_OID, "invalid cert sig algo"); uint256 version = certificate.uintAt(vPtr); @@ -368,11 +382,15 @@ contract CertManager is ICertManager { returns (uint64 notAfter, int64 maxPathLen, bytes32 issuerHash, bytes32 subjectHash, bytes memory pubKey) { Asn1Ptr issuerPtr = certificate.nextSiblingOf(sigAlgoPtr); + _requireAsn1Tag(certificate, issuerPtr, 0x30); issuerHash = certificate.keccak(issuerPtr.content(), issuerPtr.length()); Asn1Ptr validityPtr = certificate.nextSiblingOf(issuerPtr); + _requireAsn1Tag(certificate, validityPtr, 0x30); Asn1Ptr subjectPtr = certificate.nextSiblingOf(validityPtr); + _requireAsn1Tag(certificate, subjectPtr, 0x30); subjectHash = certificate.keccak(subjectPtr.content(), subjectPtr.length()); Asn1Ptr subjectPublicKeyInfoPtr = certificate.nextSiblingOf(subjectPtr); + _requireAsn1Tag(certificate, subjectPublicKeyInfoPtr, 0x30); Asn1Ptr extensionsPtr = certificate.nextSiblingOf(subjectPublicKeyInfoPtr); if (certificate[extensionsPtr.header()] == 0x81) { @@ -395,6 +413,7 @@ contract CertManager is ICertManager { returns (bytes memory subjectPubKey) { Asn1Ptr pubKeyAlgoPtr = certificate.firstChildOf(subjectPublicKeyInfoPtr); + _requireAsn1Tag(certificate, pubKeyAlgoPtr, 0x30); Asn1Ptr pubKeyAlgoIdPtr = certificate.firstChildOf(pubKeyAlgoPtr); Asn1Ptr algoParamsPtr = certificate.nextSiblingOf(pubKeyAlgoIdPtr); Asn1Ptr subjectPublicKeyPtr = certificate.nextSiblingOf(pubKeyAlgoPtr); @@ -435,6 +454,7 @@ contract CertManager is ICertManager { { require(certificate[extensionsPtr.header()] == 0xa3, "invalid extensions"); extensionsPtr = certificate.firstChildOf(extensionsPtr); + _requireAsn1Tag(certificate, extensionsPtr, 0x30); Asn1Ptr extensionPtr = certificate.firstChildOf(extensionsPtr); uint256 end = extensionsPtr.content() + extensionsPtr.length(); bool basicConstraintsFound = false; @@ -442,6 +462,7 @@ contract CertManager is ICertManager { maxPathLen = -1; while (true) { + _requireAsn1Tag(certificate, extensionPtr, 0x30); Asn1Ptr oidPtr = certificate.firstChildOf(extensionPtr); bytes32 oid = certificate.keccak(oidPtr.content(), oidPtr.length()); @@ -544,6 +565,7 @@ contract CertManager is ICertManager { bytes memory signatureHints ) internal view { Asn1Ptr sigAlgoPtr = certificate.nextSiblingOf(ptr); + _requireAsn1Tag(certificate, sigAlgoPtr, 0x30); require(certificate.keccak(sigAlgoPtr.content(), sigAlgoPtr.length()) == CERT_ALGO_OID, "invalid cert sig algo"); Asn1Ptr sigPtr = certificate.nextSiblingOf(sigAlgoPtr); require(sigPtr.header() + sigPtr.totalLength() == certificate.length, "trailing cert fields"); diff --git a/test/CertManager.t.sol b/test/CertManager.t.sol index cc40cfd..fd55a60 100644 --- a/test/CertManager.t.sol +++ b/test/CertManager.t.sol @@ -166,6 +166,35 @@ contract CertManagerTest is Test { cm.verifyCACertWithHints(rootTwin, rootHash, hints); } + function test_VerifyCACertWithHints_RejectsOuterTagSubstitution() public { + vm.warp(1775145600); + CertManager cm = new CertManager(new P384Verifier()); + + bytes32 rootHash = keccak256(CB0); + bytes memory mutated = bytes.concat(CB1); + mutated[0] = 0x31; // constructed SET with the same children is not an X.509 Certificate SEQUENCE. + + vm.expectRevert(CertManager.InvalidAsn1Tag.selector); + cm.verifyCACertWithHints(mutated, rootHash, ""); + } + + function test_VerifyCACertWithHints_RejectsTbsAlgorithmTagSubstitution() public { + vm.warp(1775145600); + CertManager cm = new CertManager(new P384Verifier()); + + bytes32 rootHash = keccak256(CB0); + bytes memory mutated = bytes.concat(CB1); + Asn1Ptr root = mutated.root(); + Asn1Ptr tbsPtr = mutated.firstChildOf(root); + Asn1Ptr versionPtr = mutated.firstChildOf(tbsPtr); + Asn1Ptr serialPtr = mutated.nextSiblingOf(versionPtr); + Asn1Ptr sigAlgoPtr = mutated.nextSiblingOf(serialPtr); + mutated[sigAlgoPtr.header()] = 0x31; // constructed, but not AlgorithmIdentifier SEQUENCE. + + vm.expectRevert(CertManager.InvalidAsn1Tag.selector); + cm.verifyCACertWithHints(mutated, rootHash, ""); + } + function _verifyCA(CertManager cm, P384HintCollector collector, bytes memory cert, bytes32 parentHash) internal returns (bytes32)