diff options
Diffstat (limited to 'signerverifier/rsa.go')
-rw-r--r-- | signerverifier/rsa.go | 161 |
1 files changed, 161 insertions, 0 deletions
diff --git a/signerverifier/rsa.go b/signerverifier/rsa.go new file mode 100644 index 0000000..b039659 --- /dev/null +++ b/signerverifier/rsa.go @@ -0,0 +1,161 @@ +package signerverifier + +import ( + "context" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "fmt" + "os" + "strings" +) + +const ( + RSAKeyType = "rsa" + RSAKeyScheme = "rsassa-pss-sha256" + RSAPrivateKeyPEM = "RSA PRIVATE KEY" +) + +// RSAPSSSignerVerifier is a dsse.SignerVerifier compliant interface to sign and +// verify signatures using RSA keys following the RSA-PSS scheme. +type RSAPSSSignerVerifier struct { + keyID string + private *rsa.PrivateKey + public *rsa.PublicKey +} + +// NewRSAPSSSignerVerifierFromSSLibKey creates an RSAPSSSignerVerifier from an +// SSLibKey. +func NewRSAPSSSignerVerifierFromSSLibKey(key *SSLibKey) (*RSAPSSSignerVerifier, error) { + if len(key.KeyVal.Public) == 0 { + return nil, ErrInvalidKey + } + + _, publicParsedKey, err := decodeAndParsePEM([]byte(key.KeyVal.Public)) + if err != nil { + return nil, fmt.Errorf("unable to create RSA-PSS signerverifier: %w", err) + } + + if len(key.KeyVal.Private) > 0 { + _, privateParsedKey, err := decodeAndParsePEM([]byte(key.KeyVal.Private)) + if err != nil { + return nil, fmt.Errorf("unable to create RSA-PSS signerverifier: %w", err) + } + + return &RSAPSSSignerVerifier{ + keyID: key.KeyID, + public: publicParsedKey.(*rsa.PublicKey), + private: privateParsedKey.(*rsa.PrivateKey), + }, nil + } + + return &RSAPSSSignerVerifier{ + keyID: key.KeyID, + public: publicParsedKey.(*rsa.PublicKey), + private: nil, + }, nil +} + +// Sign creates a signature for `data`. +func (sv *RSAPSSSignerVerifier) Sign(ctx context.Context, data []byte) ([]byte, error) { + if sv.private == nil { + return nil, ErrNotPrivateKey + } + + hashedData := hashBeforeSigning(data, sha256.New()) + + return rsa.SignPSS(rand.Reader, sv.private, crypto.SHA256, hashedData, &rsa.PSSOptions{SaltLength: sha256.Size, Hash: crypto.SHA256}) +} + +// Verify verifies the `sig` value passed in against `data`. +func (sv *RSAPSSSignerVerifier) Verify(ctx context.Context, data []byte, sig []byte) error { + hashedData := hashBeforeSigning(data, sha256.New()) + + if err := rsa.VerifyPSS(sv.public, crypto.SHA256, hashedData, sig, &rsa.PSSOptions{SaltLength: sha256.Size, Hash: crypto.SHA256}); err != nil { + return ErrSignatureVerificationFailed + } + + return nil +} + +// KeyID returns the identifier of the key used to create the +// RSAPSSSignerVerifier instance. +func (sv *RSAPSSSignerVerifier) KeyID() (string, error) { + return sv.keyID, nil +} + +// Public returns the public portion of the key used to create the +// RSAPSSSignerVerifier instance. +func (sv *RSAPSSSignerVerifier) Public() crypto.PublicKey { + return sv.public +} + +// LoadRSAPSSKeyFromFile returns an SSLibKey instance for an RSA key stored in a +// file. +func LoadRSAPSSKeyFromFile(path string) (*SSLibKey, error) { + contents, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("unable to load RSA key from file: %w", err) + } + + return LoadRSAPSSKeyFromBytes(contents) +} + +// LoadRSAPSSKeyFromBytes is a function that takes a byte array as input. This byte array should represent a PEM encoded RSA key, as PEM encoding is required. +// The function returns an SSLibKey instance, which is a struct that holds the key data. + +func LoadRSAPSSKeyFromBytes(contents []byte) (*SSLibKey, error) { + pemData, keyObj, err := decodeAndParsePEM(contents) + if err != nil { + return nil, fmt.Errorf("unable to load RSA key from file: %w", err) + } + + key := &SSLibKey{ + KeyType: RSAKeyType, + Scheme: RSAKeyScheme, + KeyIDHashAlgorithms: KeyIDHashAlgorithms, + KeyVal: KeyVal{}, + } + + pubKeyBytes, err := marshalAndGeneratePEM(keyObj) + if err != nil { + return nil, fmt.Errorf("unable to load RSA key from file: %w", err) + } + key.KeyVal.Public = strings.TrimSpace(string(pubKeyBytes)) + + if _, ok := keyObj.(*rsa.PrivateKey); ok { + key.KeyVal.Private = strings.TrimSpace(string(generatePEMBlock(pemData.Bytes, RSAPrivateKeyPEM))) + } + + if len(key.KeyID) == 0 { + keyID, err := calculateKeyID(key) + if err != nil { + return nil, fmt.Errorf("unable to load RSA key from file: %w", err) + } + key.KeyID = keyID + } + + return key, nil +} + +func marshalAndGeneratePEM(key interface{}) ([]byte, error) { + var pubKeyBytes []byte + var err error + + switch k := key.(type) { + case *rsa.PublicKey: + pubKeyBytes, err = x509.MarshalPKIXPublicKey(k) + case *rsa.PrivateKey: + pubKeyBytes, err = x509.MarshalPKIXPublicKey(k.Public()) + default: + return nil, fmt.Errorf("unexpected key type: %T", k) + } + + if err != nil { + return nil, err + } + + return generatePEMBlock(pubKeyBytes, PublicKeyPEM), nil +} |