summaryrefslogtreecommitdiffstats
path: root/src/aclk/mqtt_websockets/mqtt_wss_client.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/aclk/mqtt_websockets/mqtt_wss_client.c')
-rw-r--r--src/aclk/mqtt_websockets/mqtt_wss_client.c148
1 files changed, 68 insertions, 80 deletions
diff --git a/src/aclk/mqtt_websockets/mqtt_wss_client.c b/src/aclk/mqtt_websockets/mqtt_wss_client.c
index f5b4025d7..2d231ef44 100644
--- a/src/aclk/mqtt_websockets/mqtt_wss_client.c
+++ b/src/aclk/mqtt_websockets/mqtt_wss_client.c
@@ -1,5 +1,4 @@
// SPDX-License-Identifier: GPL-3.0-only
-// Copyright (C) 2020 Timotej Šiškovič
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
@@ -19,9 +18,6 @@
#include <sys/socket.h>
#include <netinet/in.h>
-#include <arpa/inet.h>
-#include <netinet/tcp.h> //TCP_NODELAY
-#include <netdb.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
@@ -107,8 +103,6 @@ struct mqtt_wss_client_struct {
int mqtt_keepalive;
- pthread_mutex_t pub_lock;
-
// signifies that we didn't write all MQTT wanted
// us to write during last cycle (e.g. due to buffer
// size) and thus we should arm POLLOUT
@@ -121,7 +115,7 @@ struct mqtt_wss_client_struct {
void (*msg_callback)(const char *, const void *, size_t, int);
void (*puback_callback)(uint16_t packet_id);
- pthread_mutex_t stat_lock;
+ SPINLOCK stat_lock;
struct mqtt_wss_stats stats;
#ifdef MQTT_WSS_DEBUG
@@ -173,14 +167,13 @@ mqtt_wss_client mqtt_wss_new(const char *log_prefix,
SSL_library_init();
SSL_load_error_strings();
- mqtt_wss_client client = mw_calloc(1, sizeof(struct mqtt_wss_client_struct));
+ mqtt_wss_client client = callocz(1, sizeof(struct mqtt_wss_client_struct));
if (!client) {
mws_error(log, "OOM alocating mqtt_wss_client");
goto fail;
}
- pthread_mutex_init(&client->pub_lock, NULL);
- pthread_mutex_init(&client->stat_lock, NULL);
+ spinlock_init(&client->stat_lock);
client->msg_callback = msg_callback;
client->puback_callback = puback_callback;
@@ -229,7 +222,7 @@ fail_3:
fail_2:
ws_client_destroy(client->ws_client);
fail_1:
- mw_free(client);
+ freez(client);
fail:
mqtt_wss_log_ctx_destroy(log);
return NULL;
@@ -253,12 +246,15 @@ void mqtt_wss_destroy(mqtt_wss_client client)
// as it "borrows" this pointer and might use it
if (client->target_host == client->host)
client->target_host = NULL;
+
if (client->target_host)
- mw_free(client->target_host);
+ freez(client->target_host);
+
if (client->host)
- mw_free(client->host);
- mw_free(client->proxy_passwd);
- mw_free(client->proxy_uname);
+ freez(client->host);
+
+ freez(client->proxy_passwd);
+ freez(client->proxy_uname);
if (client->ssl)
SSL_free(client->ssl);
@@ -269,11 +265,8 @@ void mqtt_wss_destroy(mqtt_wss_client client)
if (client->sockfd > 0)
close(client->sockfd);
- pthread_mutex_destroy(&client->pub_lock);
- pthread_mutex_destroy(&client->stat_lock);
-
mqtt_wss_log_ctx_destroy(client->log);
- mw_free(client);
+ freez(client);
}
static int cert_verify_callback(int preverify_ok, X509_STORE_CTX *ctx)
@@ -298,7 +291,7 @@ static int cert_verify_callback(int preverify_ok, X509_STORE_CTX *ctx)
mws_error(client->log, "verify error:num=%d:%s:depth=%d:%s", err,
X509_verify_cert_error_string(err), depth, err_str);
- mw_free(err_str);
+ freez(err_str);
}
if (!preverify_ok && err == X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT &&
@@ -362,14 +355,14 @@ static int http_parse_reply(mqtt_wss_client client, rbuf_t buf)
}
if (http_code != 200) {
- ptr = mw_malloc(idx + 1);
+ ptr = mallocz(idx + 1);
if (!ptr)
return 6;
rbuf_pop(buf, ptr, idx);
ptr[idx] = 0;
mws_error(client->log, "http_proxy returned error code %d \"%s\"", http_code, ptr);
- mw_free(ptr);
+ freez(ptr);
return 7;
}/* else
rbuf_bump_tail(buf, idx);*/
@@ -450,7 +443,7 @@ static int http_proxy_connect(mqtt_wss_client client)
if (client->proxy_uname) {
size_t creds_plain_len = strlen(client->proxy_uname) + strlen(client->proxy_passwd) + 2;
- char *creds_plain = mw_malloc(creds_plain_len);
+ char *creds_plain = mallocz(creds_plain_len);
if (!creds_plain) {
mws_error(client->log, "OOM creds_plain");
rc = 6;
@@ -460,9 +453,9 @@ static int http_proxy_connect(mqtt_wss_client client)
// OpenSSL encoder puts newline every 64 output bytes
// we remove those but during encoding we need that space in the buffer
creds_base64_len += (1+(creds_base64_len/64)) * strlen("\n");
- char *creds_base64 = mw_malloc(creds_base64_len + 1);
+ char *creds_base64 = mallocz(creds_base64_len + 1);
if (!creds_base64) {
- mw_free(creds_plain);
+ freez(creds_plain);
mws_error(client->log, "OOM creds_base64");
rc = 6;
goto cleanup;
@@ -475,12 +468,12 @@ static int http_proxy_connect(mqtt_wss_client client)
int b64_len;
base64_encode_helper((unsigned char*)creds_base64, &b64_len, (unsigned char*)creds_plain, strlen(creds_plain));
- mw_free(creds_plain);
+ freez(creds_plain);
r_buf_ptr = rbuf_get_linear_insert_range(r_buf, &r_buf_linear_insert_capacity);
snprintf(r_buf_ptr, r_buf_linear_insert_capacity,"Proxy-Authorization: Basic %s" HTTP_ENDLINE, creds_base64);
write(client->sockfd, r_buf_ptr, strlen(r_buf_ptr));
- mw_free(creds_base64);
+ freez(creds_base64);
}
write(client->sockfd, HTTP_ENDLINE, strlen(HTTP_ENDLINE));
@@ -523,15 +516,14 @@ cleanup:
return rc;
}
-int mqtt_wss_connect(mqtt_wss_client client, char *host, int port, struct mqtt_connect_params *mqtt_params, int ssl_flags, struct mqtt_wss_proxy *proxy)
+int mqtt_wss_connect(
+ mqtt_wss_client client,
+ char *host,
+ int port,
+ struct mqtt_connect_params *mqtt_params,
+ int ssl_flags,
+ struct mqtt_wss_proxy *proxy)
{
- struct sockaddr_in addr;
- memset(&addr, 0, sizeof(addr));
- addr.sin_family = AF_INET;
-
- struct hostent *he;
- struct in_addr **addr_list;
-
if (!mqtt_params) {
mws_error(client->log, "mqtt_params can't be null!");
return -1;
@@ -545,23 +537,35 @@ int mqtt_wss_connect(mqtt_wss_client client, char *host, int port, struct mqtt_c
if (client->target_host == client->host)
client->target_host = NULL;
+
if (client->target_host)
- mw_free(client->target_host);
+ freez(client->target_host);
+
if (client->host)
- mw_free(client->host);
+ freez(client->host);
+
+ if (client->proxy_uname) {
+ freez(client->proxy_uname);
+ client->proxy_uname = NULL;
+ }
+
+ if (client->proxy_passwd) {
+ freez(client->proxy_passwd);
+ client->proxy_passwd = NULL;
+ }
if (proxy && proxy->type != MQTT_WSS_DIRECT) {
- client->host = mw_strdup(proxy->host);
+ client->host = strdupz(proxy->host);
client->port = proxy->port;
- client->target_host = mw_strdup(host);
+ client->target_host = strdupz(host);
client->target_port = port;
client->proxy_type = proxy->type;
if (proxy->username)
- client->proxy_uname = mw_strdup(proxy->username);
+ client->proxy_uname = strdupz(proxy->username);
if (proxy->password)
- client->proxy_passwd = mw_strdup(proxy->password);
+ client->proxy_passwd = strdupz(proxy->password);
} else {
- client->host = mw_strdup(host);
+ client->host = strdupz(host);
client->port = port;
client->target_host = client->host;
client->target_port = port;
@@ -569,30 +573,19 @@ int mqtt_wss_connect(mqtt_wss_client client, char *host, int port, struct mqtt_c
client->ssl_flags = ssl_flags;
- //TODO gethostbyname -> getaddinfo
- // hstrerror -> gai_strerror
- if ((he = gethostbyname(client->host)) == NULL) {
- mws_error(client->log, "gethostbyname() error \"%s\"", hstrerror(h_errno));
- return -1;
- }
-
- addr_list = (struct in_addr **)he->h_addr_list;
- if(!addr_list[0]) {
- mws_error(client->log, "No IP addr resolved");
- return -1;
- }
- mws_debug(client->log, "Resolved IP: %s", inet_ntoa(*addr_list[0]));
- addr.sin_addr = *addr_list[0];
- addr.sin_port = htons(client->port);
-
if (client->sockfd > 0)
close(client->sockfd);
- client->sockfd = socket(AF_INET, SOCK_STREAM | DEFAULT_SOCKET_FLAGS, 0);
- if (client->sockfd < 0) {
- mws_error(client->log, "Couldn't create socket()");
- return -1;
+
+ char port_str[16];
+ snprintf(port_str, sizeof(port_str) -1, "%d", client->port);
+ int fd = connect_to_this_ip46(IPPROTO_TCP, SOCK_STREAM, client->host, 0, port_str, NULL);
+ if (fd < 0) {
+ mws_error(client->log, "Could not connect to remote endpoint \"%s\", port %d.\n", client->host, port);
+ return -3;
}
+ client->sockfd = fd;
+
#ifndef SOCK_CLOEXEC
int flags = fcntl(client->sockfd, F_GETFD);
if (flags != -1)
@@ -600,19 +593,10 @@ int mqtt_wss_connect(mqtt_wss_client client, char *host, int port, struct mqtt_c
#endif
int flag = 1;
- int result = setsockopt(client->sockfd,
- IPPROTO_TCP,
- TCP_NODELAY,
- &flag,
- sizeof(int));
+ int result = setsockopt(client->sockfd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(int));
if (result < 0)
mws_error(client->log, "Could not dissable NAGLE");
- if (connect(client->sockfd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
- mws_error(client->log, "Could not connect to remote endpoint \"%s\", port %d.\n", client->host, client->port);
- return -3;
- }
-
client->poll_fds[POLLFD_SOCKET].fd = client->sockfd;
if (fcntl(client->sockfd, F_SETFL, fcntl(client->sockfd, F_GETFL, 0) | O_NONBLOCK) == -1) {
@@ -640,6 +624,7 @@ int mqtt_wss_connect(mqtt_wss_client client, char *host, int port, struct mqtt_c
// free SSL structs from possible previous connections
if (client->ssl)
SSL_free(client->ssl);
+
if (client->ssl_ctx)
SSL_CTX_free(client->ssl_ctx);
@@ -675,6 +660,7 @@ int mqtt_wss_connect(mqtt_wss_client client, char *host, int port, struct mqtt_c
mws_error(client->log, "SSL could not connect");
return -5;
}
+
if (result == -1) {
int ec = SSL_get_error(client->ssl, result);
if (ec != SSL_ERROR_WANT_READ && ec != SSL_ERROR_WANT_WRITE) {
@@ -693,14 +679,16 @@ int mqtt_wss_connect(mqtt_wss_client client, char *host, int port, struct mqtt_c
auth.username_free = NULL;
auth.password = (char*)mqtt_params->password;
auth.password_free = NULL;
+
struct mqtt_lwt_properties lwt;
lwt.will_topic = (char*)mqtt_params->will_topic;
lwt.will_topic_free = NULL;
lwt.will_message = (void*)mqtt_params->will_msg;
lwt.will_message_free = NULL; // TODO expose no copy version to API
lwt.will_message_size = mqtt_params->will_msg_len;
- lwt.will_qos = (mqtt_params->will_flags & MQTT_WSS_PUB_QOSMASK);
- lwt.will_retain = mqtt_params->will_flags & MQTT_WSS_PUB_RETAIN;
+ lwt.will_qos = (int) (mqtt_params->will_flags & MQTT_WSS_PUB_QOSMASK);
+ lwt.will_retain = (int) mqtt_params->will_flags & MQTT_WSS_PUB_RETAIN;
+
int ret = mqtt_ng_connect(client->mqtt, &auth, mqtt_params->will_msg ? &lwt : NULL, 1, client->mqtt_keepalive);
if (ret) {
mws_error(client->log, "Error generating MQTT connect");
@@ -955,9 +943,9 @@ int mqtt_wss_service(mqtt_wss_client client, int timeout_ms)
#ifdef DEBUG_ULTRA_VERBOSE
mws_debug(client->log, "SSL_Read: Read %d.", ret);
#endif
- pthread_mutex_lock(&client->stat_lock);
+ spinlock_lock(&client->stat_lock);
client->stats.bytes_rx += ret;
- pthread_mutex_unlock(&client->stat_lock);
+ spinlock_unlock(&client->stat_lock);
rbuf_bump_head(client->ws_client->buf_read, ret);
} else {
int errnobkp = errno;
@@ -1023,9 +1011,9 @@ int mqtt_wss_service(mqtt_wss_client client, int timeout_ms)
#ifdef DEBUG_ULTRA_VERBOSE
mws_debug(client->log, "SSL_Write: Written %d of avail %d.", ret, size);
#endif
- pthread_mutex_lock(&client->stat_lock);
+ spinlock_lock(&client->stat_lock);
client->stats.bytes_tx += ret;
- pthread_mutex_unlock(&client->stat_lock);
+ spinlock_unlock(&client->stat_lock);
rbuf_bump_tail(client->ws_client->buf_write, ret);
} else {
int errnobkp = errno;
@@ -1115,10 +1103,10 @@ int mqtt_wss_subscribe(mqtt_wss_client client, char *topic, int max_qos_level)
struct mqtt_wss_stats mqtt_wss_get_stats(mqtt_wss_client client)
{
struct mqtt_wss_stats current;
- pthread_mutex_lock(&client->stat_lock);
+ spinlock_lock(&client->stat_lock);
current = client->stats;
memset(&client->stats, 0, sizeof(client->stats));
- pthread_mutex_unlock(&client->stat_lock);
+ spinlock_unlock(&client->stat_lock);
mqtt_ng_get_stats(client->mqtt, &current.mqtt);
return current;
}