diff --git a/platform/linux.c b/platform/linux.c index 68aa1cb..bffeade 100644 --- a/platform/linux.c +++ b/platform/linux.c @@ -1,46 +1,87 @@ -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include +#include "mqtt_internal.h" #include "platform.h" const size_t max_receive_buffer_size = 4 * 4096; // 16 KB +#define MAX_TASKS 16 + struct _PlatformData { - pthread_t read_thread; + pthread_t tasks[MAX_TASKS]; + int sock; }; -void initialize_platform(MQTTHandle *handle) { +PlatformStatusCode platform_init(MQTTHandle *handle) { handle->platform = calloc(1, sizeof(struct _PlatformData)); -} - -MQTTStatus run_read_task(MQTTHandle *handle, Reader reader) { - if (pthread_create(&handle->platform->read_thread, NULL, (void *(*)(void *))reader, (void *)handle)) { - return MQTT_STATUS_ERROR; + handle->platform->sock = -1; + if (handle->platform) { + return PlatformStatusOk; } - return MQTT_STATUS_OK; + return PlatformStatusError; } -MQTTStatus join_read_task(MQTTHandle *handle) { - if (handle->platform->read_thread) { - pthread_join(handle->platform->read_thread, NULL); - handle->platform->read_thread = 0; +PlatformStatusCode platform_release(MQTTHandle *handle) { + PlatformData *p = handle->platform; + + for (uint8_t free_task = 0; free_task < MAX_TASKS; free_task++) { + if (p->tasks[free_task] != 0) { + DEBUG_LOG("Cannot free platform handle, there are tasks running!"); + return PlatformStatusError; + } } - return MQTT_STATUS_OK; -} -void release_platform(MQTTHandle *handle) { free(handle->platform); + return PlatformStatusOk; } +PlatformStatusCode platform_run_task(MQTTHandle *handle, int *task_handle, PlatformTask callback) { + PlatformData *p = handle->platform; + uint8_t free_task = 0; -bool hostname_to_ip(char *hostname , char *ip) { + for (free_task = 0; free_task < MAX_TASKS; free_task++) { + if (p->tasks[free_task] == 0) { + break; + } + } + if (free_task == MAX_TASKS) { + return PlatformStatusError; + } + + if (pthread_create(&p->tasks[free_task], NULL, (void *(*)(void *))callback, (void *)handle)) { + return PlatformStatusError; + } + + return PlatformStatusOk; +} + +PlatformStatusCode platform_cleanup_task(MQTTHandle *handle, int task_handle) { + PlatformData *p = handle->platform; + + if ((task_handle < 0) || (task_handle >= MAX_TASKS)) { + return PlatformStatusError; + } + + if (p->tasks[task_handle]) { + pthread_join(p->tasks[task_handle], NULL); + p->tasks[task_handle] = 0; + } + return PlatformStatusOk; +} + +PlatformStatusCode platform_resolve_host(char *hostname , char *ip) { struct addrinfo hints, *servinfo; struct sockaddr_in *h; @@ -50,8 +91,8 @@ bool hostname_to_ip(char *hostname , char *ip) { int ret = getaddrinfo(hostname, NULL, &hints, &servinfo); if (ret != 0) { - // fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(rv)); - return false; + DEBUG_LOG("Resolving host failed: %s", gai_strerror(ret)); + return PlatformStatusError; } // loop through all the results and connect to the first we can @@ -61,5 +102,99 @@ bool hostname_to_ip(char *hostname , char *ip) { } freeaddrinfo(servinfo); // all done with this structure - return true; + return PlatformStatusOk; +} + +PlatformStatusCode platform_connect(MQTTHandle *handle) { + PlatformData *p = handle->platform; + + int ret; + struct sockaddr_in servaddr; + memset(&servaddr, 0, sizeof(servaddr)); + + p->sock = socket(AF_INET, SOCK_STREAM, 0); + servaddr.sin_family = AF_INET; + servaddr.sin_port = htons(handle->config->port); + + char ip[40]; + if (platform_resolve_host(handle->config->hostname, ip) != PlatformStatusOk) { + bool free_handle = handle->error_handler(handle, handle->config, MQTT_Error_Host_Not_Found); + if (free_handle) { + mqtt_free(handle); + } + DEBUG_LOG("Resolving hostname failed: %s", strerror(errno)); + close(p->sock); + return PlatformStatusError; + } + ret = inet_pton(AF_INET, ip, &(servaddr.sin_addr)); + if (ret == 0) { + bool free_handle = handle->error_handler(handle, handle->config, MQTT_Error_Host_Not_Found); + if (free_handle) { + mqtt_free(handle); + } + DEBUG_LOG("Converting to servaddr failed: %s", strerror(errno)); + close(p->sock); + return PlatformStatusError; + } + + ret = connect(p->sock, (struct sockaddr *)&servaddr, sizeof(servaddr)); + if (ret != 0) { + bool free_handle = handle->error_handler(handle, handle->config, MQTT_Error_Connection_Refused); + if (free_handle) { + mqtt_free(handle); + } + DEBUG_LOG("Connection failed: %s", strerror(errno)); + close(p->sock); + return PlatformStatusError; + } + + return PlatformStatusOk; +} + +PlatformStatusCode platform_read(MQTTHandle *handle, Buffer *buffer) { + PlatformData *p = handle->platform; + + while (1) { + ssize_t num_bytes = read(p->sock, &buffer->data[buffer->position], buffer_free_space(buffer)); + if (num_bytes == 0) { + /* Socket closed, coordinated shutdown */ + DEBUG_LOG("Socket closed"); + return PlatformStatusError; + } else if (num_bytes < 0) { + if ((errno == EINTR) || (errno == EAGAIN)) { + continue; + } + + /* Set reader task to dead */ + handle->reader_alive = false; + return PlatformStatusError; + } + + buffer->position += num_bytes; + return PlatformStatusOk; + } +} + +PlatformStatusCode platform_write(MQTTHandle *handle, Buffer *buffer) { + PlatformData *p = handle->platform; + + while (!buffer_eof(buffer)) { + ssize_t bytes = write(p->sock, buffer->data + buffer->position, buffer_free_space(buffer)); + if (bytes <= 0) { + return PlatformStatusError; + } + buffer->position += bytes; + } + + return PlatformStatusOk; +} + +PlatformStatusCode platform_disconnect(MQTTHandle *handle) { + PlatformData *p = handle->platform; + if (p->sock >= 0) { + close(p->sock); + p->sock = -1; + } + + return PlatformStatusOk; } diff --git a/platform/platform.h b/platform/platform.h index e70e5e3..e47801e 100644 --- a/platform/platform.h +++ b/platform/platform.h @@ -3,41 +3,95 @@ #include "mqtt_internal.h" -typedef void (*Reader)(MQTTHandle *handle); +typedef void (*PlatformTask)(MQTTHandle *handle); /** maximum receiver buffer size, defined by platform */ extern const size_t max_receive_buffer_size; +typedef enum { + PlatformStatusOk, /**< Everything ok */ + PlatformStatusError, /**< Non-recoverable error */ + PlatformStatusRetry /**< Recoverable error */ +} PlatformStatusCode; -bool hostname_to_ip(char *hostname, char *ip); /** * Initialize platform specific data * * @param handle: The handle to initialize + * @return Platform status code */ -void initialize_platform(MQTTHandle *handle); +PlatformStatusCode platform_init(MQTTHandle *handle); + +/** + * Platform specific function to release resources associated with a MQTTHandle + * + * @param handle: The handle to clean up + * @return Platform status code + */ +PlatformStatusCode platform_release(MQTTHandle *handle); /** * Platform specific function to start a reading thread * * @param handle: The broker connection handle - * @param reader: callback to run in the thread + * @param task_handle: Task handle output + * @param callback: callback to run in the thread + * @return Platform status code */ -MQTTStatus run_read_task(MQTTHandle *handle, Reader reader); +PlatformStatusCode platform_run_task(MQTTHandle *handle, int *task_handle, PlatformTask callback); /** * Platform specific function to clean up the reading thread - * - * @param handle: The broker connection handle + * + * @param handle: State handle + * @param task_handle: Task handle to clean up + * @return Platform status code */ -MQTTStatus join_read_task(MQTTHandle *handle); +PlatformStatusCode platform_cleanup_task(MQTTHandle *handle, int task_handle); /** - * Platform specific function to release resources associated with a MQTTHandle - * - * @param handle: The handle to clean up + * Resolve host + * + * @param hostname: Hostname to resolve + * @param ip_out: resulting IP address if no error occured + * @return Platform status code */ -void release_platform(MQTTHandle *handle); +PlatformStatusCode platform_resolve_host(char *hostname, char *ip_out); + +/** + * Connect to host from configuration + * + * @param handle: The configuration + * @return Platform status code + */ +PlatformStatusCode platform_connect(MQTTHandle *handle); + +/** + * Read from the "socket" in the handle + * + * @param handle: State handle + * @param buffer: Read target + * @return Platform status code + */ +PlatformStatusCode platform_read(MQTTHandle *handle, Buffer *buffer); + +/** + * Write to the "socket" in the handle + * + * @param handle: State handle + * @param buffer: Write source + * @return Platform status code + */ +PlatformStatusCode platform_write(MQTTHandle *handle, Buffer *buffer); + +/** + * Disconnect the "socket" in the handle + * + * @param handle: State handle + * @return Platform status code + */ +PlatformStatusCode platform_disconnect(MQTTHandle *handle); + #endif /* platform_h__included */ diff --git a/src/mqtt.c b/src/mqtt.c index 17a68ce..29807b5 100644 --- a/src/mqtt.c +++ b/src/mqtt.c @@ -14,16 +14,11 @@ #include "protocol.h" #include "debug.h" -static inline void mqtt_free(MQTTHandle *handle) { - release_platform(handle); +void mqtt_free(MQTTHandle *handle) { + platform_release(handle); free(handle); } -static inline void disconnect(MQTTHandle *handle) { - close(handle->sock); - // FIXME: Do we have to do anything else? -} - static inline void parse_packet(MQTTHandle *handle, MQTTPacket *packet) { switch (packet->packet_type) { case PacketTypeConnAck: @@ -35,7 +30,7 @@ static inline void parse_packet(MQTTHandle *handle, MQTTPacket *packet) { case PacketTypeUnsubAck: if (!dispatch_packet(handle, packet)) { DEBUG_LOG("Unexpected packet! (type: %s, packet_id: %d)", get_packet_name(packet), get_packet_id(packet)); - disconnect(handle); + (void)platform_disconnect(handle); } break; @@ -55,70 +50,61 @@ static inline void parse_packet(MQTTHandle *handle, MQTTPacket *packet) { case PacketTypePingReq: case PacketTypeDisconnect: DEBUG_LOG("Server packet on client connection? What's up with the broker?"); - disconnect(handle); + (void)platform_disconnect(handle); break; } } static void _reader(MQTTHandle *handle) { - ssize_t num_bytes; - char *read_buffer = malloc(max_receive_buffer_size); - uint8_t offset = 0; + Buffer *buffer = buffer_allocate(max_receive_buffer_size); handle->reader_alive = true; while (1) { - num_bytes = read(handle->sock, &read_buffer[offset], max_receive_buffer_size - offset); - if (num_bytes == 0) { - /* Socket closed, coordinated shutdown */ - DEBUG_LOG("Socket closed"); - handle->reader_alive = false; - return; - } else if (num_bytes < 0) { - if ((errno == EINTR) || (errno == EAGAIN)) { - continue; - } - - /* Set reader task to dead */ + PlatformStatusCode ret = platform_read(handle, buffer); + if (ret == PlatformStatusError) { handle->reader_alive = false; return; } while (1) { - Buffer *buffer = buffer_from_data_no_copy(read_buffer, num_bytes); + buffer->len = buffer->position; + buffer->position = 0; + MQTTPacket *packet = mqtt_packet_decode(buffer); if (packet == NULL) { // invalid packet - if (num_bytes < max_receive_buffer_size) { - // maybe not long enough, try to fetch the rest - offset += num_bytes; - free(buffer); + if (buffer_free_space(buffer) > 0) { + // half packet, fetch more + buffer->position = buffer->len; + buffer->len = max_receive_buffer_size; break; } else { // no space in buffer, bail and reconnect DEBUG_LOG("Buffer overflow!"); - disconnect(handle); + platform_disconnect(handle); handle->reader_alive = false; - free(buffer); + buffer_release(buffer); return; } } else { - // hexdump(buffer->data, num_bytes, 2); - + hexdump(buffer->data, num_bytes, 2); parse_packet(handle, packet); free_MQTTPacket(packet); if (!buffer_eof(buffer)) { + buffer->position = buffer->len; + buffer->len = max_receive_buffer_size; + // Not complete recv buffer was consumed, so we have more than one packet in there size_t remaining = max_receive_buffer_size - buffer->position; - memmove(read_buffer, read_buffer + buffer->position, remaining); - offset -= buffer->position; - num_bytes -= buffer->position; - free(buffer); + memmove(buffer->data, buffer->data + buffer->position, remaining); + buffer->position = 0; + break; } else { // buffer consumed completely, read another chunk - offset = 0; - free(buffer); + buffer->position = 0; + buffer->len = max_receive_buffer_size; break; } } @@ -127,58 +113,30 @@ static void _reader(MQTTHandle *handle) { } static void _mqtt_connect(MQTTHandle *handle, MQTTEventHandler callback, void *context) { - int ret; - struct sockaddr_in servaddr; - memset(&servaddr, 0, sizeof(servaddr)); + PlatformStatusCode ret = platform_connect(handle); - handle->sock = socket(AF_INET, SOCK_STREAM, 0); - servaddr.sin_family = AF_INET; - servaddr.sin_port = htons(handle->config->port); - - char ip[40]; - if (!hostname_to_ip(handle->config->hostname, ip)) { - bool free_handle = handle->error_handler(handle, handle->config, MQTT_Error_Host_Not_Found); - if (free_handle) { - mqtt_free(handle); - } - DEBUG_LOG("Resolving hostname failed: %s", strerror(errno)); - close(handle->sock); - return; - } - ret = inet_pton(AF_INET, ip, &(servaddr.sin_addr)); - if (ret == 0) { - bool free_handle = handle->error_handler(handle, handle->config, MQTT_Error_Host_Not_Found); - if (free_handle) { - mqtt_free(handle); - } - DEBUG_LOG("Converting to servaddr failed: %s", strerror(errno)); - close(handle->sock); + if (ret == PlatformStatusError) { + DEBUG_LOG("Could not connect"); return; } - ret = connect(handle->sock, (struct sockaddr *)&servaddr, sizeof(servaddr)); - if (ret != 0) { - bool free_handle = handle->error_handler(handle, handle->config, MQTT_Error_Connection_Refused); - if (free_handle) { - mqtt_free(handle); - } - DEBUG_LOG("Connection failed: %s", strerror(errno)); - close(handle->sock); + ret = platform_run_task(handle, &handle->read_task_handle, _reader); + if (ret == PlatformStatusError) { + DEBUG_LOG("Could not start read task"); return; } - run_read_task(handle, _reader); - expect_packet(handle, PacketTypeConnAck, 0, callback, context); bool result = send_connect_packet(handle); if (result == false) { DEBUG_LOG("Sending connect packet failed, running error handler"); bool free_handle = handle->error_handler(handle, handle->config, MQTT_Error_Broker_Disconnected); + platform_disconnect(handle); if (free_handle) { + platform_cleanup_task(handle, handle->read_task_handle); mqtt_free(handle); } - close(handle->sock); } } @@ -190,7 +148,11 @@ MQTTHandle *mqtt_connect(MQTTConfig *config, MQTTEventHandler callback, void *co } MQTTHandle *handle = calloc(sizeof(struct _MQTTHandle), 1); - initialize_platform(handle); + PlatformStatusCode ret = platform_init(handle); + if (ret == PlatformStatusError) { + free(handle); + return NULL; + } if (config->port == 0) { config->port = 1883; @@ -208,8 +170,8 @@ MQTTHandle *mqtt_connect(MQTTConfig *config, MQTTEventHandler callback, void *co MQTTStatus mqtt_reconnect(MQTTHandle *handle, MQTTEventHandler callback, void *context) { if (handle->reader_alive) { DEBUG_LOG("Closing old connection"); - close(handle->sock); - join_read_task(handle); + platform_disconnect(handle); + platform_cleanup_task(handle, handle->read_task_handle); } // TODO: re-submit unacknowledged packages with QoS > 0 @@ -249,10 +211,8 @@ MQTTStatus mqtt_publish(MQTTHandle *handle, char *topic, char *payload, MQTTQosL MQTTStatus mqtt_disconnect(MQTTHandle *handle, MQTTEventHandler callback, void *callback_context) { send_disconnect_packet(handle); - if (close(handle->sock)) { - return MQTT_STATUS_ERROR; - } - join_read_task(handle); + platform_disconnect(handle); + platform_cleanup_task(handle, handle->read_task_handle); mqtt_free(handle); if (callback) { diff --git a/src/mqtt_internal.h b/src/mqtt_internal.h index 972487f..fefeb92 100644 --- a/src/mqtt_internal.h +++ b/src/mqtt_internal.h @@ -14,8 +14,8 @@ struct _MQTTHandle { MQTTErrorHandler error_handler; Subscriptions subscriptions; - int sock; bool reader_alive; + int read_task_handle; uint16_t packet_id_counter; @@ -23,4 +23,6 @@ struct _MQTTHandle { PlatformData *platform; }; +void mqtt_free(MQTTHandle *handle); + #endif /* mqtt_internal_h__included */ diff --git a/src/protocol.c b/src/protocol.c index c685c34..d6ca975 100644 --- a/src/protocol.c +++ b/src/protocol.c @@ -1,5 +1,6 @@ #include #include +#include #include "mqtt_internal.h" #include "packet.h" @@ -17,16 +18,9 @@ typedef struct { */ bool send_buffer(MQTTHandle *handle, Buffer *buffer) { - while (!buffer_eof(buffer)) { - ssize_t bytes = write(handle->sock, buffer->data + buffer->position, buffer_free_space(buffer)); - if (bytes <= 0) { - buffer_release(buffer); - return false; - } - buffer->position += bytes; - } + PlatformStatusCode ret = platform_write(handle, buffer); buffer_release(buffer); - return true; + return (ret == PlatformStatusOk); } /*