Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,10 @@ class DefaultProviderAttrs {
+ " # PQC key factories\n"
+ " # =======================================================================\n"
+ " #\n"
+ "Service.KeyFactory.ML-KEM = com.ibm.crypto.plus.provider.PQCKeyFactory$MLKEM\n"
+ "KeyFactory.ML-KEM-512.alias.add = ML_KEM_512, MLKEM512, OID.2.16.840.1.101.3.4.4.1, 2.16.840.1.101.3.4.4.1\n"
+ "Service.KeyFactory.ML-KEM-512 = com.ibm.crypto.plus.provider.PQCKeyFactory$MLKEM512\n"
+ "KeyFactory.ML-KEM-768.alias.add = ML-KEM, ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2\n"
+ "KeyFactory.ML-KEM-768.alias.add = ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2\n"
+ "Service.KeyFactory.ML-KEM-768 = com.ibm.crypto.plus.provider.PQCKeyFactory$MLKEM768\n"
+ "KeyFactory.ML-KEM-1024.alias.add = ML_KEM_1024, MLKEM1024, OID.2.16.840.1.101.3.4.4.3, 2.16.840.1.101.3.4.4.3\n"
+ "Service.KeyFactory.ML-KEM-1024 = com.ibm.crypto.plus.provider.PQCKeyFactory$MLKEM1024\n"
Expand Down Expand Up @@ -315,9 +316,10 @@ class DefaultProviderAttrs {
+ " # PQC key encapsulation mechanisms\n"
+ " # =======================================================================\n"
+ " #\n"
+ "Service.KEM.ML-KEM = com.ibm.crypto.plus.provider.MLKEMImpl$MLKEM\n"
+ "KEM.ML-KEM-512.alias.add = ML_KEM_512, MLKEM512, OID.2.16.840.1.101.3.4.4.1, 2.16.840.1.101.3.4.4.1\n"
+ "Service.KEM.ML-KEM-512 = com.ibm.crypto.plus.provider.MLKEMImpl$MLKEM512\n"
+ "KEM.ML-KEM-768.alias.add = ML-KEM, ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2\n"
+ "KEM.ML-KEM-768.alias.add = ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2\n"
+ "Service.KEM.ML-KEM-768 = com.ibm.crypto.plus.provider.MLKEMImpl$MLKEM768\n"

+ "KEM.ML-KEM-1024.alias.add = ML_KEM_1024, MLKEM1024, OID.2.16.840.1.101.3.4.4.3, 2.16.840.1.101.3.4.4.3\n"
Expand Down
57 changes: 48 additions & 9 deletions src/main/java/com/ibm/crypto/plus/provider/MLKEMImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,22 @@ public MLKEMImpl(OpenJCEPlusProvider provider, String alg) {
this.alg = alg;
}

private int getEncapsulationLength() {
private int getEncapsulationLength(String algorithm) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since, this is no longer using the Algorithm for this Implementation instance. Couldn't this cause issues with callers mix matching things easier?

int size = 0;

switch (this.alg) {
switch (algorithm) {
case "ML-KEM-512":
size = 768;
break;
case "ML-KEM-768":
size = 1088;
break;
default:
case "ML-KEM-1024":
size = 1568;
break;
default:
// If algorithm is generic "ML-KEM", default to ML-KEM-768
size = 1088;
}
return size;
}
Expand All @@ -72,8 +76,15 @@ public KEMSpi.EncapsulatorSpi engineNewEncapsulator(PublicKey publicKey,

if (!(pubKey instanceof PQCPublicKey)) {
// Try and convert this key to a usage PQCPublicKey
// First verify it's an ML-KEM key
String keyAlgorithm = publicKey.getAlgorithm();
if (keyAlgorithm == null || !keyAlgorithm.startsWith("ML-KEM")) {
throw new InvalidKeyException("unsupported key");
}

// Use the key's actual algorithm, not the generic "ML-KEM"
try {
KeyFactory kf = KeyFactory.getInstance(this.alg, this.provider.getName());
KeyFactory kf = KeyFactory.getInstance(keyAlgorithm, this.provider.getName());
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can cause a algorithm miss match if they create this instance as ML-KEM-786, but pass in a ML-KEM-512 key that should not be allowed.

EncodedKeySpec publicKeySpec = new X509EncodedKeySpec(publicKey.getEncoded());
pubKey = kf.generatePublic(publicKeySpec);

Expand Down Expand Up @@ -105,7 +116,9 @@ class MLKEMEncapsulator implements KEMSpi.EncapsulatorSpi {

@Override
public KEM.Encapsulated engineEncapsulate(int from, int to, String algorithm) {
int encapLen = getEncapsulationLength();
// Get the actual algorithm from the public key
String keyAlgorithm = publicKey.getAlgorithm();
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue here.

int encapLen = getEncapsulationLength(keyAlgorithm);
byte[] encapsulation = new byte[encapLen];
byte[] secret = new byte[SECRETSIZE];

Expand All @@ -130,7 +143,8 @@ public KEM.Encapsulated engineEncapsulate(int from, int to, String algorithm) {

@Override
public int engineEncapsulationSize() {
return getEncapsulationLength();
String keyAlgorithm = publicKey.getAlgorithm();
return getEncapsulationLength(keyAlgorithm);
}

@Override
Expand All @@ -155,9 +169,16 @@ public KEMSpi.DecapsulatorSpi engineNewDecapsulator(PrivateKey privateKey,

if (!(privKey instanceof PQCPrivateKey)) {
// Try and convert this key to a usage PQCPrivateKey
// First verify it's an ML-KEM key
String keyAlgorithm = privateKey.getAlgorithm();
if (keyAlgorithm == null || !keyAlgorithm.startsWith("ML-KEM")) {
throw new InvalidKeyException("unsupported key");
}

// Use the key's actual algorithm, not the generic "ML-KEM"
byte[] encoding = null;
try {
KeyFactory kf = KeyFactory.getInstance(this.alg, this.provider.getName());
KeyFactory kf = KeyFactory.getInstance(keyAlgorithm, this.provider.getName());
encoding = privateKey.getEncoded();
PKCS8EncodedKeySpec privateKeySpec = new PKCS8EncodedKeySpec(encoding);
privKey = kf.generatePrivate(privateKeySpec);
Expand Down Expand Up @@ -197,6 +218,17 @@ public SecretKey engineDecapsulate(byte[] cipherText, int from, int to, String a
if (algorithm == null || cipherText == null) {
throw new NullPointerException();
}

// Validate encapsulation length matches the key's algorithm
String keyAlgorithm = privateKey.getAlgorithm();
int expectedEncapLen = getEncapsulationLength(keyAlgorithm);
if (cipherText.length != expectedEncapLen) {
throw new DecapsulateException(
"Invalid key encapsulation message length: expected " +
expectedEncapLen + " bytes for " + keyAlgorithm +
", but got " + cipherText.length + " bytes");
}

try {
secret = OJPKEM.KEM_decapsulate(((PQCPrivateKey) this.privateKey).getPQCKey().getPKeyId(),
cipherText, provider);
Expand All @@ -210,8 +242,8 @@ public SecretKey engineDecapsulate(byte[] cipherText, int from, int to, String a

@Override
public int engineEncapsulationSize() {

return getEncapsulationLength();
String keyAlgorithm = privateKey.getAlgorithm();
return getEncapsulationLength(keyAlgorithm);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same type of issue here too. The algorithm the should be based on the one requested when this was created.

}

@Override
Expand All @@ -222,6 +254,13 @@ public int engineSecretSize() {

}

public static final class MLKEM extends MLKEMImpl {

public MLKEM(OpenJCEPlusProvider provider) {
super(provider, "ML-KEM");
}
}

public static final class MLKEM512 extends MLKEMImpl {

public MLKEM512(OpenJCEPlusProvider provider) {
Expand Down
23 changes: 21 additions & 2 deletions src/main/java/com/ibm/crypto/plus/provider/PQCKeyFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,20 @@ private void checkKeyAlgo(Key key) throws InvalidKeyException {
String keyAlg = key.getAlgorithm();
if (keyAlg == null) {
throw new InvalidKeyException("Algorithm associate with key is null.");
} else if (!(key.getAlgorithm().equalsIgnoreCase(this.algName) ||
(PQCKnownOIDs.findMatch(key.getAlgorithm()).stdName().equalsIgnoreCase(this.algName)))) {
}

// Check if algorithms match exactly or via OID lookup
boolean matches = key.getAlgorithm().equalsIgnoreCase(this.algName) ||
(PQCKnownOIDs.findMatch(key.getAlgorithm()).stdName().equalsIgnoreCase(this.algName));

// Special case for generic ML-KEM: Allow any ML-KEM parameter set variant
// (ML-KEM-512, ML-KEM-768, ML-KEM-1024) when using the generic "ML-KEM" KeyFactory.
// This enables interoperability with KEM.getInstance("ML-KEM", ...).
if (!matches && "ML-KEM".equals(this.algName) && keyAlg.startsWith("ML-KEM")) {
matches = true;
}

if (!matches) {
throw new InvalidKeyException("Expected a " + this.algName + " key, but got " + keyAlg);
}

Expand Down Expand Up @@ -217,6 +229,13 @@ private boolean checkEncoded(byte[] key, boolean pub) {
}
}

public static final class MLKEM extends PQCKeyFactory {

public MLKEM(OpenJCEPlusProvider provider) {
super(provider, "ML-KEM");
}
}

public static final class MLKEM512 extends PQCKeyFactory {

public MLKEM512(OpenJCEPlusProvider provider) {
Expand Down
8 changes: 6 additions & 2 deletions src/test/ProviderDefAttrs.config
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,12 @@ Service.KeyFactory.RSAPSS = com.ibm.crypto.plus.provider.RSAKeyFactory$PSS
# PQC key factories
# =======================================================================
#
Service.KeyFactory.ML-KEM = com.ibm.crypto.plus.provider.PQCKeyFactory$MLKEM

KeyFactory.ML-KEM-512.alias.add = ML_KEM_512, MLKEM512, OID.2.16.840.1.101.3.4.4.1, 2.16.840.1.101.3.4.4.1
Service.KeyFactory.ML-KEM-512 = com.ibm.crypto.plus.provider.PQCKeyFactory$MLKEM512

KeyFactory.ML-KEM-768.alias.add = ML-KEM, ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2
KeyFactory.ML-KEM-768.alias.add = ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2
Service.KeyFactory.ML-KEM-768 = com.ibm.crypto.plus.provider.PQCKeyFactory$MLKEM768

KeyFactory.ML-KEM-1024.alias.add = ML_KEM_1024, MLKEM1024, OID.2.16.840.1.101.3.4.4.3, 2.16.840.1.101.3.4.4.3
Expand Down Expand Up @@ -464,10 +466,12 @@ Service.MessageDigest.SHA3-512 = com.ibm.crypto.plus.provider.MessageDigest$SHA3
# PQC key encapsulation mechanisms
# =======================================================================
#
Service.KEM.ML-KEM = com.ibm.crypto.plus.provider.MLKEMImpl$MLKEM

KEM.ML-KEM-512.alias.add = ML_KEM_512, MLKEM512, OID.2.16.840.1.101.3.4.4.1, 2.16.840.1.101.3.4.4.1
Service.KEM.ML-KEM-512 = com.ibm.crypto.plus.provider.MLKEMImpl$MLKEM512

KEM.ML-KEM-768.alias.add = ML-KEM, ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2
KEM.ML-KEM-768.alias.add = ML_KEM_768, MLKEM768, OID.2.16.840.1.101.3.4.4.2, 2.16.840.1.101.3.4.4.2
Service.KEM.ML-KEM-768 = com.ibm.crypto.plus.provider.MLKEMImpl$MLKEM768

KEM.ML-KEM-1024.alias.add = ML_KEM_1024, MLKEM1024, OID.2.16.840.1.101.3.4.4.3, 2.16.840.1.101.3.4.4.3
Expand Down
51 changes: 51 additions & 0 deletions src/test/java/ibm/jceplus/junit/base/BaseTestKEM.java
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,57 @@ public void testKEMKeys(String Algorithm) throws Exception {
}
}

/**
* Tests that decapsulation fails with a DecapsulateException when attempting to decapsulate
* an encapsulation message that was created with a different ML-KEM algorithm variant.
*
* <p>This test verifies that the KEM implementation properly validates the encapsulation
* message length during decapsulation. Each ML-KEM variant (ML-KEM-512, ML-KEM-768, ML-KEM-1024)
* produces encapsulation messages of different lengths. When a decapsulator receives an
* encapsulation message with an incorrect length (from a different variant), it should
* reject it with a DecapsulateException containing an appropriate error message.
*
* <p>Test procedure:
* <ol>
* <li>Generate a key pair using the first algorithm (keyAlgorithm)</li>
* <li>Generate a different key pair using a second algorithm (wrongAlgorithm)</li>
* <li>Create an encapsulation using the second key pair (wrong length for first algorithm)</li>
* <li>Attempt to decapsulate using the first key pair's private key</li>
* <li>Verify that a DecapsulateException is thrown with the expected error message</li>
* </ol>
*
* @param keyAlgorithm the ML-KEM algorithm variant to use for the decapsulation key pair
* @param wrongAlgorithm the ML-KEM algorithm variant to use for creating the encapsulation
* (produces wrong length for keyAlgorithm)
* @throws Exception if an unexpected error occurs during test execution
*/
@ParameterizedTest
@CsvSource({"ML-KEM-512,ML-KEM-768", "ML-KEM-768,ML-KEM-1024", "ML-KEM-1024,ML-KEM-512"})
public void testKEMInvalidEncapsulationLength(String keyAlgorithm, String wrongAlgorithm) throws Exception {
// Generate a key pair with one algorithm
KeyPair keyPair = generateKeyPair(keyAlgorithm);

// Create encapsulation with a different algorithm (wrong length)
KEM kemWrong = KEM.getInstance(wrongAlgorithm, getProviderName());
KeyPair wrongKeyPair = generateKeyPair(wrongAlgorithm);
KEM.Encapsulator encapsulator = kemWrong.newEncapsulator(wrongKeyPair.getPublic());
KEM.Encapsulated encapsulated = encapsulator.encapsulate(0, 32, "AES");

// Try to decapsulate with the original key (wrong length)
KEM kem = KEM.getInstance(keyAlgorithm, getProviderName());
KEM.Decapsulator decapsulator = kem.newDecapsulator(keyPair.getPrivate());

try {
decapsulator.decapsulate(encapsulated.encapsulation(), 0, 32, "AES");
fail("testKEMInvalidEncapsulationLength failed - Invalid encapsulation length did not cause a DecapsulateException for " + keyAlgorithm + " with " + wrongAlgorithm + " encapsulation");
} catch (javax.crypto.DecapsulateException de) {
assertTrue(de.getMessage().contains("Invalid key encapsulation message length"),
"Expected error message about invalid encapsulation length, but got: " + de.getMessage());
assertTrue(de.getMessage().contains(keyAlgorithm),
"Expected error message to mention key algorithm " + keyAlgorithm + ", but got: " + de.getMessage());
}
}

protected KeyPair generateKeyPair(String Algorithm) throws Exception {
pqcKeyPairGen = KeyPairGenerator.getInstance(Algorithm, getProviderName());

Expand Down
Loading
Loading