summaryrefslogtreecommitdiffstats
path: root/src/lua/lua_compress.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/lua/lua_compress.c')
-rw-r--r--src/lua/lua_compress.c622
1 files changed, 622 insertions, 0 deletions
diff --git a/src/lua/lua_compress.c b/src/lua/lua_compress.c
new file mode 100644
index 0000000..77c82c5
--- /dev/null
+++ b/src/lua/lua_compress.c
@@ -0,0 +1,622 @@
+/*-
+ * Copyright 2021 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lua_common.h"
+#include "unix-std.h"
+#include <zlib.h>
+
+#ifdef SYS_ZSTD
+#include "zstd.h"
+#include "zstd_errors.h"
+#else
+#include "contrib/zstd/zstd.h"
+#include "contrib/zstd/error_public.h"
+#endif
+
+/***
+ * @module rspamd_compress
+ * This module contains compression/decompression routines (zstd and zlib currently)
+ */
+
+/***
+ * @function zstd.compress_ctx()
+ * Creates new compression ctx
+ * @return {compress_ctx} new compress ctx
+ */
+LUA_FUNCTION_DEF(zstd, compress_ctx);
+
+/***
+ * @function zstd.compress_ctx()
+ * Creates new compression ctx
+ * @return {compress_ctx} new compress ctx
+ */
+LUA_FUNCTION_DEF(zstd, decompress_ctx);
+
+LUA_FUNCTION_DEF(zstd_compress, stream);
+LUA_FUNCTION_DEF(zstd_compress, dtor);
+
+LUA_FUNCTION_DEF(zstd_decompress, stream);
+LUA_FUNCTION_DEF(zstd_decompress, dtor);
+
+static const struct luaL_reg zstd_compress_lib_f[] = {
+ LUA_INTERFACE_DEF(zstd, compress_ctx),
+ LUA_INTERFACE_DEF(zstd, decompress_ctx),
+ {NULL, NULL}};
+
+static const struct luaL_reg zstd_compress_lib_m[] = {
+ LUA_INTERFACE_DEF(zstd_compress, stream),
+ {"__gc", lua_zstd_compress_dtor},
+ {NULL, NULL}};
+
+static const struct luaL_reg zstd_decompress_lib_m[] = {
+ LUA_INTERFACE_DEF(zstd_decompress, stream),
+ {"__gc", lua_zstd_decompress_dtor},
+ {NULL, NULL}};
+
+static ZSTD_CStream *
+lua_check_zstd_compress_ctx(lua_State *L, gint pos)
+{
+ void *ud = rspamd_lua_check_udata(L, pos, "rspamd{zstd_compress}");
+ luaL_argcheck(L, ud != NULL, pos, "'zstd_compress' expected");
+ return ud ? *(ZSTD_CStream **) ud : NULL;
+}
+
+static ZSTD_DStream *
+lua_check_zstd_decompress_ctx(lua_State *L, gint pos)
+{
+ void *ud = rspamd_lua_check_udata(L, pos, "rspamd{zstd_decompress}");
+ luaL_argcheck(L, ud != NULL, pos, "'zstd_decompress' expected");
+ return ud ? *(ZSTD_DStream **) ud : NULL;
+}
+
+int lua_zstd_push_error(lua_State *L, int err)
+{
+ lua_pushnil(L);
+ lua_pushfstring(L, "zstd error %d (%s)", err, ZSTD_getErrorString(err));
+
+ return 2;
+}
+
+gint lua_compress_zstd_compress(lua_State *L)
+{
+ LUA_TRACE_POINT;
+ struct rspamd_lua_text *t = NULL, *res;
+ gsize sz, r;
+ gint comp_level = 1;
+
+ t = lua_check_text_or_string(L, 1);
+
+ if (t == NULL || t->start == NULL) {
+ return luaL_error(L, "invalid arguments");
+ }
+
+ if (lua_type(L, 2) == LUA_TNUMBER) {
+ comp_level = lua_tointeger(L, 2);
+ }
+
+ sz = ZSTD_compressBound(t->len);
+
+ if (ZSTD_isError(sz)) {
+ msg_err("cannot compress data: %s", ZSTD_getErrorName(sz));
+ lua_pushnil(L);
+
+ return 1;
+ }
+
+ res = lua_newuserdata(L, sizeof(*res));
+ res->start = g_malloc(sz);
+ res->flags = RSPAMD_TEXT_FLAG_OWN;
+ rspamd_lua_setclass(L, "rspamd{text}", -1);
+ r = ZSTD_compress((void *) res->start, sz, t->start, t->len, comp_level);
+
+ if (ZSTD_isError(r)) {
+ msg_err("cannot compress data: %s", ZSTD_getErrorName(r));
+ lua_pop(L, 1); /* Text will be freed here */
+ lua_pushnil(L);
+
+ return 1;
+ }
+
+ res->len = r;
+
+ return 1;
+}
+
+gint lua_compress_zstd_decompress(lua_State *L)
+{
+ LUA_TRACE_POINT;
+ struct rspamd_lua_text *t = NULL, *res;
+ gsize outlen, r;
+ ZSTD_DStream *zstream;
+ ZSTD_inBuffer zin;
+ ZSTD_outBuffer zout;
+ gchar *out;
+
+ t = lua_check_text_or_string(L, 1);
+
+ if (t == NULL || t->start == NULL) {
+ return luaL_error(L, "invalid arguments");
+ }
+
+ zstream = ZSTD_createDStream();
+ ZSTD_initDStream(zstream);
+
+ zin.pos = 0;
+ zin.src = t->start;
+ zin.size = t->len;
+
+ if ((outlen = ZSTD_getDecompressedSize(zin.src, zin.size)) == 0) {
+ outlen = ZSTD_DStreamOutSize();
+ }
+
+ out = g_malloc(outlen);
+
+ zout.dst = out;
+ zout.pos = 0;
+ zout.size = outlen;
+
+ while (zin.pos < zin.size) {
+ r = ZSTD_decompressStream(zstream, &zout, &zin);
+
+ if (ZSTD_isError(r)) {
+ msg_err("cannot decompress data: %s", ZSTD_getErrorName(r));
+ ZSTD_freeDStream(zstream);
+ g_free(out);
+ lua_pushstring(L, ZSTD_getErrorName(r));
+ lua_pushnil(L);
+
+ return 2;
+ }
+
+ if (zin.pos < zin.size && zout.pos == zout.size) {
+ /* We need to extend output buffer */
+ zout.size = zout.size * 2;
+ out = g_realloc(zout.dst, zout.size);
+ zout.dst = out;
+ }
+ }
+
+ ZSTD_freeDStream(zstream);
+ lua_pushnil(L); /* Error */
+ res = lua_newuserdata(L, sizeof(*res));
+ res->start = out;
+ res->flags = RSPAMD_TEXT_FLAG_OWN;
+ rspamd_lua_setclass(L, "rspamd{text}", -1);
+ res->len = zout.pos;
+
+ return 2;
+}
+
+gint lua_compress_zlib_decompress(lua_State *L, bool is_gzip)
+{
+ LUA_TRACE_POINT;
+ struct rspamd_lua_text *t = NULL, *res;
+ gsize sz;
+ z_stream strm;
+ gint rc;
+ guchar *p;
+ gsize remain;
+ gssize size_limit = -1;
+
+ int windowBits = is_gzip ? (MAX_WBITS + 16) : (MAX_WBITS);
+
+ t = lua_check_text_or_string(L, 1);
+
+ if (t == NULL || t->start == NULL) {
+ return luaL_error(L, "invalid arguments");
+ }
+
+ if (lua_type(L, 2) == LUA_TNUMBER) {
+ size_limit = lua_tointeger(L, 2);
+ if (size_limit <= 0) {
+ return luaL_error(L, "invalid arguments (size_limit)");
+ }
+
+ sz = MIN(t->len * 2, size_limit);
+ }
+ else {
+ sz = t->len * 2;
+ }
+
+ memset(&strm, 0, sizeof(strm));
+ /* windowBits +16 to decode gzip, zlib 1.2.0.4+ */
+
+ /* Here are dragons to distinguish between raw deflate and zlib */
+ if (windowBits == MAX_WBITS && t->len > 0) {
+ if ((int) (unsigned char) ((t->start[0] << 4)) != 0x80) {
+ /* Assume raw deflate */
+ windowBits = -windowBits;
+ }
+ }
+
+ rc = inflateInit2(&strm, windowBits);
+
+ if (rc != Z_OK) {
+ return luaL_error(L, "cannot init zlib");
+ }
+
+ strm.avail_in = t->len;
+ strm.next_in = (guchar *) t->start;
+
+ res = lua_newuserdata(L, sizeof(*res));
+ res->start = g_malloc(sz);
+ res->flags = RSPAMD_TEXT_FLAG_OWN;
+ rspamd_lua_setclass(L, "rspamd{text}", -1);
+
+ p = (guchar *) res->start;
+ remain = sz;
+
+ while (strm.avail_in != 0) {
+ strm.avail_out = remain;
+ strm.next_out = p;
+
+ rc = inflate(&strm, Z_NO_FLUSH);
+
+ if (rc != Z_OK && rc != Z_BUF_ERROR) {
+ if (rc == Z_STREAM_END) {
+ break;
+ }
+ else {
+ msg_err("cannot decompress data: %s (last error: %s)",
+ zError(rc), strm.msg);
+ lua_pop(L, 1); /* Text will be freed here */
+ lua_pushnil(L);
+ inflateEnd(&strm);
+
+ return 1;
+ }
+ }
+
+ res->len = strm.total_out;
+
+ if (strm.avail_out == 0 && strm.avail_in != 0) {
+
+ if (size_limit > 0 || res->len >= G_MAXUINT32 / 2) {
+ if (res->len > size_limit || res->len >= G_MAXUINT32 / 2) {
+ lua_pop(L, 1); /* Text will be freed here */
+ lua_pushnil(L);
+ inflateEnd(&strm);
+
+ return 1;
+ }
+ }
+
+ /* Need to allocate more */
+ remain = res->len;
+ res->start = g_realloc((gpointer) res->start, res->len * 2);
+ sz = res->len * 2;
+ p = (guchar *) res->start + remain;
+ remain = sz - remain;
+ }
+ }
+
+ inflateEnd(&strm);
+ res->len = strm.total_out;
+
+ return 1;
+}
+
+gint lua_compress_zlib_compress(lua_State *L)
+{
+ LUA_TRACE_POINT;
+ struct rspamd_lua_text *t = NULL, *res;
+ gsize sz;
+ z_stream strm;
+ gint rc, comp_level = Z_DEFAULT_COMPRESSION;
+ guchar *p;
+ gsize remain;
+
+ t = lua_check_text_or_string(L, 1);
+
+ if (t == NULL || t->start == NULL) {
+ return luaL_error(L, "invalid arguments");
+ }
+
+ if (lua_isnumber(L, 2)) {
+ comp_level = lua_tointeger(L, 2);
+
+ if (comp_level > Z_BEST_COMPRESSION || comp_level < Z_BEST_SPEED) {
+ return luaL_error(L, "invalid arguments: compression level must be between %d and %d",
+ Z_BEST_SPEED, Z_BEST_COMPRESSION);
+ }
+ }
+
+
+ memset(&strm, 0, sizeof(strm));
+ rc = deflateInit2(&strm, comp_level, Z_DEFLATED,
+ MAX_WBITS + 16, MAX_MEM_LEVEL - 1, Z_DEFAULT_STRATEGY);
+
+ if (rc != Z_OK) {
+ return luaL_error(L, "cannot init zlib: %s", zError(rc));
+ }
+
+ sz = deflateBound(&strm, t->len);
+
+ strm.avail_in = t->len;
+ strm.next_in = (guchar *) t->start;
+
+ res = lua_newuserdata(L, sizeof(*res));
+ res->start = g_malloc(sz);
+ res->flags = RSPAMD_TEXT_FLAG_OWN;
+ rspamd_lua_setclass(L, "rspamd{text}", -1);
+
+ p = (guchar *) res->start;
+ remain = sz;
+
+ while (strm.avail_in != 0) {
+ strm.avail_out = remain;
+ strm.next_out = p;
+
+ rc = deflate(&strm, Z_FINISH);
+
+ if (rc != Z_OK && rc != Z_BUF_ERROR) {
+ if (rc == Z_STREAM_END) {
+ break;
+ }
+ else {
+ msg_err("cannot compress data: %s (last error: %s)",
+ zError(rc), strm.msg);
+ lua_pop(L, 1); /* Text will be freed here */
+ lua_pushnil(L);
+ deflateEnd(&strm);
+
+ return 1;
+ }
+ }
+
+ res->len = strm.total_out;
+
+ if (strm.avail_out == 0 && strm.avail_in != 0) {
+ /* Need to allocate more */
+ remain = res->len;
+ res->start = g_realloc((gpointer) res->start, strm.avail_in + sz);
+ sz = strm.avail_in + sz;
+ p = (guchar *) res->start + remain;
+ remain = sz - remain;
+ }
+ }
+
+ deflateEnd(&strm);
+ res->len = strm.total_out;
+
+ return 1;
+}
+
+/* Stream API interface for Zstd: both compression and decompression */
+
+/* Operations allowed by zstd stream methods */
+static const char *const zstd_stream_op[] = {
+ "continue",
+ "flush",
+ "end",
+ NULL};
+
+static gint
+lua_zstd_compress_ctx(lua_State *L)
+{
+ ZSTD_CCtx *ctx, **pctx;
+
+ pctx = lua_newuserdata(L, sizeof(*pctx));
+ ctx = ZSTD_createCCtx();
+
+ if (!ctx) {
+ return luaL_error(L, "context create failed");
+ }
+
+ *pctx = ctx;
+ rspamd_lua_setclass(L, "rspamd{zstd_compress}", -1);
+ return 1;
+}
+
+static gint
+lua_zstd_compress_dtor(lua_State *L)
+{
+ ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx(L, 1);
+
+ if (ctx) {
+ ZSTD_freeCCtx(ctx);
+ }
+
+ return 0;
+}
+
+static gint
+lua_zstd_compress_reset(lua_State *L)
+{
+ ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx(L, 1);
+
+ if (ctx) {
+ ZSTD_CCtx_reset(ctx, ZSTD_reset_session_and_parameters);
+ }
+ else {
+ return luaL_error(L, "invalid arguments");
+ }
+
+ return 0;
+}
+
+static gint
+lua_zstd_compress_stream(lua_State *L)
+{
+ ZSTD_CStream *ctx = lua_check_zstd_compress_ctx(L, 1);
+ struct rspamd_lua_text *t = lua_check_text_or_string(L, 2);
+ int op = luaL_checkoption(L, 3, zstd_stream_op[0], zstd_stream_op);
+ int err = 0;
+ ZSTD_inBuffer inb;
+ ZSTD_outBuffer onb;
+
+ if (ctx && t) {
+ gsize dlen = 0;
+
+ inb.size = t->len;
+ inb.pos = 0;
+ inb.src = (const void *) t->start;
+
+ onb.pos = 0;
+ onb.size = ZSTD_CStreamInSize(); /* Initial guess */
+ onb.dst = NULL;
+
+ for (;;) {
+ if ((onb.dst = g_realloc(onb.dst, onb.size)) == NULL) {
+ return lua_zstd_push_error(L, ZSTD_error_memory_allocation);
+ }
+
+ dlen = onb.size;
+
+ int res = ZSTD_compressStream2(ctx, &onb, &inb, op);
+
+ if (res == 0) {
+ /* All done */
+ break;
+ }
+
+ if ((err = ZSTD_getErrorCode(res))) {
+ break;
+ }
+
+ onb.size *= 2;
+ res += dlen; /* Hint returned by compression routine */
+
+ /* Either double the buffer, or use the hint provided */
+ if (onb.size < res) {
+ onb.size = res;
+ }
+ }
+ }
+ else {
+ return luaL_error(L, "invalid arguments");
+ }
+
+ if (err) {
+ return lua_zstd_push_error(L, err);
+ }
+
+ lua_new_text(L, onb.dst, onb.pos, TRUE);
+
+ return 1;
+}
+
+static gint
+lua_zstd_decompress_dtor(lua_State *L)
+{
+ ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx(L, 1);
+
+ if (ctx) {
+ ZSTD_freeDStream(ctx);
+ }
+
+ return 0;
+}
+
+
+static gint
+lua_zstd_decompress_ctx(lua_State *L)
+{
+ ZSTD_DStream *ctx, **pctx;
+
+ pctx = lua_newuserdata(L, sizeof(*pctx));
+ ctx = ZSTD_createDStream();
+
+ if (!ctx) {
+ return luaL_error(L, "context create failed");
+ }
+
+ *pctx = ctx;
+ rspamd_lua_setclass(L, "rspamd{zstd_decompress}", -1);
+ return 1;
+}
+
+static gint
+lua_zstd_decompress_stream(lua_State *L)
+{
+ ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx(L, 1);
+ struct rspamd_lua_text *t = lua_check_text_or_string(L, 2);
+ int err = 0;
+ ZSTD_inBuffer inb;
+ ZSTD_outBuffer onb;
+
+ if (ctx && t) {
+ gsize dlen = 0;
+
+ if (t->len == 0) {
+ return lua_zstd_push_error(L, ZSTD_error_init_missing);
+ }
+
+ inb.size = t->len;
+ inb.pos = 0;
+ inb.src = (const void *) t->start;
+
+ onb.pos = 0;
+ onb.size = ZSTD_DStreamInSize(); /* Initial guess */
+ onb.dst = NULL;
+
+ for (;;) {
+ if ((onb.dst = g_realloc(onb.dst, onb.size)) == NULL) {
+ return lua_zstd_push_error(L, ZSTD_error_memory_allocation);
+ }
+
+ dlen = onb.size;
+
+ int res = ZSTD_decompressStream(ctx, &onb, &inb);
+
+ if (res == 0) {
+ /* All done */
+ break;
+ }
+
+ if ((err = ZSTD_getErrorCode(res))) {
+ break;
+ }
+
+ onb.size *= 2;
+ res += dlen; /* Hint returned by compression routine */
+
+ /* Either double the buffer, or use the hint provided */
+ if (onb.size < res) {
+ onb.size = res;
+ }
+ }
+ }
+ else {
+ return luaL_error(L, "invalid arguments");
+ }
+
+ if (err) {
+ return lua_zstd_push_error(L, err);
+ }
+
+ lua_new_text(L, onb.dst, onb.pos, TRUE);
+
+ return 1;
+}
+
+static gint
+lua_load_zstd(lua_State *L)
+{
+ lua_newtable(L);
+ luaL_register(L, NULL, zstd_compress_lib_f);
+
+ return 1;
+}
+
+void luaopen_compress(lua_State *L)
+{
+ rspamd_lua_new_class(L, "rspamd{zstd_compress}", zstd_compress_lib_m);
+ rspamd_lua_new_class(L, "rspamd{zstd_decompress}", zstd_decompress_lib_m);
+ lua_pop(L, 2);
+
+ rspamd_lua_add_preload(L, "rspamd_zstd", lua_load_zstd);
+}