diff options
author | Tyler Murphy <tylerm@tylerm.dev> | 2023-04-18 11:01:56 -0400 |
---|---|---|
committer | Tyler Murphy <tylerm@tylerm.dev> | 2023-04-18 11:01:56 -0400 |
commit | 5eef60aa1a6ac9e174708f19a154ce9660589b6d (patch) | |
tree | 649e238dc905b7071efc9e4a06b61bcb286ee8a0 /src | |
parent | fix makefile (diff) | |
download | wrapper-main.tar.gz wrapper-main.tar.bz2 wrapper-main.zip |
Diffstat (limited to 'src')
-rw-r--r-- | src/main.c | 1 | ||||
-rw-r--r-- | src/packet/buffer.c | 30 | ||||
-rw-r--r-- | src/packet/record.c | 19 | ||||
-rw-r--r-- | src/server/addr.c | 31 | ||||
-rw-r--r-- | src/server/resolver.c | 2 |
5 files changed, 63 insertions, 20 deletions
@@ -1,7 +1,6 @@ #include "server/server.h" #include <stdlib.h> -#include <sys/select.h> #define DEFAULT_PORT 53 diff --git a/src/packet/buffer.c b/src/packet/buffer.c index 28dd73b..b609476 100644 --- a/src/packet/buffer.c +++ b/src/packet/buffer.c @@ -14,6 +14,7 @@ struct PacketBuffer { PacketBuffer* buffer_create(int capacity) { PacketBuffer* buffer = malloc(sizeof(PacketBuffer)); buffer->arr = malloc(capacity); + memset(buffer->arr, 0, capacity); buffer->capacity = capacity; buffer->size = 0; buffer->index = 0; @@ -30,7 +31,7 @@ void buffer_seek(PacketBuffer* buffer, int index) { } uint8_t buffer_read(PacketBuffer* buffer) { - if (buffer->index > buffer->size) { + if (buffer->index >= buffer->size) { return 0; } uint8_t data = buffer->arr[buffer->index]; @@ -53,7 +54,7 @@ uint32_t buffer_read_int(PacketBuffer* buffer) { } uint8_t buffer_get(PacketBuffer* buffer, int index) { - if (index > buffer->size) { + if (index >= buffer->size) { return 0; } uint8_t data = buffer->arr[index]; @@ -73,8 +74,12 @@ uint16_t buffer_get_size(PacketBuffer* buffer) { } static void write(uint8_t** buffer, uint8_t* size, uint8_t* capacity, uint8_t data) { - if (*size == *capacity) { - *capacity *= 2; + if (*size >= *capacity) { + if (*capacity >= 128) { + *capacity = 255; + } else { + *capacity *= 2; + } *buffer = realloc(*buffer, *capacity); } (*buffer)[*size] = data; @@ -144,7 +149,16 @@ void buffer_read_string(PacketBuffer* buffer, uint8_t** out) { buffer_read_n(buffer, out, len); } +static void buffer_expand(PacketBuffer* buffer, int capacity) { + if (buffer->capacity >= capacity) return; + + buffer->arr = realloc(buffer->arr, capacity); + memset(buffer->arr + buffer->capacity, 0, capacity - buffer->capacity); + buffer->capacity = capacity; +} + void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len) { + buffer_expand(buffer, buffer->index + len + 1); *out = malloc(len + 1); *out[0] = len; memcpy(*out + 1, buffer->arr + buffer->index, len); @@ -152,7 +166,7 @@ void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len) { } void buffer_write(PacketBuffer* buffer, uint8_t data) { - if(buffer->index == buffer->capacity) { + if(buffer->index >= buffer->capacity) { buffer->capacity *= 2; buffer->arr = realloc(buffer->arr, buffer->capacity); } @@ -205,11 +219,7 @@ void buffer_write_string(PacketBuffer* buffer, uint8_t* in) { } void buffer_write_n(PacketBuffer* buffer, uint8_t* in, int len) { - if (buffer->size + len >= buffer->capacity) { - buffer->capacity *= 2; - buffer->capacity += len; - buffer->arr = realloc(buffer->arr, buffer->capacity); - } + buffer_expand(buffer, buffer->index + len + 1); memcpy(buffer->arr + buffer->index, in, len); buffer->size += len; buffer->index += len; diff --git a/src/packet/record.c b/src/packet/record.c index 53934ed..c676bce 100644 --- a/src/packet/record.c +++ b/src/packet/record.c @@ -140,17 +140,28 @@ static void read_txt_record(PacketBuffer* buffer, Record* record) { TXTRecord data; data.len = 0; data.text = malloc(sizeof(uint8_t*) * 2); - uint8_t capacity = 2; + uint8_t total = record->len; while (1) { if (data.len >= capacity) { - capacity *= 2; + if (capacity >= 128) { + capacity = 255; + } else { + capacity *= 2; + } data.text = realloc(data.text, sizeof(uint8_t*) * capacity); } - + buffer_read_string(buffer, &data.text[data.len]); - if(data.text[data.len][0] == 0) break; + if(data.text[data.len][0] == 0) { + free(data.text[data.len]); + break; + } + data.len++; + + total -= data.text[data.len - 1][0] + 1; + if (total == 0) break; } record->data.txt = data; diff --git a/src/server/addr.c b/src/server/addr.c index c83b216..494b694 100644 --- a/src/server/addr.c +++ b/src/server/addr.c @@ -1,3 +1,4 @@ +#include <errno.h> #include <netinet/in.h> #include <stdio.h> #include <stdlib.h> @@ -159,8 +160,30 @@ int32_t write_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAd ); } +static int get_socket_error(int fd) { + int err = 1; + socklen_t len = sizeof err; + if (-1 == getsockopt(fd, SOL_SOCKET, SO_ERROR, (char *)&err, &len)) + return 0; + if (err) + errno = err; + return err; +} + +static int close_socket(int fd) { + if (fd >= 0) { + get_socket_error(fd); // first clear any errors, which can cause close to fail + if (shutdown(fd, SHUT_RDWR) < 0) // secondly, terminate the 'reliable' delivery + if (errno != ENOTCONN && errno != EINVAL) // SGI causes EINVAL + perror("shutdown"); + if (close(fd) < 0) // finally call close() + perror("close"); + } + return 0; +} + int32_t close_udp_socket(UdpSocket* socket) { - return close(socket->sockfd); + return close_socket(socket->sockfd); } int32_t create_tcp_socket(AddrType type, TcpSocket* sock) { @@ -204,7 +227,7 @@ int32_t accept_tcp_socket(TcpSocket* socket, TcpStream* stream) { } int32_t close_tcp_socket(TcpSocket* socket) { - return close(socket->sockfd); + return close_socket(socket->sockfd); } int32_t connect_tcp_stream(SocketAddr* servaddr, TcpStream* stream) { @@ -222,7 +245,7 @@ int32_t connect_tcp_stream(SocketAddr* servaddr, TcpStream* stream) { } int32_t read_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) { - return recv(stream->streamfd, buffer, len, 0); + return recv(stream->streamfd, buffer, len, MSG_WAITALL); } int32_t write_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) { @@ -230,5 +253,5 @@ int32_t write_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) { } int32_t close_tcp_stream(TcpStream* stream) { - return close(stream->streamfd); + return close_socket(stream->streamfd); } diff --git a/src/server/resolver.c b/src/server/resolver.c index c9e5246..67149bc 100644 --- a/src/server/resolver.c +++ b/src/server/resolver.c @@ -45,7 +45,7 @@ static bool lookup( static bool search(Question* question, Packet* result, BindingType type, const RecordMap* map) { IpAddr addr; - uint8_t ip[4] = {1, 1, 1, 1}; + uint8_t ip[4] = {9,9,9,9}; create_ip_addr(ip, &addr); uint16_t port = 53; |