summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 18:15:16 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 18:15:16 +0000
commit7ff64900ddd056f849635cef0f384be76c46c318 (patch)
treea95f201f843c1eceae41457bca3297b9ddf09c60
parentInitial commit. (diff)
downloadgolang-github-containers-libtrust-7ff64900ddd056f849635cef0f384be76c46c318.tar.xz
golang-github-containers-libtrust-7ff64900ddd056f849635cef0f384be76c46c318.zip
Adding upstream version 0.0~git20230121.c1716e8.upstream/0.0_git20230121.c1716e8upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
-rw-r--r--CODE-OF-CONDUCT.md3
-rw-r--r--CONTRIBUTING.md13
-rw-r--r--LICENSE191
-rw-r--r--MAINTAINERS3
-rw-r--r--README.md22
-rw-r--r--SECURITY.md3
-rw-r--r--certificates.go175
-rw-r--r--certificates_test.go111
-rw-r--r--doc.go9
-rw-r--r--ec_key.go422
-rw-r--r--ec_key_no_openssl.go23
-rw-r--r--ec_key_openssl.go24
-rw-r--r--ec_key_test.go157
-rw-r--r--filter.go50
-rw-r--r--filter_test.go81
-rw-r--r--hash.go56
-rw-r--r--jsonsign.go657
-rw-r--r--jsonsign_test.go380
-rw-r--r--key.go253
-rw-r--r--key_files.go255
-rw-r--r--key_files_test.go220
-rw-r--r--key_manager.go175
-rw-r--r--key_test.go80
-rw-r--r--rsa_key.go427
-rw-r--r--rsa_key_test.go157
-rw-r--r--testutil/certificates.go94
-rw-r--r--tlsdemo/README.md50
-rw-r--r--tlsdemo/client.go89
-rw-r--r--tlsdemo/gencert.go62
-rw-r--r--tlsdemo/genkeys.go61
-rw-r--r--tlsdemo/server.go80
-rw-r--r--trustgraph/graph.go50
-rw-r--r--trustgraph/memory_graph.go133
-rw-r--r--trustgraph/memory_graph_test.go174
-rw-r--r--trustgraph/statement.go227
-rw-r--r--trustgraph/statement_test.go417
-rw-r--r--util.go363
-rw-r--r--util_test.go45
38 files changed, 5792 insertions, 0 deletions
diff --git a/CODE-OF-CONDUCT.md b/CODE-OF-CONDUCT.md
new file mode 100644
index 0000000..a7d8acb
--- /dev/null
+++ b/CODE-OF-CONDUCT.md
@@ -0,0 +1,3 @@
+## The libtrust Project Community Code of Conduct
+
+The libtrust project follows the [Containers Community Code of Conduct](https://github.com/containers/common/blob/main/CODE-OF-CONDUCT.md).
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..05be0f8
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,13 @@
+# Contributing to libtrust
+
+Want to hack on libtrust? Awesome! Here are instructions to get you
+started.
+
+libtrust is a part of the [Docker](https://www.docker.com) project, and follows
+the same rules and principles. If you're already familiar with the way
+Docker does things, you'll feel right at home.
+
+Otherwise, go read
+[Docker's contributions guidelines](https://github.com/docker/docker/blob/master/CONTRIBUTING.md).
+
+Happy hacking!
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..2744858
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,191 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ Copyright 2014 Docker, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/MAINTAINERS b/MAINTAINERS
new file mode 100644
index 0000000..9768175
--- /dev/null
+++ b/MAINTAINERS
@@ -0,0 +1,3 @@
+Solomon Hykes <solomon@docker.com>
+Josh Hawn <josh@docker.com> (github: jlhawn)
+Derek McGowan <derek@docker.com> (github: dmcgowan)
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..dcffb31
--- /dev/null
+++ b/README.md
@@ -0,0 +1,22 @@
+# libtrust
+
+> **WARNING** this library is no longer actively developed, and will be integrated
+> in the [docker/distribution][https://www.github.com/docker/distribution]
+> repository in future.
+
+Libtrust is library for managing authentication and authorization using public key cryptography.
+
+Authentication is handled using the identity attached to the public key.
+Libtrust provides multiple methods to prove possession of the private key associated with an identity.
+ - TLS x509 certificates
+ - Signature verification
+ - Key Challenge
+
+Authorization and access control is managed through a distributed trust graph.
+Trust servers are used as the authorities of the trust graph and allow caching portions of the graph for faster access.
+
+## Copyright and license
+
+Code and documentation copyright 2014 Docker, inc. Code released under the Apache 2.0 license.
+Docs released under Creative commons.
+
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 0000000..966f4f0
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,3 @@
+## Security and Disclosure Information Policy for the libtrust Project
+
+The libtrust Project follows the [Security and Disclosure Information Policy](https://github.com/containers/common/blob/main/SECURITY.md) for the Containers Projects.
diff --git a/certificates.go b/certificates.go
new file mode 100644
index 0000000..3dcca33
--- /dev/null
+++ b/certificates.go
@@ -0,0 +1,175 @@
+package libtrust
+
+import (
+ "crypto/rand"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "fmt"
+ "io/ioutil"
+ "math/big"
+ "net"
+ "time"
+)
+
+type certTemplateInfo struct {
+ commonName string
+ domains []string
+ ipAddresses []net.IP
+ isCA bool
+ clientAuth bool
+ serverAuth bool
+}
+
+func generateCertTemplate(info *certTemplateInfo) *x509.Certificate {
+ // Generate a certificate template which is valid from the past week to
+ // 10 years from now. The usage of the certificate depends on the
+ // specified fields in the given certTempInfo object.
+ var (
+ keyUsage x509.KeyUsage
+ extKeyUsage []x509.ExtKeyUsage
+ )
+
+ if info.isCA {
+ keyUsage = x509.KeyUsageCertSign
+ }
+
+ if info.clientAuth {
+ extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageClientAuth)
+ }
+
+ if info.serverAuth {
+ extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageServerAuth)
+ }
+
+ return &x509.Certificate{
+ SerialNumber: big.NewInt(0),
+ Subject: pkix.Name{
+ CommonName: info.commonName,
+ },
+ NotBefore: time.Now().Add(-time.Hour * 24 * 7),
+ NotAfter: time.Now().Add(time.Hour * 24 * 365 * 10),
+ DNSNames: info.domains,
+ IPAddresses: info.ipAddresses,
+ IsCA: info.isCA,
+ KeyUsage: keyUsage,
+ ExtKeyUsage: extKeyUsage,
+ BasicConstraintsValid: info.isCA,
+ }
+}
+
+func generateCert(pub PublicKey, priv PrivateKey, subInfo, issInfo *certTemplateInfo) (cert *x509.Certificate, err error) {
+ pubCertTemplate := generateCertTemplate(subInfo)
+ privCertTemplate := generateCertTemplate(issInfo)
+
+ certDER, err := x509.CreateCertificate(
+ rand.Reader, pubCertTemplate, privCertTemplate,
+ pub.CryptoPublicKey(), priv.CryptoPrivateKey(),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("failed to create certificate: %s", err)
+ }
+
+ cert, err = x509.ParseCertificate(certDER)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse certificate: %s", err)
+ }
+
+ return
+}
+
+// GenerateSelfSignedServerCert creates a self-signed certificate for the
+// given key which is to be used for TLS servers with the given domains and
+// IP addresses.
+func GenerateSelfSignedServerCert(key PrivateKey, domains []string, ipAddresses []net.IP) (*x509.Certificate, error) {
+ info := &certTemplateInfo{
+ commonName: key.KeyID(),
+ domains: domains,
+ ipAddresses: ipAddresses,
+ serverAuth: true,
+ }
+
+ return generateCert(key.PublicKey(), key, info, info)
+}
+
+// GenerateSelfSignedClientCert creates a self-signed certificate for the
+// given key which is to be used for TLS clients.
+func GenerateSelfSignedClientCert(key PrivateKey) (*x509.Certificate, error) {
+ info := &certTemplateInfo{
+ commonName: key.KeyID(),
+ clientAuth: true,
+ }
+
+ return generateCert(key.PublicKey(), key, info, info)
+}
+
+// GenerateCACert creates a certificate which can be used as a trusted
+// certificate authority.
+func GenerateCACert(signer PrivateKey, trustedKey PublicKey) (*x509.Certificate, error) {
+ subjectInfo := &certTemplateInfo{
+ commonName: trustedKey.KeyID(),
+ isCA: true,
+ }
+ issuerInfo := &certTemplateInfo{
+ commonName: signer.KeyID(),
+ }
+
+ return generateCert(trustedKey, signer, subjectInfo, issuerInfo)
+}
+
+// GenerateCACertPool creates a certificate authority pool to be used for a
+// TLS configuration. Any self-signed certificates issued by the specified
+// trusted keys will be verified during a TLS handshake
+func GenerateCACertPool(signer PrivateKey, trustedKeys []PublicKey) (*x509.CertPool, error) {
+ certPool := x509.NewCertPool()
+
+ for _, trustedKey := range trustedKeys {
+ cert, err := GenerateCACert(signer, trustedKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate CA certificate: %s", err)
+ }
+
+ certPool.AddCert(cert)
+ }
+
+ return certPool, nil
+}
+
+// LoadCertificateBundle loads certificates from the given file. The file should be pem encoded
+// containing one or more certificates. The expected pem type is "CERTIFICATE".
+func LoadCertificateBundle(filename string) ([]*x509.Certificate, error) {
+ b, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return nil, err
+ }
+ certificates := []*x509.Certificate{}
+ var block *pem.Block
+ block, b = pem.Decode(b)
+ for ; block != nil; block, b = pem.Decode(b) {
+ if block.Type == "CERTIFICATE" {
+ cert, err := x509.ParseCertificate(block.Bytes)
+ if err != nil {
+ return nil, err
+ }
+ certificates = append(certificates, cert)
+ } else {
+ return nil, fmt.Errorf("invalid pem block type: %s", block.Type)
+ }
+ }
+
+ return certificates, nil
+}
+
+// LoadCertificatePool loads a CA pool from the given file. The file should be pem encoded
+// containing one or more certificates. The expected pem type is "CERTIFICATE".
+func LoadCertificatePool(filename string) (*x509.CertPool, error) {
+ certs, err := LoadCertificateBundle(filename)
+ if err != nil {
+ return nil, err
+ }
+ pool := x509.NewCertPool()
+ for _, cert := range certs {
+ pool.AddCert(cert)
+ }
+ return pool, nil
+}
diff --git a/certificates_test.go b/certificates_test.go
new file mode 100644
index 0000000..c111f35
--- /dev/null
+++ b/certificates_test.go
@@ -0,0 +1,111 @@
+package libtrust
+
+import (
+ "encoding/pem"
+ "io/ioutil"
+ "net"
+ "os"
+ "path"
+ "testing"
+)
+
+func TestGenerateCertificates(t *testing.T) {
+ key, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ _, err = GenerateSelfSignedServerCert(key, []string{"localhost"}, []net.IP{net.ParseIP("127.0.0.1")})
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ _, err = GenerateSelfSignedClientCert(key)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestGenerateCACertPool(t *testing.T) {
+ key, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ caKey1, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ caKey2, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ _, err = GenerateCACertPool(key, []PublicKey{caKey1.PublicKey(), caKey2.PublicKey()})
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestLoadCertificates(t *testing.T) {
+ key, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ caKey1, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+ caKey2, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ cert1, err := GenerateCACert(caKey1, key)
+ if err != nil {
+ t.Fatal(err)
+ }
+ cert2, err := GenerateCACert(caKey2, key)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ d, err := ioutil.TempDir("/tmp", "cert-test")
+ if err != nil {
+ t.Fatal(err)
+ }
+ caFile := path.Join(d, "ca.pem")
+ f, err := os.OpenFile(caFile, os.O_CREATE|os.O_WRONLY, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = pem.Encode(f, &pem.Block{Type: "CERTIFICATE", Bytes: cert1.Raw})
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = pem.Encode(f, &pem.Block{Type: "CERTIFICATE", Bytes: cert2.Raw})
+ if err != nil {
+ t.Fatal(err)
+ }
+ f.Close()
+
+ certs, err := LoadCertificateBundle(caFile)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(certs) != 2 {
+ t.Fatalf("Wrong number of certs received, expected: %d, received %d", 2, len(certs))
+ }
+
+ pool, err := LoadCertificatePool(caFile)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(pool.Subjects()) != 2 {
+ t.Fatalf("Invalid certificate pool")
+ }
+}
diff --git a/doc.go b/doc.go
new file mode 100644
index 0000000..ec5d215
--- /dev/null
+++ b/doc.go
@@ -0,0 +1,9 @@
+/*
+Package libtrust provides an interface for managing authentication and
+authorization using public key cryptography. Authentication is handled
+using the identity attached to the public key and verified through TLS
+x509 certificates, a key challenge, or signature. Authorization and
+access control is managed through a trust graph distributed between
+both remote trust servers and locally cached and managed data.
+*/
+package libtrust
diff --git a/ec_key.go b/ec_key.go
new file mode 100644
index 0000000..0ee1b91
--- /dev/null
+++ b/ec_key.go
@@ -0,0 +1,422 @@
+package libtrust
+
+import (
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
+ "crypto/x509"
+ "encoding/json"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "io"
+ "math/big"
+)
+
+/*
+ * EC DSA PUBLIC KEY
+ */
+
+// ecPublicKey implements a libtrust.PublicKey using elliptic curve digital
+// signature algorithms.
+type ecPublicKey struct {
+ *ecdsa.PublicKey
+ curveName string
+ signatureAlgorithm *signatureAlgorithm
+ extended map[string]interface{}
+}
+
+func fromECPublicKey(cryptoPublicKey *ecdsa.PublicKey) (*ecPublicKey, error) {
+ curve := cryptoPublicKey.Curve
+
+ switch {
+ case curve == elliptic.P256():
+ return &ecPublicKey{cryptoPublicKey, "P-256", es256, map[string]interface{}{}}, nil
+ case curve == elliptic.P384():
+ return &ecPublicKey{cryptoPublicKey, "P-384", es384, map[string]interface{}{}}, nil
+ case curve == elliptic.P521():
+ return &ecPublicKey{cryptoPublicKey, "P-521", es512, map[string]interface{}{}}, nil
+ default:
+ return nil, errors.New("unsupported elliptic curve")
+ }
+}
+
+// KeyType returns the key type for elliptic curve keys, i.e., "EC".
+func (k *ecPublicKey) KeyType() string {
+ return "EC"
+}
+
+// CurveName returns the elliptic curve identifier.
+// Possible values are "P-256", "P-384", and "P-521".
+func (k *ecPublicKey) CurveName() string {
+ return k.curveName
+}
+
+// KeyID returns a distinct identifier which is unique to this Public Key.
+func (k *ecPublicKey) KeyID() string {
+ return keyIDFromCryptoKey(k)
+}
+
+func (k *ecPublicKey) String() string {
+ return fmt.Sprintf("EC Public Key <%s>", k.KeyID())
+}
+
+// Verify verifyies the signature of the data in the io.Reader using this
+// PublicKey. The alg parameter should identify the digital signature
+// algorithm which was used to produce the signature and should be supported
+// by this public key. Returns a nil error if the signature is valid.
+func (k *ecPublicKey) Verify(data io.Reader, alg string, signature []byte) error {
+ // For EC keys there is only one supported signature algorithm depending
+ // on the curve parameters.
+ if k.signatureAlgorithm.HeaderParam() != alg {
+ return fmt.Errorf("unable to verify signature: EC Public Key with curve %q does not support signature algorithm %q", k.curveName, alg)
+ }
+
+ // signature is the concatenation of (r, s), base64Url encoded.
+ sigLength := len(signature)
+ expectedOctetLength := 2 * ((k.Params().BitSize + 7) >> 3)
+ if sigLength != expectedOctetLength {
+ return fmt.Errorf("signature length is %d octets long, should be %d", sigLength, expectedOctetLength)
+ }
+
+ rBytes, sBytes := signature[:sigLength/2], signature[sigLength/2:]
+ r := new(big.Int).SetBytes(rBytes)
+ s := new(big.Int).SetBytes(sBytes)
+
+ hasher := k.signatureAlgorithm.HashID().New()
+ _, err := io.Copy(hasher, data)
+ if err != nil {
+ return fmt.Errorf("error reading data to sign: %s", err)
+ }
+ hash := hasher.Sum(nil)
+
+ if !ecdsa.Verify(k.PublicKey, hash, r, s) {
+ return errors.New("invalid signature")
+ }
+
+ return nil
+}
+
+// CryptoPublicKey returns the internal object which can be used as a
+// crypto.PublicKey for use with other standard library operations. The type
+// is either *rsa.PublicKey or *ecdsa.PublicKey
+func (k *ecPublicKey) CryptoPublicKey() crypto.PublicKey {
+ return k.PublicKey
+}
+
+func (k *ecPublicKey) toMap() map[string]interface{} {
+ jwk := make(map[string]interface{})
+ for k, v := range k.extended {
+ jwk[k] = v
+ }
+ jwk["kty"] = k.KeyType()
+ jwk["kid"] = k.KeyID()
+ jwk["crv"] = k.CurveName()
+
+ xBytes := k.X.Bytes()
+ yBytes := k.Y.Bytes()
+ octetLength := (k.Params().BitSize + 7) >> 3
+ // MUST include leading zeros in the output so that x, y are each
+ // *octetLength* bytes long.
+ xBuf := make([]byte, octetLength-len(xBytes), octetLength)
+ yBuf := make([]byte, octetLength-len(yBytes), octetLength)
+ xBuf = append(xBuf, xBytes...)
+ yBuf = append(yBuf, yBytes...)
+
+ jwk["x"] = joseBase64UrlEncode(xBuf)
+ jwk["y"] = joseBase64UrlEncode(yBuf)
+
+ return jwk
+}
+
+// MarshalJSON serializes this Public Key using the JWK JSON serialization format for
+// elliptic curve keys.
+func (k *ecPublicKey) MarshalJSON() (data []byte, err error) {
+ return json.Marshal(k.toMap())
+}
+
+// PEMBlock serializes this Public Key to DER-encoded PKIX format.
+func (k *ecPublicKey) PEMBlock() (*pem.Block, error) {
+ derBytes, err := x509.MarshalPKIXPublicKey(k.PublicKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to serialize EC PublicKey to DER-encoded PKIX format: %s", err)
+ }
+ k.extended["kid"] = k.KeyID() // For display purposes.
+ return createPemBlock("PUBLIC KEY", derBytes, k.extended)
+}
+
+func (k *ecPublicKey) AddExtendedField(field string, value interface{}) {
+ k.extended[field] = value
+}
+
+func (k *ecPublicKey) GetExtendedField(field string) interface{} {
+ v, ok := k.extended[field]
+ if !ok {
+ return nil
+ }
+ return v
+}
+
+func ecPublicKeyFromMap(jwk map[string]interface{}) (*ecPublicKey, error) {
+ // JWK key type (kty) has already been determined to be "EC".
+ // Need to extract 'crv', 'x', 'y', and 'kid' and check for
+ // consistency.
+
+ // Get the curve identifier value.
+ crv, err := stringFromMap(jwk, "crv")
+ if err != nil {
+ return nil, fmt.Errorf("JWK EC Public Key curve identifier: %s", err)
+ }
+
+ var (
+ curve elliptic.Curve
+ sigAlg *signatureAlgorithm
+ )
+
+ switch {
+ case crv == "P-256":
+ curve = elliptic.P256()
+ sigAlg = es256
+ case crv == "P-384":
+ curve = elliptic.P384()
+ sigAlg = es384
+ case crv == "P-521":
+ curve = elliptic.P521()
+ sigAlg = es512
+ default:
+ return nil, fmt.Errorf("JWK EC Public Key curve identifier not supported: %q\n", crv)
+ }
+
+ // Get the X and Y coordinates for the public key point.
+ xB64Url, err := stringFromMap(jwk, "x")
+ if err != nil {
+ return nil, fmt.Errorf("JWK EC Public Key x-coordinate: %s", err)
+ }
+ x, err := parseECCoordinate(xB64Url, curve)
+ if err != nil {
+ return nil, fmt.Errorf("JWK EC Public Key x-coordinate: %s", err)
+ }
+
+ yB64Url, err := stringFromMap(jwk, "y")
+ if err != nil {
+ return nil, fmt.Errorf("JWK EC Public Key y-coordinate: %s", err)
+ }
+ y, err := parseECCoordinate(yB64Url, curve)
+ if err != nil {
+ return nil, fmt.Errorf("JWK EC Public Key y-coordinate: %s", err)
+ }
+
+ key := &ecPublicKey{
+ PublicKey: &ecdsa.PublicKey{Curve: curve, X: x, Y: y},
+ curveName: crv, signatureAlgorithm: sigAlg,
+ }
+
+ // Key ID is optional too, but if it exists, it should match the key.
+ _, ok := jwk["kid"]
+ if ok {
+ kid, err := stringFromMap(jwk, "kid")
+ if err != nil {
+ return nil, fmt.Errorf("JWK EC Public Key ID: %s", err)
+ }
+ if kid != key.KeyID() {
+ return nil, fmt.Errorf("JWK EC Public Key ID does not match: %s", kid)
+ }
+ }
+
+ key.extended = jwk
+
+ return key, nil
+}
+
+/*
+ * EC DSA PRIVATE KEY
+ */
+
+// ecPrivateKey implements a JWK Private Key using elliptic curve digital signature
+// algorithms.
+type ecPrivateKey struct {
+ ecPublicKey
+ *ecdsa.PrivateKey
+}
+
+func fromECPrivateKey(cryptoPrivateKey *ecdsa.PrivateKey) (*ecPrivateKey, error) {
+ publicKey, err := fromECPublicKey(&cryptoPrivateKey.PublicKey)
+ if err != nil {
+ return nil, err
+ }
+
+ return &ecPrivateKey{*publicKey, cryptoPrivateKey}, nil
+}
+
+// PublicKey returns the Public Key data associated with this Private Key.
+func (k *ecPrivateKey) PublicKey() PublicKey {
+ return &k.ecPublicKey
+}
+
+func (k *ecPrivateKey) String() string {
+ return fmt.Sprintf("EC Private Key <%s>", k.KeyID())
+}
+
+// Sign signs the data read from the io.Reader using a signature algorithm supported
+// by the elliptic curve private key. If the specified hashing algorithm is
+// supported by this key, that hash function is used to generate the signature
+// otherwise the the default hashing algorithm for this key is used. Returns
+// the signature and the name of the JWK signature algorithm used, e.g.,
+// "ES256", "ES384", "ES512".
+func (k *ecPrivateKey) Sign(data io.Reader, hashID crypto.Hash) (signature []byte, alg string, err error) {
+ // Generate a signature of the data using the internal alg.
+ // The given hashId is only a suggestion, and since EC keys only support
+ // on signature/hash algorithm given the curve name, we disregard it for
+ // the elliptic curve JWK signature implementation.
+ r, s, err := k.sign(data, hashID)
+ if err != nil {
+ return nil, "", fmt.Errorf("error producing signature: %s", err)
+ }
+
+ rBytes, sBytes := r.Bytes(), s.Bytes()
+ octetLength := (k.ecPublicKey.Params().BitSize + 7) >> 3
+ // MUST include leading zeros in the output
+ rBuf := make([]byte, octetLength-len(rBytes), octetLength)
+ sBuf := make([]byte, octetLength-len(sBytes), octetLength)
+
+ rBuf = append(rBuf, rBytes...)
+ sBuf = append(sBuf, sBytes...)
+
+ signature = append(rBuf, sBuf...)
+ alg = k.signatureAlgorithm.HeaderParam()
+
+ return
+}
+
+// CryptoPrivateKey returns the internal object which can be used as a
+// crypto.PublicKey for use with other standard library operations. The type
+// is either *rsa.PublicKey or *ecdsa.PublicKey
+func (k *ecPrivateKey) CryptoPrivateKey() crypto.PrivateKey {
+ return k.PrivateKey
+}
+
+func (k *ecPrivateKey) toMap() map[string]interface{} {
+ jwk := k.ecPublicKey.toMap()
+
+ dBytes := k.D.Bytes()
+ // The length of this octet string MUST be ceiling(log-base-2(n)/8)
+ // octets (where n is the order of the curve). This is because the private
+ // key d must be in the interval [1, n-1] so the bitlength of d should be
+ // no larger than the bitlength of n-1. The easiest way to find the octet
+ // length is to take bitlength(n-1), add 7 to force a carry, and shift this
+ // bit sequence right by 3, which is essentially dividing by 8 and adding
+ // 1 if there is any remainder. Thus, the private key value d should be
+ // output to (bitlength(n-1)+7)>>3 octets.
+ n := k.ecPublicKey.Params().N
+ octetLength := (new(big.Int).Sub(n, big.NewInt(1)).BitLen() + 7) >> 3
+ // Create a buffer with the necessary zero-padding.
+ dBuf := make([]byte, octetLength-len(dBytes), octetLength)
+ dBuf = append(dBuf, dBytes...)
+
+ jwk["d"] = joseBase64UrlEncode(dBuf)
+
+ return jwk
+}
+
+// MarshalJSON serializes this Private Key using the JWK JSON serialization format for
+// elliptic curve keys.
+func (k *ecPrivateKey) MarshalJSON() (data []byte, err error) {
+ return json.Marshal(k.toMap())
+}
+
+// PEMBlock serializes this Private Key to DER-encoded PKIX format.
+func (k *ecPrivateKey) PEMBlock() (*pem.Block, error) {
+ derBytes, err := x509.MarshalECPrivateKey(k.PrivateKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to serialize EC PrivateKey to DER-encoded PKIX format: %s", err)
+ }
+ k.extended["keyID"] = k.KeyID() // For display purposes.
+ return createPemBlock("EC PRIVATE KEY", derBytes, k.extended)
+}
+
+func ecPrivateKeyFromMap(jwk map[string]interface{}) (*ecPrivateKey, error) {
+ dB64Url, err := stringFromMap(jwk, "d")
+ if err != nil {
+ return nil, fmt.Errorf("JWK EC Private Key: %s", err)
+ }
+
+ // JWK key type (kty) has already been determined to be "EC".
+ // Need to extract the public key information, then extract the private
+ // key value 'd'.
+ publicKey, err := ecPublicKeyFromMap(jwk)
+ if err != nil {
+ return nil, err
+ }
+
+ d, err := parseECPrivateParam(dB64Url, publicKey.Curve)
+ if err != nil {
+ return nil, fmt.Errorf("JWK EC Private Key d-param: %s", err)
+ }
+
+ key := &ecPrivateKey{
+ ecPublicKey: *publicKey,
+ PrivateKey: &ecdsa.PrivateKey{
+ PublicKey: *publicKey.PublicKey,
+ D: d,
+ },
+ }
+
+ return key, nil
+}
+
+/*
+ * Key Generation Functions.
+ */
+
+func generateECPrivateKey(curve elliptic.Curve) (k *ecPrivateKey, err error) {
+ k = new(ecPrivateKey)
+ k.PrivateKey, err = ecdsa.GenerateKey(curve, rand.Reader)
+ if err != nil {
+ return nil, err
+ }
+
+ k.ecPublicKey.PublicKey = &k.PrivateKey.PublicKey
+ k.extended = make(map[string]interface{})
+
+ return
+}
+
+// GenerateECP256PrivateKey generates a key pair using elliptic curve P-256.
+func GenerateECP256PrivateKey() (PrivateKey, error) {
+ k, err := generateECPrivateKey(elliptic.P256())
+ if err != nil {
+ return nil, fmt.Errorf("error generating EC P-256 key: %s", err)
+ }
+
+ k.curveName = "P-256"
+ k.signatureAlgorithm = es256
+
+ return k, nil
+}
+
+// GenerateECP384PrivateKey generates a key pair using elliptic curve P-384.
+func GenerateECP384PrivateKey() (PrivateKey, error) {
+ k, err := generateECPrivateKey(elliptic.P384())
+ if err != nil {
+ return nil, fmt.Errorf("error generating EC P-384 key: %s", err)
+ }
+
+ k.curveName = "P-384"
+ k.signatureAlgorithm = es384
+
+ return k, nil
+}
+
+// GenerateECP521PrivateKey generates aß key pair using elliptic curve P-521.
+func GenerateECP521PrivateKey() (PrivateKey, error) {
+ k, err := generateECPrivateKey(elliptic.P521())
+ if err != nil {
+ return nil, fmt.Errorf("error generating EC P-521 key: %s", err)
+ }
+
+ k.curveName = "P-521"
+ k.signatureAlgorithm = es512
+
+ return k, nil
+}
diff --git a/ec_key_no_openssl.go b/ec_key_no_openssl.go
new file mode 100644
index 0000000..d6cdaca
--- /dev/null
+++ b/ec_key_no_openssl.go
@@ -0,0 +1,23 @@
+// +build !libtrust_openssl
+
+package libtrust
+
+import (
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/rand"
+ "fmt"
+ "io"
+ "math/big"
+)
+
+func (k *ecPrivateKey) sign(data io.Reader, hashID crypto.Hash) (r, s *big.Int, err error) {
+ hasher := k.signatureAlgorithm.HashID().New()
+ _, err = io.Copy(hasher, data)
+ if err != nil {
+ return nil, nil, fmt.Errorf("error reading data to sign: %s", err)
+ }
+ hash := hasher.Sum(nil)
+
+ return ecdsa.Sign(rand.Reader, k.PrivateKey, hash)
+}
diff --git a/ec_key_openssl.go b/ec_key_openssl.go
new file mode 100644
index 0000000..4137511
--- /dev/null
+++ b/ec_key_openssl.go
@@ -0,0 +1,24 @@
+// +build libtrust_openssl
+
+package libtrust
+
+import (
+ "bytes"
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/rand"
+ "fmt"
+ "io"
+ "math/big"
+)
+
+func (k *ecPrivateKey) sign(data io.Reader, hashID crypto.Hash) (r, s *big.Int, err error) {
+ hId := k.signatureAlgorithm.HashID()
+ buf := new(bytes.Buffer)
+ _, err = buf.ReadFrom(data)
+ if err != nil {
+ return nil, nil, fmt.Errorf("error reading data: %s", err)
+ }
+
+ return ecdsa.HashSign(rand.Reader, k.PrivateKey, buf.Bytes(), hId)
+}
diff --git a/ec_key_test.go b/ec_key_test.go
new file mode 100644
index 0000000..26ac381
--- /dev/null
+++ b/ec_key_test.go
@@ -0,0 +1,157 @@
+package libtrust
+
+import (
+ "bytes"
+ "encoding/json"
+ "testing"
+)
+
+func generateECTestKeys(t *testing.T) []PrivateKey {
+ p256Key, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ p384Key, err := GenerateECP384PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ p521Key, err := GenerateECP521PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ return []PrivateKey{p256Key, p384Key, p521Key}
+}
+
+func TestECKeys(t *testing.T) {
+ ecKeys := generateECTestKeys(t)
+
+ for _, ecKey := range ecKeys {
+ if ecKey.KeyType() != "EC" {
+ t.Fatalf("key type must be %q, instead got %q", "EC", ecKey.KeyType())
+ }
+ }
+}
+
+func TestECSignVerify(t *testing.T) {
+ ecKeys := generateECTestKeys(t)
+
+ message := "Hello, World!"
+ data := bytes.NewReader([]byte(message))
+
+ sigAlgs := []*signatureAlgorithm{es256, es384, es512}
+
+ for i, ecKey := range ecKeys {
+ sigAlg := sigAlgs[i]
+
+ t.Logf("%s signature of %q with kid: %s\n", sigAlg.HeaderParam(), message, ecKey.KeyID())
+
+ data.Seek(0, 0) // Reset the byte reader
+
+ // Sign
+ sig, alg, err := ecKey.Sign(data, sigAlg.HashID())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ data.Seek(0, 0) // Reset the byte reader
+
+ // Verify
+ err = ecKey.Verify(data, alg, sig)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+}
+
+func TestMarshalUnmarshalECKeys(t *testing.T) {
+ ecKeys := generateECTestKeys(t)
+ data := bytes.NewReader([]byte("This is a test. I repeat: this is only a test."))
+ sigAlgs := []*signatureAlgorithm{es256, es384, es512}
+
+ for i, ecKey := range ecKeys {
+ sigAlg := sigAlgs[i]
+ privateJWKJSON, err := json.MarshalIndent(ecKey, "", " ")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ publicJWKJSON, err := json.MarshalIndent(ecKey.PublicKey(), "", " ")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ t.Logf("JWK Private Key: %s", string(privateJWKJSON))
+ t.Logf("JWK Public Key: %s", string(publicJWKJSON))
+
+ privKey2, err := UnmarshalPrivateKeyJWK(privateJWKJSON)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ pubKey2, err := UnmarshalPublicKeyJWK(publicJWKJSON)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Ensure we can sign/verify a message with the unmarshalled keys.
+ data.Seek(0, 0) // Reset the byte reader
+ signature, alg, err := privKey2.Sign(data, sigAlg.HashID())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ data.Seek(0, 0) // Reset the byte reader
+ err = pubKey2.Verify(data, alg, signature)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+}
+
+func TestFromCryptoECKeys(t *testing.T) {
+ ecKeys := generateECTestKeys(t)
+
+ for _, ecKey := range ecKeys {
+ cryptoPrivateKey := ecKey.CryptoPrivateKey()
+ cryptoPublicKey := ecKey.CryptoPublicKey()
+
+ pubKey, err := FromCryptoPublicKey(cryptoPublicKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if pubKey.KeyID() != ecKey.KeyID() {
+ t.Fatal("public key key ID mismatch")
+ }
+
+ privKey, err := FromCryptoPrivateKey(cryptoPrivateKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if privKey.KeyID() != ecKey.KeyID() {
+ t.Fatal("public key key ID mismatch")
+ }
+ }
+}
+
+func TestExtendedFields(t *testing.T) {
+ key, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ key.AddExtendedField("test", "foobar")
+ val := key.GetExtendedField("test")
+
+ gotVal, ok := val.(string)
+ if !ok {
+ t.Fatalf("value is not a string")
+ } else if gotVal != val {
+ t.Fatalf("value %q is not equal to %q", gotVal, val)
+ }
+
+}
diff --git a/filter.go b/filter.go
new file mode 100644
index 0000000..5b2b4fc
--- /dev/null
+++ b/filter.go
@@ -0,0 +1,50 @@
+package libtrust
+
+import (
+ "path/filepath"
+)
+
+// FilterByHosts filters the list of PublicKeys to only those which contain a
+// 'hosts' pattern which matches the given host. If *includeEmpty* is true,
+// then keys which do not specify any hosts are also returned.
+func FilterByHosts(keys []PublicKey, host string, includeEmpty bool) ([]PublicKey, error) {
+ filtered := make([]PublicKey, 0, len(keys))
+
+ for _, pubKey := range keys {
+ var hosts []string
+ switch v := pubKey.GetExtendedField("hosts").(type) {
+ case []string:
+ hosts = v
+ case []interface{}:
+ for _, value := range v {
+ h, ok := value.(string)
+ if !ok {
+ continue
+ }
+ hosts = append(hosts, h)
+ }
+ }
+
+ if len(hosts) == 0 {
+ if includeEmpty {
+ filtered = append(filtered, pubKey)
+ }
+ continue
+ }
+
+ // Check if any hosts match pattern
+ for _, hostPattern := range hosts {
+ match, err := filepath.Match(hostPattern, host)
+ if err != nil {
+ return nil, err
+ }
+
+ if match {
+ filtered = append(filtered, pubKey)
+ continue
+ }
+ }
+ }
+
+ return filtered, nil
+}
diff --git a/filter_test.go b/filter_test.go
new file mode 100644
index 0000000..997e554
--- /dev/null
+++ b/filter_test.go
@@ -0,0 +1,81 @@
+package libtrust
+
+import (
+ "testing"
+)
+
+func compareKeySlices(t *testing.T, sliceA, sliceB []PublicKey) {
+ if len(sliceA) != len(sliceB) {
+ t.Fatalf("slice size %d, expected %d", len(sliceA), len(sliceB))
+ }
+
+ for i, itemA := range sliceA {
+ itemB := sliceB[i]
+ if itemA != itemB {
+ t.Fatalf("slice index %d not equal: %#v != %#v", i, itemA, itemB)
+ }
+ }
+}
+
+func TestFilter(t *testing.T) {
+ keys := make([]PublicKey, 0, 8)
+
+ // Create 8 keys and add host entries.
+ for i := 0; i < cap(keys); i++ {
+ key, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // we use both []interface{} and []string here because jwt uses
+ // []interface{} format, while PEM uses []string
+ switch {
+ case i == 0:
+ // Don't add entries for this key, key 0.
+ break
+ case i%2 == 0:
+ // Should catch keys 2, 4, and 6.
+ key.AddExtendedField("hosts", []interface{}{"*.even.example.com"})
+ case i == 7:
+ // Should catch only the last key, and make it match any hostname.
+ key.AddExtendedField("hosts", []string{"*"})
+ default:
+ // should catch keys 1, 3, 5.
+ key.AddExtendedField("hosts", []string{"*.example.com"})
+ }
+
+ keys = append(keys, key)
+ }
+
+ // Should match 2 keys, the empty one, and the one that matches all hosts.
+ matchedKeys, err := FilterByHosts(keys, "foo.bar.com", true)
+ if err != nil {
+ t.Fatal(err)
+ }
+ expectedMatch := []PublicKey{keys[0], keys[7]}
+ compareKeySlices(t, expectedMatch, matchedKeys)
+
+ // Should match 1 key, the one that matches any host.
+ matchedKeys, err = FilterByHosts(keys, "foo.bar.com", false)
+ if err != nil {
+ t.Fatal(err)
+ }
+ expectedMatch = []PublicKey{keys[7]}
+ compareKeySlices(t, expectedMatch, matchedKeys)
+
+ // Should match keys that end in "example.com", and the key that matches anything.
+ matchedKeys, err = FilterByHosts(keys, "foo.example.com", false)
+ if err != nil {
+ t.Fatal(err)
+ }
+ expectedMatch = []PublicKey{keys[1], keys[3], keys[5], keys[7]}
+ compareKeySlices(t, expectedMatch, matchedKeys)
+
+ // Should match all of the keys except the empty key.
+ matchedKeys, err = FilterByHosts(keys, "foo.even.example.com", false)
+ if err != nil {
+ t.Fatal(err)
+ }
+ expectedMatch = keys[1:]
+ compareKeySlices(t, expectedMatch, matchedKeys)
+}
diff --git a/hash.go b/hash.go
new file mode 100644
index 0000000..a2df787
--- /dev/null
+++ b/hash.go
@@ -0,0 +1,56 @@
+package libtrust
+
+import (
+ "crypto"
+ _ "crypto/sha256" // Registrer SHA224 and SHA256
+ _ "crypto/sha512" // Registrer SHA384 and SHA512
+ "fmt"
+)
+
+type signatureAlgorithm struct {
+ algHeaderParam string
+ hashID crypto.Hash
+}
+
+func (h *signatureAlgorithm) HeaderParam() string {
+ return h.algHeaderParam
+}
+
+func (h *signatureAlgorithm) HashID() crypto.Hash {
+ return h.hashID
+}
+
+var (
+ rs256 = &signatureAlgorithm{"RS256", crypto.SHA256}
+ rs384 = &signatureAlgorithm{"RS384", crypto.SHA384}
+ rs512 = &signatureAlgorithm{"RS512", crypto.SHA512}
+ es256 = &signatureAlgorithm{"ES256", crypto.SHA256}
+ es384 = &signatureAlgorithm{"ES384", crypto.SHA384}
+ es512 = &signatureAlgorithm{"ES512", crypto.SHA512}
+)
+
+func rsaSignatureAlgorithmByName(alg string) (*signatureAlgorithm, error) {
+ switch {
+ case alg == "RS256":
+ return rs256, nil
+ case alg == "RS384":
+ return rs384, nil
+ case alg == "RS512":
+ return rs512, nil
+ default:
+ return nil, fmt.Errorf("RSA Digital Signature Algorithm %q not supported", alg)
+ }
+}
+
+func rsaPKCS1v15SignatureAlgorithmForHashID(hashID crypto.Hash) *signatureAlgorithm {
+ switch {
+ case hashID == crypto.SHA512:
+ return rs512
+ case hashID == crypto.SHA384:
+ return rs384
+ case hashID == crypto.SHA256:
+ fallthrough
+ default:
+ return rs256
+ }
+}
diff --git a/jsonsign.go b/jsonsign.go
new file mode 100644
index 0000000..cb2ca9a
--- /dev/null
+++ b/jsonsign.go
@@ -0,0 +1,657 @@
+package libtrust
+
+import (
+ "bytes"
+ "crypto"
+ "crypto/x509"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "sort"
+ "time"
+ "unicode"
+)
+
+var (
+ // ErrInvalidSignContent is used when the content to be signed is invalid.
+ ErrInvalidSignContent = errors.New("invalid sign content")
+
+ // ErrInvalidJSONContent is used when invalid json is encountered.
+ ErrInvalidJSONContent = errors.New("invalid json content")
+
+ // ErrMissingSignatureKey is used when the specified signature key
+ // does not exist in the JSON content.
+ ErrMissingSignatureKey = errors.New("missing signature key")
+)
+
+type jsHeader struct {
+ JWK PublicKey `json:"jwk,omitempty"`
+ Algorithm string `json:"alg"`
+ Chain []string `json:"x5c,omitempty"`
+}
+
+type jsSignature struct {
+ Header jsHeader `json:"header"`
+ Signature string `json:"signature"`
+ Protected string `json:"protected,omitempty"`
+}
+
+type jsSignaturesSorted []jsSignature
+
+func (jsbkid jsSignaturesSorted) Swap(i, j int) { jsbkid[i], jsbkid[j] = jsbkid[j], jsbkid[i] }
+func (jsbkid jsSignaturesSorted) Len() int { return len(jsbkid) }
+
+func (jsbkid jsSignaturesSorted) Less(i, j int) bool {
+ ki, kj := jsbkid[i].Header.JWK.KeyID(), jsbkid[j].Header.JWK.KeyID()
+ si, sj := jsbkid[i].Signature, jsbkid[j].Signature
+
+ if ki == kj {
+ return si < sj
+ }
+
+ return ki < kj
+}
+
+type signKey struct {
+ PrivateKey
+ Chain []*x509.Certificate
+}
+
+// JSONSignature represents a signature of a json object.
+type JSONSignature struct {
+ payload string
+ signatures []jsSignature
+ indent string
+ formatLength int
+ formatTail []byte
+}
+
+func newJSONSignature() *JSONSignature {
+ return &JSONSignature{
+ signatures: make([]jsSignature, 0, 1),
+ }
+}
+
+// Payload returns the encoded payload of the signature. This
+// payload should not be signed directly
+func (js *JSONSignature) Payload() ([]byte, error) {
+ return joseBase64UrlDecode(js.payload)
+}
+
+func (js *JSONSignature) protectedHeader() (string, error) {
+ protected := map[string]interface{}{
+ "formatLength": js.formatLength,
+ "formatTail": joseBase64UrlEncode(js.formatTail),
+ "time": time.Now().UTC().Format(time.RFC3339),
+ }
+ protectedBytes, err := json.Marshal(protected)
+ if err != nil {
+ return "", err
+ }
+
+ return joseBase64UrlEncode(protectedBytes), nil
+}
+
+func (js *JSONSignature) signBytes(protectedHeader string) ([]byte, error) {
+ buf := make([]byte, len(js.payload)+len(protectedHeader)+1)
+ copy(buf, protectedHeader)
+ buf[len(protectedHeader)] = '.'
+ copy(buf[len(protectedHeader)+1:], js.payload)
+ return buf, nil
+}
+
+// Sign adds a signature using the given private key.
+func (js *JSONSignature) Sign(key PrivateKey) error {
+ protected, err := js.protectedHeader()
+ if err != nil {
+ return err
+ }
+ signBytes, err := js.signBytes(protected)
+ if err != nil {
+ return err
+ }
+ sigBytes, algorithm, err := key.Sign(bytes.NewReader(signBytes), crypto.SHA256)
+ if err != nil {
+ return err
+ }
+
+ js.signatures = append(js.signatures, jsSignature{
+ Header: jsHeader{
+ JWK: key.PublicKey(),
+ Algorithm: algorithm,
+ },
+ Signature: joseBase64UrlEncode(sigBytes),
+ Protected: protected,
+ })
+
+ return nil
+}
+
+// SignWithChain adds a signature using the given private key
+// and setting the x509 chain. The public key of the first element
+// in the chain must be the public key corresponding with the sign key.
+func (js *JSONSignature) SignWithChain(key PrivateKey, chain []*x509.Certificate) error {
+ // Ensure key.Chain[0] is public key for key
+ //key.Chain.PublicKey
+ //key.PublicKey().CryptoPublicKey()
+
+ // Verify chain
+ protected, err := js.protectedHeader()
+ if err != nil {
+ return err
+ }
+ signBytes, err := js.signBytes(protected)
+ if err != nil {
+ return err
+ }
+ sigBytes, algorithm, err := key.Sign(bytes.NewReader(signBytes), crypto.SHA256)
+ if err != nil {
+ return err
+ }
+
+ header := jsHeader{
+ Chain: make([]string, len(chain)),
+ Algorithm: algorithm,
+ }
+
+ for i, cert := range chain {
+ header.Chain[i] = base64.StdEncoding.EncodeToString(cert.Raw)
+ }
+
+ js.signatures = append(js.signatures, jsSignature{
+ Header: header,
+ Signature: joseBase64UrlEncode(sigBytes),
+ Protected: protected,
+ })
+
+ return nil
+}
+
+// Verify verifies all the signatures and returns the list of
+// public keys used to sign. Any x509 chains are not checked.
+func (js *JSONSignature) Verify() ([]PublicKey, error) {
+ keys := make([]PublicKey, len(js.signatures))
+ for i, signature := range js.signatures {
+ signBytes, err := js.signBytes(signature.Protected)
+ if err != nil {
+ return nil, err
+ }
+ var publicKey PublicKey
+ if len(signature.Header.Chain) > 0 {
+ certBytes, err := base64.StdEncoding.DecodeString(signature.Header.Chain[0])
+ if err != nil {
+ return nil, err
+ }
+ cert, err := x509.ParseCertificate(certBytes)
+ if err != nil {
+ return nil, err
+ }
+ publicKey, err = FromCryptoPublicKey(cert.PublicKey)
+ if err != nil {
+ return nil, err
+ }
+ } else if signature.Header.JWK != nil {
+ publicKey = signature.Header.JWK
+ } else {
+ return nil, errors.New("missing public key")
+ }
+
+ sigBytes, err := joseBase64UrlDecode(signature.Signature)
+ if err != nil {
+ return nil, err
+ }
+
+ err = publicKey.Verify(bytes.NewReader(signBytes), signature.Header.Algorithm, sigBytes)
+ if err != nil {
+ return nil, err
+ }
+
+ keys[i] = publicKey
+ }
+ return keys, nil
+}
+
+// VerifyChains verifies all the signatures and the chains associated
+// with each signature and returns the list of verified chains.
+// Signatures without an x509 chain are not checked.
+func (js *JSONSignature) VerifyChains(ca *x509.CertPool) ([][]*x509.Certificate, error) {
+ chains := make([][]*x509.Certificate, 0, len(js.signatures))
+ for _, signature := range js.signatures {
+ signBytes, err := js.signBytes(signature.Protected)
+ if err != nil {
+ return nil, err
+ }
+ var publicKey PublicKey
+ if len(signature.Header.Chain) > 0 {
+ certBytes, err := base64.StdEncoding.DecodeString(signature.Header.Chain[0])
+ if err != nil {
+ return nil, err
+ }
+ cert, err := x509.ParseCertificate(certBytes)
+ if err != nil {
+ return nil, err
+ }
+ publicKey, err = FromCryptoPublicKey(cert.PublicKey)
+ if err != nil {
+ return nil, err
+ }
+ intermediates := x509.NewCertPool()
+ if len(signature.Header.Chain) > 1 {
+ intermediateChain := signature.Header.Chain[1:]
+ for i := range intermediateChain {
+ certBytes, err := base64.StdEncoding.DecodeString(intermediateChain[i])
+ if err != nil {
+ return nil, err
+ }
+ intermediate, err := x509.ParseCertificate(certBytes)
+ if err != nil {
+ return nil, err
+ }
+ intermediates.AddCert(intermediate)
+ }
+ }
+
+ verifyOptions := x509.VerifyOptions{
+ Intermediates: intermediates,
+ Roots: ca,
+ }
+
+ verifiedChains, err := cert.Verify(verifyOptions)
+ if err != nil {
+ return nil, err
+ }
+ chains = append(chains, verifiedChains...)
+
+ sigBytes, err := joseBase64UrlDecode(signature.Signature)
+ if err != nil {
+ return nil, err
+ }
+
+ err = publicKey.Verify(bytes.NewReader(signBytes), signature.Header.Algorithm, sigBytes)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ }
+ return chains, nil
+}
+
+// JWS returns JSON serialized JWS according to
+// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-7.2
+func (js *JSONSignature) JWS() ([]byte, error) {
+ if len(js.signatures) == 0 {
+ return nil, errors.New("missing signature")
+ }
+
+ sort.Sort(jsSignaturesSorted(js.signatures))
+
+ jsonMap := map[string]interface{}{
+ "payload": js.payload,
+ "signatures": js.signatures,
+ }
+
+ return json.MarshalIndent(jsonMap, "", " ")
+}
+
+func notSpace(r rune) bool {
+ return !unicode.IsSpace(r)
+}
+
+func detectJSONIndent(jsonContent []byte) (indent string) {
+ if len(jsonContent) > 2 && jsonContent[0] == '{' && jsonContent[1] == '\n' {
+ quoteIndex := bytes.IndexRune(jsonContent[1:], '"')
+ if quoteIndex > 0 {
+ indent = string(jsonContent[2 : quoteIndex+1])
+ }
+ }
+ return
+}
+
+type jsParsedHeader struct {
+ JWK json.RawMessage `json:"jwk"`
+ Algorithm string `json:"alg"`
+ Chain []string `json:"x5c"`
+}
+
+type jsParsedSignature struct {
+ Header jsParsedHeader `json:"header"`
+ Signature string `json:"signature"`
+ Protected string `json:"protected"`
+}
+
+// ParseJWS parses a JWS serialized JSON object into a Json Signature.
+func ParseJWS(content []byte) (*JSONSignature, error) {
+ type jsParsed struct {
+ Payload string `json:"payload"`
+ Signatures []jsParsedSignature `json:"signatures"`
+ }
+ parsed := &jsParsed{}
+ err := json.Unmarshal(content, parsed)
+ if err != nil {
+ return nil, err
+ }
+ if len(parsed.Signatures) == 0 {
+ return nil, errors.New("missing signatures")
+ }
+ payload, err := joseBase64UrlDecode(parsed.Payload)
+ if err != nil {
+ return nil, err
+ }
+
+ js, err := NewJSONSignature(payload)
+ if err != nil {
+ return nil, err
+ }
+ js.signatures = make([]jsSignature, len(parsed.Signatures))
+ for i, signature := range parsed.Signatures {
+ header := jsHeader{
+ Algorithm: signature.Header.Algorithm,
+ }
+ if signature.Header.Chain != nil {
+ header.Chain = signature.Header.Chain
+ }
+ if signature.Header.JWK != nil {
+ publicKey, err := UnmarshalPublicKeyJWK([]byte(signature.Header.JWK))
+ if err != nil {
+ return nil, err
+ }
+ header.JWK = publicKey
+ }
+ js.signatures[i] = jsSignature{
+ Header: header,
+ Signature: signature.Signature,
+ Protected: signature.Protected,
+ }
+ }
+
+ return js, nil
+}
+
+// NewJSONSignature returns a new unsigned JWS from a json byte array.
+// JSONSignature will need to be signed before serializing or storing.
+// Optionally, one or more signatures can be provided as byte buffers,
+// containing serialized JWS signatures, to assemble a fully signed JWS
+// package. It is the callers responsibility to ensure uniqueness of the
+// provided signatures.
+func NewJSONSignature(content []byte, signatures ...[]byte) (*JSONSignature, error) {
+ var dataMap map[string]interface{}
+ err := json.Unmarshal(content, &dataMap)
+ if err != nil {
+ return nil, err
+ }
+
+ js := newJSONSignature()
+ js.indent = detectJSONIndent(content)
+
+ js.payload = joseBase64UrlEncode(content)
+
+ // Find trailing } and whitespace, put in protected header
+ closeIndex := bytes.LastIndexFunc(content, notSpace)
+ if content[closeIndex] != '}' {
+ return nil, ErrInvalidJSONContent
+ }
+ lastRuneIndex := bytes.LastIndexFunc(content[:closeIndex], notSpace)
+ if content[lastRuneIndex] == ',' {
+ return nil, ErrInvalidJSONContent
+ }
+ js.formatLength = lastRuneIndex + 1
+ js.formatTail = content[js.formatLength:]
+
+ if len(signatures) > 0 {
+ for _, signature := range signatures {
+ var parsedJSig jsParsedSignature
+
+ if err := json.Unmarshal(signature, &parsedJSig); err != nil {
+ return nil, err
+ }
+
+ // TODO(stevvooe): A lot of the code below is repeated in
+ // ParseJWS. It will require more refactoring to fix that.
+ jsig := jsSignature{
+ Header: jsHeader{
+ Algorithm: parsedJSig.Header.Algorithm,
+ },
+ Signature: parsedJSig.Signature,
+ Protected: parsedJSig.Protected,
+ }
+
+ if parsedJSig.Header.Chain != nil {
+ jsig.Header.Chain = parsedJSig.Header.Chain
+ }
+
+ if parsedJSig.Header.JWK != nil {
+ publicKey, err := UnmarshalPublicKeyJWK([]byte(parsedJSig.Header.JWK))
+ if err != nil {
+ return nil, err
+ }
+ jsig.Header.JWK = publicKey
+ }
+
+ js.signatures = append(js.signatures, jsig)
+ }
+ }
+
+ return js, nil
+}
+
+// NewJSONSignatureFromMap returns a new unsigned JSONSignature from a map or
+// struct. JWS will need to be signed before serializing or storing.
+func NewJSONSignatureFromMap(content interface{}) (*JSONSignature, error) {
+ switch content.(type) {
+ case map[string]interface{}:
+ case struct{}:
+ default:
+ return nil, errors.New("invalid data type")
+ }
+
+ js := newJSONSignature()
+ js.indent = " "
+
+ payload, err := json.MarshalIndent(content, "", js.indent)
+ if err != nil {
+ return nil, err
+ }
+ js.payload = joseBase64UrlEncode(payload)
+
+ // Remove '\n}' from formatted section, put in protected header
+ js.formatLength = len(payload) - 2
+ js.formatTail = payload[js.formatLength:]
+
+ return js, nil
+}
+
+func readIntFromMap(key string, m map[string]interface{}) (int, bool) {
+ value, ok := m[key]
+ if !ok {
+ return 0, false
+ }
+ switch v := value.(type) {
+ case int:
+ return v, true
+ case float64:
+ return int(v), true
+ default:
+ return 0, false
+ }
+}
+
+func readStringFromMap(key string, m map[string]interface{}) (v string, ok bool) {
+ value, ok := m[key]
+ if !ok {
+ return "", false
+ }
+ v, ok = value.(string)
+ return
+}
+
+// ParsePrettySignature parses a formatted signature into a
+// JSON signature. If the signatures are missing the format information
+// an error is thrown. The formatted signature must be created by
+// the same method as format signature.
+func ParsePrettySignature(content []byte, signatureKey string) (*JSONSignature, error) {
+ var contentMap map[string]json.RawMessage
+ err := json.Unmarshal(content, &contentMap)
+ if err != nil {
+ return nil, fmt.Errorf("error unmarshalling content: %s", err)
+ }
+ sigMessage, ok := contentMap[signatureKey]
+ if !ok {
+ return nil, ErrMissingSignatureKey
+ }
+
+ var signatureBlocks []jsParsedSignature
+ err = json.Unmarshal([]byte(sigMessage), &signatureBlocks)
+ if err != nil {
+ return nil, fmt.Errorf("error unmarshalling signatures: %s", err)
+ }
+
+ js := newJSONSignature()
+ js.signatures = make([]jsSignature, len(signatureBlocks))
+
+ for i, signatureBlock := range signatureBlocks {
+ protectedBytes, err := joseBase64UrlDecode(signatureBlock.Protected)
+ if err != nil {
+ return nil, fmt.Errorf("base64 decode error: %s", err)
+ }
+ var protectedHeader map[string]interface{}
+ err = json.Unmarshal(protectedBytes, &protectedHeader)
+ if err != nil {
+ return nil, fmt.Errorf("error unmarshalling protected header: %s", err)
+ }
+
+ formatLength, ok := readIntFromMap("formatLength", protectedHeader)
+ if !ok {
+ return nil, errors.New("missing formatted length")
+ }
+ encodedTail, ok := readStringFromMap("formatTail", protectedHeader)
+ if !ok {
+ return nil, errors.New("missing formatted tail")
+ }
+ formatTail, err := joseBase64UrlDecode(encodedTail)
+ if err != nil {
+ return nil, fmt.Errorf("base64 decode error on tail: %s", err)
+ }
+ if js.formatLength == 0 {
+ js.formatLength = formatLength
+ } else if js.formatLength != formatLength {
+ return nil, errors.New("conflicting format length")
+ }
+ if len(js.formatTail) == 0 {
+ js.formatTail = formatTail
+ } else if bytes.Compare(js.formatTail, formatTail) != 0 {
+ return nil, errors.New("conflicting format tail")
+ }
+
+ header := jsHeader{
+ Algorithm: signatureBlock.Header.Algorithm,
+ Chain: signatureBlock.Header.Chain,
+ }
+ if signatureBlock.Header.JWK != nil {
+ publicKey, err := UnmarshalPublicKeyJWK([]byte(signatureBlock.Header.JWK))
+ if err != nil {
+ return nil, fmt.Errorf("error unmarshalling public key: %s", err)
+ }
+ header.JWK = publicKey
+ }
+ js.signatures[i] = jsSignature{
+ Header: header,
+ Signature: signatureBlock.Signature,
+ Protected: signatureBlock.Protected,
+ }
+ }
+ if js.formatLength > len(content) {
+ return nil, errors.New("invalid format length")
+ }
+ formatted := make([]byte, js.formatLength+len(js.formatTail))
+ copy(formatted, content[:js.formatLength])
+ copy(formatted[js.formatLength:], js.formatTail)
+ js.indent = detectJSONIndent(formatted)
+ js.payload = joseBase64UrlEncode(formatted)
+
+ return js, nil
+}
+
+// PrettySignature formats a json signature into an easy to read
+// single json serialized object.
+func (js *JSONSignature) PrettySignature(signatureKey string) ([]byte, error) {
+ if len(js.signatures) == 0 {
+ return nil, errors.New("no signatures")
+ }
+ payload, err := joseBase64UrlDecode(js.payload)
+ if err != nil {
+ return nil, err
+ }
+ payload = payload[:js.formatLength]
+
+ sort.Sort(jsSignaturesSorted(js.signatures))
+
+ var marshalled []byte
+ var marshallErr error
+ if js.indent != "" {
+ marshalled, marshallErr = json.MarshalIndent(js.signatures, js.indent, js.indent)
+ } else {
+ marshalled, marshallErr = json.Marshal(js.signatures)
+ }
+ if marshallErr != nil {
+ return nil, marshallErr
+ }
+
+ buf := bytes.NewBuffer(make([]byte, 0, len(payload)+len(marshalled)+34))
+ buf.Write(payload)
+ buf.WriteByte(',')
+ if js.indent != "" {
+ buf.WriteByte('\n')
+ buf.WriteString(js.indent)
+ buf.WriteByte('"')
+ buf.WriteString(signatureKey)
+ buf.WriteString("\": ")
+ buf.Write(marshalled)
+ buf.WriteByte('\n')
+ } else {
+ buf.WriteByte('"')
+ buf.WriteString(signatureKey)
+ buf.WriteString("\":")
+ buf.Write(marshalled)
+ }
+ buf.WriteByte('}')
+
+ return buf.Bytes(), nil
+}
+
+// Signatures provides the signatures on this JWS as opaque blobs, sorted by
+// keyID. These blobs can be stored and reassembled with payloads. Internally,
+// they are simply marshaled json web signatures but implementations should
+// not rely on this.
+func (js *JSONSignature) Signatures() ([][]byte, error) {
+ sort.Sort(jsSignaturesSorted(js.signatures))
+
+ var sb [][]byte
+ for _, jsig := range js.signatures {
+ p, err := json.Marshal(jsig)
+ if err != nil {
+ return nil, err
+ }
+
+ sb = append(sb, p)
+ }
+
+ return sb, nil
+}
+
+// Merge combines the signatures from one or more other signatures into the
+// method receiver. If the payloads differ for any argument, an error will be
+// returned and the receiver will not be modified.
+func (js *JSONSignature) Merge(others ...*JSONSignature) error {
+ merged := js.signatures
+ for _, other := range others {
+ if js.payload != other.payload {
+ return fmt.Errorf("payloads differ from merge target")
+ }
+ merged = append(merged, other.signatures...)
+ }
+
+ js.signatures = merged
+ return nil
+}
diff --git a/jsonsign_test.go b/jsonsign_test.go
new file mode 100644
index 0000000..43e26ff
--- /dev/null
+++ b/jsonsign_test.go
@@ -0,0 +1,380 @@
+package libtrust
+
+import (
+ "bytes"
+ "crypto/rand"
+ "crypto/x509"
+ "encoding/json"
+ "fmt"
+ "io"
+ "testing"
+
+ "github.com/containers/libtrust/testutil"
+)
+
+func createTestJSON(sigKey string, indent string) (map[string]interface{}, []byte) {
+ testMap := map[string]interface{}{
+ "name": "dmcgowan/mycontainer",
+ "config": map[string]interface{}{
+ "ports": []int{9101, 9102},
+ "run": "/bin/echo \"Hello\"",
+ },
+ "layers": []string{
+ "2893c080-27f5-11e4-8c21-0800200c9a66",
+ "c54bc25b-fbb2-497b-a899-a8bc1b5b9d55",
+ "4d5d7e03-f908-49f3-a7f6-9ba28dfe0fb4",
+ "0b6da891-7f7f-4abf-9c97-7887549e696c",
+ "1d960389-ae4f-4011-85fd-18d0f96a67ad",
+ },
+ }
+ formattedSection := `{"config":{"ports":[9101,9102],"run":"/bin/echo \"Hello\""},"layers":["2893c080-27f5-11e4-8c21-0800200c9a66","c54bc25b-fbb2-497b-a899-a8bc1b5b9d55","4d5d7e03-f908-49f3-a7f6-9ba28dfe0fb4","0b6da891-7f7f-4abf-9c97-7887549e696c","1d960389-ae4f-4011-85fd-18d0f96a67ad"],"name":"dmcgowan/mycontainer","%s":[{"header":{`
+ formattedSection = fmt.Sprintf(formattedSection, sigKey)
+ if indent != "" {
+ buf := bytes.NewBuffer(nil)
+ json.Indent(buf, []byte(formattedSection), "", indent)
+ return testMap, buf.Bytes()
+ }
+ return testMap, []byte(formattedSection)
+
+}
+
+func TestSignJSON(t *testing.T) {
+ key, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatalf("Error generating EC key: %s", err)
+ }
+
+ testMap, _ := createTestJSON("buildSignatures", " ")
+ indented, err := json.MarshalIndent(testMap, "", " ")
+ if err != nil {
+ t.Fatalf("Marshall error: %s", err)
+ }
+
+ js, err := NewJSONSignature(indented)
+ if err != nil {
+ t.Fatalf("Error creating JSON signature: %s", err)
+ }
+ err = js.Sign(key)
+ if err != nil {
+ t.Fatalf("Error signing content: %s", err)
+ }
+
+ keys, err := js.Verify()
+ if err != nil {
+ t.Fatalf("Error verifying signature: %s", err)
+ }
+ if len(keys) != 1 {
+ t.Fatalf("Error wrong number of keys returned")
+ }
+ if keys[0].KeyID() != key.KeyID() {
+ t.Fatalf("Unexpected public key returned")
+ }
+
+}
+
+func TestSignMap(t *testing.T) {
+ key, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatalf("Error generating EC key: %s", err)
+ }
+
+ testMap, _ := createTestJSON("buildSignatures", " ")
+ js, err := NewJSONSignatureFromMap(testMap)
+ if err != nil {
+ t.Fatalf("Error creating JSON signature: %s", err)
+ }
+ err = js.Sign(key)
+ if err != nil {
+ t.Fatalf("Error signing JSON signature: %s", err)
+ }
+
+ keys, err := js.Verify()
+ if err != nil {
+ t.Fatalf("Error verifying signature: %s", err)
+ }
+ if len(keys) != 1 {
+ t.Fatalf("Error wrong number of keys returned")
+ }
+ if keys[0].KeyID() != key.KeyID() {
+ t.Fatalf("Unexpected public key returned")
+ }
+}
+
+func TestFormattedJson(t *testing.T) {
+ key, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatalf("Error generating EC key: %s", err)
+ }
+
+ testMap, firstSection := createTestJSON("buildSignatures", " ")
+ indented, err := json.MarshalIndent(testMap, "", " ")
+ if err != nil {
+ t.Fatalf("Marshall error: %s", err)
+ }
+
+ js, err := NewJSONSignature(indented)
+ if err != nil {
+ t.Fatalf("Error creating JSON signature: %s", err)
+ }
+ err = js.Sign(key)
+ if err != nil {
+ t.Fatalf("Error signing content: %s", err)
+ }
+
+ b, err := js.PrettySignature("buildSignatures")
+ if err != nil {
+ t.Fatalf("Error signing map: %s", err)
+ }
+
+ if bytes.Compare(b[:len(firstSection)], firstSection) != 0 {
+ t.Fatalf("Wrong signed value\nExpected:\n%s\nActual:\n%s", firstSection, b[:len(firstSection)])
+ }
+
+ parsed, err := ParsePrettySignature(b, "buildSignatures")
+ if err != nil {
+ t.Fatalf("Error parsing formatted signature: %s", err)
+ }
+
+ keys, err := parsed.Verify()
+ if err != nil {
+ t.Fatalf("Error verifying signature: %s", err)
+ }
+ if len(keys) != 1 {
+ t.Fatalf("Error wrong number of keys returned")
+ }
+ if keys[0].KeyID() != key.KeyID() {
+ t.Fatalf("Unexpected public key returned")
+ }
+
+ var unmarshalled map[string]interface{}
+ err = json.Unmarshal(b, &unmarshalled)
+ if err != nil {
+ t.Fatalf("Could not unmarshall after parse: %s", err)
+ }
+
+}
+
+func TestFormattedFlatJson(t *testing.T) {
+ key, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatalf("Error generating EC key: %s", err)
+ }
+
+ testMap, firstSection := createTestJSON("buildSignatures", "")
+ unindented, err := json.Marshal(testMap)
+ if err != nil {
+ t.Fatalf("Marshall error: %s", err)
+ }
+
+ js, err := NewJSONSignature(unindented)
+ if err != nil {
+ t.Fatalf("Error creating JSON signature: %s", err)
+ }
+ err = js.Sign(key)
+ if err != nil {
+ t.Fatalf("Error signing JSON signature: %s", err)
+ }
+
+ b, err := js.PrettySignature("buildSignatures")
+ if err != nil {
+ t.Fatalf("Error signing map: %s", err)
+ }
+
+ if bytes.Compare(b[:len(firstSection)], firstSection) != 0 {
+ t.Fatalf("Wrong signed value\nExpected:\n%s\nActual:\n%s", firstSection, b[:len(firstSection)])
+ }
+
+ parsed, err := ParsePrettySignature(b, "buildSignatures")
+ if err != nil {
+ t.Fatalf("Error parsing formatted signature: %s", err)
+ }
+
+ keys, err := parsed.Verify()
+ if err != nil {
+ t.Fatalf("Error verifying signature: %s", err)
+ }
+ if len(keys) != 1 {
+ t.Fatalf("Error wrong number of keys returned")
+ }
+ if keys[0].KeyID() != key.KeyID() {
+ t.Fatalf("Unexpected public key returned")
+ }
+}
+
+func generateTrustChain(t *testing.T, key PrivateKey, ca *x509.Certificate) (PrivateKey, []*x509.Certificate) {
+ parent := ca
+ parentKey := key
+ chain := make([]*x509.Certificate, 6)
+ for i := 5; i > 0; i-- {
+ intermediatekey, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatalf("Error generate key: %s", err)
+ }
+ chain[i], err = testutil.GenerateIntermediate(intermediatekey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent)
+ if err != nil {
+ t.Fatalf("Error generating intermdiate certificate: %s", err)
+ }
+ parent = chain[i]
+ parentKey = intermediatekey
+ }
+ trustKey, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatalf("Error generate key: %s", err)
+ }
+ chain[0], err = testutil.GenerateTrustCert(trustKey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent)
+ if err != nil {
+ t.Fatalf("Error generate trust cert: %s", err)
+ }
+
+ return trustKey, chain
+}
+
+func TestChainVerify(t *testing.T) {
+ caKey, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatalf("Error generating key: %s", err)
+ }
+ ca, err := testutil.GenerateTrustCA(caKey.CryptoPublicKey(), caKey.CryptoPrivateKey())
+ if err != nil {
+ t.Fatalf("Error generating ca: %s", err)
+ }
+ trustKey, chain := generateTrustChain(t, caKey, ca)
+
+ testMap, _ := createTestJSON("verifySignatures", " ")
+ js, err := NewJSONSignatureFromMap(testMap)
+ if err != nil {
+ t.Fatalf("Error creating JSONSignature from map: %s", err)
+ }
+
+ err = js.SignWithChain(trustKey, chain)
+ if err != nil {
+ t.Fatalf("Error signing with chain: %s", err)
+ }
+
+ pool := x509.NewCertPool()
+ pool.AddCert(ca)
+ chains, err := js.VerifyChains(pool)
+ if err != nil {
+ t.Fatalf("Error verifying content: %s", err)
+ }
+ if len(chains) != 1 {
+ t.Fatalf("Unexpected chains length: %d", len(chains))
+ }
+ if len(chains[0]) != 7 {
+ t.Fatalf("Unexpected chain length: %d", len(chains[0]))
+ }
+}
+
+func TestInvalidChain(t *testing.T) {
+ caKey, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatalf("Error generating key: %s", err)
+ }
+ ca, err := testutil.GenerateTrustCA(caKey.CryptoPublicKey(), caKey.CryptoPrivateKey())
+ if err != nil {
+ t.Fatalf("Error generating ca: %s", err)
+ }
+ trustKey, chain := generateTrustChain(t, caKey, ca)
+
+ testMap, _ := createTestJSON("verifySignatures", " ")
+ js, err := NewJSONSignatureFromMap(testMap)
+ if err != nil {
+ t.Fatalf("Error creating JSONSignature from map: %s", err)
+ }
+
+ err = js.SignWithChain(trustKey, chain[:5])
+ if err != nil {
+ t.Fatalf("Error signing with chain: %s", err)
+ }
+
+ pool := x509.NewCertPool()
+ pool.AddCert(ca)
+ chains, err := js.VerifyChains(pool)
+ if err == nil {
+ t.Fatalf("Expected error verifying with bad chain")
+ }
+ if len(chains) != 0 {
+ t.Fatalf("Unexpected chains returned from invalid verify")
+ }
+}
+
+func TestMergeSignatures(t *testing.T) {
+ pk1, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatalf("unexpected error generating private key 1: %v", err)
+ }
+
+ pk2, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatalf("unexpected error generating private key 2: %v", err)
+ }
+
+ payload := make([]byte, 1<<10)
+ if _, err = io.ReadFull(rand.Reader, payload); err != nil {
+ t.Fatalf("error generating payload: %v", err)
+ }
+
+ payload, _ = json.Marshal(map[string]interface{}{"data": payload})
+
+ sig1, err := NewJSONSignature(payload)
+ if err != nil {
+ t.Fatalf("unexpected error creating signature 1: %v", err)
+ }
+
+ if err := sig1.Sign(pk1); err != nil {
+ t.Fatalf("unexpected error signing with pk1: %v", err)
+ }
+
+ sig2, err := NewJSONSignature(payload)
+ if err != nil {
+ t.Fatalf("unexpected error creating signature 2: %v", err)
+ }
+
+ if err := sig2.Sign(pk2); err != nil {
+ t.Fatalf("unexpected error signing with pk2: %v", err)
+ }
+
+ // Now, we actually merge into sig1
+ if err := sig1.Merge(sig2); err != nil {
+ t.Fatalf("unexpected error merging: %v", err)
+ }
+
+ // Verify the new signature package
+ pubkeys, err := sig1.Verify()
+ if err != nil {
+ t.Fatalf("unexpected error during verify: %v", err)
+ }
+
+ // Make sure the pubkeys match the two private keys from before
+ privkeys := map[string]PrivateKey{
+ pk1.KeyID(): pk1,
+ pk2.KeyID(): pk2,
+ }
+
+ found := map[string]struct{}{}
+
+ for _, pubkey := range pubkeys {
+ if _, ok := privkeys[pubkey.KeyID()]; !ok {
+ t.Fatalf("unexpected public key found during verification: %v", pubkey)
+ }
+
+ found[pubkey.KeyID()] = struct{}{}
+ }
+
+ // Make sure we've found all the private keys from verification
+ for keyid, _ := range privkeys {
+ if _, ok := found[keyid]; !ok {
+ t.Fatalf("public key %v not found during verification", keyid)
+ }
+ }
+
+ // Create another signature, with a different payload, and ensure we get an error.
+ sig3, err := NewJSONSignature([]byte("{}"))
+ if err != nil {
+ t.Fatalf("unexpected error making signature for sig3: %v", err)
+ }
+
+ if err := sig1.Merge(sig3); err == nil {
+ t.Fatalf("error expected during invalid merge with different payload")
+ }
+}
diff --git a/key.go b/key.go
new file mode 100644
index 0000000..73642db
--- /dev/null
+++ b/key.go
@@ -0,0 +1,253 @@
+package libtrust
+
+import (
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/json"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "io"
+)
+
+// PublicKey is a generic interface for a Public Key.
+type PublicKey interface {
+ // KeyType returns the key type for this key. For elliptic curve keys,
+ // this value should be "EC". For RSA keys, this value should be "RSA".
+ KeyType() string
+ // KeyID returns a distinct identifier which is unique to this Public Key.
+ // The format generated by this library is a base32 encoding of a 240 bit
+ // hash of the public key data divided into 12 groups like so:
+ // ABCD:EFGH:IJKL:MNOP:QRST:UVWX:YZ23:4567:ABCD:EFGH:IJKL:MNOP
+ KeyID() string
+ // Verify verifyies the signature of the data in the io.Reader using this
+ // Public Key. The alg parameter should identify the digital signature
+ // algorithm which was used to produce the signature and should be
+ // supported by this public key. Returns a nil error if the signature
+ // is valid.
+ Verify(data io.Reader, alg string, signature []byte) error
+ // CryptoPublicKey returns the internal object which can be used as a
+ // crypto.PublicKey for use with other standard library operations. The type
+ // is either *rsa.PublicKey or *ecdsa.PublicKey
+ CryptoPublicKey() crypto.PublicKey
+ // These public keys can be serialized to the standard JSON encoding for
+ // JSON Web Keys. See section 6 of the IETF draft RFC for JOSE JSON Web
+ // Algorithms.
+ MarshalJSON() ([]byte, error)
+ // These keys can also be serialized to the standard PEM encoding.
+ PEMBlock() (*pem.Block, error)
+ // The string representation of a key is its key type and ID.
+ String() string
+ AddExtendedField(string, interface{})
+ GetExtendedField(string) interface{}
+}
+
+// PrivateKey is a generic interface for a Private Key.
+type PrivateKey interface {
+ // A PrivateKey contains all fields and methods of a PublicKey of the
+ // same type. The MarshalJSON method also outputs the private key as a
+ // JSON Web Key, and the PEMBlock method outputs the private key as a
+ // PEM block.
+ PublicKey
+ // PublicKey returns the PublicKey associated with this PrivateKey.
+ PublicKey() PublicKey
+ // Sign signs the data read from the io.Reader using a signature algorithm
+ // supported by the private key. If the specified hashing algorithm is
+ // supported by this key, that hash function is used to generate the
+ // signature otherwise the the default hashing algorithm for this key is
+ // used. Returns the signature and identifier of the algorithm used.
+ Sign(data io.Reader, hashID crypto.Hash) (signature []byte, alg string, err error)
+ // CryptoPrivateKey returns the internal object which can be used as a
+ // crypto.PublicKey for use with other standard library operations. The
+ // type is either *rsa.PublicKey or *ecdsa.PublicKey
+ CryptoPrivateKey() crypto.PrivateKey
+}
+
+// FromCryptoPublicKey returns a libtrust PublicKey representation of the given
+// *ecdsa.PublicKey or *rsa.PublicKey. Returns a non-nil error when the given
+// key is of an unsupported type.
+func FromCryptoPublicKey(cryptoPublicKey crypto.PublicKey) (PublicKey, error) {
+ switch cryptoPublicKey := cryptoPublicKey.(type) {
+ case *ecdsa.PublicKey:
+ return fromECPublicKey(cryptoPublicKey)
+ case *rsa.PublicKey:
+ return fromRSAPublicKey(cryptoPublicKey), nil
+ default:
+ return nil, fmt.Errorf("public key type %T is not supported", cryptoPublicKey)
+ }
+}
+
+// FromCryptoPrivateKey returns a libtrust PrivateKey representation of the given
+// *ecdsa.PrivateKey or *rsa.PrivateKey. Returns a non-nil error when the given
+// key is of an unsupported type.
+func FromCryptoPrivateKey(cryptoPrivateKey crypto.PrivateKey) (PrivateKey, error) {
+ switch cryptoPrivateKey := cryptoPrivateKey.(type) {
+ case *ecdsa.PrivateKey:
+ return fromECPrivateKey(cryptoPrivateKey)
+ case *rsa.PrivateKey:
+ return fromRSAPrivateKey(cryptoPrivateKey), nil
+ default:
+ return nil, fmt.Errorf("private key type %T is not supported", cryptoPrivateKey)
+ }
+}
+
+// UnmarshalPublicKeyPEM parses the PEM encoded data and returns a libtrust
+// PublicKey or an error if there is a problem with the encoding.
+func UnmarshalPublicKeyPEM(data []byte) (PublicKey, error) {
+ pemBlock, _ := pem.Decode(data)
+ if pemBlock == nil {
+ return nil, errors.New("unable to find PEM encoded data")
+ } else if pemBlock.Type != "PUBLIC KEY" {
+ return nil, fmt.Errorf("unable to get PublicKey from PEM type: %s", pemBlock.Type)
+ }
+
+ return pubKeyFromPEMBlock(pemBlock)
+}
+
+// UnmarshalPublicKeyPEMBundle parses the PEM encoded data as a bundle of
+// PEM blocks appended one after the other and returns a slice of PublicKey
+// objects that it finds.
+func UnmarshalPublicKeyPEMBundle(data []byte) ([]PublicKey, error) {
+ pubKeys := []PublicKey{}
+
+ for {
+ var pemBlock *pem.Block
+ pemBlock, data = pem.Decode(data)
+ if pemBlock == nil {
+ break
+ } else if pemBlock.Type != "PUBLIC KEY" {
+ return nil, fmt.Errorf("unable to get PublicKey from PEM type: %s", pemBlock.Type)
+ }
+
+ pubKey, err := pubKeyFromPEMBlock(pemBlock)
+ if err != nil {
+ return nil, err
+ }
+
+ pubKeys = append(pubKeys, pubKey)
+ }
+
+ return pubKeys, nil
+}
+
+// UnmarshalPrivateKeyPEM parses the PEM encoded data and returns a libtrust
+// PrivateKey or an error if there is a problem with the encoding.
+func UnmarshalPrivateKeyPEM(data []byte) (PrivateKey, error) {
+ pemBlock, _ := pem.Decode(data)
+ if pemBlock == nil {
+ return nil, errors.New("unable to find PEM encoded data")
+ }
+
+ var key PrivateKey
+
+ switch {
+ case pemBlock.Type == "RSA PRIVATE KEY":
+ rsaPrivateKey, err := x509.ParsePKCS1PrivateKey(pemBlock.Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decode RSA Private Key PEM data: %s", err)
+ }
+ key = fromRSAPrivateKey(rsaPrivateKey)
+ case pemBlock.Type == "EC PRIVATE KEY":
+ ecPrivateKey, err := x509.ParseECPrivateKey(pemBlock.Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decode EC Private Key PEM data: %s", err)
+ }
+ key, err = fromECPrivateKey(ecPrivateKey)
+ if err != nil {
+ return nil, err
+ }
+ default:
+ return nil, fmt.Errorf("unable to get PrivateKey from PEM type: %s", pemBlock.Type)
+ }
+
+ addPEMHeadersToKey(pemBlock, key.PublicKey())
+
+ return key, nil
+}
+
+// UnmarshalPublicKeyJWK unmarshals the given JSON Web Key into a generic
+// Public Key to be used with libtrust.
+func UnmarshalPublicKeyJWK(data []byte) (PublicKey, error) {
+ jwk := make(map[string]interface{})
+
+ err := json.Unmarshal(data, &jwk)
+ if err != nil {
+ return nil, fmt.Errorf(
+ "decoding JWK Public Key JSON data: %s\n", err,
+ )
+ }
+
+ // Get the Key Type value.
+ kty, err := stringFromMap(jwk, "kty")
+ if err != nil {
+ return nil, fmt.Errorf("JWK Public Key type: %s", err)
+ }
+
+ switch {
+ case kty == "EC":
+ // Call out to unmarshal EC public key.
+ return ecPublicKeyFromMap(jwk)
+ case kty == "RSA":
+ // Call out to unmarshal RSA public key.
+ return rsaPublicKeyFromMap(jwk)
+ default:
+ return nil, fmt.Errorf(
+ "JWK Public Key type not supported: %q\n", kty,
+ )
+ }
+}
+
+// UnmarshalPublicKeyJWKSet parses the JSON encoded data as a JSON Web Key Set
+// and returns a slice of Public Key objects.
+func UnmarshalPublicKeyJWKSet(data []byte) ([]PublicKey, error) {
+ rawKeys, err := loadJSONKeySetRaw(data)
+ if err != nil {
+ return nil, err
+ }
+
+ pubKeys := make([]PublicKey, 0, len(rawKeys))
+
+ for _, rawKey := range rawKeys {
+ pubKey, err := UnmarshalPublicKeyJWK(rawKey)
+ if err != nil {
+ return nil, err
+ }
+ pubKeys = append(pubKeys, pubKey)
+ }
+
+ return pubKeys, nil
+}
+
+// UnmarshalPrivateKeyJWK unmarshals the given JSON Web Key into a generic
+// Private Key to be used with libtrust.
+func UnmarshalPrivateKeyJWK(data []byte) (PrivateKey, error) {
+ jwk := make(map[string]interface{})
+
+ err := json.Unmarshal(data, &jwk)
+ if err != nil {
+ return nil, fmt.Errorf(
+ "decoding JWK Private Key JSON data: %s\n", err,
+ )
+ }
+
+ // Get the Key Type value.
+ kty, err := stringFromMap(jwk, "kty")
+ if err != nil {
+ return nil, fmt.Errorf("JWK Private Key type: %s", err)
+ }
+
+ switch {
+ case kty == "EC":
+ // Call out to unmarshal EC private key.
+ return ecPrivateKeyFromMap(jwk)
+ case kty == "RSA":
+ // Call out to unmarshal RSA private key.
+ return rsaPrivateKeyFromMap(jwk)
+ default:
+ return nil, fmt.Errorf(
+ "JWK Private Key type not supported: %q\n", kty,
+ )
+ }
+}
diff --git a/key_files.go b/key_files.go
new file mode 100644
index 0000000..c526de5
--- /dev/null
+++ b/key_files.go
@@ -0,0 +1,255 @@
+package libtrust
+
+import (
+ "encoding/json"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "strings"
+)
+
+var (
+ // ErrKeyFileDoesNotExist indicates that the private key file does not exist.
+ ErrKeyFileDoesNotExist = errors.New("key file does not exist")
+)
+
+func readKeyFileBytes(filename string) ([]byte, error) {
+ data, err := ioutil.ReadFile(filename)
+ if err != nil {
+ if os.IsNotExist(err) {
+ err = ErrKeyFileDoesNotExist
+ } else {
+ err = fmt.Errorf("unable to read key file %s: %s", filename, err)
+ }
+
+ return nil, err
+ }
+
+ return data, nil
+}
+
+/*
+ Loading and Saving of Public and Private Keys in either PEM or JWK format.
+*/
+
+// LoadKeyFile opens the given filename and attempts to read a Private Key
+// encoded in either PEM or JWK format (if .json or .jwk file extension).
+func LoadKeyFile(filename string) (PrivateKey, error) {
+ contents, err := readKeyFileBytes(filename)
+ if err != nil {
+ return nil, err
+ }
+
+ var key PrivateKey
+
+ if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") {
+ key, err = UnmarshalPrivateKeyJWK(contents)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decode private key JWK: %s", err)
+ }
+ } else {
+ key, err = UnmarshalPrivateKeyPEM(contents)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decode private key PEM: %s", err)
+ }
+ }
+
+ return key, nil
+}
+
+// LoadPublicKeyFile opens the given filename and attempts to read a Public Key
+// encoded in either PEM or JWK format (if .json or .jwk file extension).
+func LoadPublicKeyFile(filename string) (PublicKey, error) {
+ contents, err := readKeyFileBytes(filename)
+ if err != nil {
+ return nil, err
+ }
+
+ var key PublicKey
+
+ if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") {
+ key, err = UnmarshalPublicKeyJWK(contents)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decode public key JWK: %s", err)
+ }
+ } else {
+ key, err = UnmarshalPublicKeyPEM(contents)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decode public key PEM: %s", err)
+ }
+ }
+
+ return key, nil
+}
+
+// SaveKey saves the given key to a file using the provided filename.
+// This process will overwrite any existing file at the provided location.
+func SaveKey(filename string, key PrivateKey) error {
+ var encodedKey []byte
+ var err error
+
+ if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") {
+ // Encode in JSON Web Key format.
+ encodedKey, err = json.MarshalIndent(key, "", " ")
+ if err != nil {
+ return fmt.Errorf("unable to encode private key JWK: %s", err)
+ }
+ } else {
+ // Encode in PEM format.
+ pemBlock, err := key.PEMBlock()
+ if err != nil {
+ return fmt.Errorf("unable to encode private key PEM: %s", err)
+ }
+ encodedKey = pem.EncodeToMemory(pemBlock)
+ }
+
+ err = ioutil.WriteFile(filename, encodedKey, os.FileMode(0600))
+ if err != nil {
+ return fmt.Errorf("unable to write private key file %s: %s", filename, err)
+ }
+
+ return nil
+}
+
+// SavePublicKey saves the given public key to the file.
+func SavePublicKey(filename string, key PublicKey) error {
+ var encodedKey []byte
+ var err error
+
+ if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") {
+ // Encode in JSON Web Key format.
+ encodedKey, err = json.MarshalIndent(key, "", " ")
+ if err != nil {
+ return fmt.Errorf("unable to encode public key JWK: %s", err)
+ }
+ } else {
+ // Encode in PEM format.
+ pemBlock, err := key.PEMBlock()
+ if err != nil {
+ return fmt.Errorf("unable to encode public key PEM: %s", err)
+ }
+ encodedKey = pem.EncodeToMemory(pemBlock)
+ }
+
+ err = ioutil.WriteFile(filename, encodedKey, os.FileMode(0644))
+ if err != nil {
+ return fmt.Errorf("unable to write public key file %s: %s", filename, err)
+ }
+
+ return nil
+}
+
+// Public Key Set files
+
+type jwkSet struct {
+ Keys []json.RawMessage `json:"keys"`
+}
+
+// LoadKeySetFile loads a key set
+func LoadKeySetFile(filename string) ([]PublicKey, error) {
+ if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") {
+ return loadJSONKeySetFile(filename)
+ }
+
+ // Must be a PEM format file
+ return loadPEMKeySetFile(filename)
+}
+
+func loadJSONKeySetRaw(data []byte) ([]json.RawMessage, error) {
+ if len(data) == 0 {
+ // This is okay, just return an empty slice.
+ return []json.RawMessage{}, nil
+ }
+
+ keySet := jwkSet{}
+
+ err := json.Unmarshal(data, &keySet)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decode JSON Web Key Set: %s", err)
+ }
+
+ return keySet.Keys, nil
+}
+
+func loadJSONKeySetFile(filename string) ([]PublicKey, error) {
+ contents, err := readKeyFileBytes(filename)
+ if err != nil && err != ErrKeyFileDoesNotExist {
+ return nil, err
+ }
+
+ return UnmarshalPublicKeyJWKSet(contents)
+}
+
+func loadPEMKeySetFile(filename string) ([]PublicKey, error) {
+ data, err := readKeyFileBytes(filename)
+ if err != nil && err != ErrKeyFileDoesNotExist {
+ return nil, err
+ }
+
+ return UnmarshalPublicKeyPEMBundle(data)
+}
+
+// AddKeySetFile adds a key to a key set
+func AddKeySetFile(filename string, key PublicKey) error {
+ if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") {
+ return addKeySetJSONFile(filename, key)
+ }
+
+ // Must be a PEM format file
+ return addKeySetPEMFile(filename, key)
+}
+
+func addKeySetJSONFile(filename string, key PublicKey) error {
+ encodedKey, err := json.Marshal(key)
+ if err != nil {
+ return fmt.Errorf("unable to encode trusted client key: %s", err)
+ }
+
+ contents, err := readKeyFileBytes(filename)
+ if err != nil && err != ErrKeyFileDoesNotExist {
+ return err
+ }
+
+ rawEntries, err := loadJSONKeySetRaw(contents)
+ if err != nil {
+ return err
+ }
+
+ rawEntries = append(rawEntries, json.RawMessage(encodedKey))
+ entriesWrapper := jwkSet{Keys: rawEntries}
+
+ encodedEntries, err := json.MarshalIndent(entriesWrapper, "", " ")
+ if err != nil {
+ return fmt.Errorf("unable to encode trusted client keys: %s", err)
+ }
+
+ err = ioutil.WriteFile(filename, encodedEntries, os.FileMode(0644))
+ if err != nil {
+ return fmt.Errorf("unable to write trusted client keys file %s: %s", filename, err)
+ }
+
+ return nil
+}
+
+func addKeySetPEMFile(filename string, key PublicKey) error {
+ // Encode to PEM, open file for appending, write PEM.
+ file, err := os.OpenFile(filename, os.O_CREATE|os.O_APPEND|os.O_RDWR, os.FileMode(0644))
+ if err != nil {
+ return fmt.Errorf("unable to open trusted client keys file %s: %s", filename, err)
+ }
+ defer file.Close()
+
+ pemBlock, err := key.PEMBlock()
+ if err != nil {
+ return fmt.Errorf("unable to encoded trusted key: %s", err)
+ }
+
+ _, err = file.Write(pem.EncodeToMemory(pemBlock))
+ if err != nil {
+ return fmt.Errorf("unable to write trusted keys file: %s", err)
+ }
+
+ return nil
+}
diff --git a/key_files_test.go b/key_files_test.go
new file mode 100644
index 0000000..57e691f
--- /dev/null
+++ b/key_files_test.go
@@ -0,0 +1,220 @@
+package libtrust
+
+import (
+ "errors"
+ "io/ioutil"
+ "os"
+ "testing"
+)
+
+func makeTempFile(t *testing.T, prefix string) (filename string) {
+ file, err := ioutil.TempFile("", prefix)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ filename = file.Name()
+ file.Close()
+
+ return
+}
+
+func TestKeyFiles(t *testing.T) {
+ key, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ testKeyFiles(t, key)
+
+ key, err = GenerateRSA2048PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ testKeyFiles(t, key)
+}
+
+func testKeyFiles(t *testing.T, key PrivateKey) {
+ var err error
+
+ privateKeyFilename := makeTempFile(t, "private_key")
+ privateKeyFilenamePEM := privateKeyFilename + ".pem"
+ privateKeyFilenameJWK := privateKeyFilename + ".jwk"
+
+ publicKeyFilename := makeTempFile(t, "public_key")
+ publicKeyFilenamePEM := publicKeyFilename + ".pem"
+ publicKeyFilenameJWK := publicKeyFilename + ".jwk"
+
+ if err = SaveKey(privateKeyFilenamePEM, key); err != nil {
+ t.Fatal(err)
+ }
+
+ if err = SaveKey(privateKeyFilenameJWK, key); err != nil {
+ t.Fatal(err)
+ }
+
+ if err = SavePublicKey(publicKeyFilenamePEM, key.PublicKey()); err != nil {
+ t.Fatal(err)
+ }
+
+ if err = SavePublicKey(publicKeyFilenameJWK, key.PublicKey()); err != nil {
+ t.Fatal(err)
+ }
+
+ loadedPEMKey, err := LoadKeyFile(privateKeyFilenamePEM)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ loadedJWKKey, err := LoadKeyFile(privateKeyFilenameJWK)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ loadedPEMPublicKey, err := LoadPublicKeyFile(publicKeyFilenamePEM)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ loadedJWKPublicKey, err := LoadPublicKeyFile(publicKeyFilenameJWK)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if key.KeyID() != loadedPEMKey.KeyID() {
+ t.Fatal(errors.New("key IDs do not match"))
+ }
+
+ if key.KeyID() != loadedJWKKey.KeyID() {
+ t.Fatal(errors.New("key IDs do not match"))
+ }
+
+ if key.KeyID() != loadedPEMPublicKey.KeyID() {
+ t.Fatal(errors.New("key IDs do not match"))
+ }
+
+ if key.KeyID() != loadedJWKPublicKey.KeyID() {
+ t.Fatal(errors.New("key IDs do not match"))
+ }
+
+ os.Remove(privateKeyFilename)
+ os.Remove(privateKeyFilenamePEM)
+ os.Remove(privateKeyFilenameJWK)
+ os.Remove(publicKeyFilename)
+ os.Remove(publicKeyFilenamePEM)
+ os.Remove(publicKeyFilenameJWK)
+}
+
+func TestTrustedHostKeysFile(t *testing.T) {
+ trustedHostKeysFilename := makeTempFile(t, "trusted_host_keys")
+ trustedHostKeysFilenamePEM := trustedHostKeysFilename + ".pem"
+ trustedHostKeysFilenameJWK := trustedHostKeysFilename + ".json"
+
+ testTrustedHostKeysFile(t, trustedHostKeysFilenamePEM)
+ testTrustedHostKeysFile(t, trustedHostKeysFilenameJWK)
+
+ os.Remove(trustedHostKeysFilename)
+ os.Remove(trustedHostKeysFilenamePEM)
+ os.Remove(trustedHostKeysFilenameJWK)
+}
+
+func testTrustedHostKeysFile(t *testing.T, trustedHostKeysFilename string) {
+ hostAddress1 := "docker.example.com:2376"
+ hostKey1, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ hostKey1.AddExtendedField("hosts", []string{hostAddress1})
+ err = AddKeySetFile(trustedHostKeysFilename, hostKey1.PublicKey())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ trustedHostKeysMapping, err := LoadKeySetFile(trustedHostKeysFilename)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for addr, hostKey := range trustedHostKeysMapping {
+ t.Logf("Host Address: %d\n", addr)
+ t.Logf("Host Key: %s\n\n", hostKey)
+ }
+
+ hostAddress2 := "192.168.59.103:2376"
+ hostKey2, err := GenerateRSA2048PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ hostKey2.AddExtendedField("hosts", hostAddress2)
+ err = AddKeySetFile(trustedHostKeysFilename, hostKey2.PublicKey())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ trustedHostKeysMapping, err = LoadKeySetFile(trustedHostKeysFilename)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for addr, hostKey := range trustedHostKeysMapping {
+ t.Logf("Host Address: %d\n", addr)
+ t.Logf("Host Key: %s\n\n", hostKey)
+ }
+
+}
+
+func TestTrustedClientKeysFile(t *testing.T) {
+ trustedClientKeysFilename := makeTempFile(t, "trusted_client_keys")
+ trustedClientKeysFilenamePEM := trustedClientKeysFilename + ".pem"
+ trustedClientKeysFilenameJWK := trustedClientKeysFilename + ".json"
+
+ testTrustedClientKeysFile(t, trustedClientKeysFilenamePEM)
+ testTrustedClientKeysFile(t, trustedClientKeysFilenameJWK)
+
+ os.Remove(trustedClientKeysFilename)
+ os.Remove(trustedClientKeysFilenamePEM)
+ os.Remove(trustedClientKeysFilenameJWK)
+}
+
+func testTrustedClientKeysFile(t *testing.T, trustedClientKeysFilename string) {
+ clientKey1, err := GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = AddKeySetFile(trustedClientKeysFilename, clientKey1.PublicKey())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ trustedClientKeys, err := LoadKeySetFile(trustedClientKeysFilename)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for _, clientKey := range trustedClientKeys {
+ t.Logf("Client Key: %s\n", clientKey)
+ }
+
+ clientKey2, err := GenerateRSA2048PrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = AddKeySetFile(trustedClientKeysFilename, clientKey2.PublicKey())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ trustedClientKeys, err = LoadKeySetFile(trustedClientKeysFilename)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for _, clientKey := range trustedClientKeys {
+ t.Logf("Client Key: %s\n", clientKey)
+ }
+}
diff --git a/key_manager.go b/key_manager.go
new file mode 100644
index 0000000..9a98ae3
--- /dev/null
+++ b/key_manager.go
@@ -0,0 +1,175 @@
+package libtrust
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ "io/ioutil"
+ "net"
+ "os"
+ "path"
+ "sync"
+)
+
+// ClientKeyManager manages client keys on the filesystem
+type ClientKeyManager struct {
+ key PrivateKey
+ clientFile string
+ clientDir string
+
+ clientLock sync.RWMutex
+ clients []PublicKey
+
+ configLock sync.Mutex
+ configs []*tls.Config
+}
+
+// NewClientKeyManager loads a new manager from a set of key files
+// and managed by the given private key.
+func NewClientKeyManager(trustKey PrivateKey, clientFile, clientDir string) (*ClientKeyManager, error) {
+ m := &ClientKeyManager{
+ key: trustKey,
+ clientFile: clientFile,
+ clientDir: clientDir,
+ }
+ if err := m.loadKeys(); err != nil {
+ return nil, err
+ }
+ // TODO Start watching file and directory
+
+ return m, nil
+}
+
+func (c *ClientKeyManager) loadKeys() (err error) {
+ // Load authorized keys file
+ var clients []PublicKey
+ if c.clientFile != "" {
+ clients, err = LoadKeySetFile(c.clientFile)
+ if err != nil {
+ return fmt.Errorf("unable to load authorized keys: %s", err)
+ }
+ }
+
+ // Add clients from authorized keys directory
+ files, err := ioutil.ReadDir(c.clientDir)
+ if err != nil && !os.IsNotExist(err) {
+ return fmt.Errorf("unable to open authorized keys directory: %s", err)
+ }
+ for _, f := range files {
+ if !f.IsDir() {
+ publicKey, err := LoadPublicKeyFile(path.Join(c.clientDir, f.Name()))
+ if err != nil {
+ return fmt.Errorf("unable to load authorized key file: %s", err)
+ }
+ clients = append(clients, publicKey)
+ }
+ }
+
+ c.clientLock.Lock()
+ c.clients = clients
+ c.clientLock.Unlock()
+
+ return nil
+}
+
+// RegisterTLSConfig registers a tls configuration to manager
+// such that any changes to the keys may be reflected in
+// the tls client CA pool
+func (c *ClientKeyManager) RegisterTLSConfig(tlsConfig *tls.Config) error {
+ c.clientLock.RLock()
+ certPool, err := GenerateCACertPool(c.key, c.clients)
+ if err != nil {
+ return fmt.Errorf("CA pool generation error: %s", err)
+ }
+ c.clientLock.RUnlock()
+
+ tlsConfig.ClientCAs = certPool
+
+ c.configLock.Lock()
+ c.configs = append(c.configs, tlsConfig)
+ c.configLock.Unlock()
+
+ return nil
+}
+
+// NewIdentityAuthTLSConfig creates a tls.Config for the server to use for
+// libtrust identity authentication for the domain specified
+func NewIdentityAuthTLSConfig(trustKey PrivateKey, clients *ClientKeyManager, addr string, domain string) (*tls.Config, error) {
+ tlsConfig := newTLSConfig()
+
+ tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
+ if err := clients.RegisterTLSConfig(tlsConfig); err != nil {
+ return nil, err
+ }
+
+ // Generate cert
+ ips, domains, err := parseAddr(addr)
+ if err != nil {
+ return nil, err
+ }
+ // add domain that it expects clients to use
+ domains = append(domains, domain)
+ x509Cert, err := GenerateSelfSignedServerCert(trustKey, domains, ips)
+ if err != nil {
+ return nil, fmt.Errorf("certificate generation error: %s", err)
+ }
+ tlsConfig.Certificates = []tls.Certificate{{
+ Certificate: [][]byte{x509Cert.Raw},
+ PrivateKey: trustKey.CryptoPrivateKey(),
+ Leaf: x509Cert,
+ }}
+
+ return tlsConfig, nil
+}
+
+// NewCertAuthTLSConfig creates a tls.Config for the server to use for
+// certificate authentication
+func NewCertAuthTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) {
+ tlsConfig := newTLSConfig()
+
+ cert, err := tls.LoadX509KeyPair(certPath, keyPath)
+ if err != nil {
+ return nil, fmt.Errorf("Couldn't load X509 key pair (%s, %s): %s. Key encrypted?", certPath, keyPath, err)
+ }
+ tlsConfig.Certificates = []tls.Certificate{cert}
+
+ // Verify client certificates against a CA?
+ if caPath != "" {
+ certPool := x509.NewCertPool()
+ file, err := ioutil.ReadFile(caPath)
+ if err != nil {
+ return nil, fmt.Errorf("Couldn't read CA certificate: %s", err)
+ }
+ certPool.AppendCertsFromPEM(file)
+
+ tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
+ tlsConfig.ClientCAs = certPool
+ }
+
+ return tlsConfig, nil
+}
+
+func newTLSConfig() *tls.Config {
+ return &tls.Config{
+ NextProtos: []string{"http/1.1"},
+ // Avoid fallback on insecure SSL protocols
+ MinVersion: tls.VersionTLS10,
+ }
+}
+
+// parseAddr parses an address into an array of IPs and domains
+func parseAddr(addr string) ([]net.IP, []string, error) {
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, nil, err
+ }
+ var domains []string
+ var ips []net.IP
+ ip := net.ParseIP(host)
+ if ip != nil {
+ ips = []net.IP{ip}
+ } else {
+ domains = []string{host}
+ }
+ return ips, domains, nil
+}
diff --git a/key_test.go b/key_test.go
new file mode 100644
index 0000000..f6c59cc
--- /dev/null
+++ b/key_test.go
@@ -0,0 +1,80 @@
+package libtrust
+
+import (
+ "testing"
+)
+
+type generateFunc func() (PrivateKey, error)
+
+func runGenerateBench(b *testing.B, f generateFunc, name string) {
+ for i := 0; i < b.N; i++ {
+ _, err := f()
+ if err != nil {
+ b.Fatalf("Error generating %s: %s", name, err)
+ }
+ }
+}
+
+func runFingerprintBench(b *testing.B, f generateFunc, name string) {
+ b.StopTimer()
+ // Don't count this relatively slow generation call.
+ key, err := f()
+ if err != nil {
+ b.Fatalf("Error generating %s: %s", name, err)
+ }
+ b.StartTimer()
+
+ for i := 0; i < b.N; i++ {
+ if key.KeyID() == "" {
+ b.Fatalf("Error generating key ID for %s", name)
+ }
+ }
+}
+
+func BenchmarkECP256Generate(b *testing.B) {
+ runGenerateBench(b, GenerateECP256PrivateKey, "P256")
+}
+
+func BenchmarkECP384Generate(b *testing.B) {
+ runGenerateBench(b, GenerateECP384PrivateKey, "P384")
+}
+
+func BenchmarkECP521Generate(b *testing.B) {
+ runGenerateBench(b, GenerateECP521PrivateKey, "P521")
+}
+
+func BenchmarkRSA2048Generate(b *testing.B) {
+ runGenerateBench(b, GenerateRSA2048PrivateKey, "RSA2048")
+}
+
+func BenchmarkRSA3072Generate(b *testing.B) {
+ runGenerateBench(b, GenerateRSA3072PrivateKey, "RSA3072")
+}
+
+func BenchmarkRSA4096Generate(b *testing.B) {
+ runGenerateBench(b, GenerateRSA4096PrivateKey, "RSA4096")
+}
+
+func BenchmarkECP256Fingerprint(b *testing.B) {
+ runFingerprintBench(b, GenerateECP256PrivateKey, "P256")
+}
+
+func BenchmarkECP384Fingerprint(b *testing.B) {
+ runFingerprintBench(b, GenerateECP384PrivateKey, "P384")
+}
+
+func BenchmarkECP521Fingerprint(b *testing.B) {
+ runFingerprintBench(b, GenerateECP521PrivateKey, "P521")
+}
+
+func BenchmarkRSA2048Fingerprint(b *testing.B) {
+ runFingerprintBench(b, GenerateRSA2048PrivateKey, "RSA2048")
+}
+
+func BenchmarkRSA3072Fingerprint(b *testing.B) {
+ runFingerprintBench(b, GenerateRSA3072PrivateKey, "RSA3072")
+}
+
+func BenchmarkRSA4096Fingerprint(b *testing.B) {
+ runFingerprintBench(b, GenerateRSA4096PrivateKey, "RSA4096")
+}
diff --git a/rsa_key.go b/rsa_key.go
new file mode 100644
index 0000000..dac4cac
--- /dev/null
+++ b/rsa_key.go
@@ -0,0 +1,427 @@
+package libtrust
+
+import (
+ "crypto"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/json"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "io"
+ "math/big"
+)
+
+/*
+ * RSA DSA PUBLIC KEY
+ */
+
+// rsaPublicKey implements a JWK Public Key using RSA digital signature algorithms.
+type rsaPublicKey struct {
+ *rsa.PublicKey
+ extended map[string]interface{}
+}
+
+func fromRSAPublicKey(cryptoPublicKey *rsa.PublicKey) *rsaPublicKey {
+ return &rsaPublicKey{cryptoPublicKey, map[string]interface{}{}}
+}
+
+// KeyType returns the JWK key type for RSA keys, i.e., "RSA".
+func (k *rsaPublicKey) KeyType() string {
+ return "RSA"
+}
+
+// KeyID returns a distinct identifier which is unique to this Public Key.
+func (k *rsaPublicKey) KeyID() string {
+ return keyIDFromCryptoKey(k)
+}
+
+func (k *rsaPublicKey) String() string {
+ return fmt.Sprintf("RSA Public Key <%s>", k.KeyID())
+}
+
+// Verify verifyies the signature of the data in the io.Reader using this Public Key.
+// The alg parameter should be the name of the JWA digital signature algorithm
+// which was used to produce the signature and should be supported by this
+// public key. Returns a nil error if the signature is valid.
+func (k *rsaPublicKey) Verify(data io.Reader, alg string, signature []byte) error {
+ // Verify the signature of the given date, return non-nil error if valid.
+ sigAlg, err := rsaSignatureAlgorithmByName(alg)
+ if err != nil {
+ return fmt.Errorf("unable to verify Signature: %s", err)
+ }
+
+ hasher := sigAlg.HashID().New()
+ _, err = io.Copy(hasher, data)
+ if err != nil {
+ return fmt.Errorf("error reading data to sign: %s", err)
+ }
+ hash := hasher.Sum(nil)
+
+ err = rsa.VerifyPKCS1v15(k.PublicKey, sigAlg.HashID(), hash, signature)
+ if err != nil {
+ return fmt.Errorf("invalid %s signature: %s", sigAlg.HeaderParam(), err)
+ }
+
+ return nil
+}
+
+// CryptoPublicKey returns the internal object which can be used as a
+// crypto.PublicKey for use with other standard library operations. The type
+// is either *rsa.PublicKey or *ecdsa.PublicKey
+func (k *rsaPublicKey) CryptoPublicKey() crypto.PublicKey {
+ return k.PublicKey
+}
+
+func (k *rsaPublicKey) toMap() map[string]interface{} {
+ jwk := make(map[string]interface{})
+ for k, v := range k.extended {
+ jwk[k] = v
+ }
+ jwk["kty"] = k.KeyType()
+ jwk["kid"] = k.KeyID()
+ jwk["n"] = joseBase64UrlEncode(k.N.Bytes())
+ jwk["e"] = joseBase64UrlEncode(serializeRSAPublicExponentParam(k.E))
+
+ return jwk
+}
+
+// MarshalJSON serializes this Public Key using the JWK JSON serialization format for
+// RSA keys.
+func (k *rsaPublicKey) MarshalJSON() (data []byte, err error) {
+ return json.Marshal(k.toMap())
+}
+
+// PEMBlock serializes this Public Key to DER-encoded PKIX format.
+func (k *rsaPublicKey) PEMBlock() (*pem.Block, error) {
+ derBytes, err := x509.MarshalPKIXPublicKey(k.PublicKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to serialize RSA PublicKey to DER-encoded PKIX format: %s", err)
+ }
+ k.extended["kid"] = k.KeyID() // For display purposes.
+ return createPemBlock("PUBLIC KEY", derBytes, k.extended)
+}
+
+func (k *rsaPublicKey) AddExtendedField(field string, value interface{}) {
+ k.extended[field] = value
+}
+
+func (k *rsaPublicKey) GetExtendedField(field string) interface{} {
+ v, ok := k.extended[field]
+ if !ok {
+ return nil
+ }
+ return v
+}
+
+func rsaPublicKeyFromMap(jwk map[string]interface{}) (*rsaPublicKey, error) {
+ // JWK key type (kty) has already been determined to be "RSA".
+ // Need to extract 'n', 'e', and 'kid' and check for
+ // consistency.
+
+ // Get the modulus parameter N.
+ nB64Url, err := stringFromMap(jwk, "n")
+ if err != nil {
+ return nil, fmt.Errorf("JWK RSA Public Key modulus: %s", err)
+ }
+
+ n, err := parseRSAModulusParam(nB64Url)
+ if err != nil {
+ return nil, fmt.Errorf("JWK RSA Public Key modulus: %s", err)
+ }
+
+ // Get the public exponent E.
+ eB64Url, err := stringFromMap(jwk, "e")
+ if err != nil {
+ return nil, fmt.Errorf("JWK RSA Public Key exponent: %s", err)
+ }
+
+ e, err := parseRSAPublicExponentParam(eB64Url)
+ if err != nil {
+ return nil, fmt.Errorf("JWK RSA Public Key exponent: %s", err)
+ }
+
+ key := &rsaPublicKey{
+ PublicKey: &rsa.PublicKey{N: n, E: e},
+ }
+
+ // Key ID is optional, but if it exists, it should match the key.
+ _, ok := jwk["kid"]
+ if ok {
+ kid, err := stringFromMap(jwk, "kid")
+ if err != nil {
+ return nil, fmt.Errorf("JWK RSA Public Key ID: %s", err)
+ }
+ if kid != key.KeyID() {
+ return nil, fmt.Errorf("JWK RSA Public Key ID does not match: %s", kid)
+ }
+ }
+
+ if _, ok := jwk["d"]; ok {
+ return nil, fmt.Errorf("JWK RSA Public Key cannot contain private exponent")
+ }
+
+ key.extended = jwk
+
+ return key, nil
+}
+
+/*
+ * RSA DSA PRIVATE KEY
+ */
+
+// rsaPrivateKey implements a JWK Private Key using RSA digital signature algorithms.
+type rsaPrivateKey struct {
+ rsaPublicKey
+ *rsa.PrivateKey
+}
+
+func fromRSAPrivateKey(cryptoPrivateKey *rsa.PrivateKey) *rsaPrivateKey {
+ return &rsaPrivateKey{
+ *fromRSAPublicKey(&cryptoPrivateKey.PublicKey),
+ cryptoPrivateKey,
+ }
+}
+
+// PublicKey returns the Public Key data associated with this Private Key.
+func (k *rsaPrivateKey) PublicKey() PublicKey {
+ return &k.rsaPublicKey
+}
+
+func (k *rsaPrivateKey) String() string {
+ return fmt.Sprintf("RSA Private Key <%s>", k.KeyID())
+}
+
+// Sign signs the data read from the io.Reader using a signature algorithm supported
+// by the RSA private key. If the specified hashing algorithm is supported by
+// this key, that hash function is used to generate the signature otherwise the
+// the default hashing algorithm for this key is used. Returns the signature
+// and the name of the JWK signature algorithm used, e.g., "RS256", "RS384",
+// "RS512".
+func (k *rsaPrivateKey) Sign(data io.Reader, hashID crypto.Hash) (signature []byte, alg string, err error) {
+ // Generate a signature of the data using the internal alg.
+ sigAlg := rsaPKCS1v15SignatureAlgorithmForHashID(hashID)
+ hasher := sigAlg.HashID().New()
+
+ _, err = io.Copy(hasher, data)
+ if err != nil {
+ return nil, "", fmt.Errorf("error reading data to sign: %s", err)
+ }
+ hash := hasher.Sum(nil)
+
+ signature, err = rsa.SignPKCS1v15(rand.Reader, k.PrivateKey, sigAlg.HashID(), hash)
+ if err != nil {
+ return nil, "", fmt.Errorf("error producing signature: %s", err)
+ }
+
+ alg = sigAlg.HeaderParam()
+
+ return
+}
+
+// CryptoPrivateKey returns the internal object which can be used as a
+// crypto.PublicKey for use with other standard library operations. The type
+// is either *rsa.PublicKey or *ecdsa.PublicKey
+func (k *rsaPrivateKey) CryptoPrivateKey() crypto.PrivateKey {
+ return k.PrivateKey
+}
+
+func (k *rsaPrivateKey) toMap() map[string]interface{} {
+ k.Precompute() // Make sure the precomputed values are stored.
+ jwk := k.rsaPublicKey.toMap()
+
+ jwk["d"] = joseBase64UrlEncode(k.D.Bytes())
+ jwk["p"] = joseBase64UrlEncode(k.Primes[0].Bytes())
+ jwk["q"] = joseBase64UrlEncode(k.Primes[1].Bytes())
+ jwk["dp"] = joseBase64UrlEncode(k.Precomputed.Dp.Bytes())
+ jwk["dq"] = joseBase64UrlEncode(k.Precomputed.Dq.Bytes())
+ jwk["qi"] = joseBase64UrlEncode(k.Precomputed.Qinv.Bytes())
+
+ otherPrimes := k.Primes[2:]
+
+ if len(otherPrimes) > 0 {
+ otherPrimesInfo := make([]interface{}, len(otherPrimes))
+ for i, r := range otherPrimes {
+ otherPrimeInfo := make(map[string]string, 3)
+ otherPrimeInfo["r"] = joseBase64UrlEncode(r.Bytes())
+ crtVal := k.Precomputed.CRTValues[i]
+ otherPrimeInfo["d"] = joseBase64UrlEncode(crtVal.Exp.Bytes())
+ otherPrimeInfo["t"] = joseBase64UrlEncode(crtVal.Coeff.Bytes())
+ otherPrimesInfo[i] = otherPrimeInfo
+ }
+ jwk["oth"] = otherPrimesInfo
+ }
+
+ return jwk
+}
+
+// MarshalJSON serializes this Private Key using the JWK JSON serialization format for
+// RSA keys.
+func (k *rsaPrivateKey) MarshalJSON() (data []byte, err error) {
+ return json.Marshal(k.toMap())
+}
+
+// PEMBlock serializes this Private Key to DER-encoded PKIX format.
+func (k *rsaPrivateKey) PEMBlock() (*pem.Block, error) {
+ derBytes := x509.MarshalPKCS1PrivateKey(k.PrivateKey)
+ k.extended["keyID"] = k.KeyID() // For display purposes.
+ return createPemBlock("RSA PRIVATE KEY", derBytes, k.extended)
+}
+
+func rsaPrivateKeyFromMap(jwk map[string]interface{}) (*rsaPrivateKey, error) {
+ // The JWA spec for RSA Private Keys (draft rfc section 5.3.2) states that
+ // only the private key exponent 'd' is REQUIRED, the others are just for
+ // signature/decryption optimizations and SHOULD be included when the JWK
+ // is produced. We MAY choose to accept a JWK which only includes 'd', but
+ // we're going to go ahead and not choose to accept it without the extra
+ // fields. Only the 'oth' field will be optional (for multi-prime keys).
+ privateExponent, err := parseRSAPrivateKeyParamFromMap(jwk, "d")
+ if err != nil {
+ return nil, fmt.Errorf("JWK RSA Private Key exponent: %s", err)
+ }
+ firstPrimeFactor, err := parseRSAPrivateKeyParamFromMap(jwk, "p")
+ if err != nil {
+ return nil, fmt.Errorf("JWK RSA Private Key prime factor: %s", err)
+ }
+ secondPrimeFactor, err := parseRSAPrivateKeyParamFromMap(jwk, "q")
+ if err != nil {
+ return nil, fmt.Errorf("JWK RSA Private Key prime factor: %s", err)
+ }
+ firstFactorCRT, err := parseRSAPrivateKeyParamFromMap(jwk, "dp")
+ if err != nil {
+ return nil, fmt.Errorf("JWK RSA Private Key CRT exponent: %s", err)
+ }
+ secondFactorCRT, err := parseRSAPrivateKeyParamFromMap(jwk, "dq")
+ if err != nil {
+ return nil, fmt.Errorf("JWK RSA Private Key CRT exponent: %s", err)
+ }
+ crtCoeff, err := parseRSAPrivateKeyParamFromMap(jwk, "qi")
+ if err != nil {
+ return nil, fmt.Errorf("JWK RSA Private Key CRT coefficient: %s", err)
+ }
+
+ var oth interface{}
+ if _, ok := jwk["oth"]; ok {
+ oth = jwk["oth"]
+ delete(jwk, "oth")
+ }
+
+ // JWK key type (kty) has already been determined to be "RSA".
+ // Need to extract the public key information, then extract the private
+ // key values.
+ publicKey, err := rsaPublicKeyFromMap(jwk)
+ if err != nil {
+ return nil, err
+ }
+
+ privateKey := &rsa.PrivateKey{
+ PublicKey: *publicKey.PublicKey,
+ D: privateExponent,
+ Primes: []*big.Int{firstPrimeFactor, secondPrimeFactor},
+ Precomputed: rsa.PrecomputedValues{
+ Dp: firstFactorCRT,
+ Dq: secondFactorCRT,
+ Qinv: crtCoeff,
+ },
+ }
+
+ if oth != nil {
+ // Should be an array of more JSON objects.
+ otherPrimesInfo, ok := oth.([]interface{})
+ if !ok {
+ return nil, errors.New("JWK RSA Private Key: Invalid other primes info: must be an array")
+ }
+ numOtherPrimeFactors := len(otherPrimesInfo)
+ if numOtherPrimeFactors == 0 {
+ return nil, errors.New("JWK RSA Privake Key: Invalid other primes info: must be absent or non-empty")
+ }
+ otherPrimeFactors := make([]*big.Int, numOtherPrimeFactors)
+ productOfPrimes := new(big.Int).Mul(firstPrimeFactor, secondPrimeFactor)
+ crtValues := make([]rsa.CRTValue, numOtherPrimeFactors)
+
+ for i, val := range otherPrimesInfo {
+ otherPrimeinfo, ok := val.(map[string]interface{})
+ if !ok {
+ return nil, errors.New("JWK RSA Private Key: Invalid other prime info: must be a JSON object")
+ }
+
+ otherPrimeFactor, err := parseRSAPrivateKeyParamFromMap(otherPrimeinfo, "r")
+ if err != nil {
+ return nil, fmt.Errorf("JWK RSA Private Key prime factor: %s", err)
+ }
+ otherFactorCRT, err := parseRSAPrivateKeyParamFromMap(otherPrimeinfo, "d")
+ if err != nil {
+ return nil, fmt.Errorf("JWK RSA Private Key CRT exponent: %s", err)
+ }
+ otherCrtCoeff, err := parseRSAPrivateKeyParamFromMap(otherPrimeinfo, "t")
+ if err != nil {
+ return nil, fmt.Errorf("JWK RSA Private Key CRT coefficient: %s", err)
+ }
+
+ crtValue := crtValues[i]
+ crtValue.Exp = otherFactorCRT
+ crtValue.Coeff = otherCrtCoeff
+ crtValue.R = productOfPrimes
+ otherPrimeFactors[i] = otherPrimeFactor
+ productOfPrimes = new(big.Int).Mul(productOfPrimes, otherPrimeFactor)
+ }
+
+ privateKey.Primes = append(privateKey.Primes, otherPrimeFactors...)
+ privateKey.Precomputed.CRTValues = crtValues
+ }
+
+ key := &rsaPrivateKey{
+ rsaPublicKey: *publicKey,
+ PrivateKey: privateKey,
+ }
+
+ return key, nil
+}
+
+/*
+ * Key Generation Functions.
+ */
+
+func generateRSAPrivateKey(bits int) (k *rsaPrivateKey, err error) {
+ k = new(rsaPrivateKey)
+ k.PrivateKey, err = rsa.GenerateKey(rand.Reader, bits)
+ if err != nil {
+ return nil, err
+ }
+
+ k.rsaPublicKey.PublicKey = &k.PrivateKey.PublicKey
+ k.extended = make(map[string]interface{})
+
+ return
+}
+
+// GenerateRSA2048PrivateKey generates a key pair using 2048-bit RSA.
+func GenerateRSA2048PrivateKey() (PrivateKey, error) {
+ k, err := generateRSAPrivateKey(2048)
+ if err != nil {
+ return nil, fmt.Errorf("error generating RSA 2048-bit key: %s", err)
+ }
+
+ return k, nil
+}
+
+// GenerateRSA3072PrivateKey generates a key pair using 3072-bit RSA.
+func GenerateRSA3072PrivateKey() (PrivateKey, error) {
+ k, err := generateRSAPrivateKey(3072)
+ if err != nil {
+ return nil, fmt.Errorf("error generating RSA 3072-bit key: %s", err)
+ }
+
+ return k, nil
+}
+
+// GenerateRSA4096PrivateKey generates a key pair using 4096-bit RSA.
+func GenerateRSA4096PrivateKey() (PrivateKey, error) {
+ k, err := generateRSAPrivateKey(4096)
+ if err != nil {
+ return nil, fmt.Errorf("error generating RSA 4096-bit key: %s", err)
+ }
+
+ return k, nil
+}
diff --git a/rsa_key_test.go b/rsa_key_test.go
new file mode 100644
index 0000000..5ec7707
--- /dev/null
+++ b/rsa_key_test.go
@@ -0,0 +1,157 @@
+package libtrust
+
+import (
+ "bytes"
+ "encoding/json"
+ "log"
+ "testing"
+)
+
+var rsaKeys []PrivateKey
+
+func init() {
+ var err error
+ rsaKeys, err = generateRSATestKeys()
+ if err != nil {
+ log.Fatal(err)
+ }
+}
+
+func generateRSATestKeys() (keys []PrivateKey, err error) {
+ log.Println("Generating RSA 2048-bit Test Key")
+ rsa2048Key, err := GenerateRSA2048PrivateKey()
+ if err != nil {
+ return
+ }
+
+ log.Println("Generating RSA 3072-bit Test Key")
+ rsa3072Key, err := GenerateRSA3072PrivateKey()
+ if err != nil {
+ return
+ }
+
+ log.Println("Generating RSA 4096-bit Test Key")
+ rsa4096Key, err := GenerateRSA4096PrivateKey()
+ if err != nil {
+ return
+ }
+
+ log.Println("Done generating RSA Test Keys!")
+ keys = []PrivateKey{rsa2048Key, rsa3072Key, rsa4096Key}
+
+ return
+}
+
+func TestRSAKeys(t *testing.T) {
+ for _, rsaKey := range rsaKeys {
+ if rsaKey.KeyType() != "RSA" {
+ t.Fatalf("key type must be %q, instead got %q", "RSA", rsaKey.KeyType())
+ }
+ }
+}
+
+func TestRSASignVerify(t *testing.T) {
+ message := "Hello, World!"
+ data := bytes.NewReader([]byte(message))
+
+ sigAlgs := []*signatureAlgorithm{rs256, rs384, rs512}
+
+ for i, rsaKey := range rsaKeys {
+ sigAlg := sigAlgs[i]
+
+ t.Logf("%s signature of %q with kid: %s\n", sigAlg.HeaderParam(), message, rsaKey.KeyID())
+
+ data.Seek(0, 0) // Reset the byte reader
+
+ // Sign
+ sig, alg, err := rsaKey.Sign(data, sigAlg.HashID())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ data.Seek(0, 0) // Reset the byte reader
+
+ // Verify
+ err = rsaKey.Verify(data, alg, sig)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+}
+
+func TestMarshalUnmarshalRSAKeys(t *testing.T) {
+ data := bytes.NewReader([]byte("This is a test. I repeat: this is only a test."))
+ sigAlgs := []*signatureAlgorithm{rs256, rs384, rs512}
+
+ for i, rsaKey := range rsaKeys {
+ sigAlg := sigAlgs[i]
+ privateJWKJSON, err := json.MarshalIndent(rsaKey, "", " ")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ publicJWKJSON, err := json.MarshalIndent(rsaKey.PublicKey(), "", " ")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ t.Logf("JWK Private Key: %s", string(privateJWKJSON))
+ t.Logf("JWK Public Key: %s", string(publicJWKJSON))
+
+ privKey2, err := UnmarshalPrivateKeyJWK(privateJWKJSON)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ pubKey2, err := UnmarshalPublicKeyJWK(publicJWKJSON)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Ensure we can sign/verify a message with the unmarshalled keys.
+ data.Seek(0, 0) // Reset the byte reader
+ signature, alg, err := privKey2.Sign(data, sigAlg.HashID())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ data.Seek(0, 0) // Reset the byte reader
+ err = pubKey2.Verify(data, alg, signature)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // It's a good idea to validate the Private Key to make sure our
+ // (un)marshal process didn't corrupt the extra parameters.
+ k := privKey2.(*rsaPrivateKey)
+ err = k.PrivateKey.Validate()
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+}
+
+func TestFromCryptoRSAKeys(t *testing.T) {
+ for _, rsaKey := range rsaKeys {
+ cryptoPrivateKey := rsaKey.CryptoPrivateKey()
+ cryptoPublicKey := rsaKey.CryptoPublicKey()
+
+ pubKey, err := FromCryptoPublicKey(cryptoPublicKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if pubKey.KeyID() != rsaKey.KeyID() {
+ t.Fatal("public key key ID mismatch")
+ }
+
+ privKey, err := FromCryptoPrivateKey(cryptoPrivateKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if privKey.KeyID() != rsaKey.KeyID() {
+ t.Fatal("public key key ID mismatch")
+ }
+ }
+}
diff --git a/testutil/certificates.go b/testutil/certificates.go
new file mode 100644
index 0000000..89debf6
--- /dev/null
+++ b/testutil/certificates.go
@@ -0,0 +1,94 @@
+package testutil
+
+import (
+ "crypto"
+ "crypto/rand"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "math/big"
+ "time"
+)
+
+// GenerateTrustCA generates a new certificate authority for testing.
+func GenerateTrustCA(pub crypto.PublicKey, priv crypto.PrivateKey) (*x509.Certificate, error) {
+ cert := &x509.Certificate{
+ SerialNumber: big.NewInt(0),
+ Subject: pkix.Name{
+ CommonName: "CA Root",
+ },
+ NotBefore: time.Now().Add(-time.Second),
+ NotAfter: time.Now().Add(time.Hour),
+ IsCA: true,
+ KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
+ BasicConstraintsValid: true,
+ }
+
+ certDER, err := x509.CreateCertificate(rand.Reader, cert, cert, pub, priv)
+ if err != nil {
+ return nil, err
+ }
+
+ cert, err = x509.ParseCertificate(certDER)
+ if err != nil {
+ return nil, err
+ }
+
+ return cert, nil
+}
+
+// GenerateIntermediate generates an intermediate certificate for testing using
+// the parent certificate (likely a CA) and the provided keys.
+func GenerateIntermediate(key crypto.PublicKey, parentKey crypto.PrivateKey, parent *x509.Certificate) (*x509.Certificate, error) {
+ cert := &x509.Certificate{
+ SerialNumber: big.NewInt(0),
+ Subject: pkix.Name{
+ CommonName: "Intermediate",
+ },
+ NotBefore: time.Now().Add(-time.Second),
+ NotAfter: time.Now().Add(time.Hour),
+ IsCA: true,
+ KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
+ BasicConstraintsValid: true,
+ }
+
+ certDER, err := x509.CreateCertificate(rand.Reader, cert, parent, key, parentKey)
+ if err != nil {
+ return nil, err
+ }
+
+ cert, err = x509.ParseCertificate(certDER)
+ if err != nil {
+ return nil, err
+ }
+
+ return cert, nil
+}
+
+// GenerateTrustCert generates a new trust certificate for testing. Unlike the
+// intermediate certificates, this certificate should be used for signature
+// only, not creating certificates.
+func GenerateTrustCert(key crypto.PublicKey, parentKey crypto.PrivateKey, parent *x509.Certificate) (*x509.Certificate, error) {
+ cert := &x509.Certificate{
+ SerialNumber: big.NewInt(0),
+ Subject: pkix.Name{
+ CommonName: "Trust Cert",
+ },
+ NotBefore: time.Now().Add(-time.Second),
+ NotAfter: time.Now().Add(time.Hour),
+ IsCA: true,
+ KeyUsage: x509.KeyUsageDigitalSignature,
+ BasicConstraintsValid: true,
+ }
+
+ certDER, err := x509.CreateCertificate(rand.Reader, cert, parent, key, parentKey)
+ if err != nil {
+ return nil, err
+ }
+
+ cert, err = x509.ParseCertificate(certDER)
+ if err != nil {
+ return nil, err
+ }
+
+ return cert, nil
+}
diff --git a/tlsdemo/README.md b/tlsdemo/README.md
new file mode 100644
index 0000000..24124db
--- /dev/null
+++ b/tlsdemo/README.md
@@ -0,0 +1,50 @@
+## Libtrust TLS Config Demo
+
+This program generates key pairs and trust files for a TLS client and server.
+
+To generate the keys, run:
+
+```
+$ go run genkeys.go
+```
+
+The generated files are:
+
+```
+$ ls -l client_data/ server_data/
+client_data/:
+total 24
+-rw------- 1 jlhawn staff 281 Aug 8 16:21 private_key.json
+-rw-r--r-- 1 jlhawn staff 225 Aug 8 16:21 public_key.json
+-rw-r--r-- 1 jlhawn staff 275 Aug 8 16:21 trusted_hosts.json
+
+server_data/:
+total 24
+-rw-r--r-- 1 jlhawn staff 348 Aug 8 16:21 trusted_clients.json
+-rw------- 1 jlhawn staff 281 Aug 8 16:21 private_key.json
+-rw-r--r-- 1 jlhawn staff 225 Aug 8 16:21 public_key.json
+```
+
+The private key and public key for the client and server are stored in `private_key.json` and `public_key.json`, respectively, and in their respective directories. They are represented as JSON Web Keys: JSON objects which represent either an ECDSA or RSA private key. The host keys trusted by the client are stored in `trusted_hosts.json` and contain a mapping of an internet address, `<HOSTNAME_OR_IP>:<PORT>`, to a JSON Web Key which is a JSON object representing either an ECDSA or RSA public key of the trusted server. The client keys trusted by the server are stored in `trusted_clients.json` and contain an array of JSON objects which contain a comment field which can be used describe the key and a JSON Web Key which is a JSON object representing either an ECDSA or RSA public key of the trusted client.
+
+To start the server, run:
+
+```
+$ go run server.go
+```
+
+This starts an HTTPS server which listens on `localhost:8888`. The server configures itself with a certificate which is valid for both `localhost` and `127.0.0.1` and uses the key from `server_data/private_key.json`. It accepts connections from clients which present a certificate for a key that it is configured to trust from the `trusted_clients.json` file and returns a simple 'hello' message.
+
+To make a request using the client, run:
+
+```
+$ go run client.go
+```
+
+This command creates an HTTPS client which makes a GET request to `https://localhost:8888`. The client configures itself with a certificate using the key from `client_data/private_key.json`. It only connects to a server which presents a certificate signed by the key specified for the `localhost:8888` address from `client_data/trusted_hosts.json` and made to be used for the `localhost` hostname. If the connection succeeds, it prints the response from the server.
+
+The file `gencert.go` can be used to generate PEM encoded version of the client key and certificate. If you save them to `key.pem` and `cert.pem` respectively, you can use them with `curl` to test out the server (if it is still running).
+
+```
+curl --cert cert.pem --key key.pem -k https://localhost:8888
+```
diff --git a/tlsdemo/client.go b/tlsdemo/client.go
new file mode 100644
index 0000000..8d35600
--- /dev/null
+++ b/tlsdemo/client.go
@@ -0,0 +1,89 @@
+package main
+
+import (
+ "crypto/tls"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/http"
+
+ "github.com/containers/libtrust"
+)
+
+var (
+ serverAddress = "localhost:8888"
+ privateKeyFilename = "client_data/private_key.pem"
+ trustedHostsFilename = "client_data/trusted_hosts.pem"
+)
+
+func main() {
+ // Load Client Key.
+ clientKey, err := libtrust.LoadKeyFile(privateKeyFilename)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Generate Client Certificate.
+ selfSignedClientCert, err := libtrust.GenerateSelfSignedClientCert(clientKey)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Load trusted host keys.
+ hostKeys, err := libtrust.LoadKeySetFile(trustedHostsFilename)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Ensure the host we want to connect to is trusted!
+ host, _, err := net.SplitHostPort(serverAddress)
+ if err != nil {
+ log.Fatal(err)
+ }
+ serverKeys, err := libtrust.FilterByHosts(hostKeys, host, false)
+ if err != nil {
+ log.Fatalf("%q is not a known and trusted host", host)
+ }
+
+ // Generate a CA pool with the trusted host's key.
+ caPool, err := libtrust.GenerateCACertPool(clientKey, serverKeys)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Create HTTP Client.
+ client := &http.Client{
+ Transport: &http.Transport{
+ TLSClientConfig: &tls.Config{
+ Certificates: []tls.Certificate{
+ tls.Certificate{
+ Certificate: [][]byte{selfSignedClientCert.Raw},
+ PrivateKey: clientKey.CryptoPrivateKey(),
+ Leaf: selfSignedClientCert,
+ },
+ },
+ RootCAs: caPool,
+ },
+ },
+ }
+
+ var makeRequest = func(url string) {
+ resp, err := client.Get(url)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer resp.Body.Close()
+
+ body, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ log.Println(resp.Status)
+ log.Println(string(body))
+ }
+
+ // Make the request to the trusted server!
+ makeRequest(fmt.Sprintf("https://%s", serverAddress))
+}
diff --git a/tlsdemo/gencert.go b/tlsdemo/gencert.go
new file mode 100644
index 0000000..3638f68
--- /dev/null
+++ b/tlsdemo/gencert.go
@@ -0,0 +1,62 @@
+package main
+
+import (
+ "encoding/pem"
+ "fmt"
+ "log"
+ "net"
+
+ "github.com/containers/libtrust"
+)
+
+var (
+ serverAddress = "localhost:8888"
+ clientPrivateKeyFilename = "client_data/private_key.pem"
+ trustedHostsFilename = "client_data/trusted_hosts.pem"
+)
+
+func main() {
+ key, err := libtrust.LoadKeyFile(clientPrivateKeyFilename)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ keyPEMBlock, err := key.PEMBlock()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ encodedPrivKey := pem.EncodeToMemory(keyPEMBlock)
+ fmt.Printf("Client Key:\n\n%s\n", string(encodedPrivKey))
+
+ cert, err := libtrust.GenerateSelfSignedClientCert(key)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ encodedCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
+ fmt.Printf("Client Cert:\n\n%s\n", string(encodedCert))
+
+ trustedServerKeys, err := libtrust.LoadKeySetFile(trustedHostsFilename)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ hostname, _, err := net.SplitHostPort(serverAddress)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ trustedServerKeys, err = libtrust.FilterByHosts(trustedServerKeys, hostname, false)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ caCert, err := libtrust.GenerateCACert(key, trustedServerKeys[0])
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ encodedCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCert.Raw})
+ fmt.Printf("CA Cert:\n\n%s\n", string(encodedCert))
+}
diff --git a/tlsdemo/genkeys.go b/tlsdemo/genkeys.go
new file mode 100644
index 0000000..4c7a7aa
--- /dev/null
+++ b/tlsdemo/genkeys.go
@@ -0,0 +1,61 @@
+package main
+
+import (
+ "log"
+
+ "github.com/containers/libtrust"
+)
+
+func main() {
+ // Generate client key.
+ clientKey, err := libtrust.GenerateECP256PrivateKey()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Add a comment for the client key.
+ clientKey.AddExtendedField("comment", "TLS Demo Client")
+
+ // Save the client key, public and private versions.
+ err = libtrust.SaveKey("client_data/private_key.pem", clientKey)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ err = libtrust.SavePublicKey("client_data/public_key.pem", clientKey.PublicKey())
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Generate server key.
+ serverKey, err := libtrust.GenerateECP256PrivateKey()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Set the list of addresses to use for the server.
+ serverKey.AddExtendedField("hosts", []string{"localhost", "docker.example.com"})
+
+ // Save the server key, public and private versions.
+ err = libtrust.SaveKey("server_data/private_key.pem", serverKey)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ err = libtrust.SavePublicKey("server_data/public_key.pem", serverKey.PublicKey())
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Generate Authorized Keys file for server.
+ err = libtrust.AddKeySetFile("server_data/trusted_clients.pem", clientKey.PublicKey())
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Generate Known Host Keys file for client.
+ err = libtrust.AddKeySetFile("client_data/trusted_hosts.pem", serverKey.PublicKey())
+ if err != nil {
+ log.Fatal(err)
+ }
+}
diff --git a/tlsdemo/server.go b/tlsdemo/server.go
new file mode 100644
index 0000000..d4b8ae4
--- /dev/null
+++ b/tlsdemo/server.go
@@ -0,0 +1,80 @@
+package main
+
+import (
+ "crypto/tls"
+ "fmt"
+ "html"
+ "log"
+ "net"
+ "net/http"
+
+ "github.com/containers/libtrust"
+)
+
+var (
+ serverAddress = "localhost:8888"
+ privateKeyFilename = "server_data/private_key.pem"
+ authorizedClientsFilename = "server_data/trusted_clients.pem"
+)
+
+func requestHandler(w http.ResponseWriter, r *http.Request) {
+ clientCert := r.TLS.PeerCertificates[0]
+ keyID := clientCert.Subject.CommonName
+ log.Printf("Request from keyID: %s\n", keyID)
+ fmt.Fprintf(w, "Hello, client! I'm a server! And you are %T: %s.\n", clientCert.PublicKey, html.EscapeString(keyID))
+}
+
+func main() {
+ // Load server key.
+ serverKey, err := libtrust.LoadKeyFile(privateKeyFilename)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Generate server certificate.
+ selfSignedServerCert, err := libtrust.GenerateSelfSignedServerCert(
+ serverKey, []string{"localhost"}, []net.IP{net.ParseIP("127.0.0.1")},
+ )
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Load authorized client keys.
+ authorizedClients, err := libtrust.LoadKeySetFile(authorizedClientsFilename)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Create CA pool using trusted client keys.
+ caPool, err := libtrust.GenerateCACertPool(serverKey, authorizedClients)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Create TLS config, requiring client certificates.
+ tlsConfig := &tls.Config{
+ Certificates: []tls.Certificate{
+ tls.Certificate{
+ Certificate: [][]byte{selfSignedServerCert.Raw},
+ PrivateKey: serverKey.CryptoPrivateKey(),
+ Leaf: selfSignedServerCert,
+ },
+ },
+ ClientAuth: tls.RequireAndVerifyClientCert,
+ ClientCAs: caPool,
+ }
+
+ // Create HTTP server with simple request handler.
+ server := &http.Server{
+ Addr: serverAddress,
+ Handler: http.HandlerFunc(requestHandler),
+ }
+
+ // Listen and server HTTPS using the libtrust TLS config.
+ listener, err := net.Listen("tcp", server.Addr)
+ if err != nil {
+ log.Fatal(err)
+ }
+ tlsListener := tls.NewListener(listener, tlsConfig)
+ server.Serve(tlsListener)
+}
diff --git a/trustgraph/graph.go b/trustgraph/graph.go
new file mode 100644
index 0000000..b49be77
--- /dev/null
+++ b/trustgraph/graph.go
@@ -0,0 +1,50 @@
+package trustgraph
+
+import "github.com/containers/libtrust"
+
+// TrustGraph represents a graph of authorization mapping
+// public keys to nodes and grants between nodes.
+type TrustGraph interface {
+ // Verifies that the given public key is allowed to perform
+ // the given action on the given node according to the trust
+ // graph.
+ Verify(libtrust.PublicKey, string, uint16) (bool, error)
+
+ // GetGrants returns an array of all grant chains which are used to
+ // allow the requested permission.
+ GetGrants(libtrust.PublicKey, string, uint16) ([][]*Grant, error)
+}
+
+// Grant represents a transfer of permission from one part of the
+// trust graph to another. This is the only way to delegate
+// permission between two different sub trees in the graph.
+type Grant struct {
+ // Subject is the namespace being granted
+ Subject string
+
+ // Permissions is a bit map of permissions
+ Permission uint16
+
+ // Grantee represents the node being granted
+ // a permission scope. The grantee can be
+ // either a namespace item or a key id where namespace
+ // items will always start with a '/'.
+ Grantee string
+
+ // statement represents the statement used to create
+ // this object.
+ statement *Statement
+}
+
+// Permissions
+// Read node 0x01 (can read node, no sub nodes)
+// Write node 0x02 (can write to node object, cannot create subnodes)
+// Read subtree 0x04 (delegates read to each sub node)
+// Write subtree 0x08 (delegates write to each sub node, included create on the subject)
+//
+// Permission shortcuts
+// ReadItem = 0x01
+// WriteItem = 0x03
+// ReadAccess = 0x07
+// WriteAccess = 0x0F
+// Delegate = 0x0F
diff --git a/trustgraph/memory_graph.go b/trustgraph/memory_graph.go
new file mode 100644
index 0000000..9ba8af7
--- /dev/null
+++ b/trustgraph/memory_graph.go
@@ -0,0 +1,133 @@
+package trustgraph
+
+import (
+ "strings"
+
+ "github.com/containers/libtrust"
+)
+
+type grantNode struct {
+ grants []*Grant
+ children map[string]*grantNode
+}
+
+type memoryGraph struct {
+ roots map[string]*grantNode
+}
+
+func newGrantNode() *grantNode {
+ return &grantNode{
+ grants: []*Grant{},
+ children: map[string]*grantNode{},
+ }
+}
+
+// NewMemoryGraph returns a new in memory trust graph created from
+// a static list of grants. This graph is immutable after creation
+// and any alterations should create a new instance.
+func NewMemoryGraph(grants []*Grant) TrustGraph {
+ roots := map[string]*grantNode{}
+ for _, grant := range grants {
+ parts := strings.Split(grant.Grantee, "/")
+ nodes := roots
+ var node *grantNode
+ var nodeOk bool
+ for _, part := range parts {
+ node, nodeOk = nodes[part]
+ if !nodeOk {
+ node = newGrantNode()
+ nodes[part] = node
+ }
+ if part != "" {
+ node.grants = append(node.grants, grant)
+ }
+ nodes = node.children
+ }
+ }
+ return &memoryGraph{roots}
+}
+
+func (g *memoryGraph) getGrants(name string) []*Grant {
+ nameParts := strings.Split(name, "/")
+ nodes := g.roots
+ var node *grantNode
+ var nodeOk bool
+ for _, part := range nameParts {
+ node, nodeOk = nodes[part]
+ if !nodeOk {
+ return nil
+ }
+ nodes = node.children
+ }
+ return node.grants
+}
+
+func isSubName(name, sub string) bool {
+ if strings.HasPrefix(name, sub) {
+ if len(name) == len(sub) || name[len(sub)] == '/' {
+ return true
+ }
+ }
+ return false
+}
+
+type walkFunc func(*Grant, []*Grant) bool
+
+func foundWalkFunc(*Grant, []*Grant) bool {
+ return true
+}
+
+func (g *memoryGraph) walkGrants(start, target string, permission uint16, f walkFunc, chain []*Grant, visited map[*Grant]bool, collect bool) bool {
+ if visited == nil {
+ visited = map[*Grant]bool{}
+ }
+ grants := g.getGrants(start)
+ subGrants := make([]*Grant, 0, len(grants))
+ for _, grant := range grants {
+ if visited[grant] {
+ continue
+ }
+ visited[grant] = true
+ if grant.Permission&permission == permission {
+ if isSubName(target, grant.Subject) {
+ if f(grant, chain) {
+ return true
+ }
+ } else {
+ subGrants = append(subGrants, grant)
+ }
+ }
+ }
+ for _, grant := range subGrants {
+ var chainCopy []*Grant
+ if collect {
+ chainCopy = make([]*Grant, len(chain)+1)
+ copy(chainCopy, chain)
+ chainCopy[len(chainCopy)-1] = grant
+ } else {
+ chainCopy = nil
+ }
+
+ if g.walkGrants(grant.Subject, target, permission, f, chainCopy, visited, collect) {
+ return true
+ }
+ }
+ return false
+}
+
+func (g *memoryGraph) Verify(key libtrust.PublicKey, node string, permission uint16) (bool, error) {
+ return g.walkGrants(key.KeyID(), node, permission, foundWalkFunc, nil, nil, false), nil
+}
+
+func (g *memoryGraph) GetGrants(key libtrust.PublicKey, node string, permission uint16) ([][]*Grant, error) {
+ grants := [][]*Grant{}
+ collect := func(grant *Grant, chain []*Grant) bool {
+ grantChain := make([]*Grant, len(chain)+1)
+ copy(grantChain, chain)
+ grantChain[len(grantChain)-1] = grant
+ grants = append(grants, grantChain)
+ return false
+ }
+ g.walkGrants(key.KeyID(), node, permission, collect, nil, nil, true)
+ return grants, nil
+}
diff --git a/trustgraph/memory_graph_test.go b/trustgraph/memory_graph_test.go
new file mode 100644
index 0000000..fcbb6a0
--- /dev/null
+++ b/trustgraph/memory_graph_test.go
@@ -0,0 +1,174 @@
+package trustgraph
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/containers/libtrust"
+)
+
+func createTestKeysAndGrants(count int) ([]*Grant, []libtrust.PrivateKey) {
+ grants := make([]*Grant, count)
+ keys := make([]libtrust.PrivateKey, count)
+ for i := 0; i < count; i++ {
+ pk, err := libtrust.GenerateECP256PrivateKey()
+ if err != nil {
+ panic(err)
+ }
+ grant := &Grant{
+ Subject: fmt.Sprintf("/user-%d", i+1),
+ Permission: 0x0f,
+ Grantee: pk.KeyID(),
+ }
+ keys[i] = pk
+ grants[i] = grant
+ }
+ return grants, keys
+}
+
+func testVerified(t *testing.T, g TrustGraph, k libtrust.PublicKey, keyName, target string, permission uint16) {
+ if ok, err := g.Verify(k, target, permission); err != nil {
+ t.Fatalf("Unexpected error during verification: %s", err)
+ } else if !ok {
+ t.Errorf("key failed verification\n\tKey: %s(%s)\n\tNamespace: %s", keyName, k.KeyID(), target)
+ }
+}
+
+func testNotVerified(t *testing.T, g TrustGraph, k libtrust.PublicKey, keyName, target string, permission uint16) {
+ if ok, err := g.Verify(k, target, permission); err != nil {
+ t.Fatalf("Unexpected error during verification: %s", err)
+ } else if ok {
+ t.Errorf("key should have failed verification\n\tKey: %s(%s)\n\tNamespace: %s", keyName, k.KeyID(), target)
+ }
+}
+
+func TestVerify(t *testing.T) {
+ grants, keys := createTestKeysAndGrants(4)
+ extraGrants := make([]*Grant, 3)
+ extraGrants[0] = &Grant{
+ Subject: "/user-3",
+ Permission: 0x0f,
+ Grantee: "/user-2",
+ }
+ extraGrants[1] = &Grant{
+ Subject: "/user-3/sub-project",
+ Permission: 0x0f,
+ Grantee: "/user-4",
+ }
+ extraGrants[2] = &Grant{
+ Subject: "/user-4",
+ Permission: 0x07,
+ Grantee: "/user-1",
+ }
+ grants = append(grants, extraGrants...)
+
+ g := NewMemoryGraph(grants)
+
+ testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f)
+ testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1/some-project/sub-value", 0x0f)
+ testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-4", 0x07)
+ testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2/", 0x0f)
+ testVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-3/sub-value", 0x0f)
+ testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3/sub-value", 0x0f)
+ testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3", 0x0f)
+ testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3/", 0x0f)
+ testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3/sub-project", 0x0f)
+ testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3/sub-project/app", 0x0f)
+ testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-4", 0x0f)
+
+ testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-2", 0x0f)
+ testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-3/sub-value", 0x0f)
+ testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-4", 0x0f)
+ testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-1/", 0x0f)
+ testNotVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-2", 0x0f)
+ testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-4", 0x0f)
+ testNotVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3", 0x0f)
+}
+
+func TestCircularWalk(t *testing.T) {
+ grants, keys := createTestKeysAndGrants(3)
+ user1Grant := &Grant{
+ Subject: "/user-2",
+ Permission: 0x0f,
+ Grantee: "/user-1",
+ }
+ user2Grant := &Grant{
+ Subject: "/user-1",
+ Permission: 0x0f,
+ Grantee: "/user-2",
+ }
+ grants = append(grants, user1Grant, user2Grant)
+
+ g := NewMemoryGraph(grants)
+
+ testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f)
+ testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-2", 0x0f)
+ testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x0f)
+ testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-1", 0x0f)
+ testVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-3", 0x0f)
+
+ testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-3", 0x0f)
+ testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3", 0x0f)
+}
+
+func assertGrantSame(t *testing.T, actual, expected *Grant) {
+ if actual != expected {
+ t.Fatalf("Unexpected grant retrieved\n\tExpected: %v\n\tActual: %v", expected, actual)
+ }
+}
+
+func TestGetGrants(t *testing.T) {
+ grants, keys := createTestKeysAndGrants(5)
+ extraGrants := make([]*Grant, 4)
+ extraGrants[0] = &Grant{
+ Subject: "/user-3/friend-project",
+ Permission: 0x0f,
+ Grantee: "/user-2/friends",
+ }
+ extraGrants[1] = &Grant{
+ Subject: "/user-3/sub-project",
+ Permission: 0x0f,
+ Grantee: "/user-4",
+ }
+ extraGrants[2] = &Grant{
+ Subject: "/user-2/friends",
+ Permission: 0x0f,
+ Grantee: "/user-5/fun-project",
+ }
+ extraGrants[3] = &Grant{
+ Subject: "/user-5/fun-project",
+ Permission: 0x0f,
+ Grantee: "/user-1",
+ }
+ grants = append(grants, extraGrants...)
+
+ g := NewMemoryGraph(grants)
+
+ grantChains, err := g.GetGrants(keys[3], "/user-3/sub-project/specific-app", 0x0f)
+ if err != nil {
+ t.Fatalf("Error getting grants: %s", err)
+ }
+ if len(grantChains) != 1 {
+ t.Fatalf("Expected number of grant chains returned, expected %d, received %d", 1, len(grantChains))
+ }
+ if len(grantChains[0]) != 2 {
+ t.Fatalf("Unexpected number of grants retrieved\n\tExpected: %d\n\tActual: %d", 2, len(grantChains[0]))
+ }
+ assertGrantSame(t, grantChains[0][0], grants[3])
+ assertGrantSame(t, grantChains[0][1], extraGrants[1])
+
+ grantChains, err = g.GetGrants(keys[0], "/user-3/friend-project/fun-app", 0x0f)
+ if err != nil {
+ t.Fatalf("Error getting grants: %s", err)
+ }
+ if len(grantChains) != 1 {
+ t.Fatalf("Expected number of grant chains returned, expected %d, received %d", 1, len(grantChains))
+ }
+ if len(grantChains[0]) != 4 {
+ t.Fatalf("Unexpected number of grants retrieved\n\tExpected: %d\n\tActual: %d", 2, len(grantChains[0]))
+ }
+ assertGrantSame(t, grantChains[0][0], grants[0])
+ assertGrantSame(t, grantChains[0][1], extraGrants[3])
+ assertGrantSame(t, grantChains[0][2], extraGrants[2])
+ assertGrantSame(t, grantChains[0][3], extraGrants[0])
+}
diff --git a/trustgraph/statement.go b/trustgraph/statement.go
new file mode 100644
index 0000000..c684ec0
--- /dev/null
+++ b/trustgraph/statement.go
@@ -0,0 +1,227 @@
+package trustgraph
+
+import (
+ "crypto/x509"
+ "encoding/json"
+ "io"
+ "io/ioutil"
+ "sort"
+ "strings"
+ "time"
+
+ "github.com/containers/libtrust"
+)
+
+type jsonGrant struct {
+ Subject string `json:"subject"`
+ Permission uint16 `json:"permission"`
+ Grantee string `json:"grantee"`
+}
+
+type jsonRevocation struct {
+ Subject string `json:"subject"`
+ Revocation uint16 `json:"revocation"`
+ Grantee string `json:"grantee"`
+}
+
+type jsonStatement struct {
+ Revocations []*jsonRevocation `json:"revocations"`
+ Grants []*jsonGrant `json:"grants"`
+ Expiration time.Time `json:"expiration"`
+ IssuedAt time.Time `json:"issuedAt"`
+}
+
+func (g *jsonGrant) Grant(statement *Statement) *Grant {
+ return &Grant{
+ Subject: g.Subject,
+ Permission: g.Permission,
+ Grantee: g.Grantee,
+ statement: statement,
+ }
+}
+
+// Statement represents a set of grants made from a verifiable
+// authority. A statement has an expiration associated with it
+// set by the authority.
+type Statement struct {
+ jsonStatement
+
+ signature *libtrust.JSONSignature
+}
+
+// IsExpired returns whether the statement has expired
+func (s *Statement) IsExpired() bool {
+ return s.Expiration.Before(time.Now().Add(-10 * time.Second))
+}
+
+// Bytes returns an indented json representation of the statement
+// in a byte array. This value can be written to a file or stream
+// without alteration.
+func (s *Statement) Bytes() ([]byte, error) {
+ return s.signature.PrettySignature("signatures")
+}
+
+// LoadStatement loads and verifies a statement from an input stream.
+func LoadStatement(r io.Reader, authority *x509.CertPool) (*Statement, error) {
+ b, err := ioutil.ReadAll(r)
+ if err != nil {
+ return nil, err
+ }
+ js, err := libtrust.ParsePrettySignature(b, "signatures")
+ if err != nil {
+ return nil, err
+ }
+ payload, err := js.Payload()
+ if err != nil {
+ return nil, err
+ }
+ var statement Statement
+ err = json.Unmarshal(payload, &statement.jsonStatement)
+ if err != nil {
+ return nil, err
+ }
+
+ if authority == nil {
+ _, err = js.Verify()
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ _, err = js.VerifyChains(authority)
+ if err != nil {
+ return nil, err
+ }
+ }
+ statement.signature = js
+
+ return &statement, nil
+}
+
+// CreateStatements creates and signs a statement from a stream of grants
+// and revocations in a JSON array.
+func CreateStatement(grants, revocations io.Reader, expiration time.Duration, key libtrust.PrivateKey, chain []*x509.Certificate) (*Statement, error) {
+ var statement Statement
+ err := json.NewDecoder(grants).Decode(&statement.jsonStatement.Grants)
+ if err != nil {
+ return nil, err
+ }
+ err = json.NewDecoder(revocations).Decode(&statement.jsonStatement.Revocations)
+ if err != nil {
+ return nil, err
+ }
+ statement.jsonStatement.Expiration = time.Now().UTC().Add(expiration)
+ statement.jsonStatement.IssuedAt = time.Now().UTC()
+
+ b, err := json.MarshalIndent(&statement.jsonStatement, "", " ")
+ if err != nil {
+ return nil, err
+ }
+
+ statement.signature, err = libtrust.NewJSONSignature(b)
+ if err != nil {
+ return nil, err
+ }
+ err = statement.signature.SignWithChain(key, chain)
+ if err != nil {
+ return nil, err
+ }
+
+ return &statement, nil
+}
+
+type statementList []*Statement
+
+func (s statementList) Len() int {
+ return len(s)
+}
+
+func (s statementList) Less(i, j int) bool {
+ return s[i].IssuedAt.Before(s[j].IssuedAt)
+}
+
+func (s statementList) Swap(i, j int) {
+ s[i], s[j] = s[j], s[i]
+}
+
+// CollapseStatements returns a single list of the valid statements as well as the
+// time when the next grant will expire.
+func CollapseStatements(statements []*Statement, useExpired bool) ([]*Grant, time.Time, error) {
+ sorted := make(statementList, 0, len(statements))
+ for _, statement := range statements {
+ if useExpired || !statement.IsExpired() {
+ sorted = append(sorted, statement)
+ }
+ }
+ sort.Sort(sorted)
+
+ var minExpired time.Time
+ var grantCount int
+ roots := map[string]*grantNode{}
+ for i, statement := range sorted {
+ if statement.Expiration.Before(minExpired) || i == 0 {
+ minExpired = statement.Expiration
+ }
+ for _, grant := range statement.Grants {
+ parts := strings.Split(grant.Grantee, "/")
+ nodes := roots
+ g := grant.Grant(statement)
+ grantCount = grantCount + 1
+
+ for _, part := range parts {
+ node, nodeOk := nodes[part]
+ if !nodeOk {
+ node = newGrantNode()
+ nodes[part] = node
+ }
+ node.grants = append(node.grants, g)
+ nodes = node.children
+ }
+ }
+
+ for _, revocation := range statement.Revocations {
+ parts := strings.Split(revocation.Grantee, "/")
+ nodes := roots
+
+ var node *grantNode
+ var nodeOk bool
+ for _, part := range parts {
+ node, nodeOk = nodes[part]
+ if !nodeOk {
+ break
+ }
+ nodes = node.children
+ }
+ if node != nil {
+ for _, grant := range node.grants {
+ if isSubName(grant.Subject, revocation.Subject) {
+ grant.Permission = grant.Permission &^ revocation.Revocation
+ }
+ }
+ }
+ }
+ }
+
+ retGrants := make([]*Grant, 0, grantCount)
+ for _, rootNodes := range roots {
+ retGrants = append(retGrants, rootNodes.grants...)
+ }
+
+ return retGrants, minExpired, nil
+}
+
+// FilterStatements filters the statements to statements including the given grants.
+func FilterStatements(grants []*Grant) ([]*Statement, error) {
+ statements := map[*Statement]bool{}
+ for _, grant := range grants {
+ if grant.statement != nil {
+ statements[grant.statement] = true
+ }
+ }
+ retStatements := make([]*Statement, len(statements))
+ var i int
+ for statement := range statements {
+ retStatements[i] = statement
+ i++
+ }
+ return retStatements, nil
+}
diff --git a/trustgraph/statement_test.go b/trustgraph/statement_test.go
new file mode 100644
index 0000000..1b8f7c8
--- /dev/null
+++ b/trustgraph/statement_test.go
@@ -0,0 +1,417 @@
+package trustgraph
+
+import (
+ "bytes"
+ "crypto/x509"
+ "encoding/json"
+ "testing"
+ "time"
+
+ "github.com/containers/libtrust"
+ "github.com/containers/libtrust/testutil"
+)
+
+const testStatementExpiration = time.Hour * 5
+
+func generateStatement(grants []*Grant, key libtrust.PrivateKey, chain []*x509.Certificate) (*Statement, error) {
+ var statement Statement
+
+ statement.Grants = make([]*jsonGrant, len(grants))
+ for i, grant := range grants {
+ statement.Grants[i] = &jsonGrant{
+ Subject: grant.Subject,
+ Permission: grant.Permission,
+ Grantee: grant.Grantee,
+ }
+ }
+ statement.IssuedAt = time.Now()
+ statement.Expiration = time.Now().Add(testStatementExpiration)
+ statement.Revocations = make([]*jsonRevocation, 0)
+
+ marshalled, err := json.MarshalIndent(statement.jsonStatement, "", " ")
+ if err != nil {
+ return nil, err
+ }
+
+ sig, err := libtrust.NewJSONSignature(marshalled)
+ if err != nil {
+ return nil, err
+ }
+ err = sig.SignWithChain(key, chain)
+ if err != nil {
+ return nil, err
+ }
+ statement.signature = sig
+
+ return &statement, nil
+}
+
+func generateTrustChain(t *testing.T, chainLen int) (libtrust.PrivateKey, *x509.CertPool, []*x509.Certificate) {
+ caKey, err := libtrust.GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatalf("Error generating key: %s", err)
+ }
+ ca, err := testutil.GenerateTrustCA(caKey.CryptoPublicKey(), caKey.CryptoPrivateKey())
+ if err != nil {
+ t.Fatalf("Error generating ca: %s", err)
+ }
+
+ parent := ca
+ parentKey := caKey
+ chain := make([]*x509.Certificate, chainLen)
+ for i := chainLen - 1; i > 0; i-- {
+ intermediatekey, err := libtrust.GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatalf("Error generate key: %s", err)
+ }
+ chain[i], err = testutil.GenerateIntermediate(intermediatekey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent)
+ if err != nil {
+ t.Fatalf("Error generating intermdiate certificate: %s", err)
+ }
+ parent = chain[i]
+ parentKey = intermediatekey
+ }
+ trustKey, err := libtrust.GenerateECP256PrivateKey()
+ if err != nil {
+ t.Fatalf("Error generate key: %s", err)
+ }
+ chain[0], err = testutil.GenerateTrustCert(trustKey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent)
+ if err != nil {
+ t.Fatalf("Error generate trust cert: %s", err)
+ }
+
+ caPool := x509.NewCertPool()
+ caPool.AddCert(ca)
+
+ return trustKey, caPool, chain
+}
+
+func TestLoadStatement(t *testing.T) {
+ grantCount := 4
+ grants, _ := createTestKeysAndGrants(grantCount)
+
+ trustKey, caPool, chain := generateTrustChain(t, 6)
+
+ statement, err := generateStatement(grants, trustKey, chain)
+ if err != nil {
+ t.Fatalf("Error generating statement: %s", err)
+ }
+
+ statementBytes, err := statement.Bytes()
+ if err != nil {
+ t.Fatalf("Error getting statement bytes: %s", err)
+ }
+
+ s2, err := LoadStatement(bytes.NewReader(statementBytes), caPool)
+ if err != nil {
+ t.Fatalf("Error loading statement: %s", err)
+ }
+ if len(s2.Grants) != grantCount {
+ t.Fatalf("Unexpected grant length\n\tExpected: %d\n\tActual: %d", grantCount, len(s2.Grants))
+ }
+
+ pool := x509.NewCertPool()
+ _, err = LoadStatement(bytes.NewReader(statementBytes), pool)
+ if err == nil {
+ t.Fatalf("No error thrown verifying without an authority")
+ } else if _, ok := err.(x509.UnknownAuthorityError); !ok {
+ t.Fatalf("Unexpected error verifying without authority: %s", err)
+ }
+
+ s2, err = LoadStatement(bytes.NewReader(statementBytes), nil)
+ if err != nil {
+ t.Fatalf("Error loading statement: %s", err)
+ }
+ if len(s2.Grants) != grantCount {
+ t.Fatalf("Unexpected grant length\n\tExpected: %d\n\tActual: %d", grantCount, len(s2.Grants))
+ }
+
+ badData := make([]byte, len(statementBytes))
+ copy(badData, statementBytes)
+ badData[0] = '['
+ _, err = LoadStatement(bytes.NewReader(badData), nil)
+ if err == nil {
+ t.Fatalf("No error thrown parsing bad json")
+ }
+
+ alteredData := make([]byte, len(statementBytes))
+ copy(alteredData, statementBytes)
+ alteredData[30] = '0'
+ _, err = LoadStatement(bytes.NewReader(alteredData), nil)
+ if err == nil {
+ t.Fatalf("No error thrown from bad data")
+ }
+}
+
+func TestCollapseGrants(t *testing.T) {
+ grantCount := 8
+ grants, keys := createTestKeysAndGrants(grantCount)
+ linkGrants := make([]*Grant, 4)
+ linkGrants[0] = &Grant{
+ Subject: "/user-3",
+ Permission: 0x0f,
+ Grantee: "/user-2",
+ }
+ linkGrants[1] = &Grant{
+ Subject: "/user-3/sub-project",
+ Permission: 0x0f,
+ Grantee: "/user-4",
+ }
+ linkGrants[2] = &Grant{
+ Subject: "/user-6",
+ Permission: 0x0f,
+ Grantee: "/user-7",
+ }
+ linkGrants[3] = &Grant{
+ Subject: "/user-6/sub-project/specific-app",
+ Permission: 0x0f,
+ Grantee: "/user-5",
+ }
+ trustKey, pool, chain := generateTrustChain(t, 3)
+
+ statements := make([]*Statement, 3)
+ var err error
+ statements[0], err = generateStatement(grants[0:4], trustKey, chain)
+ if err != nil {
+ t.Fatalf("Error generating statement: %s", err)
+ }
+ statements[1], err = generateStatement(grants[4:], trustKey, chain)
+ if err != nil {
+ t.Fatalf("Error generating statement: %s", err)
+ }
+ statements[2], err = generateStatement(linkGrants, trustKey, chain)
+ if err != nil {
+ t.Fatalf("Error generating statement: %s", err)
+ }
+
+ statementsCopy := make([]*Statement, len(statements))
+ for i, statement := range statements {
+ b, err := statement.Bytes()
+ if err != nil {
+ t.Fatalf("Error getting statement bytes: %s", err)
+ }
+ verifiedStatement, err := LoadStatement(bytes.NewReader(b), pool)
+ if err != nil {
+ t.Fatalf("Error loading statement: %s", err)
+ }
+ // Force sort by reversing order
+ statementsCopy[len(statementsCopy)-i-1] = verifiedStatement
+ }
+ statements = statementsCopy
+
+ collapsedGrants, expiration, err := CollapseStatements(statements, false)
+ if len(collapsedGrants) != 12 {
+ t.Fatalf("Unexpected number of grants\n\tExpected: %d\n\tActual: %d", 12, len(collapsedGrants))
+ }
+ if expiration.After(time.Now().Add(time.Hour*5)) || expiration.Before(time.Now()) {
+ t.Fatalf("Unexpected expiration time: %s", expiration.String())
+ }
+ g := NewMemoryGraph(collapsedGrants)
+
+ testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f)
+ testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x0f)
+ testVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-3", 0x0f)
+ testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-4", 0x0f)
+ testVerified(t, g, keys[4].PublicKey(), "user-key-5", "/user-5", 0x0f)
+ testVerified(t, g, keys[5].PublicKey(), "user-key-6", "/user-6", 0x0f)
+ testVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-7", 0x0f)
+ testVerified(t, g, keys[7].PublicKey(), "user-key-8", "/user-8", 0x0f)
+ testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3", 0x0f)
+ testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3/sub-project/specific-app", 0x0f)
+ testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3/sub-project", 0x0f)
+ testVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-6", 0x0f)
+ testVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-6/sub-project/specific-app", 0x0f)
+ testVerified(t, g, keys[4].PublicKey(), "user-key-5", "/user-6/sub-project/specific-app", 0x0f)
+
+ testNotVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3", 0x0f)
+ testNotVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-6/sub-project", 0x0f)
+ testNotVerified(t, g, keys[4].PublicKey(), "user-key-5", "/user-6/sub-project", 0x0f)
+
+ // Add revocation grant
+ statements = append(statements, &Statement{
+ jsonStatement{
+ IssuedAt: time.Now(),
+ Expiration: time.Now().Add(testStatementExpiration),
+ Grants: []*jsonGrant{},
+ Revocations: []*jsonRevocation{
+ &jsonRevocation{
+ Subject: "/user-1",
+ Revocation: 0x0f,
+ Grantee: keys[0].KeyID(),
+ },
+ &jsonRevocation{
+ Subject: "/user-2",
+ Revocation: 0x08,
+ Grantee: keys[1].KeyID(),
+ },
+ &jsonRevocation{
+ Subject: "/user-6",
+ Revocation: 0x0f,
+ Grantee: "/user-7",
+ },
+ &jsonRevocation{
+ Subject: "/user-9",
+ Revocation: 0x0f,
+ Grantee: "/user-10",
+ },
+ },
+ },
+ nil,
+ })
+
+ collapsedGrants, expiration, err = CollapseStatements(statements, false)
+ if len(collapsedGrants) != 12 {
+ t.Fatalf("Unexpected number of grants\n\tExpected: %d\n\tActual: %d", 12, len(collapsedGrants))
+ }
+ if expiration.After(time.Now().Add(time.Hour*5)) || expiration.Before(time.Now()) {
+ t.Fatalf("Unexpected expiration time: %s", expiration.String())
+ }
+ g = NewMemoryGraph(collapsedGrants)
+
+ testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f)
+ testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x0f)
+ testNotVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-6/sub-project/specific-app", 0x0f)
+
+ testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x07)
+}
+
+func TestFilterStatements(t *testing.T) {
+ grantCount := 8
+ grants, keys := createTestKeysAndGrants(grantCount)
+ linkGrants := make([]*Grant, 3)
+ linkGrants[0] = &Grant{
+ Subject: "/user-3",
+ Permission: 0x0f,
+ Grantee: "/user-2",
+ }
+ linkGrants[1] = &Grant{
+ Subject: "/user-5",
+ Permission: 0x0f,
+ Grantee: "/user-4",
+ }
+ linkGrants[2] = &Grant{
+ Subject: "/user-7",
+ Permission: 0x0f,
+ Grantee: "/user-6",
+ }
+
+ trustKey, _, chain := generateTrustChain(t, 3)
+
+ statements := make([]*Statement, 5)
+ var err error
+ statements[0], err = generateStatement(grants[0:2], trustKey, chain)
+ if err != nil {
+ t.Fatalf("Error generating statement: %s", err)
+ }
+ statements[1], err = generateStatement(grants[2:4], trustKey, chain)
+ if err != nil {
+ t.Fatalf("Error generating statement: %s", err)
+ }
+ statements[2], err = generateStatement(grants[4:6], trustKey, chain)
+ if err != nil {
+ t.Fatalf("Error generating statement: %s", err)
+ }
+ statements[3], err = generateStatement(grants[6:], trustKey, chain)
+ if err != nil {
+ t.Fatalf("Error generating statement: %s", err)
+ }
+ statements[4], err = generateStatement(linkGrants, trustKey, chain)
+ if err != nil {
+ t.Fatalf("Error generating statement: %s", err)
+ }
+ collapsed, _, err := CollapseStatements(statements, false)
+ if err != nil {
+ t.Fatalf("Error collapsing grants: %s", err)
+ }
+
+ // Filter 1, all 5 statements
+ filter1, err := FilterStatements(collapsed)
+ if err != nil {
+ t.Fatalf("Error filtering statements: %s", err)
+ }
+ if len(filter1) != 5 {
+ t.Fatalf("Wrong number of statements, expected %d, received %d", 5, len(filter1))
+ }
+
+ // Filter 2, one statement
+ filter2, err := FilterStatements([]*Grant{collapsed[0]})
+ if err != nil {
+ t.Fatalf("Error filtering statements: %s", err)
+ }
+ if len(filter2) != 1 {
+ t.Fatalf("Wrong number of statements, expected %d, received %d", 1, len(filter2))
+ }
+
+ // Filter 3, 2 statements, from graph lookup
+ g := NewMemoryGraph(collapsed)
+ lookupGrants, err := g.GetGrants(keys[1], "/user-3", 0x0f)
+ if err != nil {
+ t.Fatalf("Error looking up grants: %s", err)
+ }
+ if len(lookupGrants) != 1 {
+ t.Fatalf("Wrong numberof grant chains returned from lookup, expected %d, received %d", 1, len(lookupGrants))
+ }
+ if len(lookupGrants[0]) != 2 {
+ t.Fatalf("Wrong number of grants looked up, expected %d, received %d", 2, len(lookupGrants))
+ }
+ filter3, err := FilterStatements(lookupGrants[0])
+ if err != nil {
+ t.Fatalf("Error filtering statements: %s", err)
+ }
+ if len(filter3) != 2 {
+ t.Fatalf("Wrong number of statements, expected %d, received %d", 2, len(filter3))
+ }
+
+}
+
+func TestCreateStatement(t *testing.T) {
+ grantJSON := bytes.NewReader([]byte(`[
+ {
+ "subject": "/user-2",
+ "permission": 15,
+ "grantee": "/user-1"
+ },
+ {
+ "subject": "/user-7",
+ "permission": 1,
+ "grantee": "/user-9"
+ },
+ {
+ "subject": "/user-3",
+ "permission": 15,
+ "grantee": "/user-2"
+ }
+]`))
+ revocationJSON := bytes.NewReader([]byte(`[
+ {
+ "subject": "user-8",
+ "revocation": 12,
+ "grantee": "user-9"
+ }
+]`))
+
+ trustKey, pool, chain := generateTrustChain(t, 3)
+
+ statement, err := CreateStatement(grantJSON, revocationJSON, testStatementExpiration, trustKey, chain)
+ if err != nil {
+ t.Fatalf("Error creating statement: %s", err)
+ }
+
+ b, err := statement.Bytes()
+ if err != nil {
+ t.Fatalf("Error retrieving bytes: %s", err)
+ }
+
+ verified, err := LoadStatement(bytes.NewReader(b), pool)
+ if err != nil {
+ t.Fatalf("Error loading statement: %s", err)
+ }
+
+ if len(verified.Grants) != 3 {
+ t.Errorf("Unexpected number of grants, expected %d, received %d", 3, len(verified.Grants))
+ }
+
+ if len(verified.Revocations) != 1 {
+ t.Errorf("Unexpected number of revocations, expected %d, received %d", 1, len(verified.Revocations))
+ }
+}
diff --git a/util.go b/util.go
new file mode 100644
index 0000000..a5a101d
--- /dev/null
+++ b/util.go
@@ -0,0 +1,363 @@
+package libtrust
+
+import (
+ "bytes"
+ "crypto"
+ "crypto/elliptic"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/base32"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "math/big"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+)
+
+// LoadOrCreateTrustKey will load a PrivateKey from the specified path
+func LoadOrCreateTrustKey(trustKeyPath string) (PrivateKey, error) {
+ if err := os.MkdirAll(filepath.Dir(trustKeyPath), 0700); err != nil {
+ return nil, err
+ }
+
+ trustKey, err := LoadKeyFile(trustKeyPath)
+ if err == ErrKeyFileDoesNotExist {
+ trustKey, err = GenerateECP256PrivateKey()
+ if err != nil {
+ return nil, fmt.Errorf("error generating key: %s", err)
+ }
+
+ if err := SaveKey(trustKeyPath, trustKey); err != nil {
+ return nil, fmt.Errorf("error saving key file: %s", err)
+ }
+
+ dir, file := filepath.Split(trustKeyPath)
+ if err := SavePublicKey(filepath.Join(dir, "public-"+file), trustKey.PublicKey()); err != nil {
+ return nil, fmt.Errorf("error saving public key file: %s", err)
+ }
+ } else if err != nil {
+ return nil, fmt.Errorf("error loading key file: %s", err)
+ }
+ return trustKey, nil
+}
+
+// NewIdentityAuthTLSClientConfig returns a tls.Config configured to use identity
+// based authentication from the specified dockerUrl, the rootConfigPath and
+// the server name to which it is connecting.
+// If trustUnknownHosts is true it will automatically add the host to the
+// known-hosts.json in rootConfigPath.
+func NewIdentityAuthTLSClientConfig(dockerUrl string, trustUnknownHosts bool, rootConfigPath string, serverName string) (*tls.Config, error) {
+ tlsConfig := newTLSConfig()
+
+ trustKeyPath := filepath.Join(rootConfigPath, "key.json")
+ knownHostsPath := filepath.Join(rootConfigPath, "known-hosts.json")
+
+ u, err := url.Parse(dockerUrl)
+ if err != nil {
+ return nil, fmt.Errorf("unable to parse machine url")
+ }
+
+ if u.Scheme == "unix" {
+ return nil, nil
+ }
+
+ addr := u.Host
+ proto := "tcp"
+
+ trustKey, err := LoadOrCreateTrustKey(trustKeyPath)
+ if err != nil {
+ return nil, fmt.Errorf("unable to load trust key: %s", err)
+ }
+
+ knownHosts, err := LoadKeySetFile(knownHostsPath)
+ if err != nil {
+ return nil, fmt.Errorf("could not load trusted hosts file: %s", err)
+ }
+
+ allowedHosts, err := FilterByHosts(knownHosts, addr, false)
+ if err != nil {
+ return nil, fmt.Errorf("error filtering hosts: %s", err)
+ }
+
+ certPool, err := GenerateCACertPool(trustKey, allowedHosts)
+ if err != nil {
+ return nil, fmt.Errorf("Could not create CA pool: %s", err)
+ }
+
+ tlsConfig.ServerName = serverName
+ tlsConfig.RootCAs = certPool
+
+ x509Cert, err := GenerateSelfSignedClientCert(trustKey)
+ if err != nil {
+ return nil, fmt.Errorf("certificate generation error: %s", err)
+ }
+
+ tlsConfig.Certificates = []tls.Certificate{{
+ Certificate: [][]byte{x509Cert.Raw},
+ PrivateKey: trustKey.CryptoPrivateKey(),
+ Leaf: x509Cert,
+ }}
+
+ tlsConfig.InsecureSkipVerify = true
+
+ testConn, err := tls.Dial(proto, addr, tlsConfig)
+ if err != nil {
+ return nil, fmt.Errorf("tls Handshake error: %s", err)
+ }
+
+ opts := x509.VerifyOptions{
+ Roots: tlsConfig.RootCAs,
+ CurrentTime: time.Now(),
+ DNSName: tlsConfig.ServerName,
+ Intermediates: x509.NewCertPool(),
+ }
+
+ certs := testConn.ConnectionState().PeerCertificates
+ for i, cert := range certs {
+ if i == 0 {
+ continue
+ }
+ opts.Intermediates.AddCert(cert)
+ }
+
+ if _, err := certs[0].Verify(opts); err != nil {
+ if _, ok := err.(x509.UnknownAuthorityError); ok {
+ if trustUnknownHosts {
+ pubKey, err := FromCryptoPublicKey(certs[0].PublicKey)
+ if err != nil {
+ return nil, fmt.Errorf("error extracting public key from cert: %s", err)
+ }
+
+ pubKey.AddExtendedField("hosts", []string{addr})
+
+ if err := AddKeySetFile(knownHostsPath, pubKey); err != nil {
+ return nil, fmt.Errorf("error adding machine to known hosts: %s", err)
+ }
+ } else {
+ return nil, fmt.Errorf("unable to connect. unknown host: %s", addr)
+ }
+ }
+ }
+
+ testConn.Close()
+ tlsConfig.InsecureSkipVerify = false
+
+ return tlsConfig, nil
+}
+
+// joseBase64UrlEncode encodes the given data using the standard base64 url
+// encoding format but with all trailing '=' characters omitted in accordance
+// with the jose specification.
+// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2
+func joseBase64UrlEncode(b []byte) string {
+ return strings.TrimRight(base64.URLEncoding.EncodeToString(b), "=")
+}
+
+// joseBase64UrlDecode decodes the given string using the standard base64 url
+// decoder but first adds the appropriate number of trailing '=' characters in
+// accordance with the jose specification.
+// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2
+func joseBase64UrlDecode(s string) ([]byte, error) {
+ s = strings.Replace(s, "\n", "", -1)
+ s = strings.Replace(s, " ", "", -1)
+ switch len(s) % 4 {
+ case 0:
+ case 2:
+ s += "=="
+ case 3:
+ s += "="
+ default:
+ return nil, errors.New("illegal base64url string")
+ }
+ return base64.URLEncoding.DecodeString(s)
+}
+
+func keyIDEncode(b []byte) string {
+ s := strings.TrimRight(base32.StdEncoding.EncodeToString(b), "=")
+ var buf bytes.Buffer
+ var i int
+ for i = 0; i < len(s)/4-1; i++ {
+ start := i * 4
+ end := start + 4
+ buf.WriteString(s[start:end] + ":")
+ }
+ buf.WriteString(s[i*4:])
+ return buf.String()
+}
+
+func keyIDFromCryptoKey(pubKey PublicKey) string {
+ // Generate and return a 'libtrust' fingerprint of the public key.
+ // For an RSA key this should be:
+ // SHA256(DER encoded ASN1)
+ // Then truncated to 240 bits and encoded into 12 base32 groups like so:
+ // ABCD:EFGH:IJKL:MNOP:QRST:UVWX:YZ23:4567:ABCD:EFGH:IJKL:MNOP
+ derBytes, err := x509.MarshalPKIXPublicKey(pubKey.CryptoPublicKey())
+ if err != nil {
+ return ""
+ }
+ hasher := crypto.SHA256.New()
+ hasher.Write(derBytes)
+ return keyIDEncode(hasher.Sum(nil)[:30])
+}
+
+func stringFromMap(m map[string]interface{}, key string) (string, error) {
+ val, ok := m[key]
+ if !ok {
+ return "", fmt.Errorf("%q value not specified", key)
+ }
+
+ str, ok := val.(string)
+ if !ok {
+ return "", fmt.Errorf("%q value must be a string", key)
+ }
+ delete(m, key)
+
+ return str, nil
+}
+
+func parseECCoordinate(cB64Url string, curve elliptic.Curve) (*big.Int, error) {
+ curveByteLen := (curve.Params().BitSize + 7) >> 3
+
+ cBytes, err := joseBase64UrlDecode(cB64Url)
+ if err != nil {
+ return nil, fmt.Errorf("invalid base64 URL encoding: %s", err)
+ }
+ cByteLength := len(cBytes)
+ if cByteLength != curveByteLen {
+ return nil, fmt.Errorf("invalid number of octets: got %d, should be %d", cByteLength, curveByteLen)
+ }
+ return new(big.Int).SetBytes(cBytes), nil
+}
+
+func parseECPrivateParam(dB64Url string, curve elliptic.Curve) (*big.Int, error) {
+ dBytes, err := joseBase64UrlDecode(dB64Url)
+ if err != nil {
+ return nil, fmt.Errorf("invalid base64 URL encoding: %s", err)
+ }
+
+ // The length of this octet string MUST be ceiling(log-base-2(n)/8)
+ // octets (where n is the order of the curve). This is because the private
+ // key d must be in the interval [1, n-1] so the bitlength of d should be
+ // no larger than the bitlength of n-1. The easiest way to find the octet
+ // length is to take bitlength(n-1), add 7 to force a carry, and shift this
+ // bit sequence right by 3, which is essentially dividing by 8 and adding
+ // 1 if there is any remainder. Thus, the private key value d should be
+ // output to (bitlength(n-1)+7)>>3 octets.
+ n := curve.Params().N
+ octetLength := (new(big.Int).Sub(n, big.NewInt(1)).BitLen() + 7) >> 3
+ dByteLength := len(dBytes)
+
+ if dByteLength != octetLength {
+ return nil, fmt.Errorf("invalid number of octets: got %d, should be %d", dByteLength, octetLength)
+ }
+
+ return new(big.Int).SetBytes(dBytes), nil
+}
+
+func parseRSAModulusParam(nB64Url string) (*big.Int, error) {
+ nBytes, err := joseBase64UrlDecode(nB64Url)
+ if err != nil {
+ return nil, fmt.Errorf("invalid base64 URL encoding: %s", err)
+ }
+
+ return new(big.Int).SetBytes(nBytes), nil
+}
+
+func serializeRSAPublicExponentParam(e int) []byte {
+ // We MUST use the minimum number of octets to represent E.
+ // E is supposed to be 65537 for performance and security reasons
+ // and is what golang's rsa package generates, but it might be
+ // different if imported from some other generator.
+ buf := make([]byte, 4)
+ binary.BigEndian.PutUint32(buf, uint32(e))
+ var i int
+ for i = 0; i < 8; i++ {
+ if buf[i] != 0 {
+ break
+ }
+ }
+ return buf[i:]
+}
+
+func parseRSAPublicExponentParam(eB64Url string) (int, error) {
+ eBytes, err := joseBase64UrlDecode(eB64Url)
+ if err != nil {
+ return 0, fmt.Errorf("invalid base64 URL encoding: %s", err)
+ }
+ // Only the minimum number of bytes were used to represent E, but
+ // binary.BigEndian.Uint32 expects at least 4 bytes, so we need
+ // to add zero padding if necassary.
+ byteLen := len(eBytes)
+ buf := make([]byte, 4-byteLen, 4)
+ eBytes = append(buf, eBytes...)
+
+ return int(binary.BigEndian.Uint32(eBytes)), nil
+}
+
+func parseRSAPrivateKeyParamFromMap(m map[string]interface{}, key string) (*big.Int, error) {
+ b64Url, err := stringFromMap(m, key)
+ if err != nil {
+ return nil, err
+ }
+
+ paramBytes, err := joseBase64UrlDecode(b64Url)
+ if err != nil {
+ return nil, fmt.Errorf("invaled base64 URL encoding: %s", err)
+ }
+
+ return new(big.Int).SetBytes(paramBytes), nil
+}
+
+func createPemBlock(name string, derBytes []byte, headers map[string]interface{}) (*pem.Block, error) {
+ pemBlock := &pem.Block{Type: name, Bytes: derBytes, Headers: map[string]string{}}
+ for k, v := range headers {
+ switch val := v.(type) {
+ case string:
+ pemBlock.Headers[k] = val
+ case []string:
+ if k == "hosts" {
+ pemBlock.Headers[k] = strings.Join(val, ",")
+ } else {
+ // Return error, non-encodable type
+ }
+ default:
+ // Return error, non-encodable type
+ }
+ }
+
+ return pemBlock, nil
+}
+
+func pubKeyFromPEMBlock(pemBlock *pem.Block) (PublicKey, error) {
+ cryptoPublicKey, err := x509.ParsePKIXPublicKey(pemBlock.Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decode Public Key PEM data: %s", err)
+ }
+
+ pubKey, err := FromCryptoPublicKey(cryptoPublicKey)
+ if err != nil {
+ return nil, err
+ }
+
+ addPEMHeadersToKey(pemBlock, pubKey)
+
+ return pubKey, nil
+}
+
+func addPEMHeadersToKey(pemBlock *pem.Block, pubKey PublicKey) {
+ for key, value := range pemBlock.Headers {
+ var safeVal interface{}
+ if key == "hosts" {
+ safeVal = strings.Split(value, ",")
+ } else {
+ safeVal = value
+ }
+ pubKey.AddExtendedField(key, safeVal)
+ }
+}
diff --git a/util_test.go b/util_test.go
new file mode 100644
index 0000000..83b7cfb
--- /dev/null
+++ b/util_test.go
@@ -0,0 +1,45 @@
+package libtrust
+
+import (
+ "encoding/pem"
+ "reflect"
+ "testing"
+)
+
+func TestAddPEMHeadersToKey(t *testing.T) {
+ pk := &rsaPublicKey{nil, map[string]interface{}{}}
+ blk := &pem.Block{Headers: map[string]string{"hosts": "localhost,127.0.0.1"}}
+ addPEMHeadersToKey(blk, pk)
+
+ val := pk.GetExtendedField("hosts")
+ hosts, ok := val.([]string)
+ if !ok {
+ t.Fatalf("hosts type(%v), expected []string", reflect.TypeOf(val))
+ }
+ expected := []string{"localhost", "127.0.0.1"}
+ if !reflect.DeepEqual(hosts, expected) {
+ t.Errorf("hosts(%v), expected %v", hosts, expected)
+ }
+}
+
+func TestBase64URL(t *testing.T) {
+ clean := "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJwMnMiOiIyV0NUY0paMVJ2ZF9DSnVKcmlwUTF3IiwicDJjIjo0MDk2LCJlbmMiOiJBMTI4Q0JDLUhTMjU2IiwiY3R5IjoiandrK2pzb24ifQ"
+
+ tests := []string{
+ clean, // clean roundtrip
+ "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJwMnMiOiIyV0NUY0paMVJ2\nZF9DSnVKcmlwUTF3IiwicDJjIjo0MDk2LCJlbmMiOiJBMTI4Q0JDLUhTMjU2\nIiwiY3R5IjoiandrK2pzb24ifQ", // with newlines
+ "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJwMnMiOiIyV0NUY0paMVJ2 \n ZF9DSnVKcmlwUTF3IiwicDJjIjo0MDk2LCJlbmMiOiJBMTI4Q0JDLUhTMjU2 \n IiwiY3R5IjoiandrK2pzb24ifQ", // with newlines and spaces
+ }
+
+ for i, test := range tests {
+ b, err := joseBase64UrlDecode(test)
+ if err != nil {
+ t.Fatalf("on test %d: %s", i, err)
+ }
+ got := joseBase64UrlEncode(b)
+
+ if got != clean {
+ t.Errorf("expected %q, got %q", clean, got)
+ }
+ }
+}