summaryrefslogtreecommitdiffstats
path: root/dsse/sign_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'dsse/sign_test.go')
-rw-r--r--dsse/sign_test.go380
1 files changed, 380 insertions, 0 deletions
diff --git a/dsse/sign_test.go b/dsse/sign_test.go
new file mode 100644
index 0000000..c4abb73
--- /dev/null
+++ b/dsse/sign_test.go
@@ -0,0 +1,380 @@
+package dsse
+
+import (
+ "context"
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/sha256"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "math/big"
+ "testing"
+
+ "github.com/codahale/rfc6979"
+ "github.com/stretchr/testify/assert"
+)
+
+var errLength = errors.New("invalid length")
+
+func TestPAE(t *testing.T) {
+ t.Run("Empty", func(t *testing.T) {
+ var want = []byte("DSSEv1 0 0 ")
+
+ got := PAE("", []byte{})
+ assert.Equal(t, want, got, "Wrong encoding")
+ })
+ t.Run("Hello world", func(t *testing.T) {
+ var want = []byte("DSSEv1 29 http://example.com/HelloWorld 11 hello world")
+
+ got := PAE("http://example.com/HelloWorld", []byte("hello world"))
+ assert.Equal(t, want, got, "Wrong encoding")
+ })
+ t.Run("Unicode-only", func(t *testing.T) {
+ var want = []byte("DSSEv1 29 http://example.com/HelloWorld 3 ಠ")
+
+ got := PAE("http://example.com/HelloWorld", []byte("ಠ"))
+ assert.Equal(t, want, got, "Wrong encoding")
+ })
+}
+
+type nilSignerVerifier int
+
+func (n nilSignerVerifier) Sign(ctx context.Context, data []byte) ([]byte, error) {
+ return data, nil
+}
+
+func (n nilSignerVerifier) Verify(ctx context.Context, data, sig []byte) error {
+ if len(data) != len(sig) {
+ return errLength
+ }
+
+ for i := range data {
+ if data[i] != sig[i] {
+ return errVerify
+ }
+ }
+
+ return nil
+}
+
+func (n nilSignerVerifier) KeyID() (string, error) {
+ return "nil", nil
+}
+
+func (n nilSignerVerifier) Public() crypto.PublicKey {
+ return "nil-public"
+}
+
+type nullSignerVerifier int
+
+func (n nullSignerVerifier) Sign(ctx context.Context, data []byte) ([]byte, error) {
+ return data, nil
+}
+
+func (n nullSignerVerifier) Verify(ctx context.Context, data, sig []byte) error {
+ if len(data) != len(sig) {
+ return errLength
+ }
+
+ for i := range data {
+ if data[i] != sig[i] {
+ return errVerify
+ }
+ }
+
+ return nil
+}
+
+func (n nullSignerVerifier) KeyID() (string, error) {
+ return "null", nil
+}
+
+func (n nullSignerVerifier) Public() crypto.PublicKey {
+ return "null-public"
+}
+
+type errsigner int
+
+func (n errsigner) Sign(ctx context.Context, data []byte) ([]byte, error) {
+ return nil, fmt.Errorf("signing error")
+}
+
+func (n errsigner) Verify(ctx context.Context, data, sig []byte) error {
+ return errVerify
+}
+
+func (n errsigner) KeyID() (string, error) {
+ return "err", nil
+}
+
+func (n errsigner) Public() crypto.PublicKey {
+ return "err-public"
+}
+
+type errSignerVerifier int
+
+var errVerify = fmt.Errorf("accepted signatures do not match threshold, Found: 0, Expected 1")
+var errThreshold = fmt.Errorf("invalid threshold")
+
+func (n errSignerVerifier) Sign(ctx context.Context, data []byte) ([]byte, error) {
+ return data, nil
+}
+
+func (n errSignerVerifier) Verify(ctx context.Context, data, sig []byte) error {
+ return errVerify
+}
+
+func (n errSignerVerifier) KeyID() (string, error) {
+ return "err", nil
+}
+
+func (n errSignerVerifier) Public() crypto.PublicKey {
+ return "err-public"
+}
+
+type badverifier int
+
+func (n badverifier) Sign(ctx context.Context, data []byte) ([]byte, error) {
+ return append(data, byte(0)), nil
+}
+
+func (n badverifier) Verify(ctx context.Context, data, sig []byte) error {
+
+ if len(data) != len(sig) {
+ return errLength
+ }
+
+ for i := range data {
+ if data[i] != sig[i] {
+ return errVerify
+ }
+ }
+
+ return nil
+}
+
+func (n badverifier) KeyID() (string, error) {
+ return "bad", nil
+}
+
+func (n badverifier) Public() crypto.PublicKey {
+ return "bad-public"
+}
+
+func TestNoSigners(t *testing.T) {
+ t.Run("nil slice", func(t *testing.T) {
+ signer, err := NewEnvelopeSigner(nil)
+ assert.Nil(t, signer, "unexpected signer")
+ assert.NotNil(t, err, "error expected")
+ assert.Equal(t, ErrNoSigners, err, "wrong error")
+ })
+
+ t.Run("empty slice", func(t *testing.T) {
+ signer, err := NewEnvelopeSigner([]Signer{}...)
+ assert.Nil(t, signer, "unexpected signer")
+ assert.NotNil(t, err, "error expected")
+ assert.Equal(t, ErrNoSigners, err, "wrong error")
+ })
+}
+
+func TestNilSign(t *testing.T) {
+ var keyID = "nil"
+ var payloadType = "http://example.com/HelloWorld"
+ var payload = []byte("hello world")
+
+ pae := PAE(payloadType, payload)
+ want := Envelope{
+ Payload: base64.StdEncoding.EncodeToString([]byte(payload)),
+ PayloadType: payloadType,
+ Signatures: []Signature{
+ {
+ KeyID: keyID,
+ Sig: base64.StdEncoding.EncodeToString(pae),
+ },
+ },
+ }
+
+ var ns nilSignerVerifier
+ signer, err := NewEnvelopeSigner(ns)
+ assert.Nil(t, err, "unexpected error")
+
+ got, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
+ assert.Nil(t, err, "sign failed")
+ assert.Equal(t, &want, got, "bad signature")
+}
+
+func TestSignError(t *testing.T) {
+ var es errsigner
+ signer, err := NewEnvelopeSigner(es)
+ assert.Nil(t, err, "unexpected error")
+
+ got, err := signer.SignPayload(context.TODO(), "t", []byte("d"))
+ assert.Nil(t, got, "expected nil")
+ assert.NotNil(t, err, "error expected")
+ assert.Equal(t, "signing error", err.Error(), "wrong error")
+}
+
+func newEcdsaKey() *ecdsa.PrivateKey {
+ var x big.Int
+ var y big.Int
+ var d big.Int
+
+ _, ok := x.SetString("46950820868899156662930047687818585632848591499744589407958293238635476079160", 10)
+ if !ok {
+ return nil
+ }
+ _, ok = y.SetString("5640078356564379163099075877009565129882514886557779369047442380624545832820", 10)
+ if !ok {
+ return nil
+ }
+ _, ok = d.SetString("97358161215184420915383655311931858321456579547487070936769975997791359926199", 10)
+ if !ok {
+ return nil
+ }
+
+ var private = ecdsa.PrivateKey{
+ PublicKey: ecdsa.PublicKey{
+ Curve: elliptic.P256(),
+ X: &x,
+ Y: &y,
+ },
+ D: &d,
+ }
+
+ return &private
+}
+
+type ecdsaSignerVerifier struct {
+ keyID string
+ key *ecdsa.PrivateKey
+ rLen int
+ verified bool
+}
+
+func (es *ecdsaSignerVerifier) Sign(ctx context.Context, data []byte) ([]byte, error) {
+ // Data is complete message, hash it and sign the digest
+ digest := sha256.Sum256(data)
+ r, s, err := rfc6979.SignECDSA(es.key, digest[:], sha256.New)
+ if err != nil {
+ return nil, err
+ }
+
+ rb := r.Bytes()
+ sb := s.Bytes()
+ es.rLen = len(rb)
+ rawSig := append(rb, sb...)
+
+ return rawSig, nil
+}
+
+func (es *ecdsaSignerVerifier) Verify(ctx context.Context, data, sig []byte) error {
+ var r big.Int
+ var s big.Int
+ digest := sha256.Sum256(data)
+ // Signature here is the raw bytes of r and s concatenated
+ rb := sig[:es.rLen]
+ sb := sig[es.rLen:]
+ r.SetBytes(rb)
+ s.SetBytes(sb)
+
+ ok := ecdsa.Verify(&es.key.PublicKey, digest[:], &r, &s)
+ es.verified = ok
+
+ if ok {
+ return nil
+ }
+ return errVerify
+}
+
+func (es *ecdsaSignerVerifier) KeyID() (string, error) {
+ return es.keyID, nil
+}
+
+func (es *ecdsaSignerVerifier) Public() crypto.PublicKey {
+ return es.key.Public()
+}
+
+// Test against the example in the protocol specification:
+// https://github.com/secure-systems-lab/dsse/blob/master/protocol.md
+func TestEcdsaSign(t *testing.T) {
+ var keyID = "test key 123"
+ var payloadType = "http://example.com/HelloWorld"
+ var payload = "hello world"
+ var ecdsa = &ecdsaSignerVerifier{
+ keyID: keyID,
+ key: newEcdsaKey(),
+ }
+ var want = Envelope{
+ Payload: "aGVsbG8gd29ybGQ=",
+ PayloadType: payloadType,
+ Signatures: []Signature{
+ {
+ KeyID: keyID,
+ Sig: "A3JqsQGtVsJ2O2xqrI5IcnXip5GToJ3F+FnZ+O88SjtR6rDAajabZKciJTfUiHqJPcIAriEGAHTVeCUjW2JIZA==",
+ },
+ },
+ }
+
+ signer, err := NewEnvelopeSigner(ecdsa)
+ assert.Nil(t, err, "unexpected error")
+
+ env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
+ assert.Nil(t, err, "unexpected error")
+ assert.Equal(t, &want, env, "Wrong envelope generated")
+
+ // Now verify
+ verifier, err := NewEnvelopeVerifier(ecdsa)
+ assert.Nil(t, err, "unexpected error")
+ acceptedKeys, err := verifier.Verify(context.TODO(), env)
+ assert.Nil(t, err, "unexpected error")
+ assert.True(t, ecdsa.verified, "verify was not called")
+ assert.Len(t, acceptedKeys, 1, "unexpected keys")
+ assert.Equal(t, acceptedKeys[0].KeyID, keyID, "unexpected keyid")
+}
+
+func TestDecodeB64Payload(t *testing.T) {
+ var want = make([]byte, 256)
+ for i := range want {
+ want[i] = byte(i)
+ }
+ var b64Url = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0+P0BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn+AgYKDhIWGh4iJiouMjY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq+wsbKztLW2t7i5uru8vb6/wMHCw8TFxsfIycrLzM3Oz9DR0tPU1dbX2Nna29zd3t/g4eLj5OXm5+jp6uvs7e7v8PHy8/T19vf4+fr7/P3+/w=="
+ var b64Std = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0-P0BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn-AgYKDhIWGh4iJiouMjY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq-wsbKztLW2t7i5uru8vb6_wMHCw8TFxsfIycrLzM3Oz9DR0tPU1dbX2Nna29zd3t_g4eLj5OXm5-jp6uvs7e7v8PHy8_T19vf4-fr7_P3-_w=="
+ var b64UrlErr = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0+P0BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn+AgYKDhIWGh4iJiouMjY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq+wsbKztLW2t7i5uru8vb6/wMHCw8TFxsfIycrLzM3Oz9DR0tPU1dbX2Nna29zd3t/g4eLj5OXm5+jp6uvs7e7v8PHy8/T19vf4+fr7/P3+/w"
+ var b64StdErr = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0-P0BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn-AgYKDhIWGh4iJiouMjY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq-wsbKztLW2t7i5uru8vb6_wMHCw8TFxsfIycrLzM3Oz9DR0tPU1dbX2Nna29zd3t_g4eLj5OXm5-jp6uvs7e7v8PHy8_T19vf4-fr7_P3-_w"
+
+ t.Run("Standard encoding", func(t *testing.T) {
+ env := &Envelope{
+ Payload: b64Std,
+ }
+ got, err := env.DecodeB64Payload()
+ assert.Nil(t, err, "unexpected error")
+ assert.Equal(t, want, got, "wrong data")
+ })
+ t.Run("URL encoding", func(t *testing.T) {
+ env := &Envelope{
+ Payload: b64Url,
+ }
+ got, err := env.DecodeB64Payload()
+ assert.Nil(t, err, "unexpected error")
+ assert.Equal(t, want, got, "wrong data")
+ })
+
+ t.Run("Standard encoding - error", func(t *testing.T) {
+ env := &Envelope{
+ Payload: b64StdErr,
+ }
+ got, err := env.DecodeB64Payload()
+ assert.NotNil(t, err, "expected error")
+ assert.Nil(t, got, "wrong data")
+ })
+ t.Run("URL encoding - error", func(t *testing.T) {
+ env := &Envelope{
+ Payload: b64UrlErr,
+ }
+ got, err := env.DecodeB64Payload()
+ assert.NotNil(t, err, "expected error")
+ assert.Nil(t, got, "wrong data")
+ })
+}