diff --git a/buffer.c b/buffer.c new file mode 100644 index 0000000..28dd73b --- /dev/null +++ b/buffer.c @@ -0,0 +1,240 @@ +#include "buffer.h" + +#include +#include +#include + +struct PacketBuffer { + uint8_t* arr; + int capacity; + int index; + int size; +}; + +PacketBuffer* buffer_create(int capacity) { + PacketBuffer* buffer = malloc(sizeof(PacketBuffer)); + buffer->arr = malloc(capacity); + buffer->capacity = capacity; + buffer->size = 0; + buffer->index = 0; + return buffer; +} + +void buffer_free(PacketBuffer* buffer) { + free(buffer->arr); + free(buffer); +} + +void buffer_seek(PacketBuffer* buffer, int index) { + buffer->index = index; +} + +uint8_t buffer_read(PacketBuffer* buffer) { + if (buffer->index > buffer->size) { + return 0; + } + uint8_t data = buffer->arr[buffer->index]; + buffer->index++; + return data; +} + +uint16_t buffer_read_short(PacketBuffer* buffer) { + return + (uint16_t) buffer_read(buffer) << 8 | + (uint16_t) buffer_read(buffer); +} + +uint32_t buffer_read_int(PacketBuffer* buffer) { + return + (uint32_t) buffer_read(buffer) << 24 | + (uint32_t) buffer_read(buffer) << 16 | + (uint32_t) buffer_read(buffer) << 8 | + (uint32_t) buffer_read(buffer); +} + +uint8_t buffer_get(PacketBuffer* buffer, int index) { + if (index > buffer->size) { + return 0; + } + uint8_t data = buffer->arr[index]; + return data; +} + +uint8_t* buffer_get_range(PacketBuffer* buffer, int start, int len) { + uint8_t* arr = malloc(len); + for (int i = 0; i < len; i++) { + arr[i] = buffer_get(buffer, start + i); + } + return arr; +} + +uint16_t buffer_get_size(PacketBuffer* buffer) { + return (uint16_t) buffer->size + 1; +} + +static void write(uint8_t** buffer, uint8_t* size, uint8_t* capacity, uint8_t data) { + if (*size == *capacity) { + *capacity *= 2; + *buffer = realloc(*buffer, *capacity); + } + (*buffer)[*size] = data; + (*size)++; +} + +void buffer_read_qname(PacketBuffer* buffer, uint8_t** out) { + int index = buffer->index; + int jumped = 0; + + int max_jumps = 5; + int jumps_performed = 0; + + uint8_t length = 0; + uint8_t capacity = 8; + *out = malloc(capacity); + write(out, &length, &capacity, 0); + + while(1) { + if (jumps_performed > max_jumps) { + break; + } + + uint8_t len = buffer_get(buffer, index); + + if ((len & 0xC0) == 0xC0) { + if (jumped == 0) { + buffer_seek(buffer, index + 2); + } + + uint16_t b2 = (uint16_t) buffer_get(buffer, index + 1); + uint16_t offset = ((((uint16_t) len) ^ 0xC0) << 8) | b2; + index = (int) offset; + jumped = 1; + jumps_performed++; + continue; + } + + index++; + + if (len == 0) { + break; + } + + if (length > 1) { + write(out, &length, &capacity, '.'); + } + + uint8_t* range = buffer_get_range(buffer, index, len); + for (uint8_t i = 0; i < len; i++) { + write(out, &length, &capacity, range[i]); + } + free(range); + + index += (int) len; + } + + if (jumped == 0) { + buffer_seek(buffer, index); + } + + (*out)[0] = length - 1; +} + +void buffer_read_string(PacketBuffer* buffer, uint8_t** out) { + uint8_t len = buffer_read(buffer); + buffer_read_n(buffer, out, len); +} + +void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len) { + *out = malloc(len + 1); + *out[0] = len; + memcpy(*out + 1, buffer->arr + buffer->index, len); + buffer->index += len; +} + +void buffer_write(PacketBuffer* buffer, uint8_t data) { + if(buffer->index == buffer->capacity) { + buffer->capacity *= 2; + buffer->arr = realloc(buffer->arr, buffer->capacity); + } + if (buffer->size < buffer->index) { + buffer->size = buffer->index; + } + buffer->arr[buffer->index] = data; + buffer->index++; +} + +void buffer_write_short(PacketBuffer* buffer, uint16_t data) { + buffer_write(buffer, (uint8_t)(data >> 8)); + buffer_write(buffer, (uint8_t)(data & 0xFF)); +} + +void buffer_write_int(PacketBuffer* buffer, uint32_t data) { + buffer_write(buffer, (uint8_t)(data >> 24)); + buffer_write(buffer, (uint8_t)(data >> 16)); + buffer_write(buffer, (uint8_t)(data >> 8)); + buffer_write(buffer, (uint8_t)(data & 0xFF)); +} + +void buffer_write_qname(PacketBuffer* buffer, uint8_t* in) { + uint8_t part = 0; + uint8_t len = in[0]; + + buffer_write(buffer, 0); + + if (len == 0) { + return; + } + + for(uint8_t i = 0; i < len; i ++) { + if (in[i+1] == '.') { + buffer_set(buffer, part, buffer->index - (int)part - 1); + buffer_write(buffer, 0); + part = 0; + } else { + buffer_write(buffer, in[i+1]); + part++; + } + } + buffer_set(buffer, part, buffer->index - (int)part - 1); + buffer_write(buffer, 0); +} + +void buffer_write_string(PacketBuffer* buffer, uint8_t* in) { + buffer_write(buffer, in[0]); + buffer_write_n(buffer, in + 1, in[0]); +} + +void buffer_write_n(PacketBuffer* buffer, uint8_t* in, int len) { + if (buffer->size + len >= buffer->capacity) { + buffer->capacity *= 2; + buffer->capacity += len; + buffer->arr = realloc(buffer->arr, buffer->capacity); + } + memcpy(buffer->arr + buffer->index, in, len); + buffer->size += len; + buffer->index += len; +} + +void buffer_set(PacketBuffer* buffer, uint8_t data, int index) { + if (index > buffer->size) { + return; + } + buffer->arr[index] = data; +} + +void buffer_set_uint16_t(PacketBuffer* buffer, uint16_t data, int index) { + buffer_set(buffer, (uint8_t)(data >> 8), index); + buffer_set(buffer, (uint8_t)(data & 0xFF), index + 1); +} + +int buffer_get_index(PacketBuffer* buffer) { + return buffer->index; +} + +void buffer_step(PacketBuffer* buffer, int len) { + buffer->index += len; +} + +uint8_t* buffer_get_ptr(PacketBuffer* buffer) { + return buffer->arr; +} diff --git a/src/main.c b/src/main.c index 13dae57..b142fd5 100644 --- a/src/main.c +++ b/src/main.c @@ -1,7 +1,6 @@ #include "server/server.h" #include -#include #define DEFAULT_PORT 53 diff --git a/src/packet/buffer.c b/src/packet/buffer.c index 28dd73b..b609476 100644 --- a/src/packet/buffer.c +++ b/src/packet/buffer.c @@ -14,6 +14,7 @@ struct PacketBuffer { PacketBuffer* buffer_create(int capacity) { PacketBuffer* buffer = malloc(sizeof(PacketBuffer)); buffer->arr = malloc(capacity); + memset(buffer->arr, 0, capacity); buffer->capacity = capacity; buffer->size = 0; buffer->index = 0; @@ -30,7 +31,7 @@ void buffer_seek(PacketBuffer* buffer, int index) { } uint8_t buffer_read(PacketBuffer* buffer) { - if (buffer->index > buffer->size) { + if (buffer->index >= buffer->size) { return 0; } uint8_t data = buffer->arr[buffer->index]; @@ -53,7 +54,7 @@ uint32_t buffer_read_int(PacketBuffer* buffer) { } uint8_t buffer_get(PacketBuffer* buffer, int index) { - if (index > buffer->size) { + if (index >= buffer->size) { return 0; } uint8_t data = buffer->arr[index]; @@ -73,8 +74,12 @@ uint16_t buffer_get_size(PacketBuffer* buffer) { } static void write(uint8_t** buffer, uint8_t* size, uint8_t* capacity, uint8_t data) { - if (*size == *capacity) { - *capacity *= 2; + if (*size >= *capacity) { + if (*capacity >= 128) { + *capacity = 255; + } else { + *capacity *= 2; + } *buffer = realloc(*buffer, *capacity); } (*buffer)[*size] = data; @@ -144,7 +149,16 @@ void buffer_read_string(PacketBuffer* buffer, uint8_t** out) { buffer_read_n(buffer, out, len); } +static void buffer_expand(PacketBuffer* buffer, int capacity) { + if (buffer->capacity >= capacity) return; + + buffer->arr = realloc(buffer->arr, capacity); + memset(buffer->arr + buffer->capacity, 0, capacity - buffer->capacity); + buffer->capacity = capacity; +} + void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len) { + buffer_expand(buffer, buffer->index + len + 1); *out = malloc(len + 1); *out[0] = len; memcpy(*out + 1, buffer->arr + buffer->index, len); @@ -152,7 +166,7 @@ void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len) { } void buffer_write(PacketBuffer* buffer, uint8_t data) { - if(buffer->index == buffer->capacity) { + if(buffer->index >= buffer->capacity) { buffer->capacity *= 2; buffer->arr = realloc(buffer->arr, buffer->capacity); } @@ -205,11 +219,7 @@ void buffer_write_string(PacketBuffer* buffer, uint8_t* in) { } void buffer_write_n(PacketBuffer* buffer, uint8_t* in, int len) { - if (buffer->size + len >= buffer->capacity) { - buffer->capacity *= 2; - buffer->capacity += len; - buffer->arr = realloc(buffer->arr, buffer->capacity); - } + buffer_expand(buffer, buffer->index + len + 1); memcpy(buffer->arr + buffer->index, in, len); buffer->size += len; buffer->index += len; diff --git a/src/packet/record.c b/src/packet/record.c index 53934ed..c676bce 100644 --- a/src/packet/record.c +++ b/src/packet/record.c @@ -140,17 +140,28 @@ static void read_txt_record(PacketBuffer* buffer, Record* record) { TXTRecord data; data.len = 0; data.text = malloc(sizeof(uint8_t*) * 2); - uint8_t capacity = 2; + uint8_t total = record->len; while (1) { if (data.len >= capacity) { - capacity *= 2; + if (capacity >= 128) { + capacity = 255; + } else { + capacity *= 2; + } data.text = realloc(data.text, sizeof(uint8_t*) * capacity); } - + buffer_read_string(buffer, &data.text[data.len]); - if(data.text[data.len][0] == 0) break; + if(data.text[data.len][0] == 0) { + free(data.text[data.len]); + break; + } + data.len++; + + total -= data.text[data.len - 1][0] + 1; + if (total == 0) break; } record->data.txt = data; diff --git a/src/server/addr.c b/src/server/addr.c index c83b216..494b694 100644 --- a/src/server/addr.c +++ b/src/server/addr.c @@ -1,3 +1,4 @@ +#include #include #include #include @@ -159,8 +160,30 @@ int32_t write_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAd ); } +static int get_socket_error(int fd) { + int err = 1; + socklen_t len = sizeof err; + if (-1 == getsockopt(fd, SOL_SOCKET, SO_ERROR, (char *)&err, &len)) + return 0; + if (err) + errno = err; + return err; +} + +static int close_socket(int fd) { + if (fd >= 0) { + get_socket_error(fd); // first clear any errors, which can cause close to fail + if (shutdown(fd, SHUT_RDWR) < 0) // secondly, terminate the 'reliable' delivery + if (errno != ENOTCONN && errno != EINVAL) // SGI causes EINVAL + perror("shutdown"); + if (close(fd) < 0) // finally call close() + perror("close"); + } + return 0; +} + int32_t close_udp_socket(UdpSocket* socket) { - return close(socket->sockfd); + return close_socket(socket->sockfd); } int32_t create_tcp_socket(AddrType type, TcpSocket* sock) { @@ -204,7 +227,7 @@ int32_t accept_tcp_socket(TcpSocket* socket, TcpStream* stream) { } int32_t close_tcp_socket(TcpSocket* socket) { - return close(socket->sockfd); + return close_socket(socket->sockfd); } int32_t connect_tcp_stream(SocketAddr* servaddr, TcpStream* stream) { @@ -222,7 +245,7 @@ int32_t connect_tcp_stream(SocketAddr* servaddr, TcpStream* stream) { } int32_t read_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) { - return recv(stream->streamfd, buffer, len, 0); + return recv(stream->streamfd, buffer, len, MSG_WAITALL); } int32_t write_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) { @@ -230,5 +253,5 @@ int32_t write_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) { } int32_t close_tcp_stream(TcpStream* stream) { - return close(stream->streamfd); + return close_socket(stream->streamfd); } diff --git a/src/server/resolver.c b/src/server/resolver.c index c9e5246..67149bc 100644 --- a/src/server/resolver.c +++ b/src/server/resolver.c @@ -45,7 +45,7 @@ static bool lookup( static bool search(Question* question, Packet* result, BindingType type, const RecordMap* map) { IpAddr addr; - uint8_t ip[4] = {1, 1, 1, 1}; + uint8_t ip[4] = {9,9,9,9}; create_ip_addr(ip, &addr); uint16_t port = 53;