summaryrefslogtreecommitdiffstats
path: root/crypto/lskcipher.c
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/lskcipher.c')
-rw-r--r--crypto/lskcipher.c41
1 files changed, 32 insertions, 9 deletions
diff --git a/crypto/lskcipher.c b/crypto/lskcipher.c
index 9edc897309..0f1bd7dcde 100644
--- a/crypto/lskcipher.c
+++ b/crypto/lskcipher.c
@@ -88,8 +88,9 @@ EXPORT_SYMBOL_GPL(crypto_lskcipher_setkey);
static int crypto_lskcipher_crypt_unaligned(
struct crypto_lskcipher *tfm, const u8 *src, u8 *dst, unsigned len,
u8 *iv, int (*crypt)(struct crypto_lskcipher *tfm, const u8 *src,
- u8 *dst, unsigned len, u8 *iv, bool final))
+ u8 *dst, unsigned len, u8 *iv, u32 flags))
{
+ unsigned statesize = crypto_lskcipher_statesize(tfm);
unsigned ivsize = crypto_lskcipher_ivsize(tfm);
unsigned bs = crypto_lskcipher_blocksize(tfm);
unsigned cs = crypto_lskcipher_chunksize(tfm);
@@ -104,7 +105,7 @@ static int crypto_lskcipher_crypt_unaligned(
if (!tiv)
return -ENOMEM;
- memcpy(tiv, iv, ivsize);
+ memcpy(tiv, iv, ivsize + statesize);
p = kmalloc(PAGE_SIZE, GFP_ATOMIC);
err = -ENOMEM;
@@ -119,7 +120,7 @@ static int crypto_lskcipher_crypt_unaligned(
chunk &= ~(cs - 1);
memcpy(p, src, chunk);
- err = crypt(tfm, p, p, chunk, tiv, true);
+ err = crypt(tfm, p, p, chunk, tiv, CRYPTO_LSKCIPHER_FLAG_FINAL);
if (err)
goto out;
@@ -132,7 +133,7 @@ static int crypto_lskcipher_crypt_unaligned(
err = len ? -EINVAL : 0;
out:
- memcpy(iv, tiv, ivsize);
+ memcpy(iv, tiv, ivsize + statesize);
kfree_sensitive(p);
kfree_sensitive(tiv);
return err;
@@ -143,7 +144,7 @@ static int crypto_lskcipher_crypt(struct crypto_lskcipher *tfm, const u8 *src,
int (*crypt)(struct crypto_lskcipher *tfm,
const u8 *src, u8 *dst,
unsigned len, u8 *iv,
- bool final))
+ u32 flags))
{
unsigned long alignmask = crypto_lskcipher_alignmask(tfm);
struct lskcipher_alg *alg = crypto_lskcipher_alg(tfm);
@@ -156,7 +157,7 @@ static int crypto_lskcipher_crypt(struct crypto_lskcipher *tfm, const u8 *src,
goto out;
}
- ret = crypt(tfm, src, dst, len, iv, true);
+ ret = crypt(tfm, src, dst, len, iv, CRYPTO_LSKCIPHER_FLAG_FINAL);
out:
return crypto_lskcipher_errstat(alg, ret);
@@ -197,23 +198,43 @@ EXPORT_SYMBOL_GPL(crypto_lskcipher_decrypt);
static int crypto_lskcipher_crypt_sg(struct skcipher_request *req,
int (*crypt)(struct crypto_lskcipher *tfm,
const u8 *src, u8 *dst,
- unsigned len, u8 *iv,
- bool final))
+ unsigned len, u8 *ivs,
+ u32 flags))
{
struct crypto_skcipher *skcipher = crypto_skcipher_reqtfm(req);
struct crypto_lskcipher **ctx = crypto_skcipher_ctx(skcipher);
+ u8 *ivs = skcipher_request_ctx(req);
struct crypto_lskcipher *tfm = *ctx;
struct skcipher_walk walk;
+ unsigned ivsize;
+ u32 flags;
int err;
+ ivsize = crypto_lskcipher_ivsize(tfm);
+ ivs = PTR_ALIGN(ivs, crypto_skcipher_alignmask(skcipher) + 1);
+ memcpy(ivs, req->iv, ivsize);
+
+ flags = req->base.flags & CRYPTO_TFM_REQ_MAY_SLEEP;
+
+ if (req->base.flags & CRYPTO_SKCIPHER_REQ_CONT)
+ flags |= CRYPTO_LSKCIPHER_FLAG_CONT;
+
+ if (!(req->base.flags & CRYPTO_SKCIPHER_REQ_NOTFINAL))
+ flags |= CRYPTO_LSKCIPHER_FLAG_FINAL;
+
err = skcipher_walk_virt(&walk, req, false);
while (walk.nbytes) {
err = crypt(tfm, walk.src.virt.addr, walk.dst.virt.addr,
- walk.nbytes, walk.iv, walk.nbytes == walk.total);
+ walk.nbytes, ivs,
+ flags & ~(walk.nbytes == walk.total ?
+ 0 : CRYPTO_LSKCIPHER_FLAG_FINAL));
err = skcipher_walk_done(&walk, err);
+ flags |= CRYPTO_LSKCIPHER_FLAG_CONT;
}
+ memcpy(req->iv, ivs, ivsize);
+
return err;
}
@@ -276,6 +297,7 @@ static void __maybe_unused crypto_lskcipher_show(
seq_printf(m, "max keysize : %u\n", skcipher->co.max_keysize);
seq_printf(m, "ivsize : %u\n", skcipher->co.ivsize);
seq_printf(m, "chunksize : %u\n", skcipher->co.chunksize);
+ seq_printf(m, "statesize : %u\n", skcipher->co.statesize);
}
static int __maybe_unused crypto_lskcipher_report(
@@ -618,6 +640,7 @@ struct lskcipher_instance *lskcipher_alloc_instance_simple(
inst->alg.co.min_keysize = cipher_alg->co.min_keysize;
inst->alg.co.max_keysize = cipher_alg->co.max_keysize;
inst->alg.co.ivsize = cipher_alg->co.base.cra_blocksize;
+ inst->alg.co.statesize = cipher_alg->co.statesize;
/* Use struct crypto_lskcipher * by default, can be overridden */
inst->alg.co.base.cra_ctxsize = sizeof(struct crypto_lskcipher *);