Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions hpke/kem_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package hpke_test

import (
"fmt"
"testing"

"github.com/cloudflare/circl/hpke"
"github.com/cloudflare/circl/internal/test"
)

func TestKemKeysMarshal(t *testing.T) {
for _, kem := range []hpke.KEM{
hpke.KEM_P256_HKDF_SHA256,
hpke.KEM_P384_HKDF_SHA384,
hpke.KEM_P521_HKDF_SHA512,
hpke.KEM_X25519_HKDF_SHA256,
hpke.KEM_X448_HKDF_SHA512,
hpke.KEM_X25519_KYBER768_DRAFT00,
} {
checkIssue488(t, kem)
}
}

func checkIssue488(t *testing.T, kem hpke.KEM) {
scheme := kem.Scheme()
pk, sk, err := scheme.GenerateKeyPair()
if err != nil {
t.Fatal(err)
}
skBytes, err := sk.MarshalBinary()
test.CheckNoErr(t, err, "marshal private key")
pkBytes, err := pk.MarshalBinary()
test.CheckNoErr(t, err, "marshal public key")

t.Run(fmt.Sprintf("%v/PrivateKey", scheme.Name()), func(t *testing.T) {
N := scheme.PrivateKeySize()
buffer := make([]byte, N+1)
copy(buffer, skBytes)

// passing a buffer larger than the private key size should error (but no panic).
_, err := scheme.UnmarshalBinaryPrivateKey(buffer[:N+1])
test.CheckIsErr(t, err, "unmarshal private key should failed")

// passing a buffer of the exact size must be correct.
gotSk, err := scheme.UnmarshalBinaryPrivateKey(buffer[:N])
test.CheckNoErr(t, err, "unmarshal private key shouldn't fail")
test.CheckOk(sk.Equal(gotSk), "private keys are not equal", t)
})

t.Run(fmt.Sprintf("%v/PublicKey", scheme.Name()), func(t *testing.T) {
N := scheme.PublicKeySize()
buffer := make([]byte, N+1)
copy(buffer, pkBytes)

// passing a buffer larger than the public key size should error (but no panic).
_, err := scheme.UnmarshalBinaryPublicKey(buffer[:N+1])
test.CheckIsErr(t, err, "unmarshal public key should failed")

gotPk, err := scheme.UnmarshalBinaryPublicKey(buffer[:N])
test.CheckNoErr(t, err, "unmarshal public key shouldn't fail")
test.CheckOk(pk.Equal(gotPk), "public keys are not equal", t)
})
}
19 changes: 11 additions & 8 deletions hpke/shortkem.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func (s shortKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
bitmask = 0x01
}

Nsk := s.PrivateKeySize()
dkpPrk := s.labeledExtract([]byte(""), []byte("dkp_prk"), seed)
var bytes []byte
ctr := 0
Expand All @@ -64,14 +65,12 @@ func (s shortKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
dkpPrk,
[]byte("candidate"),
[]byte{byte(ctr)},
uint16(s.byteSize()),
uint16(Nsk),
)
bytes[0] &= bitmask
skBig.SetBytes(bytes)
}
l := s.PrivateKeySize()
sk := &shortKEMPrivKey{s, make([]byte, l), nil}
copy(sk.priv[l-len(bytes):], bytes)
sk := &shortKEMPrivKey{s, bytes, nil}
return sk.Public(), sk
}

Expand All @@ -83,11 +82,11 @@ func (s shortKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {

func (s shortKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) {
l := s.PrivateKeySize()
if len(data) < l {
return nil, ErrInvalidKEMPrivateKey
if len(data) != l {
return nil, kem.ErrPrivKeySize
}
sk := &shortKEMPrivKey{s, make([]byte, l), nil}
copy(sk.priv[l-len(data):l], data[:l])
copy(sk.priv, data[:l])
if !sk.validate() {
return nil, ErrInvalidKEMPrivateKey
}
Expand All @@ -96,7 +95,11 @@ func (s shortKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error)
}

func (s shortKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) {
x, y := elliptic.Unmarshal(s, data)
l := s.PublicKeySize()
if len(data) != l {
return nil, kem.ErrPubKeySize
}
x, y := elliptic.Unmarshal(s, data[:l])
if x == nil {
return nil, ErrInvalidKEMPublicKey
}
Expand Down
13 changes: 7 additions & 6 deletions hpke/xkem.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,14 @@ func (x xKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
if len(seed) != x.SeedSize() {
panic(kem.ErrSeedSize)
}
sk := &xKEMPrivKey{scheme: x, priv: make([]byte, x.size)}
Nsk := x.PrivateKeySize()
sk := &xKEMPrivKey{scheme: x, priv: make([]byte, Nsk)}
dkpPrk := x.labeledExtract([]byte(""), []byte("dkp_prk"), seed)
bytes := x.labeledExpand(
dkpPrk,
[]byte("sk"),
nil,
uint16(x.PrivateKeySize()),
uint16(Nsk),
)
copy(sk.priv, bytes)
return sk.Public(), sk
Expand All @@ -81,8 +82,8 @@ func (x xKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {

func (x xKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) {
l := x.PrivateKeySize()
if len(data) < l {
return nil, ErrInvalidKEMPrivateKey
if len(data) != l {
return nil, kem.ErrPrivKeySize
}
sk := &xKEMPrivKey{x, make([]byte, l), nil}
copy(sk.priv, data[:l])
Expand All @@ -94,8 +95,8 @@ func (x xKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) {

func (x xKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) {
l := x.PublicKeySize()
if len(data) < l {
return nil, ErrInvalidKEMPublicKey
if len(data) != l {
return nil, kem.ErrPubKeySize
}
pk := &xKEMPubKey{x, make([]byte, l)}
copy(pk.pub, data[:l])
Expand Down