Skip to content

Commit

Permalink
xwing: align with draft 04
Browse files Browse the repository at this point in the history
  • Loading branch information
bwesterb committed Oct 20, 2024
1 parent 2b4220b commit c1c1704
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 42 deletions.
4 changes: 3 additions & 1 deletion kem/xwing/scheme.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ func (*scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) {
return nil, kem.ErrPubKeySize
}

pk.Unpack(buf)
if err := pk.Unpack(buf); err != nil {
return nil, err
}
return &pk, nil
}

Expand Down
71 changes: 40 additions & 31 deletions kem/xwing/xwing.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// https://datatracker.ietf.org/doc/draft-connolly-cfrg-xwing-kem
//
// Currently implements what will likely be -01.
// Currently implements -04.
package xwing

import (
Expand All @@ -18,9 +18,10 @@ import (

// An X-Wing private key.
type PrivateKey struct {
m mlkem768.PrivateKey
x x25519.Key
xpk x25519.Key
seed [32]byte
m mlkem768.PrivateKey
x x25519.Key
xpk x25519.Key
}

// An X-Wing public key.
Expand All @@ -31,13 +32,13 @@ type PublicKey struct {

const (
// Size of a seed of a keypair
SeedSize = 96
SeedSize = 32

// Size of an X-Wing public key
PublicKeySize = 1216

// Size of an X-Wing private key
PrivateKeySize = 2464
PrivateKeySize = 32

// Size of the seed passed to EncapsulateTo
EncapsulationSeedSize = 64
Expand Down Expand Up @@ -74,9 +75,7 @@ func (sk *PrivateKey) Pack(buf []byte) {
if len(buf) != PrivateKeySize {
panic(kem.ErrPrivKeySize)
}
sk.m.Pack(buf[:mlkem768.PrivateKeySize])
copy(buf[mlkem768.PrivateKeySize:mlkem768.PrivateKeySize+32], sk.x[:])
copy(buf[mlkem768.PrivateKeySize+32:], sk.xpk[:])
copy(buf, sk.seed[:])
}

// Packs pk to buf.
Expand All @@ -95,27 +94,36 @@ func (pk *PublicKey) Pack(buf []byte) {
//
// Panics if seed is not of length SeedSize.
func DeriveKeyPair(seed []byte) (*PrivateKey, *PublicKey) {
var (
sk PrivateKey
pk PublicKey
)

deriveKeyPair(seed, &sk, &pk)

return &sk, &pk
}

func deriveKeyPair(seed []byte, sk *PrivateKey, pk *PublicKey) {
if len(seed) != SeedSize {
panic(kem.ErrSeedSize)
}

var (
pk PublicKey
sk PrivateKey
seedm [mlkem768.KeySeedSize]byte
)
var seedm [mlkem768.KeySeedSize]byte

copy(sk.seed[:], seed)

copy(seedm[:], seed[:64])
copy(sk.x[:], seed[64:])
h := sha3.NewShake128()
_, _ = h.Write(seed)
_, _ = h.Read(seedm[:])
_, _ = h.Read(sk.x[:])

pkm, skm := mlkem768.NewKeyFromSeed(seedm[:])
sk.m = *skm
pk.m = *pkm

x25519.KeyGen(&pk.x, &sk.x)
sk.xpk = pk.x

return &sk, &pk
}

// DeriveKeyPairPacked derives a keypair like DeriveKeyPair, and
Expand Down Expand Up @@ -170,15 +178,19 @@ func GenerateKeyPairPacked(rand io.Reader) ([]byte, []byte, error) {
// Warning: note that the order of the returned ss and ct matches the
// X-Wing standard, which is the reverse of the Circl KEM API.
//
// Returns ErrPubKey if ML-KEM encapsulation key check fails.
//
// Panics if pk is not of size PublicKeySize, or randomness could not
// be read from crypto/rand.Reader
func Encapsulate(pk, seed []byte) (ss, ct []byte) {
// be read from crypto/rand.Reader.
func Encapsulate(pk, seed []byte) (ss, ct []byte, err error) {
var pub PublicKey
pub.Unpack(pk)
if err := pub.Unpack(pk); err != nil {
return nil, nil, err
}
ct = make([]byte, CiphertextSize)
ss = make([]byte, SharedKeySize)
pub.EncapsulateTo(ct, ss, seed)
return ss, ct
return ss, ct, nil
}

// Decapsulate computes the shared key which is encapsulated in ct
Expand Down Expand Up @@ -276,24 +288,21 @@ func (sk *PrivateKey) DecapsulateTo(ss, ct []byte) {
// Unpacks pk from buf.
//
// Panics if buf is not of size PublicKeySize.
func (pk *PublicKey) Unpack(buf []byte) {
//
// Returns ErrPubKey if pk fails the ML-KEM encapsulation key check.
func (pk *PublicKey) Unpack(buf []byte) error {
if len(buf) != PublicKeySize {
panic(kem.ErrPubKeySize)
}

copy(pk.x[:], buf[mlkem768.PublicKeySize:])
pk.m.Unpack(buf[:mlkem768.PublicKeySize])
return pk.m.Unpack(buf[:mlkem768.PublicKeySize])
}

// Unpacks sk from buf.
//
// Panics if buf is not of size PrivateKeySize.
func (sk *PrivateKey) Unpack(buf []byte) {
if len(buf) != PrivateKeySize {
panic(kem.ErrPrivKeySize)
}

copy(sk.x[:], buf[mlkem768.PrivateKeySize:mlkem768.PrivateKeySize+32])
copy(sk.xpk[:], buf[mlkem768.PrivateKeySize+32:])
sk.m.Unpack(buf[:mlkem768.PrivateKeySize])
var pk PublicKey
deriveKeyPair(buf, sk, &pk)
}
23 changes: 13 additions & 10 deletions kem/xwing/xwing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ func writeHex(w io.Writer, prefix string, val interface{}) {
indent := " "
width := 74
hex := fmt.Sprintf("%x", val)
if len(prefix)+len(hex)+1 < width {
fmt.Fprintf(w, "%s %s\n", prefix, hex)
if len(prefix)+len(hex)+5 < width {
fmt.Fprintf(w, "%s %s\n", prefix, hex)
return
}
fmt.Fprintf(w, "%s\n", prefix)
Expand All @@ -38,19 +38,22 @@ func TestVectors(t *testing.T) {
for i := 0; i < 3; i++ {
var seed [SeedSize]byte
_, _ = h.Read(seed[:])
writeHex(w, "seed ", seed)
writeHex(w, "seed", seed)

sk, pk := DeriveKeyPairPacked(seed[:])
writeHex(w, "sk ", sk)
writeHex(w, "pk ", pk)
writeHex(w, "sk", sk)
writeHex(w, "pk", pk)

var eseed [EncapsulationSeedSize]byte
_, _ = h.Read(eseed[:])
writeHex(w, "eseed ", eseed)
writeHex(w, "eseed", eseed)

ss, ct := Encapsulate(pk, eseed[:])
writeHex(w, "ct ", ct)
writeHex(w, "ss ", ss)
ss, ct, err := Encapsulate(pk, eseed[:])
if err != nil {
t.Fatal(err)
}
writeHex(w, "ct", ct)
writeHex(w, "ss", ss)

ss2 := Decapsulate(ct, sk)
if !bytes.Equal(ss, ss2) {
Expand All @@ -66,7 +69,7 @@ func TestVectors(t *testing.T) {
var cs [32]byte
_, _ = h.Read(cs[:])
got := fmt.Sprintf("%x", cs)
want := "1b2fd3a79ad0a82d814dcdf5da62a3830bc5f48e392dfe01ac1c3f9bb37ff86e"
want := "0e414d1453095f77f7959da8ddba81559e9d62508c2f665a004467420d5d0c51"
if got != want {
t.Fatalf("%s ≠ %s", got, want)
}
Expand Down

0 comments on commit c1c1704

Please sign in to comment.