Skip to content

Commit

Permalink
Merge pull request #32865 from vespa-engine/mpolden/expect-pem-encoding
Browse files Browse the repository at this point in the history
Expect PEM-encoded sealing key
  • Loading branch information
tokle authored Nov 15, 2024
2 parents 3eb8170 + f7c74e5 commit 251cc88
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.yahoo.config.provision.HostName;
import com.yahoo.config.provision.NodeType;
import com.yahoo.config.provision.SnapshotId;
import com.yahoo.security.KeyAlgorithm;
import com.yahoo.security.KeyId;
import com.yahoo.security.KeyUtils;
import com.yahoo.security.SealedSharedKey;
Expand All @@ -28,9 +29,8 @@
import com.yahoo.vespa.hosted.provision.provisioning.SnapshotStore;

import java.security.KeyPair;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.interfaces.XECPrivateKey;
import java.security.interfaces.XECPublicKey;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -196,8 +196,12 @@ private VersionedKeyPair sealingKeyPair(SecretVersionId version) {
}
Key key = Key.fromString(sealingPrivateKeySecretName.get());
Secret sealingPrivateKey = version == null ? secretStore.getSecret(key) : secretStore.getSecret(key, version);
XECPrivateKey privateKey = KeyUtils.fromBase64EncodedX25519PrivateKey(sealingPrivateKey.secretValue().value());
XECPublicKey publicKey = KeyUtils.extractX25519PublicKey(privateKey);
PrivateKey privateKey = KeyUtils.fromPemEncodedPrivateKey(sealingPrivateKey.secretValue().value());
PublicKey publicKey = KeyUtils.extractPublicKey(privateKey);
if (KeyAlgorithm.from(privateKey.getAlgorithm()) != KeyAlgorithm.XDH) {
throw new IllegalArgumentException("Expected sealing key to use algorithm " + KeyAlgorithm.XDH +
", but got " + privateKey.getAlgorithm());
}
return new VersionedKeyPair(new KeyPair(publicKey, privateKey), sealingPrivateKey.version());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.yahoo.config.provision.ClusterSpec;
import com.yahoo.config.provision.NodeResources;
import com.yahoo.config.provision.NodeType;
import com.yahoo.security.KeyFormat;
import com.yahoo.security.KeyUtils;
import com.yahoo.security.SealedSharedKey;
import com.yahoo.vespa.hosted.provision.Node;
Expand All @@ -17,7 +18,6 @@

