summaryrefslogtreecommitdiffstats
path: root/util.go
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 18:15:16 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 18:15:16 +0000
commit7ff64900ddd056f849635cef0f384be76c46c318 (patch)
treea95f201f843c1eceae41457bca3297b9ddf09c60 /util.go
parentInitial commit. (diff)
downloadgolang-github-containers-libtrust-7ff64900ddd056f849635cef0f384be76c46c318.tar.xz
golang-github-containers-libtrust-7ff64900ddd056f849635cef0f384be76c46c318.zip
Adding upstream version 0.0~git20230121.c1716e8.upstream/0.0_git20230121.c1716e8upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'util.go')
-rw-r--r--util.go363
1 files changed, 363 insertions, 0 deletions
diff --git a/util.go b/util.go
new file mode 100644
index 0000000..a5a101d
--- /dev/null
+++ b/util.go
@@ -0,0 +1,363 @@
+package libtrust
+
+import (
+ "bytes"
+ "crypto"
+ "crypto/elliptic"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/base32"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "math/big"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+)
+
+// LoadOrCreateTrustKey will load a PrivateKey from the specified path
+func LoadOrCreateTrustKey(trustKeyPath string) (PrivateKey, error) {
+ if err := os.MkdirAll(filepath.Dir(trustKeyPath), 0700); err != nil {
+ return nil, err
+ }
+
+ trustKey, err := LoadKeyFile(trustKeyPath)
+ if err == ErrKeyFileDoesNotExist {
+ trustKey, err = GenerateECP256PrivateKey()
+ if err != nil {
+ return nil, fmt.Errorf("error generating key: %s", err)
+ }
+
+ if err := SaveKey(trustKeyPath, trustKey); err != nil {
+ return nil, fmt.Errorf("error saving key file: %s", err)
+ }
+
+ dir, file := filepath.Split(trustKeyPath)
+ if err := SavePublicKey(filepath.Join(dir, "public-"+file), trustKey.PublicKey()); err != nil {
+ return nil, fmt.Errorf("error saving public key file: %s", err)
+ }
+ } else if err != nil {
+ return nil, fmt.Errorf("error loading key file: %s", err)
+ }
+ return trustKey, nil
+}
+
+// NewIdentityAuthTLSClientConfig returns a tls.Config configured to use identity
+// based authentication from the specified dockerUrl, the rootConfigPath and
+// the server name to which it is connecting.
+// If trustUnknownHosts is true it will automatically add the host to the
+// known-hosts.json in rootConfigPath.
+func NewIdentityAuthTLSClientConfig(dockerUrl string, trustUnknownHosts bool, rootConfigPath string, serverName string) (*tls.Config, error) {
+ tlsConfig := newTLSConfig()
+
+ trustKeyPath := filepath.Join(rootConfigPath, "key.json")
+ knownHostsPath := filepath.Join(rootConfigPath, "known-hosts.json")
+
+ u, err := url.Parse(dockerUrl)
+ if err != nil {
+ return nil, fmt.Errorf("unable to parse machine url")
+ }
+
+ if u.Scheme == "unix" {
+ return nil, nil
+ }
+
+ addr := u.Host
+ proto := "tcp"
+
+ trustKey, err := LoadOrCreateTrustKey(trustKeyPath)
+ if err != nil {
+ return nil, fmt.Errorf("unable to load trust key: %s", err)
+ }
+
+ knownHosts, err := LoadKeySetFile(knownHostsPath)
+ if err != nil {
+ return nil, fmt.Errorf("could not load trusted hosts file: %s", err)
+ }
+
+ allowedHosts, err := FilterByHosts(knownHosts, addr, false)
+ if err != nil {
+ return nil, fmt.Errorf("error filtering hosts: %s", err)
+ }
+
+ certPool, err := GenerateCACertPool(trustKey, allowedHosts)
+ if err != nil {
+ return nil, fmt.Errorf("Could not create CA pool: %s", err)
+ }
+
+ tlsConfig.ServerName = serverName
+ tlsConfig.RootCAs = certPool
+
+ x509Cert, err := GenerateSelfSignedClientCert(trustKey)
+ if err != nil {
+ return nil, fmt.Errorf("certificate generation error: %s", err)
+ }
+
+ tlsConfig.Certificates = []tls.Certificate{{
+ Certificate: [][]byte{x509Cert.Raw},
+ PrivateKey: trustKey.CryptoPrivateKey(),
+ Leaf: x509Cert,
+ }}
+
+ tlsConfig.InsecureSkipVerify = true
+
+ testConn, err := tls.Dial(proto, addr, tlsConfig)
+ if err != nil {
+ return nil, fmt.Errorf("tls Handshake error: %s", err)
+ }
+
+ opts := x509.VerifyOptions{
+ Roots: tlsConfig.RootCAs,
+ CurrentTime: time.Now(),
+ DNSName: tlsConfig.ServerName,
+ Intermediates: x509.NewCertPool(),
+ }
+
+ certs := testConn.ConnectionState().PeerCertificates
+ for i, cert := range certs {
+ if i == 0 {
+ continue
+ }
+ opts.Intermediates.AddCert(cert)
+ }
+
+ if _, err := certs[0].Verify(opts); err != nil {
+ if _, ok := err.(x509.UnknownAuthorityError); ok {
+ if trustUnknownHosts {
+ pubKey, err := FromCryptoPublicKey(certs[0].PublicKey)
+ if err != nil {
+ return nil, fmt.Errorf("error extracting public key from cert: %s", err)
+ }
+
+ pubKey.AddExtendedField("hosts", []string{addr})
+
+ if err := AddKeySetFile(knownHostsPath, pubKey); err != nil {
+ return nil, fmt.Errorf("error adding machine to known hosts: %s", err)
+ }
+ } else {
+ return nil, fmt.Errorf("unable to connect. unknown host: %s", addr)
+ }
+ }
+ }
+
+ testConn.Close()
+ tlsConfig.InsecureSkipVerify = false
+
+ return tlsConfig, nil
+}
+
+// joseBase64UrlEncode encodes the given data using the standard base64 url
+// encoding format but with all trailing '=' characters omitted in accordance
+// with the jose specification.
+// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2
+func joseBase64UrlEncode(b []byte) string {
+ return strings.TrimRight(base64.URLEncoding.EncodeToString(b), "=")
+}
+
+// joseBase64UrlDecode decodes the given string using the standard base64 url
+// decoder but first adds the appropriate number of trailing '=' characters in
+// accordance with the jose specification.
+// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2
+func joseBase64UrlDecode(s string) ([]byte, error) {
+ s = strings.Replace(s, "\n", "", -1)
+ s = strings.Replace(s, " ", "", -1)
+ switch len(s) % 4 {
+ case 0:
+ case 2:
+ s += "=="
+ case 3:
+ s += "="
+ default:
+ return nil, errors.New("illegal base64url string")
+ }
+ return base64.URLEncoding.DecodeString(s)
+}
+
+func keyIDEncode(b []byte) string {
+ s := strings.TrimRight(base32.StdEncoding.EncodeToString(b), "=")
+ var buf bytes.Buffer
+ var i int
+ for i = 0; i < len(s)/4-1; i++ {
+ start := i * 4
+ end := start + 4
+ buf.WriteString(s[start:end] + ":")
+ }
+ buf.WriteString(s[i*4:])
+ return buf.String()
+}
+
+func keyIDFromCryptoKey(pubKey PublicKey) string {
+ // Generate and return a 'libtrust' fingerprint of the public key.
+ // For an RSA key this should be:
+ // SHA256(DER encoded ASN1)
+ // Then truncated to 240 bits and encoded into 12 base32 groups like so:
+ // ABCD:EFGH:IJKL:MNOP:QRST:UVWX:YZ23:4567:ABCD:EFGH:IJKL:MNOP
+ derBytes, err := x509.MarshalPKIXPublicKey(pubKey.CryptoPublicKey())
+ if err != nil {
+ return ""
+ }
+ hasher := crypto.SHA256.New()
+ hasher.Write(derBytes)
+ return keyIDEncode(hasher.Sum(nil)[:30])
+}
+
+func stringFromMap(m map[string]interface{}, key string) (string, error) {
+ val, ok := m[key]
+ if !ok {
+ return "", fmt.Errorf("%q value not specified", key)
+ }
+
+ str, ok := val.(string)
+ if !ok {
+ return "", fmt.Errorf("%q value must be a string", key)
+ }
+ delete(m, key)
+
+ return str, nil
+}
+
+func parseECCoordinate(cB64Url string, curve elliptic.Curve) (*big.Int, error) {
+ curveByteLen := (curve.Params().BitSize + 7) >> 3
+
+ cBytes, err := joseBase64UrlDecode(cB64Url)
+ if err != nil {
+ return nil, fmt.Errorf("invalid base64 URL encoding: %s", err)
+ }
+ cByteLength := len(cBytes)
+ if cByteLength != curveByteLen {
+ return nil, fmt.Errorf("invalid number of octets: got %d, should be %d", cByteLength, curveByteLen)
+ }
+ return new(big.Int).SetBytes(cBytes), nil
+}
+
+func parseECPrivateParam(dB64Url string, curve elliptic.Curve) (*big.Int, error) {
+ dBytes, err := joseBase64UrlDecode(dB64Url)
+ if err != nil {
+ return nil, fmt.Errorf("invalid base64 URL encoding: %s", err)
+ }
+
+ // The length of this octet string MUST be ceiling(log-base-2(n)/8)
+ // octets (where n is the order of the curve). This is because the private
+ // key d must be in the interval [1, n-1] so the bitlength of d should be
+ // no larger than the bitlength of n-1. The easiest way to find the octet
+ // length is to take bitlength(n-1), add 7 to force a carry, and shift this
+ // bit sequence right by 3, which is essentially dividing by 8 and adding
+ // 1 if there is any remainder. Thus, the private key value d should be
+ // output to (bitlength(n-1)+7)>>3 octets.
+ n := curve.Params().N
+ octetLength := (new(big.Int).Sub(n, big.NewInt(1)).BitLen() + 7) >> 3
+ dByteLength := len(dBytes)
+
+ if dByteLength != octetLength {
+ return nil, fmt.Errorf("invalid number of octets: got %d, should be %d", dByteLength, octetLength)
+ }
+
+ return new(big.Int).SetBytes(dBytes), nil
+}
+
+func parseRSAModulusParam(nB64Url string) (*big.Int, error) {
+ nBytes, err := joseBase64UrlDecode(nB64Url)
+ if err != nil {
+ return nil, fmt.Errorf("invalid base64 URL encoding: %s", err)
+ }
+
+ return new(big.Int).SetBytes(nBytes), nil
+}
+
+func serializeRSAPublicExponentParam(e int) []byte {
+ // We MUST use the minimum number of octets to represent E.
+ // E is supposed to be 65537 for performance and security reasons
+ // and is what golang's rsa package generates, but it might be
+ // different if imported from some other generator.
+ buf := make([]byte, 4)
+ binary.BigEndian.PutUint32(buf, uint32(e))
+ var i int
+ for i = 0; i < 8; i++ {
+ if buf[i] != 0 {
+ break
+ }
+ }
+ return buf[i:]
+}
+
+func parseRSAPublicExponentParam(eB64Url string) (int, error) {
+ eBytes, err := joseBase64UrlDecode(eB64Url)
+ if err != nil {
+ return 0, fmt.Errorf("invalid base64 URL encoding: %s", err)
+ }
+ // Only the minimum number of bytes were used to represent E, but
+ // binary.BigEndian.Uint32 expects at least 4 bytes, so we need
+ // to add zero padding if necassary.
+ byteLen := len(eBytes)
+ buf := make([]byte, 4-byteLen, 4)
+ eBytes = append(buf, eBytes...)
+
+ return int(binary.BigEndian.Uint32(eBytes)), nil
+}
+
+func parseRSAPrivateKeyParamFromMap(m map[string]interface{}, key string) (*big.Int, error) {
+ b64Url, err := stringFromMap(m, key)
+ if err != nil {
+ return nil, err
+ }
+
+ paramBytes, err := joseBase64UrlDecode(b64Url)
+ if err != nil {
+ return nil, fmt.Errorf("invaled base64 URL encoding: %s", err)
+ }
+
+ return new(big.Int).SetBytes(paramBytes), nil
+}
+
+func createPemBlock(name string, derBytes []byte, headers map[string]interface{}) (*pem.Block, error) {
+ pemBlock := &pem.Block{Type: name, Bytes: derBytes, Headers: map[string]string{}}
+ for k, v := range headers {
+ switch val := v.(type) {
+ case string:
+ pemBlock.Headers[k] = val
+ case []string:
+ if k == "hosts" {
+ pemBlock.Headers[k] = strings.Join(val, ",")
+ } else {
+ // Return error, non-encodable type
+ }
+ default:
+ // Return error, non-encodable type
+ }
+ }
+
+ return pemBlock, nil
+}
+
+func pubKeyFromPEMBlock(pemBlock *pem.Block) (PublicKey, error) {
+ cryptoPublicKey, err := x509.ParsePKIXPublicKey(pemBlock.Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decode Public Key PEM data: %s", err)
+ }
+
+ pubKey, err := FromCryptoPublicKey(cryptoPublicKey)
+ if err != nil {
+ return nil, err
+ }
+
+ addPEMHeadersToKey(pemBlock, pubKey)
+
+ return pubKey, nil
+}
+
+func addPEMHeadersToKey(pemBlock *pem.Block, pubKey PublicKey) {
+ for key, value := range pemBlock.Headers {
+ var safeVal interface{}
+ if key == "hosts" {
+ safeVal = strings.Split(value, ",")
+ } else {
+ safeVal = value
+ }
+ pubKey.AddExtendedField(key, safeVal)
+ }
+}