summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTyler Murphy <tylerm@tylerm.dev>2023-04-08 15:58:03 -0400
committerTyler Murphy <tylerm@tylerm.dev>2023-04-08 15:58:03 -0400
commit2a75d25632eb2ac966a3c52acbcc790d32abeef3 (patch)
tree239345c0666aee232fd50024ecabc0d9fc156a1c /src
parentgoofy ahh bin folder (diff)
downloadwrapper-2a75d25632eb2ac966a3c52acbcc790d32abeef3.tar.gz
wrapper-2a75d25632eb2ac966a3c52acbcc790d32abeef3.tar.bz2
wrapper-2a75d25632eb2ac966a3c52acbcc790d32abeef3.zip
read from config
Diffstat (limited to 'src')
-rw-r--r--src/io/config.c476
-rw-r--r--src/io/config.h5
-rw-r--r--src/io/map.c104
-rw-r--r--src/io/map.h20
-rw-r--r--src/packet/packet.c10
-rw-r--r--src/packet/question.c3
-rw-r--r--src/packet/record.c2
-rw-r--r--src/packet/record.h2
-rw-r--r--src/server/addr.c8
-rw-r--r--src/server/addr.h4
-rw-r--r--src/server/binding.c6
-rw-r--r--src/server/resolver.c29
-rw-r--r--src/server/resolver.h3
-rw-r--r--src/server/server.c112
14 files changed, 715 insertions, 69 deletions
diff --git a/src/io/config.c b/src/io/config.c
new file mode 100644
index 0000000..7bd4522
--- /dev/null
+++ b/src/io/config.c
@@ -0,0 +1,476 @@
+#include <errno.h>
+#include <stdio.h>
+#include <string.h>
+#include <stdlib.h>
+
+#include "config.h"
+#include "log.h"
+#include "map.h"
+
+#define MAX_LEN 1024
+#define BUF(name) char name[MAX_LEN]
+
+static int line = 0;
+
+static bool get_line(FILE* file, const BUF(buf)) {
+ line++;
+ return fgets((char*) buf, MAX_LEN, file) != NULL;
+}
+
+static bool is_whitespace(const char* buf) {
+ int i = 0;
+ char c;
+ while(c = buf[i], 1) {
+ if (c == '\n' || c == '\0') return true;
+ if (c != ' ' && c != '\n') return false;
+ i++;
+ }
+}
+
+static bool get_words(char* buf, char** words, int count) {
+ int last = 0;
+ int offset = 0;
+ int i = 0;
+
+ for(i = 0; i < count; i++) {
+ char c;
+ while(c = buf[offset], c != ' ' && c != '\0' && c != '\n') {
+ offset++;
+ }
+
+ if (offset - last < 1) {
+ return false;
+ }
+ words[i] = buf + last;
+ buf[offset] = '\0';
+ offset++;
+ last = offset;
+
+ if (c == '\0' || c == '\n') {
+ break;
+ }
+
+ }
+ return i + 1 == count;
+}
+
+static bool get_int(const char* word, uint32_t* i) {
+ char* end;
+ uint32_t res = (uint32_t) strtol(word, &end, 10);
+
+ if (*end == '\0') {
+ *i = res;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+static bool config_read_qtype(const char* qstr, RecordType* qtype) {
+ if (strcmp(qstr, "A") == 0) {
+ *qtype = A;
+ return true;
+ } else if (strcmp(qstr, "NS") == 0) {
+ *qtype = NS;
+ return true;
+ } else if (strcmp(qstr, "CNAME") == 0) {
+ *qtype = CNAME;
+ return true;
+ } else if (strcmp(qstr, "SOA") == 0) {
+ *qtype = SOA;
+ return true;
+ } else if (strcmp(qstr, "PTR") == 0) {
+ *qtype = PTR;
+ return true;
+ } else if (strcmp(qstr, "MX") == 0) {
+ *qtype = MX;
+ return true;
+ } else if (strcmp(qstr, "TXT") == 0) {
+ *qtype = TXT;
+ return true;
+ } else if (strcmp(qstr, "AAAA") == 0) {
+ *qtype = AAAA;
+ return true;
+ } else if (strcmp(qstr, "SRV") == 0) {
+ *qtype = SRV;
+ return true;
+ } else if (strcmp(qstr, "CAA") == 0) {
+ *qtype = CAA;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+static bool config_read_class(const char* cstr, uint16_t* class) {
+ if (strcmp(cstr, "IN") == 0) {
+ *class = 1;
+ return true;
+ } else if (strcmp(cstr, "CH") == 0) {
+ *class = 3;
+ return true;
+ } else if (strcmp(cstr, "HS") == 0) {
+ *class = 4;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+// Format QTYPE CLASS DOMAIN: A IN google.com
+static bool config_read_question(FILE* file, Question* question) {
+ BUF(buf);
+ if (!get_line(file, buf)) {
+ return false;
+ }
+
+ if (is_whitespace(buf)) {
+ return false;
+ }
+
+ char* words[3];
+ if (!get_words(&buf[0], &words[0], 3)) {
+ WARN("Invalid question at line %d", line);
+ return false;
+ };
+
+ uint16_t class;
+ if (!config_read_class(words[0], &class)) {
+ WARN("Invalid question class at line %d", line);
+ return false;
+ }
+
+ RecordType qtype;
+ if (!config_read_qtype(words[1], &qtype)) {
+ WARN("Invalid question qtype at line %d", line);
+ return false;
+ }
+
+ size_t domain_len = strlen(words[2]);
+ question->cls = class;
+ question->qtype = qtype;
+ question->domain = malloc(domain_len + 1);
+ question->domain[0] = domain_len;
+ memcpy(question->domain + 1, words[2], domain_len);
+
+ return true;
+}
+
+static void copy_str(char* from, uint8_t** to) {
+ size_t len = strlen(from);
+ if (len > 255) {
+ len = 255;
+ }
+
+ uint8_t* new = malloc(len + 1);
+ new[0] = len;
+ memcpy(new + 1, from, len);
+ *to = new;
+}
+
+static bool config_read_a_record(char* data, ARecord* record) {
+ sscanf(data, "%hhu.%hhu.%hhu.%hhu",
+ &record->addr[0],
+ &record->addr[1],
+ &record->addr[2],
+ &record->addr[3]
+ );
+ return true;
+}
+
+static bool config_read_ns_record(char* data, NSRecord* record) {
+ copy_str(data, &record->host);
+ return true;
+}
+
+static bool config_read_cname_record(char* data, CNAMERecord* record) {
+ copy_str(data, &record->host);
+ return true;
+}
+
+static bool config_read_soa_record(char* data, SOARecord* record) {
+ char* words[7];
+ if (!get_words(&data[0], &words[0], 7)) {
+ WARN("Invalid SOA record data at line %d", line);
+ record->mname = NULL;
+ record->nname = NULL;
+ return false;
+ }
+
+ copy_str(words[0], &record->mname);
+ copy_str(words[1], &record->nname);
+
+ if (!get_int(words[2], &record->serial)) {
+ WARN("Invalid SOA record data at line %d", line);
+ return false;
+ }
+
+ if (!get_int(words[3], &record->refresh)) {
+ WARN("Invalid SOA record data at line %d", line);
+ return false;
+ }
+
+ if (!get_int(words[4], &record->retry)) {
+ WARN("Invalid SOA record data at line %d", line);
+ return false;
+ }
+
+ if (!get_int(words[5], &record->expire)) {
+ WARN("Invalid SOA record data at line %d", line);
+ return false;
+ }
+
+ if (!get_int(words[6], &record->minimum)) {
+ WARN("Invalid SOA record data at line %d", line);
+ return false;
+ }
+
+ return true;
+}
+
+static bool config_read_ptr_record(char* data, PTRRecord* record) {
+ copy_str(data, &record->pointer);
+ return true;
+}
+
+static bool config_read_mx_record(char* data, MXRecord* record) {
+ char* words[2];
+ if (!get_words(&data[0], &words[0], 2)) {
+ WARN("Invalid MX record data at line %d", line);
+ record->host = NULL;
+ return false;
+ }
+
+ copy_str(words[1], &record->host);
+
+ uint32_t priority;
+ if (!get_int(words[0], &priority)) {
+ WARN("Invalid MX record data at line %d", line);
+ return false;
+ }
+ record->priority = (uint16_t) priority;
+
+ return true;
+}
+
+static bool config_read_txt_record(char* data, TXTRecord* record) {
+ int len = strlen(data);
+ uint8_t count = ((uint8_t)len + 254) / 255;
+ record->len = count;
+ record->text = malloc(sizeof(uint8_t*) * count);
+
+ for (uint8_t i = 0; i < count; i++) {
+ uint32_t offset = count * 255;
+ uint32_t part_len = len - offset;
+ if (part_len > 255) part_len = 255;
+
+ uint8_t* part = malloc(part_len + 1);
+ part[0] = part_len;
+ memcpy(part + 1, data + offset, part_len);
+
+ record->text[i] = part;
+ }
+
+ return true;
+}
+
+static bool config_read_aaaa_record(char* data, AAAARecord* record) {
+ for(int i = 0; i < 8; i++) {
+ if (sscanf(data, "%02hhx%02hhx:",
+ &record->addr[i*2 + 0],
+ &record->addr[i*2 + 1]
+ ) == EOF) {
+ return false;
+ }
+ }
+ return true;
+}
+
+static bool config_read_srv_record(char* data, SRVRecord* record) {
+ char* words[4];
+ if (!get_words(&data[0], &words[0], 4)) {
+ WARN("Invalid SRV record data at line %d", line);
+ record->target = NULL;
+ return false;
+ }
+
+ copy_str(words[3], &record->target);
+
+ uint32_t priority;
+ if (!get_int(words[0], &priority)) {
+ WARN("Invalid SRV record data at line %d", line);
+ return false;
+ }
+ record->priority = (uint16_t) priority;
+
+ uint32_t weight;
+ if (!get_int(words[1], &weight)) {
+ WARN("Invalid SRV record data at line %d", line);
+ return false;
+ }
+ record->weight = (uint16_t) weight;
+
+ uint32_t port;
+ if (!get_int(words[2], &port)) {
+ WARN("Invalid SRV record data at line %d", line);
+ return false;
+ }
+ record->port = (uint16_t) port;
+
+ return true;
+}
+
+static bool config_read_caa_record(char* data, CAARecord* record) {
+ char* words[4];
+ if (!get_words(&data[0], &words[0], 4)) {
+ WARN("Invalid SRV record data at line %d", line);
+ record->tag = NULL;
+ record->value = NULL;
+ return false;
+ }
+
+ copy_str(words[2], &record->tag);
+ copy_str(words[3], &record->value);
+
+ uint32_t flags;
+ if (!get_int(words[0], &flags)) {
+ WARN("Invalid SRV record data at line %d", line);
+ return false;
+ }
+ record->flags = (uint8_t) flags;
+
+ uint32_t length;
+ if (!get_int(words[1], &length)) {
+ WARN("Invalid SRV record data at line %d", line);
+ return false;
+ }
+ record->length = (uint8_t) length;
+
+ return true;
+}
+
+static bool config_read_record_data(char* data, Record* record) {
+ switch (record->type) {
+ case UNKOWN:
+ // This can never happend in here so uh do nothing i guess
+ return false;
+ case A:
+ return config_read_a_record(data, &record->data.a);
+ case NS:
+ return config_read_ns_record(data, &record->data.ns);
+ case CNAME:
+ return config_read_cname_record(data, &record->data.cname);
+ case SOA:
+ return config_read_soa_record(data, &record->data.soa);
+ case PTR:
+ return config_read_ptr_record(data, &record->data.ptr);
+ case MX:
+ return config_read_mx_record(data, &record->data.mx);
+ case TXT:
+ return config_read_txt_record(data, &record->data.txt);
+ case AAAA:
+ return config_read_aaaa_record(data, &record->data.aaaa);
+ case SRV:
+ return config_read_srv_record(data, &record->data.srv);
+ case CAA:
+ return config_read_caa_record(data, &record->data.caa);
+ }
+ return false;
+}
+
+static bool config_read_record(FILE* file, Record* record, Question* question) {
+ BUF(buf);
+ if (!get_line(file, buf)) {
+ return false;
+ }
+
+ if (is_whitespace(buf)) {
+ return false;
+ }
+
+ char* words[2];
+ if (!get_words(&buf[0], &words[0], 2)) {
+ WARN("Invalid record at line %d", line);
+ return false;
+ }
+
+ uint32_t ttl;
+ if (!get_int(words[0], &ttl)) {
+ WARN("Invalid record ttl at line %d", line);
+ return false;
+ }
+
+ record->cls = question->cls;
+ record->type = question->qtype;
+ record->len = 0;
+ record->ttl = ttl;
+ record->domain = malloc(question->domain[0] + 1);
+ memcpy(record->domain, question->domain, question->domain[0] + 1);
+
+ if(!config_read_record_data(words[1], record)) {
+ free_record(record);
+ return false;
+ }
+
+ return true;
+}
+
+static void config_push_record(Record** buf, Record record, uint16_t* capacity, uint16_t* size) {
+ if (size == capacity) {
+ *capacity *= 2;
+ *buf = realloc(*buf, sizeof(Record) * *capacity);
+ }
+ (*buf)[*size] = record;
+ (*size)++;
+}
+
+bool load_config(const char* path, RecordMap* map) {
+ FILE* file = fopen(path, "r");
+ if (file == NULL) {
+ ERROR("Failed to open file %s: %s", path, strerror(errno));
+ return false;
+ }
+
+ line = 0;
+ record_map_init(map);
+
+ while (1) {
+ Question* question = malloc(sizeof(Question));
+ if (!config_read_question(file, question)) {
+ free(question);
+ break;
+ }
+
+ INIT_LOG_BUFFER(log);
+ LOGONLY(print_question(question, log));
+ TRACE("Found config question: %s", log);
+
+ Packet* packet = malloc(sizeof(Packet));
+ memset(packet, 0, sizeof(Packet));
+ packet->authorities = NULL;
+ packet->resources = NULL;
+
+ packet->questions = malloc(sizeof(Question));
+ packet->questions[0] = *question;
+
+ uint16_t capacity = 1;
+ packet->answers = malloc(sizeof(Record));
+
+ while(1) {
+ Record record;
+ if (!config_read_record(file, &record, question)) {
+ break;
+ }
+
+ LOGONLY(print_record(&record, log));
+ TRACE("Found config record: %s", log);
+
+ config_push_record(&packet->answers, record, &capacity, &packet->header.answers);
+ }
+
+ record_map_add(map, question, packet);
+ }
+
+ fclose(file);
+ return true;
+}
diff --git a/src/io/config.h b/src/io/config.h
new file mode 100644
index 0000000..7da7e98
--- /dev/null
+++ b/src/io/config.h
@@ -0,0 +1,5 @@
+#pragma once
+
+#include "map.h"
+
+bool load_config(const char* path, RecordMap* map);
diff --git a/src/io/map.c b/src/io/map.c
new file mode 100644
index 0000000..cb4642e
--- /dev/null
+++ b/src/io/map.c
@@ -0,0 +1,104 @@
+#include <string.h>
+#include <stdlib.h>
+
+#include "map.h"
+
+void record_map_init(RecordMap* map) {
+ map->capacity = 0;
+ map->len = 0;
+ map->entries = NULL;
+}
+
+void record_map_free(RecordMap* map) {
+ for(uint32_t i = 0; i < map->capacity; i++) {
+ Entry* e = &map->entries[i];
+ if (e->key != NULL) {
+ free_question(e->key);
+ free(e->key);
+ free_packet(e->value);
+ free(e->value);
+ }
+ }
+ free(map->entries);
+}
+
+static size_t hash_question(const Question* question) {
+ size_t hash = 5381;
+ for(int i = 0; i < question->domain[0]; i++) {
+ uint8_t c = question->domain[i+1];
+ hash = ((hash << 5) + hash) + c;
+ }
+ hash = ((hash << 5) + hash) + (uint8_t)question->cls;
+ hash = ((hash << 5) + hash) + (uint8_t)question->qtype;
+ return hash;
+}
+
+static bool question_equals(const Question* a, const Question* b) {
+ if (a->qtype != b->qtype) return false;
+ if (a->cls != b->cls) return false;
+ if (a->domain[0] != b->domain[0]) return false;
+ return memcmp(a->domain+1, b->domain+1, a->domain[0]) == 0;
+}
+
+static Entry* record_map_find(Entry* entries, uint32_t capacity, const Question* key) {
+ uint32_t index = hash_question(key) % capacity;
+ while(true) {
+ Entry* entry = &entries[index];
+ if(entry->key == NULL) {
+ return entry;
+ } else if(question_equals(entry->key, key)) {
+ return entry;
+ }
+ index += 1;
+ index %= capacity;
+ }
+}
+
+static void record_map_grow(RecordMap* map, uint32_t capacity) {
+ Entry* entries = malloc(capacity * sizeof(Entry));
+ for(uint32_t i = 0; i < capacity; i++) {
+ entries[i].key = NULL;
+ entries[i].value = NULL;
+ }
+ map->len = 0;
+ for(uint32_t i = 0; i < map->capacity; i++) {
+ Entry* entry = &map->entries[i];
+ if(entry->key == NULL) continue;
+
+ Entry* dest = record_map_find(entries, capacity, entry->key);
+ dest->key = entry->key;
+ dest->value = entry->value;
+ map->len++;
+ }
+ free(map->entries);
+
+ map->entries = entries;
+ map->capacity = capacity;
+}
+
+bool record_map_get(const RecordMap* map, const Question* key, Packet* value) {
+ if(map->len == 0) return false;
+
+ Entry* e = record_map_find(map->entries, map->capacity, key);
+ if (e->key == NULL) return false;
+
+ *value = *(e->value);
+ return true;
+}
+
+void record_map_add(RecordMap* map, Question* key, Packet* value) {
+ if(map->len + 1 > map->capacity * 0.75) {
+ int capacity = (map->capacity == 0 ? 8 : (2 * map->capacity));
+ record_map_grow(map, capacity);
+ }
+ Entry* e = record_map_find(map->entries, map->capacity, key);
+ bool new_key = e->key == NULL;
+ if(new_key) {
+ map->len++;
+ e->key = key;
+ }
+
+ value->header.z = true;
+ e->value = value;
+}
+
diff --git a/src/io/map.h b/src/io/map.h
new file mode 100644
index 0000000..84b40fb
--- /dev/null
+++ b/src/io/map.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include "../packet/packet.h"
+
+typedef struct {
+ Question* key;
+ Packet* value;
+} Entry;
+
+typedef struct {
+ uint32_t capacity;
+ uint32_t len;
+ Entry* entries;
+} RecordMap;
+
+void record_map_init(RecordMap* map);
+void record_map_free(RecordMap* map);
+
+bool record_map_get(const RecordMap* map, const Question* key, Packet* value);
+void record_map_add(RecordMap* map, Question* key, Packet* value);
diff --git a/src/packet/packet.c b/src/packet/packet.c
index 9b1159d..1d96e38 100644
--- a/src/packet/packet.c
+++ b/src/packet/packet.c
@@ -79,10 +79,10 @@ bool get_random_a(Packet* packet, IpAddr* addr) {
for (uint16_t i = 0; i < packet->header.answers; i++) {
Record record = packet->answers[i];
if (record.type == A) {
- create_ip_addr((char*) &record.data.a.addr, addr);
+ create_ip_addr(record.data.a.addr, addr);
return true;
} else if (record.type == AAAA) {
- create_ip_addr6((char*) &record.data.aaaa.addr, addr);
+ create_ip_addr6(record.data.aaaa.addr, addr);
return true;
}
}
@@ -138,10 +138,10 @@ bool get_resolved_ns(Packet* packet, uint8_t* qname, IpAddr* addr) {
}
if (resource.type == A) {
- create_ip_addr((char*) &record.data.a.addr, addr);
+ create_ip_addr(record.data.a.addr, addr);
return true;
} else if (resource.type == AAAA) {
- create_ip_addr6((char*) &record.data.aaaa.addr, addr);
+ create_ip_addr6(record.data.aaaa.addr, addr);
return true;
}
}
@@ -168,4 +168,4 @@ bool get_unresoled_ns(Packet* packet, uint8_t* qname, Question* question) {
return true;
}
return false;
-} \ No newline at end of file
+}
diff --git a/src/packet/question.c b/src/packet/question.c
index c2807d0..5a08fd6 100644
--- a/src/packet/question.c
+++ b/src/packet/question.c
@@ -91,4 +91,5 @@ void print_question(Question* question, char* buffer) {
question->domain[0],
question->domain + 1
);
-} \ No newline at end of file
+}
+
diff --git a/src/packet/record.c b/src/packet/record.c
index 29c3bf0..e3e9077 100644
--- a/src/packet/record.c
+++ b/src/packet/record.c
@@ -537,4 +537,4 @@ void print_record(Record* record, char* buffer) {
);
break;
}
-} \ No newline at end of file
+}
diff --git a/src/packet/record.h b/src/packet/record.h
index 95bbbbe..479ce40 100644
--- a/src/packet/record.h
+++ b/src/packet/record.h
@@ -98,4 +98,4 @@ typedef struct {
void read_record(PacketBuffer* buffer, Record* record);
void write_record(PacketBuffer* buffer, Record* record);
void free_record(Record* record);
-void print_record(Record* record, char* buffer); \ No newline at end of file
+void print_record(Record* record, char* buffer);
diff --git a/src/server/addr.c b/src/server/addr.c
index 982da13..14f44c6 100644
--- a/src/server/addr.c
+++ b/src/server/addr.c
@@ -7,12 +7,12 @@
#include "addr.h"
#include "../io/log.h"
-void create_ip_addr(char* domain, IpAddr* addr) {
+void create_ip_addr(uint8_t* domain, IpAddr* addr) {
addr->type = V4;
memcpy(&addr->data.v4.s_addr, domain, 4);
}
-void create_ip_addr6(char* domain, IpAddr* addr) {
+void create_ip_addr6(uint8_t* domain, IpAddr* addr) {
addr->type = V6;
memcpy(&addr->data.v6.__in6_u.__u6_addr8, domain, 16);
}
@@ -76,7 +76,7 @@ void print_socket_addr(SocketAddr* addr, char* buffer) {
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 16),
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 8),
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr),
- addr->data.v4.sin_port
+ ntohs(addr->data.v4.sin_port)
);
} else {
for(int i = 0; i < 8; i++) {
@@ -85,7 +85,7 @@ void print_socket_addr(SocketAddr* addr, char* buffer) {
addr->data.v6.sin6_addr.__in6_u.__u6_addr8[i*2 + 1]
);
}
- APPEND(buffer, ":[%hu]", addr->data.v6.sin6_port);
+ APPEND(buffer, ":[%hu]", ntohs(addr->data.v6.sin6_port));
}
}
diff --git a/src/server/addr.h b/src/server/addr.h
index 173c7fd..1210850 100644
--- a/src/server/addr.h
+++ b/src/server/addr.h
@@ -20,8 +20,8 @@ typedef struct {
} data;
} IpAddr;
-void create_ip_addr(char* domain, IpAddr* addr);
-void create_ip_addr6(char* domain, IpAddr* addr);
+void create_ip_addr(uint8_t* domain, IpAddr* addr);
+void create_ip_addr6(uint8_t* domain, IpAddr* addr);
void ip_addr_any(IpAddr* addr);
void ip_addr_any6(IpAddr* addr);
diff --git a/src/server/binding.c b/src/server/binding.c
index 47c62c6..157d1d4 100644
--- a/src/server/binding.c
+++ b/src/server/binding.c
@@ -144,10 +144,10 @@ bool read_connection(Connection* connection, Packet* packet) {
}
static bool write_udp(Connection* connection, uint8_t* buf, uint16_t len) {
- //if (len > 512) {
+ if (len > 512) {
buf[2] = buf[2] | 0x03;
- // len = 512;
- // }
+ len = 512;
+ }
return write_udp_socket(
&connection->sock.udp.udp,
buf,
diff --git a/src/server/resolver.c b/src/server/resolver.c
index e05f365..a1fa82a 100644
--- a/src/server/resolver.c
+++ b/src/server/resolver.c
@@ -43,9 +43,9 @@ static bool lookup(
return true;
}
-static bool search(Question* question, Packet* result, BindingType type) {
+static bool search(Question* question, Packet* result, BindingType type, const RecordMap* map) {
IpAddr addr;
- char ip[4] = {1, 1, 1, 1};
+ uint8_t ip[4] = {1, 1, 1, 1};
create_ip_addr(ip, &addr);
uint16_t port = 53;
@@ -53,6 +53,10 @@ static bool search(Question* question, Packet* result, BindingType type) {
create_socket_addr(port, addr, &saddr);
while(1) {
+ if (record_map_get(map, question, result)) {
+ return true;
+ }
+
if (!lookup(question, result, type, saddr)) {
return false;
}
@@ -75,7 +79,7 @@ static bool search(Question* question, Packet* result, BindingType type) {
}
Packet recurse;
- if (!search(&new_question, &recurse, type)) {
+ if (!search(&new_question, &recurse, type, map)) {
return false;
}
@@ -104,7 +108,7 @@ static void push_questions(Question* from, uint8_t from_len, Question** to, uint
memcpy(*to + to_len, from, from_len * sizeof(Question));
}
-void handle_query(Packet* request, Packet* response, BindingType type) {
+void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map) {
memset(response, 0, sizeof(Packet));
response->header.id = request->header.id;
response->header.opcode = request->header.opcode;
@@ -121,7 +125,7 @@ void handle_query(Packet* request, Packet* response, BindingType type) {
Packet result;
memset(&result, 0, sizeof(Packet));
result.header.id = response->header.id;
- if (!search(&request->questions[i], &result, type)) {
+ if (!search(&request->questions[i], &result, type, map)) {
response->header.response = SERVFAIL;
break;
}
@@ -158,9 +162,14 @@ void handle_query(Packet* request, Packet* response, BindingType type) {
);
response->header.resource_entries += result.header.resource_entries;
- free(result.questions);
- free(result.answers);
- free(result.authorities);
- free(result.resources);
+ if (result.header.z == false) {
+ // Not from cache
+ free(result.questions);
+ free(result.answers);
+ free(result.authorities);
+ free(result.resources);
+ } else {
+ response->header.z = true;
+ }
}
-} \ No newline at end of file
+}
diff --git a/src/server/resolver.h b/src/server/resolver.h
index 79b4825..993b515 100644
--- a/src/server/resolver.h
+++ b/src/server/resolver.h
@@ -1,6 +1,7 @@
#pragma once
#include "../packet/packet.h"
+#include "../io/map.h"
#include "binding.h"
-void handle_query(Packet* request, Packet* response, BindingType type); \ No newline at end of file
+void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map);
diff --git a/src/server/server.c b/src/server/server.c
index c8975ee..27cc9d9 100644
--- a/src/server/server.c
+++ b/src/server/server.c
@@ -1,25 +1,67 @@
-#define _POSIX_SOURCE
+#define _POSIX_C_SOURCE 200809L
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <sys/wait.h>
#include <signal.h>
+#include <pthread.h>
#include "addr.h"
#include "server.h"
#include "resolver.h"
#include "../io/log.h"
+#include "../io/map.h"
+#include "../io/config.h"
-static pid_t udp, tcp;
+static pthread_t udp, tcp;
+static RecordMap map;
void server_init(uint16_t port, Server* server) {
INFO("Server port set to %hu", port);
create_binding(UDP, port, &server->udp);
create_binding(TCP, port, &server->tcp);
+ load_config("config", &map);
}
-static void server_listen(Binding* binding) {
+struct DnsRequest {
+ Connection connection;
+ Packet request;
+};
+
+static void* server_respond(void* arg) {
+ struct DnsRequest req = *(struct DnsRequest*) arg;
+
+ INFO("Recieved packet request ID %hu", req.request.header.id);
+
+ Packet response;
+ handle_query(&req.request, &response, req.connection.type, &map);
+
+ if (!write_connection(&req.connection, &response)) {
+ ERROR("Faled to respond to connection ID %hu: %s",
+ req.request.header.id,
+ strerror(errno)
+ );
+ }
+
+ free_packet(&req.request);
+ free_connection(&req.connection);
+
+ if (response.header.z == false) {
+ free_packet(&response);
+ } else {
+ // Dont free from config
+ free(response.questions);
+ free(response.answers);
+ free(response.authorities);
+ free(response.resources);
+ }
+
+ return NULL;
+}
+
+static void* server_listen(void* arg) {
+ Binding* binding = (Binding*) arg;
while(1) {
Connection connection;
@@ -34,67 +76,55 @@ static void server_listen(Binding* binding) {
free_connection(&connection);
continue;
}
+
+ struct DnsRequest req;
+ req.connection = connection;
+ req.request = request;
- if(fork() != 0) {
+ pthread_t thread;
+ if(pthread_create(&thread, NULL, &server_respond, &req)) {
+ ERROR("Failed to create thread for dns request ID %hu: %s",
+ request.header.id,
+ strerror(errno)
+ );
free_packet(&request);
free_connection(&connection);
continue;
}
-
- INFO("Recieved packet request ID %hu", request.header.id);
-
- Packet response;
- handle_query(&request, &response, connection.type);
-
- if (!write_connection(&connection, &response)) {
- ERROR("Failed to respond to connection ID %hu: %s", request.header.id, strerror(errno));
- }
-
- free_packet(&request);
- free_packet(&response);
- free_connection(&connection);
- exit(EXIT_SUCCESS);
}
+
+ return NULL;
}
static void signal_handler() {
printf("\n");
- kill(udp, SIGTERM);
- kill(tcp, SIGTERM);
+ pthread_kill(udp, SIGTERM);
+ pthread_kill(tcp, SIGTERM);
}
void server_run(Server* server) {
- if ((udp = fork()) == 0) {
+ if (!pthread_create(&udp, NULL, &server_listen, &server->udp)) {
INFO("Listening for connections on UDP");
- server_listen(&server->udp);
- exit(EXIT_SUCCESS);
+ } else {
+ ERROR("Failed to start UDP thread");
+ exit(EXIT_FAILURE);
}
- if ((tcp = fork()) == 0) {
+ if (!pthread_create(&tcp, NULL, &server_listen, &server->tcp)) {
INFO("Listening for connections on TCP");
- server_listen(&server->tcp);
- exit(EXIT_SUCCESS);
- }
-
- signal(SIGINT, signal_handler);
-
- int status;
- waitpid(udp, &status, 0);
- if (status == 0) {
- INFO("UDP process closed successfully");
} else {
- ERROR("UDP process failed with error code %d", status);
+ ERROR("Failed to start TCP thread");
+ exit(EXIT_FAILURE);
}
+ signal(SIGINT, signal_handler);
- waitpid(tcp, &status, 0);
- if (status == 0) {
- INFO("TCP process closed successfully");
- } else {
- ERROR("TCP process failed with error code %d", status);
- }
+ pthread_join(udp, NULL);
+ pthread_join(tcp, NULL);
}
void server_free(Server* server) {
free_binding(&server->udp);
free_binding(&server->tcp);
+ record_map_free(&map);
}
+