import java.security.KeyPair;
import java.security.PublicKey;
import java.security.interfaces.XECPrivateKey;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -54,8 +54,7 @@ void snapshot() {
// Sealing key can be rotated independently of existing snapshots
KeyPair keyPair = KeyUtils.generateX25519KeyPair();
tester.secretStore().add(new Secret(Key.fromString("snapshot/sealingPrivateKey"),
KeyUtils.toBase64EncodedX25519PrivateKey((XECPrivateKey) keyPair.getPrivate())
.getBytes(),
KeyUtils.toPem(keyPair.getPrivate(), KeyFormat.PKCS8).getBytes(),
SecretVersionId.of("2")));
assertEquals(SecretVersionId.of("1"), snapshots.require(snapshot0.id(), node0).key().sealingKeyVersion());
assertNotEquals(snapshot0.key().sharedKey(), snapshots.keyOf(snapshot0.id(), node0, receiverPublicKey),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import com.yahoo.config.provision.Zone;
import com.yahoo.config.provisioning.FlavorsConfig;
import com.yahoo.jdisc.test.MockMetric;
import com.yahoo.security.KeyFormat;
import com.yahoo.security.KeyUtils;
import com.yahoo.test.ManualClock;
import com.yahoo.transaction.NestedTransaction;
Expand Down Expand Up @@ -69,7 +70,6 @@
import com.yahoo.vespa.service.duper.TenantHostApplication;

import java.security.KeyPair;
import java.security.interfaces.XECPrivateKey;
import java.time.temporal.TemporalAmount;
import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -772,8 +772,7 @@ private SecretStoreMock defaultSecretStore() {
SecretStoreMock secretStore = new SecretStoreMock();
KeyPair keyPair = KeyUtils.generateX25519KeyPair();
secretStore.add(new Secret(Key.fromString("snapshot/sealingPrivateKey"),
KeyUtils.toBase64EncodedX25519PrivateKey((XECPrivateKey) keyPair.getPrivate())
.getBytes(),
KeyUtils.toPem(keyPair.getPrivate(), KeyFormat.PKCS8).getBytes(),
SecretVersionId.of("1")));
return secretStore;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.yahoo.config.provision.NodeType;
import com.yahoo.config.provision.SystemName;
import com.yahoo.config.provision.TenantName;
import com.yahoo.security.KeyFormat;
import com.yahoo.security.KeyUtils;
import com.yahoo.slime.SlimeUtils;
import com.yahoo.text.Utf8;
Expand All @@ -28,7 +29,6 @@
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.KeyPair;
import java.security.interfaces.XECPrivateKey;
import java.security.interfaces.XECPublicKey;
import java.time.Duration;
import java.util.Arrays;
Expand Down Expand Up @@ -876,8 +876,7 @@ public void test_snapshots() throws IOException {
.getComponent(SecretStoreMock.class.getName());
KeyPair keyPair = KeyUtils.generateX25519KeyPair();
secretStore.add(new Secret(Key.fromString("snapshot/sealingPrivateKey"),
KeyUtils.toBase64EncodedX25519PrivateKey((XECPrivateKey) keyPair.getPrivate())
.getBytes(),
KeyUtils.toPem(keyPair.getPrivate(), KeyFormat.PKCS8).getBytes(),
SecretVersionId.of("1")));

// Trigger creation of snapshots
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
* @author bjorncs
*/
public enum KeyAlgorithm {

RSA("RSA", null),
EC("EC", new ECGenParameterSpec("prime256v1")); // TODO Make curve configurable
EC("EC", new ECGenParameterSpec("prime256v1")),
XDH("XDH", new ECGenParameterSpec("X25519"));

final String algorithmName;
private final AlgorithmParameterSpec spec;
Expand All @@ -25,4 +27,18 @@ String getAlgorithmName() {
}

Optional<AlgorithmParameterSpec> getSpec() { return Optional.ofNullable(spec); }

public static KeyAlgorithm from(String name) {
for (var algorithm : values()) {
if (name.equals(algorithm.getAlgorithmName())) {
return algorithm;
} else if (algorithm == XDH && name.equals("X25519")) {
// "XDH" is the name used by the JDK for elliptic curve keys using Curve25519, while BouncyCastle uses
// "X25519"
return algorithm;
}
}
throw new IllegalArgumentException("Unknown key algorithm '" + name + "'");
}

}
91 changes: 44 additions & 47 deletions security-utils/src/main/java/com/yahoo/security/KeyUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import org.bouncycastle.asn1.ASN1Encodable;
import org.bouncycastle.asn1.ASN1Primitive;
import org.bouncycastle.asn1.edec.EdECObjectIdentifiers;
import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.asn1.x509.AlgorithmIdentifier;
Expand Down Expand Up @@ -50,6 +51,7 @@

import static com.yahoo.security.KeyAlgorithm.EC;
import static com.yahoo.security.KeyAlgorithm.RSA;
import static com.yahoo.security.KeyAlgorithm.XDH;

/**
* @author bjorncs
Expand Down Expand Up @@ -78,23 +80,30 @@ public static KeyPair generateKeypair(KeyAlgorithm algorithm) {
}

public static PublicKey extractPublicKey(PrivateKey privateKey) {
String algorithm = privateKey.getAlgorithm();
KeyAlgorithm keyAlgorithm = KeyAlgorithm.from(privateKey.getAlgorithm());
try {
if (algorithm.equals(RSA.getAlgorithmName())) {
KeyFactory keyFactory = createKeyFactory(RSA);
RSAPrivateCrtKey rsaPrivateCrtKey = (RSAPrivateCrtKey) privateKey;
RSAPublicKeySpec keySpec = new RSAPublicKeySpec(rsaPrivateCrtKey.getModulus(), rsaPrivateCrtKey.getPublicExponent());
return keyFactory.generatePublic(keySpec);
} else if (algorithm.equals(EC.getAlgorithmName())) {
KeyFactory keyFactory = createKeyFactory(EC);
BCECPrivateKey ecPrivateKey = (BCECPrivateKey) privateKey;
ECParameterSpec ecParameterSpec = ecPrivateKey.getParameters();
ECPoint ecPoint = new FixedPointCombMultiplier().multiply(ecParameterSpec.getG(), ecPrivateKey.getD());
ECPublicKeySpec keySpec = new ECPublicKeySpec(ecPoint, ecParameterSpec);
return keyFactory.generatePublic(keySpec);
} else {
throw new IllegalArgumentException("Unexpected key algorithm: " + algorithm);
}
return switch (keyAlgorithm) {
case RSA -> {
KeyFactory keyFactory = createKeyFactory(RSA);
RSAPrivateCrtKey rsaPrivateCrtKey = (RSAPrivateCrtKey) privateKey;
RSAPublicKeySpec keySpec = new RSAPublicKeySpec(rsaPrivateCrtKey.getModulus(), rsaPrivateCrtKey.getPublicExponent());
yield keyFactory.generatePublic(keySpec);
}
case EC -> {
KeyFactory keyFactory = createKeyFactory(EC);
BCECPrivateKey ecPrivateKey = (BCECPrivateKey) privateKey;
ECParameterSpec ecParameterSpec = ecPrivateKey.getParameters();
ECPoint ecPoint = new FixedPointCombMultiplier().multiply(ecParameterSpec.getG(), ecPrivateKey.getD());
ECPublicKeySpec keySpec = new ECPublicKeySpec(ecPoint, ecParameterSpec);
yield keyFactory.generatePublic(keySpec);
}
case XDH -> {
byte[] privScalar = toRawX25519PrivateKeyBytes((XECPrivateKey) privateKey);
byte[] pubPoint = new byte[X25519.POINT_SIZE];
X25519.generatePublicKey(privScalar, 0, pubPoint, 0); // scalarMultBase => public key point
yield fromRawX25519PublicKey(pubPoint);
}
};
} catch (GeneralSecurityException e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -127,7 +136,7 @@ public static PrivateKey fromPemEncodedPrivateKey(String pem) {
unknownObjects.add(pemObject);
}
}
throw new IllegalArgumentException("Expected a private key, but found " + unknownObjects.toString());
throw new IllegalArgumentException("Expected a private key, but found " + unknownObjects);
} catch (IOException e) {
throw new UncheckedIOException(e);
} catch (GeneralSecurityException e) {
Expand Down Expand Up @@ -168,14 +177,10 @@ public static String toPem(PrivateKey privateKey) {
}

public static String toPem(PrivateKey privateKey, KeyFormat format) {
switch (format) {
case PKCS1:
return toPkcs1Pem(privateKey);
case PKCS8:
return toPkcs8Pem(privateKey);
default:
throw new IllegalArgumentException("Unknown format: " + format);
}
return switch (format) {
case PKCS1 -> toPkcs1Pem(privateKey);
case PKCS8 -> toPkcs8Pem(privateKey);
};
}

public static String toPem(PublicKey publicKey) {
Expand All @@ -190,15 +195,12 @@ public static String toPem(PublicKey publicKey) {

private static String toPkcs1Pem(PrivateKey privateKey) {
try (StringWriter stringWriter = new StringWriter(); JcaPEMWriter pemWriter = new JcaPEMWriter(stringWriter)) {
String algorithm = privateKey.getAlgorithm();
String type;
if (algorithm.equals(RSA.getAlgorithmName())) {
type = "RSA PRIVATE KEY";
} else if (algorithm.equals(EC.getAlgorithmName())) {
type = "EC PRIVATE KEY";
} else {
throw new IllegalArgumentException("Unexpected key algorithm: " + algorithm);
}
KeyAlgorithm keyAlgorithm = KeyAlgorithm.from(privateKey.getAlgorithm());
String type = switch (keyAlgorithm) {
case RSA -> "RSA PRIVATE KEY";
case EC -> "EC PRIVATE KEY";
case XDH -> throw new IllegalArgumentException("Cannot use PKCS#1 for X25519 key");
};
pemWriter.writeObject(new PemObject(type, getPkcs1Bytes(privateKey)));
pemWriter.flush();
return stringWriter.toString();
Expand Down Expand Up @@ -227,9 +229,11 @@ private static byte[] getPkcs1Bytes(PrivateKey privateKey) throws IOException{

private static KeyFactory createKeyFactory(AlgorithmIdentifier algorithm) throws NoSuchAlgorithmException {
if (X9ObjectIdentifiers.id_ecPublicKey.equals(algorithm.getAlgorithm())) {
return createKeyFactory(KeyAlgorithm.EC);
return createKeyFactory(EC);
} else if (PKCSObjectIdentifiers.rsaEncryption.equals(algorithm.getAlgorithm())) {
return createKeyFactory(KeyAlgorithm.RSA);
return createKeyFactory(RSA);
} else if (EdECObjectIdentifiers.id_X25519.equals(algorithm.getAlgorithm())) {
return createKeyFactory(XDH);
} else {
throw new IllegalArgumentException("Unknown key algorithm: " + algorithm);
}
Expand Down Expand Up @@ -338,21 +342,14 @@ public static String toBase58EncodedX25519PrivateKey(XECPrivateKey privateKey) {
return Base58.codec().encode(toRawX25519PrivateKeyBytes(privateKey));
}

// TODO unify with generateKeypair()?
// TODO: In-line and remove
public static KeyPair generateX25519KeyPair() {
try {
return KeyPairGenerator.getInstance("X25519").generateKeyPair();
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException(e);
}
return generateKeypair(XDH);
}

// TODO unify with extractPublicKey()
// TODO: In-line and remove
public static XECPublicKey extractX25519PublicKey(XECPrivateKey privateKey) {
byte[] privScalar = toRawX25519PrivateKeyBytes(privateKey);
byte[] pubPoint = new byte[X25519.POINT_SIZE];
X25519.generatePublicKey(privScalar, 0, pubPoint, 0); // scalarMultBase => public key point
return fromRawX25519PublicKey(pubPoint);
return (XECPublicKey) extractPublicKey(privateKey);
}

/**
Expand Down
29 changes: 19 additions & 10 deletions security-utils/src/test/java/com/yahoo/security/KeyUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import static com.yahoo.security.ArrayUtils.unhex;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

Expand Down Expand Up @@ -55,26 +56,34 @@ void can_serialize_and_deserialize_ec_privatekey_using_pkcs8_pem_format() {
testPrivateKeySerialization(KeyAlgorithm.EC, KeyFormat.PKCS8, "PRIVATE KEY");
}

@Test
void can_serialize_and_deserialize_x25519_private_key_using_pkcs8_pem_format() {
testPrivateKeySerialization(KeyAlgorithm.XDH, KeyFormat.PKCS8, "PRIVATE KEY");
}

@Test
void can_serialize_and_deserialize_rsa_publickey_using_pem_format() {
KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA);
String pem = KeyUtils.toPem(keyPair.getPublic());
assertTrue(pem.contains("BEGIN PUBLIC KEY"));
assertTrue(pem.contains("END PUBLIC KEY"));
PublicKey deserializedKey = KeyUtils.fromPemEncodedPublicKey(pem);
assertEquals(keyPair.getPublic(), deserializedKey);
assertEquals(KeyAlgorithm.RSA.getAlgorithmName(), deserializedKey.getAlgorithm());
testPublicKeySerialization(KeyAlgorithm.RSA);
}

@Test
void can_serialize_and_deserialize_ec_publickey_using_pem_format() {
KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC);
testPublicKeySerialization(KeyAlgorithm.EC);
}

@Test
void can_serialize_and_deserialize_x25519_publickey_using_pem_format() {
testPublicKeySerialization(KeyAlgorithm.XDH);
}

private static void testPublicKeySerialization(KeyAlgorithm keyAlgorithm) {
KeyPair keyPair = KeyUtils.generateKeypair(keyAlgorithm);
String pem = KeyUtils.toPem(keyPair.getPublic());
assertTrue(pem.contains("BEGIN PUBLIC KEY"));
assertTrue(pem.contains("END PUBLIC KEY"));
PublicKey deserializedKey = KeyUtils.fromPemEncodedPublicKey(pem);
assertEquals(keyPair.getPublic(), deserializedKey);
assertEquals(KeyAlgorithm.EC.getAlgorithmName(), deserializedKey.getAlgorithm());
assertSame(keyAlgorithm, KeyAlgorithm.from(deserializedKey.getAlgorithm()));
}

private static void testPrivateKeySerialization(KeyAlgorithm keyAlgorithm, KeyFormat keyFormat, String pemLabel) {
Expand All @@ -84,7 +93,7 @@ private static void testPrivateKeySerialization(KeyAlgorithm keyAlgorithm, KeyFo
assertTrue(pem.contains("END " + pemLabel));
PrivateKey deserializedKey = KeyUtils.fromPemEncodedPrivateKey(pem);
assertEquals(keyPair.getPrivate(), deserializedKey);
assertEquals(keyAlgorithm.getAlgorithmName(), deserializedKey.getAlgorithm());
assertSame(keyAlgorithm, KeyAlgorithm.from(deserializedKey.getAlgorithm()));
}

private static XECPrivateKey xecPrivFromHex(String hex) {
Expand Down

0 comments on commit 251cc88

Please sign in to comment.