diff options
Diffstat (limited to 'src/lua/lua_tcp.c')
-rw-r--r-- | src/lua/lua_tcp.c | 2566 |
1 files changed, 2566 insertions, 0 deletions
diff --git a/src/lua/lua_tcp.c b/src/lua/lua_tcp.c new file mode 100644 index 0000000..45faa79 --- /dev/null +++ b/src/lua/lua_tcp.c @@ -0,0 +1,2566 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "lua_thread_pool.h" +#include "libserver/ssl_util.h" +#include "utlist.h" +#include "unix-std.h" +#include <math.h> + +static const gchar *M = "rspamd lua tcp"; + +/*** + * @module rspamd_tcp + * Rspamd TCP module represents generic TCP asynchronous client available from LUA code. + * This module hides all complexity: DNS resolving, sessions management, zero-copy + * text transfers and so on under the hood. It can work in partial or complete modes: + * + * - partial mode is used when you need to call a continuation routine each time data is available for read + * - complete mode calls for continuation merely when all data is read from socket (e.g. when a server sends reply and closes a connection) + * @example +local logger = require "rspamd_logger" +local tcp = require "rspamd_tcp" + +rspamd_config.SYM = function(task) + + local function cb(err, data) + logger.infox('err: %1, data: %2', err, tostring(data)) + end + + tcp.request({ + task = task, + host = "google.com", + port = 80, + data = {"GET / HTTP/1.0\r\n", "Host: google.com\r\n", "\r\n"}, + callback = cb}) +end + +-- New TCP syntax test +rspamd_config:register_symbol({ + name = 'TCP_TEST', + type = "normal", + callback = function(task) + local logger = require "rspamd_logger" + local function rcpt_done_cb(err, data, conn) + logger.errx(task, 'RCPT: got reply: %s, error: %s', data, err) + conn:close() + end + local function rcpt_cb(err, conn) + logger.errx(task, 'written rcpt, error: %s', err) + conn:add_read(rcpt_done_cb, '\r\n') + end + local function from_done_cb(err, data, conn) + logger.errx(task, 'FROM: got reply: %s, error: %s', data, err) + conn:add_write(rcpt_cb, 'RCPT TO: <test@yandex.ru>\r\n') + end + local function from_cb(err, conn) + logger.errx(task, 'written from, error: %s', err) + conn:add_read(from_done_cb, '\r\n') + end + local function hello_done_cb(err, data, conn) + logger.errx(task, 'HELO: got reply: %s, error: %s', data, err) + conn:add_write(from_cb, 'MAIL FROM: <>\r\n') + end + local function hello_cb(err, conn) + logger.errx(task, 'written hello, error: %s', err) + conn:add_read(hello_done_cb, '\r\n') + end + local function init_cb(err, data, conn) + logger.errx(task, 'got reply: %s, error: %s', data, err) + conn:add_write(hello_cb, 'HELO example.com\r\n') + end + tcp.request{ + task = task, + callback = init_cb, + stop_pattern = '\r\n', + host = 'mx.yandex.ru', + port = 25 + } + end, + priority = 10, +}) + */ + +LUA_FUNCTION_DEF(tcp, request); + +/*** + * @function rspamd_tcp.connect_sync() + * + * Creates pseudo-synchronous TCP connection. + * Each method of the connection requiring IO, becomes a yielding point, + * i.e. current thread Lua thread is get suspended and resumes as soon as IO is done + * + * This class represents low-level API, using of "lua_tcp_sync" module is recommended. + * + * @example + +local rspamd_tcp = require "rspamd_tcp" +local logger = require "rspamd_logger" + +local function http_simple_tcp_symbol(task) + + local err + local is_ok, connection = rspamd_tcp.connect_sync { + task = task, + host = '127.0.0.1', + timeout = 20, + port = 18080, + ssl = false, -- If SSL connection is needed + ssl_verify = true, -- set to false if verify is not needed + } + + is_ok, err = connection:write('GET /request_sync HTTP/1.1\r\nConnection: keep-alive\r\n\r\n') + + logger.errx(task, 'write %1, %2', is_ok, err) + if not is_ok then + logger.errx(task, 'write error: %1', err) + end + + local data + is_ok, data = connection:read_once(); + + logger.errx(task, 'read_once: is_ok: %1, data: %2', is_ok, data) + + is_ok, err = connection:write("POST /request2 HTTP/1.1\r\n\r\n") + logger.errx(task, 'write[2] %1, %2', is_ok, err) + + is_ok, data = connection:read_once(); + logger.errx(task, 'read_once[2]: is_ok %1, data: %2', is_ok, data) + + connection:close() +end + +rspamd_config:register_symbol({ + name = 'SIMPLE_TCP_TEST', + score = 1.0, + callback = http_simple_tcp_symbol, + no_squeeze = true +}) + * + */ +LUA_FUNCTION_DEF(tcp, connect_sync); + +/*** + * @method tcp:close() + * + * Closes TCP connection + */ +LUA_FUNCTION_DEF(tcp, close); + +/*** + * @method tcp:add_read(callback, [pattern]) + * + * Adds new read event to the tcp connection + * @param {function} callback to be called when data is read + * @param {string} pattern optional stop pattern + */ +LUA_FUNCTION_DEF(tcp, add_read); + +/*** + * @method tcp:add_write(callback, data) + * + * Adds new write event to the tcp connection + * @param {function} optional callback to be called when data is completely written + * @param {table/string/text} data to send to a remote server + */ +LUA_FUNCTION_DEF(tcp, add_write); + +/*** + * @method tcp:shift_callback() + * + * Shifts the current callback and go to the next one (if any) + */ +LUA_FUNCTION_DEF(tcp, shift_callback); + +/*** + * @method tcp:starttls([no_verify]) + * + * Starts tls connection + * @param {boolean} no_verify used to skip ssl verification + */ +LUA_FUNCTION_DEF(tcp, starttls); + +static const struct luaL_reg tcp_libf[] = { + LUA_INTERFACE_DEF(tcp, request), + {"new", lua_tcp_request}, + {"connect", lua_tcp_request}, + {"connect_sync", lua_tcp_connect_sync}, + {NULL, NULL}}; + +static const struct luaL_reg tcp_libm[] = { + LUA_INTERFACE_DEF(tcp, close), + LUA_INTERFACE_DEF(tcp, add_read), + LUA_INTERFACE_DEF(tcp, add_write), + LUA_INTERFACE_DEF(tcp, shift_callback), + LUA_INTERFACE_DEF(tcp, starttls), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +/*** + * @method tcp:close() + * + * Closes TCP connection + */ +LUA_FUNCTION_DEF(tcp_sync, close); + +/*** + * @method read_once() + * + * Performs one read operation. If syscall returned with EAGAIN/EINT, + * restarts the operation, so it always returns either data or error. + */ +LUA_FUNCTION_DEF(tcp_sync, read_once); + +/*** + * @method eof() + * + * True if last IO operation ended with EOF, i.e. endpoint closed connection + */ +LUA_FUNCTION_DEF(tcp_sync, eof); + +/*** + * @method shutdown() + * + * Half-shutdown TCP connection + */ +LUA_FUNCTION_DEF(tcp_sync, shutdown); + +/*** + * @method write() + * + * Writes data into the stream. If syscall returned with EAGAIN/EINT + * restarts the operation. If performs write() until all the passed + * data is written completely. + */ +LUA_FUNCTION_DEF(tcp_sync, write); + +LUA_FUNCTION_DEF(tcp_sync, gc); + +static void lua_tcp_sync_session_dtor(gpointer ud); + +static const struct luaL_reg tcp_sync_libm[] = { + LUA_INTERFACE_DEF(tcp_sync, close), + LUA_INTERFACE_DEF(tcp_sync, read_once), + LUA_INTERFACE_DEF(tcp_sync, write), + LUA_INTERFACE_DEF(tcp_sync, eof), + LUA_INTERFACE_DEF(tcp_sync, shutdown), + {"__gc", lua_tcp_sync_gc}, + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +struct lua_tcp_read_handler { + gchar *stop_pattern; + guint plen; + gint cbref; +}; + +struct lua_tcp_write_handler { + struct iovec *iov; + guint iovlen; + gint cbref; + gsize pos; + gsize total_bytes; +}; + +enum lua_tcp_handler_type { + LUA_WANT_WRITE = 0, + LUA_WANT_READ, + LUA_WANT_CONNECT +}; + +struct lua_tcp_handler { + union { + struct lua_tcp_read_handler r; + struct lua_tcp_write_handler w; + } h; + enum lua_tcp_handler_type type; +}; + +struct lua_tcp_dtor { + rspamd_mempool_destruct_t dtor; + void *data; + struct lua_tcp_dtor *next; +}; + +#define LUA_TCP_FLAG_PARTIAL (1u << 0u) +#define LUA_TCP_FLAG_SHUTDOWN (1u << 2u) +#define LUA_TCP_FLAG_CONNECTED (1u << 3u) +#define LUA_TCP_FLAG_FINISHED (1u << 4u) +#define LUA_TCP_FLAG_SYNC (1u << 5u) +#define LUA_TCP_FLAG_RESOLVED (1u << 6u) +#define LUA_TCP_FLAG_SSL (1u << 7u) +#define LUA_TCP_FLAG_SSL_NOVERIFY (1u << 8u) + +#undef TCP_DEBUG_REFS +#ifdef TCP_DEBUG_REFS +#define TCP_RETAIN(x) \ + do { \ + msg_info("retain ref %p, refcount: %d", (x), (x)->ref.refcount); \ + REF_RETAIN(x); \ + } while (0) + +#define TCP_RELEASE(x) \ + do { \ + msg_info("release ref %p, refcount: %d", (x), (x)->ref.refcount); \ + REF_RELEASE(x); \ + } while (0) +#else +#define TCP_RETAIN(x) REF_RETAIN(x) +#define TCP_RELEASE(x) REF_RELEASE(x) +#endif + +struct lua_tcp_cbdata { + struct rspamd_async_session *session; + struct rspamd_async_event *async_ev; + struct ev_loop *event_loop; + rspamd_inet_addr_t *addr; + GByteArray *in; + GQueue *handlers; + gint fd; + gint connect_cb; + guint port; + guint flags; + gchar tag[7]; + struct rspamd_io_ev ev; + struct lua_tcp_dtor *dtors; + ref_entry_t ref; + struct rspamd_task *task; + struct rspamd_symcache_dynamic_item *item; + struct thread_entry *thread; + struct rspamd_config *cfg; + struct rspamd_ssl_connection *ssl_conn; + gchar *hostname; + struct upstream *up; + gboolean eof; +}; + +#define IS_SYNC(c) (((c)->flags & LUA_TCP_FLAG_SYNC) != 0) + +#define msg_debug_tcp(...) rspamd_conditional_debug_fast(NULL, cbd->addr, \ + rspamd_lua_tcp_log_id, "lua_tcp", cbd->tag, \ + G_STRFUNC, \ + __VA_ARGS__) + +INIT_LOG_MODULE(lua_tcp) + +static void lua_tcp_handler(int fd, short what, gpointer ud); +static void lua_tcp_plan_handler_event(struct lua_tcp_cbdata *cbd, + gboolean can_read, gboolean can_write); +static void lua_tcp_unregister_event(struct lua_tcp_cbdata *cbd); + +static void +lua_tcp_void_finalyser(gpointer arg) +{ +} + +static const gdouble default_tcp_timeout = 5.0; + +static struct rspamd_dns_resolver * +lua_tcp_global_resolver(struct ev_loop *ev_base, + struct rspamd_config *cfg) +{ + static struct rspamd_dns_resolver *global_resolver; + + if (cfg && cfg->dns_resolver) { + return cfg->dns_resolver; + } + + if (global_resolver == NULL) { + global_resolver = rspamd_dns_resolver_init(NULL, ev_base, cfg); + } + + return global_resolver; +} + +static gboolean +lua_tcp_shift_handler(struct lua_tcp_cbdata *cbd) +{ + struct lua_tcp_handler *hdl; + + hdl = g_queue_pop_head(cbd->handlers); + + if (hdl == NULL) { + /* We are done */ + return FALSE; + } + + if (hdl->type == LUA_WANT_READ) { + msg_debug_tcp("switch from read handler %d", hdl->h.r.cbref); + if (hdl->h.r.cbref && hdl->h.r.cbref != -1) { + luaL_unref(cbd->cfg->lua_state, LUA_REGISTRYINDEX, hdl->h.r.cbref); + } + + if (hdl->h.r.stop_pattern) { + g_free(hdl->h.r.stop_pattern); + } + } + else if (hdl->type == LUA_WANT_WRITE) { + msg_debug_tcp("switch from write handler %d", hdl->h.r.cbref); + if (hdl->h.w.cbref && hdl->h.w.cbref != -1) { + luaL_unref(cbd->cfg->lua_state, LUA_REGISTRYINDEX, hdl->h.w.cbref); + } + + if (hdl->h.w.iov) { + g_free(hdl->h.w.iov); + } + } + else { + msg_debug_tcp("removing connect handler"); + /* LUA_WANT_CONNECT: it doesn't allocate anything, nothing to do here */ + } + + g_free(hdl); + + return TRUE; +} + +static void +lua_tcp_fin(gpointer arg) +{ + struct lua_tcp_cbdata *cbd = (struct lua_tcp_cbdata *) arg; + struct lua_tcp_dtor *dtor, *dttmp; + + if (IS_SYNC(cbd) && cbd->task) { + /* + pointer is now becoming invalid, we should remove registered destructor, + all the necessary steps are done here + */ + rspamd_mempool_replace_destructor(cbd->task->task_pool, + lua_tcp_sync_session_dtor, cbd, NULL); + } + + msg_debug_tcp("finishing TCP %s connection", IS_SYNC(cbd) ? "sync" : "async"); + + if (cbd->connect_cb != -1) { + luaL_unref(cbd->cfg->lua_state, LUA_REGISTRYINDEX, cbd->connect_cb); + } + + if (cbd->ssl_conn) { + /* TODO: postpone close in case ssl is used ! */ + rspamd_ssl_connection_free(cbd->ssl_conn); + } + + if (cbd->fd != -1) { + rspamd_ev_watcher_stop(cbd->event_loop, &cbd->ev); + close(cbd->fd); + cbd->fd = -1; + } + + if (cbd->addr) { + rspamd_inet_address_free(cbd->addr); + } + + if (cbd->up) { + rspamd_upstream_unref(cbd->up); + } + + while (lua_tcp_shift_handler(cbd)) {} + g_queue_free(cbd->handlers); + + LL_FOREACH_SAFE(cbd->dtors, dtor, dttmp) + { + dtor->dtor(dtor->data); + g_free(dtor); + } + + g_byte_array_unref(cbd->in); + g_free(cbd->hostname); + g_free(cbd); +} + +static struct lua_tcp_cbdata * +lua_check_tcp(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{tcp}"); + luaL_argcheck(L, ud != NULL, pos, "'tcp' expected"); + return ud ? *((struct lua_tcp_cbdata **) ud) : NULL; +} + +static void +lua_tcp_maybe_free(struct lua_tcp_cbdata *cbd) +{ + if (IS_SYNC(cbd)) { + /* + * in this mode, we don't remove object, we only remove the event + * Object is owned by lua and will be destroyed on __gc() + */ + + if (cbd->item) { + rspamd_symcache_item_async_dec_check(cbd->task, cbd->item, M); + cbd->item = NULL; + } + + if (cbd->async_ev) { + rspamd_session_remove_event(cbd->session, lua_tcp_void_finalyser, cbd); + } + + cbd->async_ev = NULL; + } + else { + if (cbd->item) { + rspamd_symcache_item_async_dec_check(cbd->task, cbd->item, M); + cbd->item = NULL; + } + + if (cbd->async_ev) { + rspamd_session_remove_event(cbd->session, lua_tcp_fin, cbd); + } + else { + lua_tcp_fin(cbd); + } + } +} + +#ifdef __GNUC__ +static void +lua_tcp_push_error(struct lua_tcp_cbdata *cbd, gboolean is_fatal, + const char *err, ...) __attribute__((format(printf, 3, 4))); +#endif + +static void lua_tcp_resume_thread_error_argp(struct lua_tcp_cbdata *cbd, const gchar *error, va_list argp); + +static void +lua_tcp_push_error(struct lua_tcp_cbdata *cbd, gboolean is_fatal, + const char *err, ...) +{ + va_list ap, ap_copy; + struct lua_tcp_cbdata **pcbd; + struct lua_tcp_handler *hdl; + gint cbref, top; + struct lua_callback_state cbs; + lua_State *L; + gboolean callback_called = FALSE; + + if (is_fatal && cbd->up) { + rspamd_upstream_fail(cbd->up, false, err); + } + + if (cbd->thread) { + va_start(ap, err); + lua_tcp_resume_thread_error_argp(cbd, err, ap); + va_end(ap); + + return; + } + + lua_thread_pool_prepare_callback(cbd->cfg->lua_thread_pool, &cbs); + L = cbs.L; + + va_start(ap, err); + + for (;;) { + hdl = g_queue_peek_head(cbd->handlers); + + if (hdl == NULL) { + break; + } + + if (hdl->type == LUA_WANT_READ) { + cbref = hdl->h.r.cbref; + } + else { + cbref = hdl->h.w.cbref; + } + + if (cbref != -1) { + top = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cbref); + + /* Error message */ + va_copy(ap_copy, ap); + lua_pushvfstring(L, err, ap_copy); + va_end(ap_copy); + + /* Body */ + lua_pushnil(L); + /* Connection */ + pcbd = lua_newuserdata(L, sizeof(*pcbd)); + *pcbd = cbd; + rspamd_lua_setclass(L, "rspamd{tcp}", -1); + TCP_RETAIN(cbd); + + if (cbd->item) { + rspamd_symcache_set_cur_item(cbd->task, cbd->item); + } + + if (lua_pcall(L, 3, 0, 0) != 0) { + msg_info("callback call failed: %s", lua_tostring(L, -1)); + } + + lua_settop(L, top); + + TCP_RELEASE(cbd); + + if ((cbd->flags & (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) == + (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) { + /* A callback has called `close` method, so we need to release a refcount */ + TCP_RELEASE(cbd); + } + + callback_called = TRUE; + } + + if (!is_fatal) { + if (callback_called) { + /* Stop on the first callback found */ + break; + } + else { + /* Shift to another callback to inform about non fatal error */ + msg_debug_tcp("non fatal error find matching callback"); + lua_tcp_shift_handler(cbd); + continue; + } + } + else { + msg_debug_tcp("fatal error rollback all handlers"); + lua_tcp_shift_handler(cbd); + } + } + + va_end(ap); + + lua_thread_pool_restore_callback(&cbs); +} + +static void lua_tcp_resume_thread(struct lua_tcp_cbdata *cbd, const guint8 *str, gsize len); + +static void +lua_tcp_push_data(struct lua_tcp_cbdata *cbd, const guint8 *str, gsize len) +{ + struct rspamd_lua_text *t; + struct lua_tcp_cbdata **pcbd; + struct lua_tcp_handler *hdl; + gint cbref, arg_cnt, top; + struct lua_callback_state cbs; + lua_State *L; + + if (cbd->thread) { + lua_tcp_resume_thread(cbd, str, len); + return; + } + + lua_thread_pool_prepare_callback(cbd->cfg->lua_thread_pool, &cbs); + L = cbs.L; + + hdl = g_queue_peek_head(cbd->handlers); + + g_assert(hdl != NULL); + + if (hdl->type == LUA_WANT_READ) { + cbref = hdl->h.r.cbref; + } + else { + cbref = hdl->h.w.cbref; + } + + if (cbref != -1) { + top = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cbref); + /* Error */ + lua_pushnil(L); + /* Body */ + + if (hdl->type == LUA_WANT_READ) { + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = (const gchar *) str; + t->len = len; + t->flags = 0; + arg_cnt = 3; + } + else { + arg_cnt = 2; + } + /* Connection */ + pcbd = lua_newuserdata(L, sizeof(*pcbd)); + *pcbd = cbd; + rspamd_lua_setclass(L, "rspamd{tcp}", -1); + + TCP_RETAIN(cbd); + + if (cbd->item) { + rspamd_symcache_set_cur_item(cbd->task, cbd->item); + } + + if (lua_pcall(L, arg_cnt, 0, 0) != 0) { + msg_info("callback call failed: %s", lua_tostring(L, -1)); + } + + lua_settop(L, top); + TCP_RELEASE(cbd); + + if ((cbd->flags & (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) == + (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) { + /* A callback has called `close` method, so we need to release a refcount */ + TCP_RELEASE(cbd); + } + } + + lua_thread_pool_restore_callback(&cbs); +} + +static void +lua_tcp_resume_thread_error_argp(struct lua_tcp_cbdata *cbd, const gchar *error, va_list argp) +{ + struct thread_entry *thread = cbd->thread; + lua_State *L = thread->lua_state; + + lua_pushboolean(L, FALSE); + lua_pushvfstring(L, error, argp); + + lua_tcp_shift_handler(cbd); + // lua_tcp_unregister_event (cbd); + lua_thread_pool_set_running_entry(cbd->cfg->lua_thread_pool, cbd->thread); + lua_thread_resume(thread, 2); + TCP_RELEASE(cbd); +} + +static void +lua_tcp_resume_thread(struct lua_tcp_cbdata *cbd, const guint8 *str, gsize len) +{ + /* + * typical call returns: + * + * read: + * error: + * (nil, error message) + * got data: + * (true, data) + * write/connect: + * error: + * (nil, error message) + * wrote + * (true) + */ + + lua_State *L = cbd->thread->lua_state; + struct lua_tcp_handler *hdl; + + hdl = g_queue_peek_head(cbd->handlers); + + lua_pushboolean(L, TRUE); + if (hdl->type == LUA_WANT_READ) { + lua_pushlstring(L, str, len); + } + else { + lua_pushnil(L); + } + + lua_tcp_shift_handler(cbd); + lua_thread_pool_set_running_entry(cbd->cfg->lua_thread_pool, + cbd->thread); + + if (cbd->item) { + rspamd_symcache_set_cur_item(cbd->task, cbd->item); + } + + lua_thread_resume(cbd->thread, 2); + + TCP_RELEASE(cbd); +} + +static void +lua_tcp_plan_read(struct lua_tcp_cbdata *cbd) +{ + rspamd_ev_watcher_reschedule(cbd->event_loop, &cbd->ev, EV_READ); +} + +static void +lua_tcp_connect_helper(struct lua_tcp_cbdata *cbd) +{ + /* This is used for sync mode only */ + lua_State *L = cbd->thread->lua_state; + + struct lua_tcp_cbdata **pcbd; + + lua_pushboolean(L, TRUE); + + lua_thread_pool_set_running_entry(cbd->cfg->lua_thread_pool, cbd->thread); + pcbd = lua_newuserdata(L, sizeof(*pcbd)); + *pcbd = cbd; + rspamd_lua_setclass(L, "rspamd{tcp_sync}", -1); + msg_debug_tcp("tcp connected"); + + lua_tcp_shift_handler(cbd); + + // lua_tcp_unregister_event (cbd); + lua_thread_resume(cbd->thread, 2); + TCP_RELEASE(cbd); +} + +static void +lua_tcp_write_helper(struct lua_tcp_cbdata *cbd) +{ + struct iovec *start; + guint niov, i; + gint flags = 0; + bool allocated_iov = false; + gsize remain; + gssize r; + struct iovec *cur_iov; + struct lua_tcp_handler *hdl; + struct lua_tcp_write_handler *wh; + struct msghdr msg; + + hdl = g_queue_peek_head(cbd->handlers); + + g_assert(hdl != NULL && hdl->type == LUA_WANT_WRITE); + wh = &hdl->h.w; + + if (wh->pos == wh->total_bytes) { + goto call_finish_handler; + } + + start = &wh->iov[0]; + niov = wh->iovlen; + remain = wh->pos; + /* We know that niov is small enough for that */ + + if (niov < 1024) { + cur_iov = g_alloca(niov * sizeof(struct iovec)); + } + else { + cur_iov = g_malloc0(niov * sizeof(struct iovec)); + allocated_iov = true; + } + + memcpy(cur_iov, wh->iov, niov * sizeof(struct iovec)); + + for (i = 0; i < wh->iovlen && remain > 0; i++) { + /* Find out the first iov required */ + start = &cur_iov[i]; + if (start->iov_len <= remain) { + remain -= start->iov_len; + start = &cur_iov[i + 1]; + niov--; + } + else { + start->iov_base = (void *) ((char *) start->iov_base + remain); + start->iov_len -= remain; + remain = 0; + } + } + + memset(&msg, 0, sizeof(msg)); + msg.msg_iov = start; + msg.msg_iovlen = MIN(IOV_MAX, niov); + g_assert(niov > 0); +#ifdef MSG_NOSIGNAL + flags = MSG_NOSIGNAL; +#endif + + msg_debug_tcp("want write %d io vectors of %d", (int) msg.msg_iovlen, + (int) niov); + + if (cbd->ssl_conn) { + r = rspamd_ssl_writev(cbd->ssl_conn, msg.msg_iov, msg.msg_iovlen); + } + else { + r = sendmsg(cbd->fd, &msg, flags); + } + + if (allocated_iov) { + g_free(cur_iov); + } + + if (r == -1) { + if (!(cbd->ssl_conn)) { + if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { + msg_debug_tcp("got temporary failure, retry write"); + lua_tcp_plan_handler_event(cbd, TRUE, TRUE); + return; + } + else { + lua_tcp_push_error(cbd, TRUE, + "IO write error while trying to write %d bytes: %s", + (gint) remain, strerror(errno)); + + msg_debug_tcp("write error, terminate connection"); + TCP_RELEASE(cbd); + } + } + + return; + } + else { + wh->pos += r; + } + + msg_debug_tcp("written %z bytes: %z/%z", r, + wh->pos, wh->total_bytes); + + if (wh->pos >= wh->total_bytes) { + goto call_finish_handler; + } + else { + /* Want to write more */ + if (r > 0) { + /* XXX: special case: we know that we want to write more data + * than it is available in iov function. + * + * Hence, we need to check if we can write more at some point... + */ + lua_tcp_write_helper(cbd); + } + } + + return; + +call_finish_handler: + + msg_debug_tcp("finishing TCP write, calling TCP handler"); + + if ((cbd->flags & LUA_TCP_FLAG_SHUTDOWN)) { + /* Half close the connection */ + shutdown(cbd->fd, SHUT_WR); + cbd->flags &= ~LUA_TCP_FLAG_SHUTDOWN; + } + + if (cbd->up) { + rspamd_upstream_ok(cbd->up); + } + + lua_tcp_push_data(cbd, NULL, 0); + if (!IS_SYNC(cbd)) { + lua_tcp_shift_handler(cbd); + lua_tcp_plan_handler_event(cbd, TRUE, TRUE); + } +} + +static gboolean +lua_tcp_process_read_handler(struct lua_tcp_cbdata *cbd, + struct lua_tcp_read_handler *rh, gboolean eof) +{ + guint slen; + goffset pos; + + if (rh->stop_pattern) { + slen = rh->plen; + + if (cbd->in->len >= slen) { + if ((pos = rspamd_substring_search(cbd->in->data, cbd->in->len, + rh->stop_pattern, slen)) != -1) { + msg_debug_tcp("found TCP stop pattern"); + lua_tcp_push_data(cbd, cbd->in->data, pos); + + if (!IS_SYNC(cbd)) { + lua_tcp_shift_handler(cbd); + } + if (pos + slen < cbd->in->len) { + /* We have a leftover */ + memmove(cbd->in->data, cbd->in->data + pos + slen, + cbd->in->len - (pos + slen)); + cbd->in->len = cbd->in->len - (pos + slen); + } + else { + cbd->in->len = 0; + } + + return TRUE; + } + else { + /* Plan new read */ + msg_debug_tcp("NOT found TCP stop pattern"); + + if (!cbd->eof) { + lua_tcp_plan_read(cbd); + } + else { + /* Got session finished but no stop pattern */ + lua_tcp_push_error(cbd, TRUE, + "IO read error: connection terminated"); + } + } + } + } + else { + msg_debug_tcp("read TCP partial data %d bytes", cbd->in->len); + slen = cbd->in->len; + + /* we have eaten all the data, handler should not know that there is something */ + cbd->in->len = 0; + lua_tcp_push_data(cbd, cbd->in->data, slen); + if (!IS_SYNC(cbd)) { + lua_tcp_shift_handler(cbd); + } + + return TRUE; + } + + return FALSE; +} + +static void +lua_tcp_process_read(struct lua_tcp_cbdata *cbd, + guchar *in, gssize r) +{ + struct lua_tcp_handler *hdl; + struct lua_tcp_read_handler *rh; + + hdl = g_queue_peek_head(cbd->handlers); + + g_assert(hdl != NULL && hdl->type == LUA_WANT_READ); + rh = &hdl->h.r; + + if (r > 0) { + if (cbd->flags & LUA_TCP_FLAG_PARTIAL) { + lua_tcp_push_data(cbd, in, r); + /* Plan next event */ + lua_tcp_plan_read(cbd); + } + else { + g_byte_array_append(cbd->in, in, r); + + if (!lua_tcp_process_read_handler(cbd, rh, FALSE)) { + /* Plan more read */ + lua_tcp_plan_read(cbd); + } + else { + /* Go towards the next handler */ + if (!IS_SYNC(cbd)) { + lua_tcp_plan_handler_event(cbd, TRUE, TRUE); + } + } + } + } + else if (r == 0) { + /* EOF */ + cbd->eof = TRUE; + if (cbd->in->len > 0) { + /* We have some data to process */ + lua_tcp_process_read_handler(cbd, rh, TRUE); + } + else { + lua_tcp_push_error(cbd, TRUE, "IO read error: connection terminated"); + + if ((cbd->flags & LUA_TCP_FLAG_FINISHED)) { + /* A callback has called `close` method, so we need to release a refcount */ + TCP_RELEASE(cbd); + } + } + + lua_tcp_plan_handler_event(cbd, FALSE, FALSE); + } + else { + /* An error occurred */ + if (errno == EAGAIN || errno == EINTR) { + /* Restart call */ + lua_tcp_plan_read(cbd); + + return; + } + + /* Fatal error */ + cbd->eof = TRUE; + if (cbd->in->len > 0) { + /* We have some data to process */ + lua_tcp_process_read_handler(cbd, rh, TRUE); + } + else { + lua_tcp_push_error(cbd, TRUE, + "IO read error while trying to read data: %s", + strerror(errno)); + + if ((cbd->flags & LUA_TCP_FLAG_FINISHED)) { + /* A callback has called `close` method, so we need to release a refcount */ + TCP_RELEASE(cbd); + } + } + + lua_tcp_plan_handler_event(cbd, FALSE, FALSE); + } +} + +static void +lua_tcp_handler(int fd, short what, gpointer ud) +{ + struct lua_tcp_cbdata *cbd = ud; + guchar inbuf[8192]; + gssize r; + gint so_error = 0; + socklen_t so_len = sizeof(so_error); + struct lua_callback_state cbs; + lua_State *L; + enum lua_tcp_handler_type event_type; + TCP_RETAIN(cbd); + + msg_debug_tcp("processed TCP event: %d", what); + + struct lua_tcp_handler *rh = g_queue_peek_head(cbd->handlers); + event_type = rh->type; + + rspamd_ev_watcher_stop(cbd->event_loop, &cbd->ev); + + if (what == EV_READ) { + if (cbd->ssl_conn) { + r = rspamd_ssl_read(cbd->ssl_conn, inbuf, sizeof(inbuf)); + } + else { + r = read(cbd->fd, inbuf, sizeof(inbuf)); + } + + lua_tcp_process_read(cbd, inbuf, r); + } + else if (what == EV_WRITE) { + + if (!(cbd->flags & LUA_TCP_FLAG_CONNECTED)) { + if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &so_error, &so_len) == -1) { + lua_tcp_push_error(cbd, TRUE, "Cannot get socket error: %s", + strerror(errno)); + TCP_RELEASE(cbd); + goto out; + } + else if (so_error != 0) { + lua_tcp_push_error(cbd, TRUE, "Socket error detected: %s", + strerror(so_error)); + TCP_RELEASE(cbd); + goto out; + } + else { + cbd->flags |= LUA_TCP_FLAG_CONNECTED; + + if (cbd->connect_cb != -1) { + struct lua_tcp_cbdata **pcbd; + gint top; + + lua_thread_pool_prepare_callback(cbd->cfg->lua_thread_pool, &cbs); + L = cbs.L; + + top = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cbd->connect_cb); + pcbd = lua_newuserdata(L, sizeof(*pcbd)); + *pcbd = cbd; + TCP_RETAIN(cbd); + rspamd_lua_setclass(L, "rspamd{tcp}", -1); + + if (cbd->item) { + rspamd_symcache_set_cur_item(cbd->task, cbd->item); + } + + if (lua_pcall(L, 1, 0, 0) != 0) { + msg_info("callback call failed: %s", lua_tostring(L, -1)); + } + + lua_settop(L, top); + TCP_RELEASE(cbd); + lua_thread_pool_restore_callback(&cbs); + + if ((cbd->flags & (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) == + (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) { + /* A callback has called `close` method, so we need to release a refcount */ + TCP_RELEASE(cbd); + } + } + } + } + + if (event_type == LUA_WANT_WRITE) { + lua_tcp_write_helper(cbd); + } + else if (event_type == LUA_WANT_CONNECT) { + lua_tcp_connect_helper(cbd); + } + else { + g_assert_not_reached(); + } + + if ((cbd->flags & (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) == + (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) { + /* A callback has called `close` method, so we need to release a refcount */ + TCP_RELEASE(cbd); + } + } + else { + lua_tcp_push_error(cbd, TRUE, "IO timeout"); + TCP_RELEASE(cbd); + } + +out: + TCP_RELEASE(cbd); +} + +static void +lua_tcp_plan_handler_event(struct lua_tcp_cbdata *cbd, gboolean can_read, + gboolean can_write) +{ + struct lua_tcp_handler *hdl; + + hdl = g_queue_peek_head(cbd->handlers); + + if (hdl == NULL) { + if (!(cbd->flags & LUA_TCP_FLAG_FINISHED)) { + /* We are finished with a connection */ + msg_debug_tcp("no handlers left, finish session"); + cbd->flags |= LUA_TCP_FLAG_FINISHED; + TCP_RELEASE(cbd); + } + } + else { + if (hdl->type == LUA_WANT_READ) { + + /* We need to check if we have some leftover in the buffer */ + if (cbd->in->len > 0) { + msg_debug_tcp("process read buffer leftover"); + if (lua_tcp_process_read_handler(cbd, &hdl->h.r, FALSE)) { + if (!IS_SYNC(cbd)) { + /* We can go to the next handler */ + lua_tcp_plan_handler_event(cbd, can_read, can_write); + } + } + } + else { + if (can_read) { + /* We need to plan a new event */ + msg_debug_tcp("plan new read"); + rspamd_ev_watcher_reschedule(cbd->event_loop, &cbd->ev, + EV_READ); + } + else { + /* Cannot read more */ + msg_debug_tcp("cannot read more"); + lua_tcp_push_error(cbd, FALSE, "EOF, cannot read more data"); + if (!IS_SYNC(cbd)) { + lua_tcp_shift_handler(cbd); + lua_tcp_plan_handler_event(cbd, can_read, can_write); + } + } + } + } + else if (hdl->type == LUA_WANT_WRITE) { + /* + * We need to plan write event if there is something in the + * write request + */ + + if (hdl->h.w.pos < hdl->h.w.total_bytes) { + msg_debug_tcp("plan new write"); + if (can_write) { + rspamd_ev_watcher_reschedule(cbd->event_loop, &cbd->ev, + EV_WRITE); + } + else { + /* Cannot write more */ + lua_tcp_push_error(cbd, FALSE, "EOF, cannot write more data"); + if (!IS_SYNC(cbd)) { + lua_tcp_shift_handler(cbd); + lua_tcp_plan_handler_event(cbd, can_read, can_write); + } + } + } + else { + /* We shouldn't have empty write handlers */ + g_assert_not_reached(); + } + } + else { /* LUA_WANT_CONNECT */ + msg_debug_tcp("plan new connect"); + rspamd_ev_watcher_reschedule(cbd->event_loop, &cbd->ev, + EV_WRITE); + } + } +} + +static gboolean +lua_tcp_register_event(struct lua_tcp_cbdata *cbd) +{ + if (cbd->session) { + event_finalizer_t fin = IS_SYNC(cbd) ? lua_tcp_void_finalyser : lua_tcp_fin; + + if (cbd->item) { + cbd->async_ev = rspamd_session_add_event_full(cbd->session, fin, cbd, M, + rspamd_symcache_dyn_item_name(cbd->task, cbd->item)); + } + else { + cbd->async_ev = rspamd_session_add_event(cbd->session, fin, cbd, M); + } + + if (!cbd->async_ev) { + return FALSE; + } + } + + return TRUE; +} + +static void +lua_tcp_register_watcher(struct lua_tcp_cbdata *cbd) +{ + if (cbd->item && cbd->task) { + rspamd_symcache_item_async_inc(cbd->task, cbd->item, M); + } +} + +static void +lua_tcp_ssl_on_error(gpointer ud, GError *err) +{ + struct lua_tcp_cbdata *cbd = (struct lua_tcp_cbdata *) ud; + + if (err) { + lua_tcp_push_error(cbd, TRUE, "ssl error: %s", err->message); + } + else { + lua_tcp_push_error(cbd, TRUE, "ssl error: unknown error"); + } + + TCP_RELEASE(cbd); +} + +static gboolean +lua_tcp_make_connection(struct lua_tcp_cbdata *cbd) +{ + int fd; + + rspamd_inet_address_set_port(cbd->addr, cbd->port); + fd = rspamd_inet_address_connect(cbd->addr, SOCK_STREAM, TRUE); + + if (fd == -1) { + if (cbd->session) { + rspamd_mempool_t *pool = rspamd_session_mempool(cbd->session); + msg_info_pool("cannot connect to %s (%s): %s", + rspamd_inet_address_to_string(cbd->addr), + cbd->hostname, + strerror(errno)); + } + else { + msg_info("cannot connect to %s (%s): %s", + rspamd_inet_address_to_string(cbd->addr), + cbd->hostname, + strerror(errno)); + } + + return FALSE; + } + + cbd->fd = fd; + +#if 0 + if (!(cbd->flags & LUA_TCP_FLAG_RESOLVED)) { + /* We come here without resolving, so we need to add a watcher */ + lua_tcp_register_watcher (cbd); + } + else { + cbd->flags |= LUA_TCP_FLAG_RESOLVED; + } +#endif + + if (cbd->flags & LUA_TCP_FLAG_SSL) { + gpointer ssl_ctx; + gboolean verify_peer; + + if (cbd->flags & LUA_TCP_FLAG_SSL_NOVERIFY) { + ssl_ctx = cbd->cfg->libs_ctx->ssl_ctx_noverify; + verify_peer = FALSE; + } + else { + ssl_ctx = cbd->cfg->libs_ctx->ssl_ctx; + verify_peer = TRUE; + } + + cbd->ssl_conn = rspamd_ssl_connection_new(ssl_ctx, + cbd->event_loop, + verify_peer, + cbd->tag); + + if (!rspamd_ssl_connect_fd(cbd->ssl_conn, fd, cbd->hostname, &cbd->ev, + cbd->ev.timeout, lua_tcp_handler, lua_tcp_ssl_on_error, cbd)) { + lua_tcp_push_error(cbd, TRUE, "ssl connection failed: %s", + strerror(errno)); + + return FALSE; + } + else { + lua_tcp_register_event(cbd); + } + } + else { + rspamd_ev_watcher_init(&cbd->ev, cbd->fd, EV_WRITE, + lua_tcp_handler, cbd); + lua_tcp_register_event(cbd); + lua_tcp_plan_handler_event(cbd, TRUE, TRUE); + } + + + return TRUE; +} + +static void +lua_tcp_dns_handler(struct rdns_reply *reply, gpointer ud) +{ + struct lua_tcp_cbdata *cbd = (struct lua_tcp_cbdata *) ud; + const struct rdns_request_name *rn; + + if (reply->code != RDNS_RC_NOERROR) { + rn = rdns_request_get_name(reply->request, NULL); + lua_tcp_push_error(cbd, TRUE, "unable to resolve host: %s", + rn->name); + TCP_RELEASE(cbd); + } + else { + /* + * We set this flag as it means that we have already registered the watcher + * when started DNS query + */ + struct rdns_reply_entry *entry; + + DL_FOREACH(reply->entries, entry) + { + if (entry->type == RDNS_REQUEST_A) { + cbd->addr = rspamd_inet_address_new(AF_INET, + &entry->content.a.addr); + break; + } + else if (entry->type == RDNS_REQUEST_AAAA) { + cbd->addr = rspamd_inet_address_new(AF_INET6, + &entry->content.aaa.addr); + break; + } + } + + if (cbd->addr == NULL) { + rn = rdns_request_get_name(reply->request, NULL); + lua_tcp_push_error(cbd, TRUE, "unable to resolve host: %s; no records with this name", + rn->name); + TCP_RELEASE(cbd); + return; + } + + cbd->flags |= LUA_TCP_FLAG_RESOLVED; + rspamd_inet_address_set_port(cbd->addr, cbd->port); + + if (!lua_tcp_make_connection(cbd)) { + lua_tcp_push_error(cbd, TRUE, "unable to make connection to the host %s", + rspamd_inet_address_to_string(cbd->addr)); + TCP_RELEASE(cbd); + } + } +} + +static gboolean +lua_tcp_arg_toiovec(lua_State *L, gint pos, struct lua_tcp_cbdata *cbd, + struct iovec *vec) +{ + struct rspamd_lua_text *t; + gsize len; + const gchar *str; + struct lua_tcp_dtor *dtor; + + if (lua_type(L, pos) == LUA_TUSERDATA) { + t = lua_check_text(L, pos); + + if (t) { + vec->iov_base = (void *) t->start; + vec->iov_len = t->len; + + if (t->flags & RSPAMD_TEXT_FLAG_OWN) { + /* Steal ownership */ + t->flags = 0; + dtor = g_malloc0(sizeof(*dtor)); + dtor->dtor = g_free; + dtor->data = (void *) t->start; + LL_PREPEND(cbd->dtors, dtor); + } + } + else { + msg_err("bad userdata argument at position %d", pos); + return FALSE; + } + } + else if (lua_type(L, pos) == LUA_TSTRING) { + str = luaL_checklstring(L, pos, &len); + vec->iov_base = g_malloc(len); + dtor = g_malloc0(sizeof(*dtor)); + dtor->dtor = g_free; + dtor->data = vec->iov_base; + LL_PREPEND(cbd->dtors, dtor); + memcpy(vec->iov_base, str, len); + vec->iov_len = len; + } + else { + msg_err("bad argument at position %d", pos); + return FALSE; + } + + return TRUE; +} + +/*** + * @function rspamd_tcp.request({params}) + * This function creates and sends TCP request to the specified host and port, + * resolves hostname (if needed) and invokes continuation callback upon data received + * from the remote peer. This function accepts table of arguments with the following + * attributes + * + * - `task`: rspamd task objects (implies `pool`, `session`, `ev_base` and `resolver` arguments) + * - `ev_base`: event base (if no task specified) + * - `resolver`: DNS resolver (no task) + * - `session`: events session (no task) + * - `host`: IP or name of the peer (required) + * - `port`: remote port to use + * - `data`: a table of strings or `rspamd_text` objects that contains data pieces + * - `callback`: continuation function (required) + * - `on_connect`: callback called on connection success + * - `timeout`: floating point value that specifies timeout for IO operations in **seconds** + * - `partial`: boolean flag that specifies that callback should be called on any data portion received + * - `stop_pattern`: stop reading on finding a certain pattern (e.g. \r\n.\r\n for smtp) + * - `shutdown`: half-close socket after writing (boolean: default false) + * - `read`: read response after sending request (boolean: default true) + * - `upstream`: optional upstream object that would be used to get an address + * @return {boolean} true if request has been sent + */ +static gint +lua_tcp_request(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *host; + gchar *stop_pattern = NULL; + guint port; + gint cbref, tp, conn_cbref = -1; + gsize plen = 0; + struct ev_loop *event_loop = NULL; + struct lua_tcp_cbdata *cbd; + struct rspamd_dns_resolver *resolver = NULL; + struct rspamd_async_session *session = NULL; + struct rspamd_task *task = NULL; + struct rspamd_config *cfg = NULL; + struct iovec *iov = NULL; + struct upstream *up = NULL; + guint niov = 0, total_out; + guint64 h; + gdouble timeout = default_tcp_timeout; + gboolean partial = FALSE, do_shutdown = FALSE, do_read = TRUE, + ssl = FALSE, ssl_noverify = FALSE; + + if (lua_type(L, 1) == LUA_TTABLE) { + lua_pushstring(L, "host"); + lua_gettable(L, -2); + host = luaL_checkstring(L, -1); + lua_pop(L, 1); + + lua_pushstring(L, "port"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TNUMBER) { + port = lua_tointeger(L, -1); + } + else { + /* We assume that it is a unix socket */ + port = 0; + } + + lua_pop(L, 1); + + lua_pushstring(L, "callback"); + lua_gettable(L, -2); + if (host == NULL || lua_type(L, -1) != LUA_TFUNCTION) { + lua_pop(L, 1); + msg_err("tcp request has bad params"); + lua_pushboolean(L, FALSE); + return 1; + } + cbref = luaL_ref(L, LUA_REGISTRYINDEX); + + cbd = g_malloc0(sizeof(*cbd)); + + lua_pushstring(L, "task"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TUSERDATA) { + task = lua_check_task(L, -1); + event_loop = task->event_loop; + resolver = task->resolver; + session = task->s; + cfg = task->cfg; + } + lua_pop(L, 1); + + if (task == NULL) { + lua_pushstring(L, "ev_base"); + lua_gettable(L, -2); + if (rspamd_lua_check_udata_maybe(L, -1, "rspamd{ev_base}")) { + event_loop = *(struct ev_loop **) lua_touserdata(L, -1); + } + else { + g_free(cbd); + + return luaL_error(L, "event loop is required"); + } + lua_pop(L, 1); + + lua_pushstring(L, "session"); + lua_gettable(L, -2); + if (rspamd_lua_check_udata_maybe(L, -1, "rspamd{session}")) { + session = *(struct rspamd_async_session **) lua_touserdata(L, -1); + } + else { + session = NULL; + } + lua_pop(L, 1); + + lua_pushstring(L, "config"); + lua_gettable(L, -2); + if (rspamd_lua_check_udata_maybe(L, -1, "rspamd{config}")) { + cfg = *(struct rspamd_config **) lua_touserdata(L, -1); + } + else { + cfg = NULL; + } + lua_pop(L, 1); + + lua_pushstring(L, "resolver"); + lua_gettable(L, -2); + if (rspamd_lua_check_udata_maybe(L, -1, "rspamd{resolver}")) { + resolver = *(struct rspamd_dns_resolver **) lua_touserdata(L, -1); + } + else { + resolver = lua_tcp_global_resolver(event_loop, cfg); + } + lua_pop(L, 1); + } + + lua_pushstring(L, "timeout"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TNUMBER) { + timeout = lua_tonumber(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "stop_pattern"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TSTRING) { + const gchar *p; + + p = lua_tolstring(L, -1, &plen); + + if (p && plen > 0) { + stop_pattern = g_malloc(plen); + memcpy(stop_pattern, p, plen); + } + } + lua_pop(L, 1); + + lua_pushstring(L, "partial"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TBOOLEAN) { + partial = lua_toboolean(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "shutdown"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TBOOLEAN) { + do_shutdown = lua_toboolean(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "read"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TBOOLEAN) { + do_read = lua_toboolean(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "ssl"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TBOOLEAN) { + ssl = lua_toboolean(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "ssl_noverify"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TBOOLEAN) { + ssl_noverify = lua_toboolean(L, -1); + lua_pop(L, 1); + } + else { + lua_pop(L, 1); /* Previous nil... */ + /* Similar to lua http, meh... */ + lua_pushstring(L, "no_ssl_verify"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TBOOLEAN) { + ssl_noverify = lua_toboolean(L, -1); + } + + lua_pop(L, 1); + } + + lua_pushstring(L, "on_connect"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TFUNCTION) { + conn_cbref = luaL_ref(L, LUA_REGISTRYINDEX); + } + else { + lua_pop(L, 1); + } + + lua_pushstring(L, "upstream"); + lua_gettable(L, 1); + + if (lua_type(L, -1) == LUA_TUSERDATA) { + struct rspamd_lua_upstream *lup = lua_check_upstream(L, -1); + + if (lup) { + /* Preserve pointer in case if lup is destructed */ + up = lup->up; + } + } + + lua_pop(L, 1); + + lua_pushstring(L, "data"); + lua_gettable(L, -2); + total_out = 0; + + tp = lua_type(L, -1); + if (tp == LUA_TSTRING || tp == LUA_TUSERDATA) { + iov = g_malloc(sizeof(*iov)); + niov = 1; + + if (!lua_tcp_arg_toiovec(L, -1, cbd, iov)) { + lua_pop(L, 1); + msg_err("tcp request has bad data argument"); + lua_pushboolean(L, FALSE); + g_free(iov); + g_free(cbd); + + return 1; + } + + total_out = iov[0].iov_len; + } + else if (tp == LUA_TTABLE) { + /* Count parts */ + lua_pushnil(L); + while (lua_next(L, -2) != 0) { + niov++; + lua_pop(L, 1); + } + + iov = g_malloc(sizeof(*iov) * niov); + lua_pushnil(L); + niov = 0; + + while (lua_next(L, -2) != 0) { + if (!lua_tcp_arg_toiovec(L, -1, cbd, &iov[niov])) { + lua_pop(L, 2); + msg_err("tcp request has bad data argument at pos %d", niov); + lua_pushboolean(L, FALSE); + g_free(iov); + g_free(cbd); + + return 1; + } + + total_out += iov[niov].iov_len; + niov++; + + lua_pop(L, 1); + } + } + + lua_pop(L, 1); + } + else { + return luaL_error(L, "tcp request has bad params"); + } + + if (resolver == NULL && cfg == NULL && task == NULL) { + g_free(cbd); + g_free(iov); + + return luaL_error(L, "tcp request has bad params: one of " + "{resolver,task,config} should be set"); + } + + cbd->task = task; + + if (task) { + cbd->item = rspamd_symcache_get_cur_item(task); + } + + cbd->cfg = cfg; + h = rspamd_random_uint64_fast(); + rspamd_snprintf(cbd->tag, sizeof(cbd->tag), "%uxL", h); + cbd->handlers = g_queue_new(); + cbd->hostname = g_strdup(host); + + if (total_out > 0) { + struct lua_tcp_handler *wh; + + wh = g_malloc0(sizeof(*wh)); + wh->type = LUA_WANT_WRITE; + wh->h.w.iov = iov; + wh->h.w.iovlen = niov; + wh->h.w.total_bytes = total_out; + wh->h.w.pos = 0; + /* Cannot set write handler here */ + wh->h.w.cbref = -1; + + if (cbref != -1 && !do_read) { + /* We have write only callback */ + wh->h.w.cbref = cbref; + } + else { + /* We have simple client callback */ + wh->h.w.cbref = -1; + } + + g_queue_push_tail(cbd->handlers, wh); + } + + cbd->event_loop = event_loop; + cbd->fd = -1; + cbd->port = port; + cbd->ev.timeout = timeout; + + if (ssl) { + cbd->flags |= LUA_TCP_FLAG_SSL; + + if (ssl_noverify) { + cbd->flags |= LUA_TCP_FLAG_SSL_NOVERIFY; + } + } + + if (do_read) { + cbd->in = g_byte_array_sized_new(8192); + } + else { + /* Save some space... */ + cbd->in = g_byte_array_new(); + } + + if (partial) { + cbd->flags |= LUA_TCP_FLAG_PARTIAL; + } + + if (do_shutdown) { + cbd->flags |= LUA_TCP_FLAG_SHUTDOWN; + } + + if (do_read) { + struct lua_tcp_handler *rh; + + rh = g_malloc0(sizeof(*rh)); + rh->type = LUA_WANT_READ; + rh->h.r.cbref = cbref; + rh->h.r.stop_pattern = stop_pattern; + rh->h.r.plen = plen; + g_queue_push_tail(cbd->handlers, rh); + } + + cbd->connect_cb = conn_cbref; + REF_INIT_RETAIN(cbd, lua_tcp_maybe_free); + + if (up) { + cbd->up = rspamd_upstream_ref(up); + } + + if (session) { + cbd->session = session; + + if (rspamd_session_blocked(session)) { + lua_tcp_push_error(cbd, TRUE, "async session is the blocked state"); + TCP_RELEASE(cbd); + cbd->item = NULL; /* To avoid decrease with no watcher */ + lua_pushboolean(L, FALSE); + + return 1; + } + } + + if (cbd->up) { + /* Use upstream to get addr */ + cbd->addr = rspamd_inet_address_copy(rspamd_upstream_addr_next(cbd->up), NULL); + + /* Host is numeric IP, no need to resolve */ + lua_tcp_register_watcher(cbd); + + if (!lua_tcp_make_connection(cbd)) { + lua_tcp_push_error(cbd, TRUE, "cannot connect to the host: %s", host); + lua_pushboolean(L, FALSE); + + rspamd_upstream_fail(cbd->up, true, "failed to connect"); + + /* No reset of the item as watcher has been registered */ + TCP_RELEASE(cbd); + + return 1; + } + } + else if (rspamd_parse_inet_address(&cbd->addr, + host, strlen(host), RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) { + rspamd_inet_address_set_port(cbd->addr, port); + /* Host is numeric IP, no need to resolve */ + lua_tcp_register_watcher(cbd); + + if (!lua_tcp_make_connection(cbd)) { + lua_tcp_push_error(cbd, TRUE, "cannot connect to the host: %s", host); + lua_pushboolean(L, FALSE); + + /* No reset of the item as watcher has been registered */ + TCP_RELEASE(cbd); + + return 1; + } + } + else { + if (task == NULL) { + if (!rspamd_dns_resolver_request(resolver, session, NULL, lua_tcp_dns_handler, cbd, + RDNS_REQUEST_A, host)) { + lua_tcp_push_error(cbd, TRUE, "cannot resolve host: %s", host); + lua_pushboolean(L, FALSE); + cbd->item = NULL; /* To avoid decrease with no watcher */ + TCP_RELEASE(cbd); + + return 1; + } + else { + lua_tcp_register_watcher(cbd); + } + } + else { + if (!rspamd_dns_resolver_request_task(task, lua_tcp_dns_handler, cbd, + RDNS_REQUEST_A, host)) { + lua_tcp_push_error(cbd, TRUE, "cannot resolve host: %s", host); + lua_pushboolean(L, FALSE); + cbd->item = NULL; /* To avoid decrease with no watcher */ + + TCP_RELEASE(cbd); + + return 1; + } + else { + lua_tcp_register_watcher(cbd); + } + } + } + + lua_pushboolean(L, TRUE); + return 1; +} + +/*** + * @function rspamd_tcp.connect_sync({params}) + * Creates new pseudo-synchronous connection to the specific address:port + * + * - `task`: rspamd task objects (implies `pool`, `session`, `ev_base` and `resolver` arguments) + * - `ev_base`: event base (if no task specified) + * - `resolver`: DNS resolver (no task) + * - `session`: events session (no task) + * - `config`: config (no task) + * - `host`: IP or name of the peer (required) + * - `port`: remote port to use + * - `timeout`: floating point value that specifies timeout for IO operations in **seconds** + * @return {boolean} true if request has been sent + */ +static gint +lua_tcp_connect_sync(lua_State *L) +{ + LUA_TRACE_POINT; + GError *err = NULL; + + gint64 port = -1; + gdouble timeout = default_tcp_timeout; + const gchar *host = NULL; + gint ret; + guint64 h; + + struct rspamd_task *task = NULL; + struct rspamd_async_session *session = NULL; + struct rspamd_dns_resolver *resolver = NULL; + struct rspamd_config *cfg = NULL; + struct ev_loop *ev_base = NULL; + struct lua_tcp_cbdata *cbd; + + + int arguments_validated = rspamd_lua_parse_table_arguments(L, 1, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "task=U{task};session=U{session};resolver=U{resolver};ev_base=U{ev_base};" + "*host=S;*port=I;timeout=D;config=U{config}", + &task, &session, &resolver, &ev_base, + &host, &port, &timeout, &cfg); + + if (!arguments_validated) { + if (err) { + ret = luaL_error(L, "invalid arguments: %s", err->message); + g_error_free(err); + + return ret; + } + + return luaL_error(L, "invalid arguments"); + } + + if (0 > port || port > 65535) { + return luaL_error(L, "invalid port given (correct values: 1..65535)"); + } + + if (task == NULL && (cfg == NULL || ev_base == NULL || session == NULL)) { + return luaL_error(L, "invalid arguments: either task or config+ev_base+session should be set"); + } + + if (isnan(timeout)) { + /* rspamd_lua_parse_table_arguments() sets missing N field to zero */ + timeout = default_tcp_timeout; + } + + cbd = g_new0(struct lua_tcp_cbdata, 1); + + if (task) { + static const gchar hexdigests[16] = "0123456789abcdef"; + + cfg = task->cfg; + ev_base = task->event_loop; + session = task->s; + /* Make a readable tag */ + memcpy(cbd->tag, task->task_pool->tag.uid, sizeof(cbd->tag) - 2); + cbd->tag[sizeof(cbd->tag) - 2] = hexdigests[GPOINTER_TO_INT(cbd) & 0xf]; + cbd->tag[sizeof(cbd->tag) - 1] = 0; + } + else { + h = rspamd_random_uint64_fast(); + rspamd_snprintf(cbd->tag, sizeof(cbd->tag), "%uxL", h); + } + + if (resolver == NULL) { + if (task) { + resolver = task->resolver; + } + else { + resolver = lua_tcp_global_resolver(ev_base, cfg); + } + } + + cbd->task = task; + cbd->cfg = cfg; + cbd->thread = lua_thread_pool_get_running_entry(cfg->lua_thread_pool); + + + cbd->handlers = g_queue_new(); + + cbd->event_loop = ev_base; + cbd->flags |= LUA_TCP_FLAG_SYNC; + cbd->fd = -1; + cbd->port = (guint16) port; + + cbd->in = g_byte_array_new(); + + cbd->connect_cb = -1; + + REF_INIT_RETAIN(cbd, lua_tcp_maybe_free); + + if (task) { + rspamd_mempool_add_destructor(task->task_pool, lua_tcp_sync_session_dtor, cbd); + } + + struct lua_tcp_handler *wh; + + wh = g_malloc0(sizeof(*wh)); + wh->type = LUA_WANT_CONNECT; + + g_queue_push_tail(cbd->handlers, wh); + + if (session) { + cbd->session = session; + + if (rspamd_session_blocked(session)) { + TCP_RELEASE(cbd); + lua_pushboolean(L, FALSE); + lua_pushliteral(L, "Session is being destroyed, requests are not allowed"); + + return 2; + } + } + + if (rspamd_parse_inet_address(&cbd->addr, + host, strlen(host), RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) { + rspamd_inet_address_set_port(cbd->addr, (guint16) port); + /* Host is numeric IP, no need to resolve */ + if (!lua_tcp_make_connection(cbd)) { + lua_pushboolean(L, FALSE); + lua_pushliteral(L, "Failed to initiate connection"); + + TCP_RELEASE(cbd); + + return 2; + } + } + else { + if (task == NULL) { + if (!rspamd_dns_resolver_request(resolver, session, NULL, lua_tcp_dns_handler, cbd, + RDNS_REQUEST_A, host)) { + lua_pushboolean(L, FALSE); + lua_pushliteral(L, "Failed to initiate dns request"); + + TCP_RELEASE(cbd); + + return 2; + } + else { + lua_tcp_register_watcher(cbd); + } + } + else { + cbd->item = rspamd_symcache_get_cur_item(task); + + if (!rspamd_dns_resolver_request_task(task, lua_tcp_dns_handler, cbd, + RDNS_REQUEST_A, host)) { + cbd->item = NULL; /* We have not registered watcher */ + lua_pushboolean(L, FALSE); + lua_pushliteral(L, "Failed to initiate dns request"); + TCP_RELEASE(cbd); + + return 2; + } + else { + lua_tcp_register_watcher(cbd); + } + } + } + + return lua_thread_yield(cbd->thread, 0); +} + +static gint +lua_tcp_close(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_tcp(L, 1); + + if (cbd == NULL) { + return luaL_error(L, "invalid arguments"); + } + + cbd->flags |= LUA_TCP_FLAG_FINISHED; + + if (cbd->ssl_conn) { + /* TODO: postpone close in case ssl is used ! */ + rspamd_ssl_connection_free(cbd->ssl_conn); + cbd->ssl_conn = NULL; + } + + if (cbd->fd != -1) { + rspamd_ev_watcher_stop(cbd->event_loop, &cbd->ev); + close(cbd->fd); + cbd->fd = -1; + } + + if (cbd->addr) { + rspamd_inet_address_free(cbd->addr); + cbd->addr = NULL; + } + + if (cbd->up) { + rspamd_upstream_unref(cbd->up); + cbd->up = NULL; + } + /* Do not release refcount as it will be handled elsewhere */ + /* TCP_RELEASE (cbd); */ + + return 0; +} + +static gint +lua_tcp_add_read(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_tcp(L, 1); + struct lua_tcp_handler *rh; + gchar *stop_pattern = NULL; + const gchar *p; + gsize plen = 0; + gint cbref = -1; + + if (cbd == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_type(L, 2) == LUA_TFUNCTION) { + lua_pushvalue(L, 2); + cbref = luaL_ref(L, LUA_REGISTRYINDEX); + } + + if (lua_type(L, 3) == LUA_TSTRING) { + p = lua_tolstring(L, 3, &plen); + + if (p && plen > 0) { + stop_pattern = g_malloc(plen); + memcpy(stop_pattern, p, plen); + } + } + + rh = g_malloc0(sizeof(*rh)); + rh->type = LUA_WANT_READ; + rh->h.r.cbref = cbref; + rh->h.r.stop_pattern = stop_pattern; + rh->h.r.plen = plen; + msg_debug_tcp("added read event, cbref: %d", cbref); + + g_queue_push_tail(cbd->handlers, rh); + + return 0; +} + +static gint +lua_tcp_add_write(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_tcp(L, 1); + struct lua_tcp_handler *wh; + gint cbref = -1, tp; + struct iovec *iov = NULL; + guint niov = 0, total_out = 0; + + if (cbd == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_type(L, 2) == LUA_TFUNCTION) { + lua_pushvalue(L, 2); + cbref = luaL_ref(L, LUA_REGISTRYINDEX); + } + + tp = lua_type(L, 3); + if (tp == LUA_TSTRING || tp == LUA_TUSERDATA) { + iov = g_malloc(sizeof(*iov)); + niov = 1; + + if (!lua_tcp_arg_toiovec(L, 3, cbd, iov)) { + msg_err("tcp request has bad data argument"); + lua_pushboolean(L, FALSE); + g_free(iov); + + return 1; + } + + total_out = iov[0].iov_len; + } + else if (tp == LUA_TTABLE) { + /* Count parts */ + lua_pushvalue(L, 3); + + lua_pushnil(L); + while (lua_next(L, -2) != 0) { + niov++; + lua_pop(L, 1); + } + + iov = g_malloc(sizeof(*iov) * niov); + lua_pushnil(L); + niov = 0; + + while (lua_next(L, -2) != 0) { + if (!lua_tcp_arg_toiovec(L, -1, cbd, &iov[niov])) { + lua_pop(L, 2); + msg_err("tcp request has bad data argument at pos %d", niov); + lua_pushboolean(L, FALSE); + g_free(iov); + g_free(cbd); + + return 1; + } + + total_out += iov[niov].iov_len; + niov++; + + lua_pop(L, 1); + } + + lua_pop(L, 1); + } + + wh = g_malloc0(sizeof(*wh)); + wh->type = LUA_WANT_WRITE; + wh->h.w.iov = iov; + wh->h.w.iovlen = niov; + wh->h.w.total_bytes = total_out; + wh->h.w.pos = 0; + /* Cannot set write handler here */ + wh->h.w.cbref = cbref; + msg_debug_tcp("added write event, cbref: %d", cbref); + + g_queue_push_tail(cbd->handlers, wh); + lua_pushboolean(L, TRUE); + + return 1; +} + +static gint +lua_tcp_shift_callback(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_tcp(L, 1); + + if (cbd == NULL) { + return luaL_error(L, "invalid arguments"); + } + + lua_tcp_shift_handler(cbd); + lua_tcp_plan_handler_event(cbd, TRUE, TRUE); + + return 0; +} + +static struct lua_tcp_cbdata * +lua_check_sync_tcp(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{tcp_sync}"); + luaL_argcheck(L, ud != NULL, pos, "'tcp' expected"); + return ud ? *((struct lua_tcp_cbdata **) ud) : NULL; +} + +static int +lua_tcp_sync_close(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp(L, 1); + + if (cbd == NULL) { + return luaL_error(L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + cbd->flags |= LUA_TCP_FLAG_FINISHED; + + if (cbd->fd != -1) { + rspamd_ev_watcher_stop(cbd->event_loop, &cbd->ev); + close(cbd->fd); + cbd->fd = -1; + } + + return 0; +} + +static void +lua_tcp_sync_session_dtor(gpointer ud) +{ + struct lua_tcp_cbdata *cbd = ud; + cbd->flags |= LUA_TCP_FLAG_FINISHED; + + if (cbd->fd != -1) { + msg_debug("closing sync TCP connection"); + rspamd_ev_watcher_stop(cbd->event_loop, &cbd->ev); + close(cbd->fd); + cbd->fd = -1; + } + + /* Task is gone, we should not try use it anymore */ + cbd->task = NULL; + + /* All events are removed when task is done, we should not refer them */ + cbd->async_ev = NULL; +} + +static int +lua_tcp_sync_read_once(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp(L, 1); + struct lua_tcp_handler *rh; + + if (cbd == NULL) { + return luaL_error(L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + + struct thread_entry *thread = lua_thread_pool_get_running_entry(cbd->cfg->lua_thread_pool); + + rh = g_malloc0(sizeof(*rh)); + rh->type = LUA_WANT_READ; + rh->h.r.cbref = -1; + + msg_debug_tcp("added read sync event, thread: %p", thread); + + g_queue_push_tail(cbd->handlers, rh); + lua_tcp_plan_handler_event(cbd, TRUE, TRUE); + + TCP_RETAIN(cbd); + + return lua_thread_yield(thread, 0); +} + +static int +lua_tcp_sync_write(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp(L, 1); + struct lua_tcp_handler *wh; + gint tp; + struct iovec *iov = NULL; + guint niov = 0; + gsize total_out = 0; + + if (cbd == NULL) { + return luaL_error(L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + + struct thread_entry *thread = lua_thread_pool_get_running_entry(cbd->cfg->lua_thread_pool); + + tp = lua_type(L, 2); + if (tp == LUA_TSTRING || tp == LUA_TUSERDATA) { + iov = g_malloc(sizeof(*iov)); + niov = 1; + + if (!lua_tcp_arg_toiovec(L, 2, cbd, iov)) { + msg_err("tcp request has bad data argument"); + g_free(iov); + g_free(cbd); + + return luaL_error(L, "invalid arguments second parameter (data) is expected to be either string or rspamd{text}"); + } + + total_out = iov[0].iov_len; + } + else if (tp == LUA_TTABLE) { + /* Count parts */ + lua_pushvalue(L, 3); + + lua_pushnil(L); + while (lua_next(L, -2) != 0) { + niov++; + lua_pop(L, 1); + } + + iov = g_malloc(sizeof(*iov) * niov); + lua_pushnil(L); + niov = 0; + + while (lua_next(L, -2) != 0) { + if (!lua_tcp_arg_toiovec(L, -1, cbd, &iov[niov])) { + msg_err("tcp request has bad data argument at pos %d", niov); + g_free(iov); + g_free(cbd); + + return luaL_error(L, "invalid arguments second parameter (data) is expected to be either string or rspamd{text}"); + } + + total_out += iov[niov].iov_len; + niov++; + + lua_pop(L, 1); + } + + lua_pop(L, 1); + } + + wh = g_malloc0(sizeof(*wh)); + wh->type = LUA_WANT_WRITE; + wh->h.w.iov = iov; + wh->h.w.iovlen = niov; + wh->h.w.total_bytes = total_out; + wh->h.w.pos = 0; + wh->h.w.cbref = -1; + msg_debug_tcp("added sync write event, thread: %p", thread); + + g_queue_push_tail(cbd->handlers, wh); + lua_tcp_plan_handler_event(cbd, TRUE, TRUE); + + TCP_RETAIN(cbd); + + return lua_thread_yield(thread, 0); +} + +static gint +lua_tcp_sync_eof(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp(L, 1); + if (cbd == NULL) { + return luaL_error(L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + + lua_pushboolean(L, cbd->eof); + + return 1; +} + +static gint +lua_tcp_sync_shutdown(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp(L, 1); + if (cbd == NULL) { + return luaL_error(L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + + shutdown(cbd->fd, SHUT_WR); + + return 0; +} + +static gint +lua_tcp_starttls(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_tcp(L, 1); + gpointer ssl_ctx; + gboolean verify_peer; + + if (cbd == NULL || cbd->ssl_conn != NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (cbd->flags & LUA_TCP_FLAG_SSL_NOVERIFY) { + ssl_ctx = cbd->cfg->libs_ctx->ssl_ctx_noverify; + verify_peer = FALSE; + } + else { + ssl_ctx = cbd->cfg->libs_ctx->ssl_ctx; + verify_peer = TRUE; + } + + cbd->ssl_conn = rspamd_ssl_connection_new(ssl_ctx, + cbd->event_loop, + verify_peer, + cbd->tag); + + if (!rspamd_ssl_connect_fd(cbd->ssl_conn, cbd->fd, cbd->hostname, &cbd->ev, + cbd->ev.timeout, lua_tcp_handler, lua_tcp_ssl_on_error, cbd)) { + lua_tcp_push_error(cbd, TRUE, "ssl connection failed: %s", + strerror(errno)); + } + + return 0; +} + +static gint +lua_tcp_sync_gc(lua_State *L) +{ + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp(L, 1); + if (!cbd) { + return luaL_error(L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + + lua_tcp_maybe_free(cbd); + lua_tcp_fin(cbd); + + return 0; +} + +static gint +lua_load_tcp(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, tcp_libf); + + return 1; +} + +void luaopen_tcp(lua_State *L) +{ + rspamd_lua_add_preload(L, "rspamd_tcp", lua_load_tcp); + rspamd_lua_new_class(L, "rspamd{tcp}", tcp_libm); + rspamd_lua_new_class(L, "rspamd{tcp_sync}", tcp_sync_libm); + lua_pop(L, 1); +} |