diff --git a/api/src/org/labkey/api/security/Encryption.java b/api/src/org/labkey/api/security/Encryption.java index 3f6187bd8f7..902d2665413 100644 --- a/api/src/org/labkey/api/security/Encryption.java +++ b/api/src/org/labkey/api/security/Encryption.java @@ -41,6 +41,7 @@ import org.labkey.api.util.HelpTopic; import org.labkey.api.util.HtmlStringBuilder; import org.labkey.api.util.JobRunner; +import org.labkey.api.util.QuietCloser; import org.labkey.api.util.StringUtilsLabKey; import org.labkey.api.util.logging.LogHelper; import org.labkey.api.view.ViewContext; @@ -680,103 +681,118 @@ else if (!cipher.equals(AESConfig.current.getCipherName())) } - private static final EncryptionMigrationHandler TEST_HANDLER = (oldPassPhrase, keySource, oldConfig) -> {}; + private static final EncryptionMigrationHandler TEST_HANDLER = (_, _, _) -> {}; + + private static QuietCloser createErrorCountResetter() + { + int initial = DECRYPTION_EXCEPTIONS.get(); + return () -> DECRYPTION_EXCEPTIONS.set(initial); + } public static class TestCase extends Assert { @Test public void testEncryptionAlgorithms() throws NoSuchAlgorithmException { - String passPhrase = "Here's my super secret pass phrase"; + try (var _ = createErrorCountResetter()) + { + String passPhrase = "Here's my super secret pass phrase"; - Algorithm aesPassPhrase = new AES(passPhrase, 128, "test pass phrase"); + Algorithm aesPassPhrase = new AES(passPhrase, 128, "test pass phrase"); - test(aesPassPhrase); + test(aesPassPhrase); - if (isEncryptionPassPhraseSpecified()) - { - Algorithm aes = getAES128(TEST_HANDLER); - test(aes); + if (isEncryptionPassPhraseSpecified()) + { + Algorithm aes = getAES128(TEST_HANDLER); + test(aes); - // Test that static factory method matches this configuration - Algorithm aes2 = new AES(getEncryptionPassPhrase(), 128, "test pass phrase"); + // Test that static factory method matches this configuration + Algorithm aes2 = new AES(getEncryptionPassPhrase(), 128, "test pass phrase"); - test(aes, aes2); - test(aes2, aes); - } + test(aes, aes2); + test(aes2, aes); + } - if (Cipher.getMaxAllowedKeyLength("AES") >= 256) - { - test(new AES(passPhrase, 256, "test pass phrase")); + if (Cipher.getMaxAllowedKeyLength("AES") >= 256) + { + test(new AES(passPhrase, 256, "test pass phrase")); + } } } @Test(expected = DecryptionException.class) public void testBadKeyException() { - String textToEncrypt = "this is some text I want to encrypt"; - String passPhrase = "Here's my super secret pass phrase"; - String wrongPassPhrase = passPhrase + " not"; - - // Our AES implementation can usually detect a bad pass phrase (based on padding anomalies), but this is not 100% guaranteed. - // Give the test three tries... by my calculations, this will fail once in every 2.6 million runs, which we can live with. - for (int i = 0; i < 3; i++) + try (var _ = createErrorCountResetter()) { - Algorithm aesPassPhrase = new AES(passPhrase, 128, "test pass phrase"); - byte[] encrypted = aesPassPhrase.encrypt(textToEncrypt); + String textToEncrypt = "this is some text I want to encrypt"; + String passPhrase = "Here's my super secret pass phrase"; + String wrongPassPhrase = passPhrase + " not"; + + // Our AES implementation can usually detect a bad pass phrase (based on padding anomalies), but this is not 100% guaranteed. + // Give the test three tries... by my calculations, this will fail once in every 2.6 million runs, which we can live with. + for (int i = 0; i < 3; i++) + { + Algorithm aesPassPhrase = new AES(passPhrase, 128, "test pass phrase"); + byte[] encrypted = aesPassPhrase.encrypt(textToEncrypt); - Algorithm aesWrongPassPhrase = new AES(wrongPassPhrase, 128, "test pass phrase"); - aesWrongPassPhrase.decrypt(encrypted); + Algorithm aesWrongPassPhrase = new AES(wrongPassPhrase, 128, "test pass phrase"); + aesWrongPassPhrase.decrypt(encrypted); + } } } @Test public void testMigrationFallback() { - String text = "test plaintext"; - AES oldAlgorithm = new AES("old pass phrase", 128, "old algorithm"); - byte[] oldEncrypted = oldAlgorithm.encrypt(text); + try (var _ = createErrorCountResetter()) + { + String text = "test plaintext"; + AES oldAlgorithm = new AES("old pass phrase", 128, "old algorithm"); + byte[] oldEncrypted = oldAlgorithm.encrypt(text); - // Primary (production) instance: different pass phrase, keySource == ENCRYPTION_KEY_CHANGED - AES primary = new AES("primary pass phrase", 128, ENCRYPTION_KEY_CHANGED); + // Primary (production) instance: different pass phrase, keySource == ENCRYPTION_KEY_CHANGED + AES primary = new AES("primary pass phrase", 128, ENCRYPTION_KEY_CHANGED); - // Case 1: no fallback — primary fails and counter increments - int counterBefore = DECRYPTION_EXCEPTIONS.get(); - try - { - primary.decrypt(oldEncrypted); - fail("Expected DecryptionException"); - } - catch (DecryptionException ignored) {} - assertEquals(counterBefore + 1, DECRYPTION_EXCEPTIONS.get()); + // Case 1: no fallback — primary fails and counter increments + int counterBefore = DECRYPTION_EXCEPTIONS.get(); + try + { + primary.decrypt(oldEncrypted); + fail("Expected DecryptionException"); + } + catch (DecryptionException _) {} + assertEquals(counterBefore + 1, DECRYPTION_EXCEPTIONS.get()); - // Case 2: correct fallback — transparent success, counter unchanged - _migrationFallback = oldAlgorithm; - try - { - int counterBeforeFallback = DECRYPTION_EXCEPTIONS.get(); - assertEquals(text, primary.decrypt(oldEncrypted)); - assertEquals("Counter must not increment when fallback succeeds", counterBeforeFallback, DECRYPTION_EXCEPTIONS.get()); - } - finally - { - _migrationFallback = null; - } + // Case 2: correct fallback — transparent success, counter unchanged + _migrationFallback = oldAlgorithm; + try + { + int counterBeforeFallback = DECRYPTION_EXCEPTIONS.get(); + assertEquals(text, primary.decrypt(oldEncrypted)); + assertEquals("Counter must not increment when fallback succeeds", counterBeforeFallback, DECRYPTION_EXCEPTIONS.get()); + } + finally + { + _migrationFallback = null; + } - // Case 3: wrong fallback — both algorithms fail, counter increments - _migrationFallback = new AES("wrong pass phrase", 128, "wrong fallback"); - int counterBeforeWrongFallback = DECRYPTION_EXCEPTIONS.get(); - try - { - primary.decrypt(oldEncrypted); - fail("Expected DecryptionException"); - } - catch (DecryptionException ignored) {} - finally - { - _migrationFallback = null; + // Case 3: wrong fallback — both algorithms fail, counter increments + _migrationFallback = new AES("wrong pass phrase", 128, "wrong fallback"); + int counterBeforeWrongFallback = DECRYPTION_EXCEPTIONS.get(); + try + { + primary.decrypt(oldEncrypted); + fail("Expected DecryptionException"); + } + catch (DecryptionException _) {} + finally + { + _migrationFallback = null; + } + assertEquals(counterBeforeWrongFallback + 1, DECRYPTION_EXCEPTIONS.get()); } - assertEquals(counterBeforeWrongFallback + 1, DECRYPTION_EXCEPTIONS.get()); } private void test(Algorithm algorithm)