summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTyler Murphy <tylerm@tylerm.dev>2023-04-18 11:01:56 -0400
committerTyler Murphy <tylerm@tylerm.dev>2023-04-18 11:01:56 -0400
commit5eef60aa1a6ac9e174708f19a154ce9660589b6d (patch)
tree649e238dc905b7071efc9e4a06b61bcb286ee8a0 /src
parentfix makefile (diff)
downloadwrapper-main.tar.gz
wrapper-main.tar.bz2
wrapper-main.zip
Diffstat (limited to '')
-rw-r--r--src/main.c1
-rw-r--r--src/packet/buffer.c30
-rw-r--r--src/packet/record.c19
-rw-r--r--src/server/addr.c31
-rw-r--r--src/server/resolver.c2
5 files changed, 63 insertions, 20 deletions
diff --git a/src/main.c b/src/main.c
index 13dae57..b142fd5 100644
--- a/src/main.c
+++ b/src/main.c
@@ -1,7 +1,6 @@
#include "server/server.h"
#include <stdlib.h>
-#include <sys/select.h>
#define DEFAULT_PORT 53
diff --git a/src/packet/buffer.c b/src/packet/buffer.c
index 28dd73b..b609476 100644
--- a/src/packet/buffer.c
+++ b/src/packet/buffer.c
@@ -14,6 +14,7 @@ struct PacketBuffer {
PacketBuffer* buffer_create(int capacity) {
PacketBuffer* buffer = malloc(sizeof(PacketBuffer));
buffer->arr = malloc(capacity);
+ memset(buffer->arr, 0, capacity);
buffer->capacity = capacity;
buffer->size = 0;
buffer->index = 0;
@@ -30,7 +31,7 @@ void buffer_seek(PacketBuffer* buffer, int index) {
}
uint8_t buffer_read(PacketBuffer* buffer) {
- if (buffer->index > buffer->size) {
+ if (buffer->index >= buffer->size) {
return 0;
}
uint8_t data = buffer->arr[buffer->index];
@@ -53,7 +54,7 @@ uint32_t buffer_read_int(PacketBuffer* buffer) {
}
uint8_t buffer_get(PacketBuffer* buffer, int index) {
- if (index > buffer->size) {
+ if (index >= buffer->size) {
return 0;
}
uint8_t data = buffer->arr[index];
@@ -73,8 +74,12 @@ uint16_t buffer_get_size(PacketBuffer* buffer) {
}
static void write(uint8_t** buffer, uint8_t* size, uint8_t* capacity, uint8_t data) {
- if (*size == *capacity) {
- *capacity *= 2;
+ if (*size >= *capacity) {
+ if (*capacity >= 128) {
+ *capacity = 255;
+ } else {
+ *capacity *= 2;
+ }
*buffer = realloc(*buffer, *capacity);
}
(*buffer)[*size] = data;
@@ -144,7 +149,16 @@ void buffer_read_string(PacketBuffer* buffer, uint8_t** out) {
buffer_read_n(buffer, out, len);
}
+static void buffer_expand(PacketBuffer* buffer, int capacity) {
+ if (buffer->capacity >= capacity) return;
+
+ buffer->arr = realloc(buffer->arr, capacity);
+ memset(buffer->arr + buffer->capacity, 0, capacity - buffer->capacity);
+ buffer->capacity = capacity;
+}
+
void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len) {
+ buffer_expand(buffer, buffer->index + len + 1);
*out = malloc(len + 1);
*out[0] = len;
memcpy(*out + 1, buffer->arr + buffer->index, len);
@@ -152,7 +166,7 @@ void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len) {
}
void buffer_write(PacketBuffer* buffer, uint8_t data) {
- if(buffer->index == buffer->capacity) {
+ if(buffer->index >= buffer->capacity) {
buffer->capacity *= 2;
buffer->arr = realloc(buffer->arr, buffer->capacity);
}
@@ -205,11 +219,7 @@ void buffer_write_string(PacketBuffer* buffer, uint8_t* in) {
}
void buffer_write_n(PacketBuffer* buffer, uint8_t* in, int len) {
- if (buffer->size + len >= buffer->capacity) {
- buffer->capacity *= 2;
- buffer->capacity += len;
- buffer->arr = realloc(buffer->arr, buffer->capacity);
- }
+ buffer_expand(buffer, buffer->index + len + 1);
memcpy(buffer->arr + buffer->index, in, len);
buffer->size += len;
buffer->index += len;
diff --git a/src/packet/record.c b/src/packet/record.c
index 53934ed..c676bce 100644
--- a/src/packet/record.c
+++ b/src/packet/record.c
@@ -140,17 +140,28 @@ static void read_txt_record(PacketBuffer* buffer, Record* record) {
TXTRecord data;
data.len = 0;
data.text = malloc(sizeof(uint8_t*) * 2);
-
uint8_t capacity = 2;
+ uint8_t total = record->len;
while (1) {
if (data.len >= capacity) {
- capacity *= 2;
+ if (capacity >= 128) {
+ capacity = 255;
+ } else {
+ capacity *= 2;
+ }
data.text = realloc(data.text, sizeof(uint8_t*) * capacity);
}
-
+
buffer_read_string(buffer, &data.text[data.len]);
- if(data.text[data.len][0] == 0) break;
+ if(data.text[data.len][0] == 0) {
+ free(data.text[data.len]);
+ break;
+ }
+
data.len++;
+
+ total -= data.text[data.len - 1][0] + 1;
+ if (total == 0) break;
}
record->data.txt = data;
diff --git a/src/server/addr.c b/src/server/addr.c
index c83b216..494b694 100644
--- a/src/server/addr.c
+++ b/src/server/addr.c
@@ -1,3 +1,4 @@
+#include <errno.h>
#include <netinet/in.h>
#include <stdio.h>
#include <stdlib.h>
@@ -159,8 +160,30 @@ int32_t write_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAd
);
}
+static int get_socket_error(int fd) {
+ int err = 1;
+ socklen_t len = sizeof err;
+ if (-1 == getsockopt(fd, SOL_SOCKET, SO_ERROR, (char *)&err, &len))
+ return 0;
+ if (err)
+ errno = err;
+ return err;
+}
+
+static int close_socket(int fd) {
+ if (fd >= 0) {
+ get_socket_error(fd); // first clear any errors, which can cause close to fail
+ if (shutdown(fd, SHUT_RDWR) < 0) // secondly, terminate the 'reliable' delivery
+ if (errno != ENOTCONN && errno != EINVAL) // SGI causes EINVAL
+ perror("shutdown");
+ if (close(fd) < 0) // finally call close()
+ perror("close");
+ }
+ return 0;
+}
+
int32_t close_udp_socket(UdpSocket* socket) {
- return close(socket->sockfd);
+ return close_socket(socket->sockfd);
}
int32_t create_tcp_socket(AddrType type, TcpSocket* sock) {
@@ -204,7 +227,7 @@ int32_t accept_tcp_socket(TcpSocket* socket, TcpStream* stream) {
}
int32_t close_tcp_socket(TcpSocket* socket) {
- return close(socket->sockfd);
+ return close_socket(socket->sockfd);
}
int32_t connect_tcp_stream(SocketAddr* servaddr, TcpStream* stream) {
@@ -222,7 +245,7 @@ int32_t connect_tcp_stream(SocketAddr* servaddr, TcpStream* stream) {
}
int32_t read_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) {
- return recv(stream->streamfd, buffer, len, 0);
+ return recv(stream->streamfd, buffer, len, MSG_WAITALL);
}
int32_t write_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) {
@@ -230,5 +253,5 @@ int32_t write_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) {
}
int32_t close_tcp_stream(TcpStream* stream) {
- return close(stream->streamfd);
+ return close_socket(stream->streamfd);
}
diff --git a/src/server/resolver.c b/src/server/resolver.c
index c9e5246..67149bc 100644
--- a/src/server/resolver.c
+++ b/src/server/resolver.c
@@ -45,7 +45,7 @@ static bool lookup(
static bool search(Question* question, Packet* result, BindingType type, const RecordMap* map) {
IpAddr addr;
- uint8_t ip[4] = {1, 1, 1, 1};
+ uint8_t ip[4] = {9,9,9,9};
create_ip_addr(ip, &addr);
uint16_t port = 53;