diff --git a/.gitignore b/.gitignore index ba077a4..409f624 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ bin +config diff --git a/Makefile b/Makefile index 6c2a640..9a6217d 100644 --- a/Makefile +++ b/Makefile @@ -6,6 +6,7 @@ CCFLAGS = -std=c17 -Wall -Wextra -O2 CCFLAGS += $(INCFLAGS) LDFLAGS += $(INCFLAGS) +LDFLAGS += -lpthread BIN = bin APP = $(BIN)/app diff --git a/src/io/config.c b/src/io/config.c new file mode 100644 index 0000000..7bd4522 --- /dev/null +++ b/src/io/config.c @@ -0,0 +1,476 @@ +#include +#include +#include +#include + +#include "config.h" +#include "log.h" +#include "map.h" + +#define MAX_LEN 1024 +#define BUF(name) char name[MAX_LEN] + +static int line = 0; + +static bool get_line(FILE* file, const BUF(buf)) { + line++; + return fgets((char*) buf, MAX_LEN, file) != NULL; +} + +static bool is_whitespace(const char* buf) { + int i = 0; + char c; + while(c = buf[i], 1) { + if (c == '\n' || c == '\0') return true; + if (c != ' ' && c != '\n') return false; + i++; + } +} + +static bool get_words(char* buf, char** words, int count) { + int last = 0; + int offset = 0; + int i = 0; + + for(i = 0; i < count; i++) { + char c; + while(c = buf[offset], c != ' ' && c != '\0' && c != '\n') { + offset++; + } + + if (offset - last < 1) { + return false; + } + words[i] = buf + last; + buf[offset] = '\0'; + offset++; + last = offset; + + if (c == '\0' || c == '\n') { + break; + } + + } + return i + 1 == count; +} + +static bool get_int(const char* word, uint32_t* i) { + char* end; + uint32_t res = (uint32_t) strtol(word, &end, 10); + + if (*end == '\0') { + *i = res; + return true; + } else { + return false; + } +} + +static bool config_read_qtype(const char* qstr, RecordType* qtype) { + if (strcmp(qstr, "A") == 0) { + *qtype = A; + return true; + } else if (strcmp(qstr, "NS") == 0) { + *qtype = NS; + return true; + } else if (strcmp(qstr, "CNAME") == 0) { + *qtype = CNAME; + return true; + } else if (strcmp(qstr, "SOA") == 0) { + *qtype = SOA; + return true; + } else if (strcmp(qstr, "PTR") == 0) { + *qtype = PTR; + return true; + } else if (strcmp(qstr, "MX") == 0) { + *qtype = MX; + return true; + } else if (strcmp(qstr, "TXT") == 0) { + *qtype = TXT; + return true; + } else if (strcmp(qstr, "AAAA") == 0) { + *qtype = AAAA; + return true; + } else if (strcmp(qstr, "SRV") == 0) { + *qtype = SRV; + return true; + } else if (strcmp(qstr, "CAA") == 0) { + *qtype = CAA; + return true; + } else { + return false; + } +} + +static bool config_read_class(const char* cstr, uint16_t* class) { + if (strcmp(cstr, "IN") == 0) { + *class = 1; + return true; + } else if (strcmp(cstr, "CH") == 0) { + *class = 3; + return true; + } else if (strcmp(cstr, "HS") == 0) { + *class = 4; + return true; + } else { + return false; + } +} + +// Format QTYPE CLASS DOMAIN: A IN google.com +static bool config_read_question(FILE* file, Question* question) { + BUF(buf); + if (!get_line(file, buf)) { + return false; + } + + if (is_whitespace(buf)) { + return false; + } + + char* words[3]; + if (!get_words(&buf[0], &words[0], 3)) { + WARN("Invalid question at line %d", line); + return false; + }; + + uint16_t class; + if (!config_read_class(words[0], &class)) { + WARN("Invalid question class at line %d", line); + return false; + } + + RecordType qtype; + if (!config_read_qtype(words[1], &qtype)) { + WARN("Invalid question qtype at line %d", line); + return false; + } + + size_t domain_len = strlen(words[2]); + question->cls = class; + question->qtype = qtype; + question->domain = malloc(domain_len + 1); + question->domain[0] = domain_len; + memcpy(question->domain + 1, words[2], domain_len); + + return true; +} + +static void copy_str(char* from, uint8_t** to) { + size_t len = strlen(from); + if (len > 255) { + len = 255; + } + + uint8_t* new = malloc(len + 1); + new[0] = len; + memcpy(new + 1, from, len); + *to = new; +} + +static bool config_read_a_record(char* data, ARecord* record) { + sscanf(data, "%hhu.%hhu.%hhu.%hhu", + &record->addr[0], + &record->addr[1], + &record->addr[2], + &record->addr[3] + ); + return true; +} + +static bool config_read_ns_record(char* data, NSRecord* record) { + copy_str(data, &record->host); + return true; +} + +static bool config_read_cname_record(char* data, CNAMERecord* record) { + copy_str(data, &record->host); + return true; +} + +static bool config_read_soa_record(char* data, SOARecord* record) { + char* words[7]; + if (!get_words(&data[0], &words[0], 7)) { + WARN("Invalid SOA record data at line %d", line); + record->mname = NULL; + record->nname = NULL; + return false; + } + + copy_str(words[0], &record->mname); + copy_str(words[1], &record->nname); + + if (!get_int(words[2], &record->serial)) { + WARN("Invalid SOA record data at line %d", line); + return false; + } + + if (!get_int(words[3], &record->refresh)) { + WARN("Invalid SOA record data at line %d", line); + return false; + } + + if (!get_int(words[4], &record->retry)) { + WARN("Invalid SOA record data at line %d", line); + return false; + } + + if (!get_int(words[5], &record->expire)) { + WARN("Invalid SOA record data at line %d", line); + return false; + } + + if (!get_int(words[6], &record->minimum)) { + WARN("Invalid SOA record data at line %d", line); + return false; + } + + return true; +} + +static bool config_read_ptr_record(char* data, PTRRecord* record) { + copy_str(data, &record->pointer); + return true; +} + +static bool config_read_mx_record(char* data, MXRecord* record) { + char* words[2]; + if (!get_words(&data[0], &words[0], 2)) { + WARN("Invalid MX record data at line %d", line); + record->host = NULL; + return false; + } + + copy_str(words[1], &record->host); + + uint32_t priority; + if (!get_int(words[0], &priority)) { + WARN("Invalid MX record data at line %d", line); + return false; + } + record->priority = (uint16_t) priority; + + return true; +} + +static bool config_read_txt_record(char* data, TXTRecord* record) { + int len = strlen(data); + uint8_t count = ((uint8_t)len + 254) / 255; + record->len = count; + record->text = malloc(sizeof(uint8_t*) * count); + + for (uint8_t i = 0; i < count; i++) { + uint32_t offset = count * 255; + uint32_t part_len = len - offset; + if (part_len > 255) part_len = 255; + + uint8_t* part = malloc(part_len + 1); + part[0] = part_len; + memcpy(part + 1, data + offset, part_len); + + record->text[i] = part; + } + + return true; +} + +static bool config_read_aaaa_record(char* data, AAAARecord* record) { + for(int i = 0; i < 8; i++) { + if (sscanf(data, "%02hhx%02hhx:", + &record->addr[i*2 + 0], + &record->addr[i*2 + 1] + ) == EOF) { + return false; + } + } + return true; +} + +static bool config_read_srv_record(char* data, SRVRecord* record) { + char* words[4]; + if (!get_words(&data[0], &words[0], 4)) { + WARN("Invalid SRV record data at line %d", line); + record->target = NULL; + return false; + } + + copy_str(words[3], &record->target); + + uint32_t priority; + if (!get_int(words[0], &priority)) { + WARN("Invalid SRV record data at line %d", line); + return false; + } + record->priority = (uint16_t) priority; + + uint32_t weight; + if (!get_int(words[1], &weight)) { + WARN("Invalid SRV record data at line %d", line); + return false; + } + record->weight = (uint16_t) weight; + + uint32_t port; + if (!get_int(words[2], &port)) { + WARN("Invalid SRV record data at line %d", line); + return false; + } + record->port = (uint16_t) port; + + return true; +} + +static bool config_read_caa_record(char* data, CAARecord* record) { + char* words[4]; + if (!get_words(&data[0], &words[0], 4)) { + WARN("Invalid SRV record data at line %d", line); + record->tag = NULL; + record->value = NULL; + return false; + } + + copy_str(words[2], &record->tag); + copy_str(words[3], &record->value); + + uint32_t flags; + if (!get_int(words[0], &flags)) { + WARN("Invalid SRV record data at line %d", line); + return false; + } + record->flags = (uint8_t) flags; + + uint32_t length; + if (!get_int(words[1], &length)) { + WARN("Invalid SRV record data at line %d", line); + return false; + } + record->length = (uint8_t) length; + + return true; +} + +static bool config_read_record_data(char* data, Record* record) { + switch (record->type) { + case UNKOWN: + // This can never happend in here so uh do nothing i guess + return false; + case A: + return config_read_a_record(data, &record->data.a); + case NS: + return config_read_ns_record(data, &record->data.ns); + case CNAME: + return config_read_cname_record(data, &record->data.cname); + case SOA: + return config_read_soa_record(data, &record->data.soa); + case PTR: + return config_read_ptr_record(data, &record->data.ptr); + case MX: + return config_read_mx_record(data, &record->data.mx); + case TXT: + return config_read_txt_record(data, &record->data.txt); + case AAAA: + return config_read_aaaa_record(data, &record->data.aaaa); + case SRV: + return config_read_srv_record(data, &record->data.srv); + case CAA: + return config_read_caa_record(data, &record->data.caa); + } + return false; +} + +static bool config_read_record(FILE* file, Record* record, Question* question) { + BUF(buf); + if (!get_line(file, buf)) { + return false; + } + + if (is_whitespace(buf)) { + return false; + } + + char* words[2]; + if (!get_words(&buf[0], &words[0], 2)) { + WARN("Invalid record at line %d", line); + return false; + } + + uint32_t ttl; + if (!get_int(words[0], &ttl)) { + WARN("Invalid record ttl at line %d", line); + return false; + } + + record->cls = question->cls; + record->type = question->qtype; + record->len = 0; + record->ttl = ttl; + record->domain = malloc(question->domain[0] + 1); + memcpy(record->domain, question->domain, question->domain[0] + 1); + + if(!config_read_record_data(words[1], record)) { + free_record(record); + return false; + } + + return true; +} + +static void config_push_record(Record** buf, Record record, uint16_t* capacity, uint16_t* size) { + if (size == capacity) { + *capacity *= 2; + *buf = realloc(*buf, sizeof(Record) * *capacity); + } + (*buf)[*size] = record; + (*size)++; +} + +bool load_config(const char* path, RecordMap* map) { + FILE* file = fopen(path, "r"); + if (file == NULL) { + ERROR("Failed to open file %s: %s", path, strerror(errno)); + return false; + } + + line = 0; + record_map_init(map); + + while (1) { + Question* question = malloc(sizeof(Question)); + if (!config_read_question(file, question)) { + free(question); + break; + } + + INIT_LOG_BUFFER(log); + LOGONLY(print_question(question, log)); + TRACE("Found config question: %s", log); + + Packet* packet = malloc(sizeof(Packet)); + memset(packet, 0, sizeof(Packet)); + packet->authorities = NULL; + packet->resources = NULL; + + packet->questions = malloc(sizeof(Question)); + packet->questions[0] = *question; + + uint16_t capacity = 1; + packet->answers = malloc(sizeof(Record)); + + while(1) { + Record record; + if (!config_read_record(file, &record, question)) { + break; + } + + LOGONLY(print_record(&record, log)); + TRACE("Found config record: %s", log); + + config_push_record(&packet->answers, record, &capacity, &packet->header.answers); + } + + record_map_add(map, question, packet); + } + + fclose(file); + return true; +} diff --git a/src/io/config.h b/src/io/config.h new file mode 100644 index 0000000..7da7e98 --- /dev/null +++ b/src/io/config.h @@ -0,0 +1,5 @@ +#pragma once + +#include "map.h" + +bool load_config(const char* path, RecordMap* map); diff --git a/src/io/map.c b/src/io/map.c new file mode 100644 index 0000000..cb4642e --- /dev/null +++ b/src/io/map.c @@ -0,0 +1,104 @@ +#include +#include + +#include "map.h" + +void record_map_init(RecordMap* map) { + map->capacity = 0; + map->len = 0; + map->entries = NULL; +} + +void record_map_free(RecordMap* map) { + for(uint32_t i = 0; i < map->capacity; i++) { + Entry* e = &map->entries[i]; + if (e->key != NULL) { + free_question(e->key); + free(e->key); + free_packet(e->value); + free(e->value); + } + } + free(map->entries); +} + +static size_t hash_question(const Question* question) { + size_t hash = 5381; + for(int i = 0; i < question->domain[0]; i++) { + uint8_t c = question->domain[i+1]; + hash = ((hash << 5) + hash) + c; + } + hash = ((hash << 5) + hash) + (uint8_t)question->cls; + hash = ((hash << 5) + hash) + (uint8_t)question->qtype; + return hash; +} + +static bool question_equals(const Question* a, const Question* b) { + if (a->qtype != b->qtype) return false; + if (a->cls != b->cls) return false; + if (a->domain[0] != b->domain[0]) return false; + return memcmp(a->domain+1, b->domain+1, a->domain[0]) == 0; +} + +static Entry* record_map_find(Entry* entries, uint32_t capacity, const Question* key) { + uint32_t index = hash_question(key) % capacity; + while(true) { + Entry* entry = &entries[index]; + if(entry->key == NULL) { + return entry; + } else if(question_equals(entry->key, key)) { + return entry; + } + index += 1; + index %= capacity; + } +} + +static void record_map_grow(RecordMap* map, uint32_t capacity) { + Entry* entries = malloc(capacity * sizeof(Entry)); + for(uint32_t i = 0; i < capacity; i++) { + entries[i].key = NULL; + entries[i].value = NULL; + } + map->len = 0; + for(uint32_t i = 0; i < map->capacity; i++) { + Entry* entry = &map->entries[i]; + if(entry->key == NULL) continue; + + Entry* dest = record_map_find(entries, capacity, entry->key); + dest->key = entry->key; + dest->value = entry->value; + map->len++; + } + free(map->entries); + + map->entries = entries; + map->capacity = capacity; +} + +bool record_map_get(const RecordMap* map, const Question* key, Packet* value) { + if(map->len == 0) return false; + + Entry* e = record_map_find(map->entries, map->capacity, key); + if (e->key == NULL) return false; + + *value = *(e->value); + return true; +} + +void record_map_add(RecordMap* map, Question* key, Packet* value) { + if(map->len + 1 > map->capacity * 0.75) { + int capacity = (map->capacity == 0 ? 8 : (2 * map->capacity)); + record_map_grow(map, capacity); + } + Entry* e = record_map_find(map->entries, map->capacity, key); + bool new_key = e->key == NULL; + if(new_key) { + map->len++; + e->key = key; + } + + value->header.z = true; + e->value = value; +} + diff --git a/src/io/map.h b/src/io/map.h new file mode 100644 index 0000000..84b40fb --- /dev/null +++ b/src/io/map.h @@ -0,0 +1,20 @@ +#pragma once + +#include "../packet/packet.h" + +typedef struct { + Question* key; + Packet* value; +} Entry; + +typedef struct { + uint32_t capacity; + uint32_t len; + Entry* entries; +} RecordMap; + +void record_map_init(RecordMap* map); +void record_map_free(RecordMap* map); + +bool record_map_get(const RecordMap* map, const Question* key, Packet* value); +void record_map_add(RecordMap* map, Question* key, Packet* value); diff --git a/src/packet/packet.c b/src/packet/packet.c index 9b1159d..1d96e38 100644 --- a/src/packet/packet.c +++ b/src/packet/packet.c @@ -79,10 +79,10 @@ bool get_random_a(Packet* packet, IpAddr* addr) { for (uint16_t i = 0; i < packet->header.answers; i++) { Record record = packet->answers[i]; if (record.type == A) { - create_ip_addr((char*) &record.data.a.addr, addr); + create_ip_addr(record.data.a.addr, addr); return true; } else if (record.type == AAAA) { - create_ip_addr6((char*) &record.data.aaaa.addr, addr); + create_ip_addr6(record.data.aaaa.addr, addr); return true; } } @@ -138,10 +138,10 @@ bool get_resolved_ns(Packet* packet, uint8_t* qname, IpAddr* addr) { } if (resource.type == A) { - create_ip_addr((char*) &record.data.a.addr, addr); + create_ip_addr(record.data.a.addr, addr); return true; } else if (resource.type == AAAA) { - create_ip_addr6((char*) &record.data.aaaa.addr, addr); + create_ip_addr6(record.data.aaaa.addr, addr); return true; } } @@ -168,4 +168,4 @@ bool get_unresoled_ns(Packet* packet, uint8_t* qname, Question* question) { return true; } return false; -} \ No newline at end of file +} diff --git a/src/packet/question.c b/src/packet/question.c index c2807d0..5a08fd6 100644 --- a/src/packet/question.c +++ b/src/packet/question.c @@ -91,4 +91,5 @@ void print_question(Question* question, char* buffer) { question->domain[0], question->domain + 1 ); -} \ No newline at end of file +} + diff --git a/src/packet/record.c b/src/packet/record.c index 29c3bf0..e3e9077 100644 --- a/src/packet/record.c +++ b/src/packet/record.c @@ -537,4 +537,4 @@ void print_record(Record* record, char* buffer) { ); break; } -} \ No newline at end of file +} diff --git a/src/packet/record.h b/src/packet/record.h index 95bbbbe..479ce40 100644 --- a/src/packet/record.h +++ b/src/packet/record.h @@ -98,4 +98,4 @@ typedef struct { void read_record(PacketBuffer* buffer, Record* record); void write_record(PacketBuffer* buffer, Record* record); void free_record(Record* record); -void print_record(Record* record, char* buffer); \ No newline at end of file +void print_record(Record* record, char* buffer); diff --git a/src/server/addr.c b/src/server/addr.c index 982da13..14f44c6 100644 --- a/src/server/addr.c +++ b/src/server/addr.c @@ -7,12 +7,12 @@ #include "addr.h" #include "../io/log.h" -void create_ip_addr(char* domain, IpAddr* addr) { +void create_ip_addr(uint8_t* domain, IpAddr* addr) { addr->type = V4; memcpy(&addr->data.v4.s_addr, domain, 4); } -void create_ip_addr6(char* domain, IpAddr* addr) { +void create_ip_addr6(uint8_t* domain, IpAddr* addr) { addr->type = V6; memcpy(&addr->data.v6.__in6_u.__u6_addr8, domain, 16); } @@ -76,7 +76,7 @@ void print_socket_addr(SocketAddr* addr, char* buffer) { (uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 16), (uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 8), (uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr), - addr->data.v4.sin_port + ntohs(addr->data.v4.sin_port) ); } else { for(int i = 0; i < 8; i++) { @@ -85,7 +85,7 @@ void print_socket_addr(SocketAddr* addr, char* buffer) { addr->data.v6.sin6_addr.__in6_u.__u6_addr8[i*2 + 1] ); } - APPEND(buffer, ":[%hu]", addr->data.v6.sin6_port); + APPEND(buffer, ":[%hu]", ntohs(addr->data.v6.sin6_port)); } } diff --git a/src/server/addr.h b/src/server/addr.h index 173c7fd..1210850 100644 --- a/src/server/addr.h +++ b/src/server/addr.h @@ -20,8 +20,8 @@ typedef struct { } data; } IpAddr; -void create_ip_addr(char* domain, IpAddr* addr); -void create_ip_addr6(char* domain, IpAddr* addr); +void create_ip_addr(uint8_t* domain, IpAddr* addr); +void create_ip_addr6(uint8_t* domain, IpAddr* addr); void ip_addr_any(IpAddr* addr); void ip_addr_any6(IpAddr* addr); diff --git a/src/server/binding.c b/src/server/binding.c index 47c62c6..157d1d4 100644 --- a/src/server/binding.c +++ b/src/server/binding.c @@ -144,10 +144,10 @@ bool read_connection(Connection* connection, Packet* packet) { } static bool write_udp(Connection* connection, uint8_t* buf, uint16_t len) { - //if (len > 512) { + if (len > 512) { buf[2] = buf[2] | 0x03; - // len = 512; - // } + len = 512; + } return write_udp_socket( &connection->sock.udp.udp, buf, diff --git a/src/server/resolver.c b/src/server/resolver.c index e05f365..a1fa82a 100644 --- a/src/server/resolver.c +++ b/src/server/resolver.c @@ -43,9 +43,9 @@ static bool lookup( return true; } -static bool search(Question* question, Packet* result, BindingType type) { +static bool search(Question* question, Packet* result, BindingType type, const RecordMap* map) { IpAddr addr; - char ip[4] = {1, 1, 1, 1}; + uint8_t ip[4] = {1, 1, 1, 1}; create_ip_addr(ip, &addr); uint16_t port = 53; @@ -53,6 +53,10 @@ static bool search(Question* question, Packet* result, BindingType type) { create_socket_addr(port, addr, &saddr); while(1) { + if (record_map_get(map, question, result)) { + return true; + } + if (!lookup(question, result, type, saddr)) { return false; } @@ -75,7 +79,7 @@ static bool search(Question* question, Packet* result, BindingType type) { } Packet recurse; - if (!search(&new_question, &recurse, type)) { + if (!search(&new_question, &recurse, type, map)) { return false; } @@ -104,7 +108,7 @@ static void push_questions(Question* from, uint8_t from_len, Question** to, uint memcpy(*to + to_len, from, from_len * sizeof(Question)); } -void handle_query(Packet* request, Packet* response, BindingType type) { +void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map) { memset(response, 0, sizeof(Packet)); response->header.id = request->header.id; response->header.opcode = request->header.opcode; @@ -121,7 +125,7 @@ void handle_query(Packet* request, Packet* response, BindingType type) { Packet result; memset(&result, 0, sizeof(Packet)); result.header.id = response->header.id; - if (!search(&request->questions[i], &result, type)) { + if (!search(&request->questions[i], &result, type, map)) { response->header.response = SERVFAIL; break; } @@ -158,9 +162,14 @@ void handle_query(Packet* request, Packet* response, BindingType type) { ); response->header.resource_entries += result.header.resource_entries; - free(result.questions); - free(result.answers); - free(result.authorities); - free(result.resources); + if (result.header.z == false) { + // Not from cache + free(result.questions); + free(result.answers); + free(result.authorities); + free(result.resources); + } else { + response->header.z = true; + } } -} \ No newline at end of file +} diff --git a/src/server/resolver.h b/src/server/resolver.h index 79b4825..993b515 100644 --- a/src/server/resolver.h +++ b/src/server/resolver.h @@ -1,6 +1,7 @@ #pragma once #include "../packet/packet.h" +#include "../io/map.h" #include "binding.h" -void handle_query(Packet* request, Packet* response, BindingType type); \ No newline at end of file +void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map); diff --git a/src/server/server.c b/src/server/server.c index c8975ee..27cc9d9 100644 --- a/src/server/server.c +++ b/src/server/server.c @@ -1,25 +1,67 @@ -#define _POSIX_SOURCE +#define _POSIX_C_SOURCE 200809L #include #include #include #include #include #include +#include #include "addr.h" #include "server.h" #include "resolver.h" #include "../io/log.h" +#include "../io/map.h" +#include "../io/config.h" -static pid_t udp, tcp; +static pthread_t udp, tcp; +static RecordMap map; void server_init(uint16_t port, Server* server) { INFO("Server port set to %hu", port); create_binding(UDP, port, &server->udp); create_binding(TCP, port, &server->tcp); + load_config("config", &map); } -static void server_listen(Binding* binding) { +struct DnsRequest { + Connection connection; + Packet request; +}; + +static void* server_respond(void* arg) { + struct DnsRequest req = *(struct DnsRequest*) arg; + + INFO("Recieved packet request ID %hu", req.request.header.id); + + Packet response; + handle_query(&req.request, &response, req.connection.type, &map); + + if (!write_connection(&req.connection, &response)) { + ERROR("Faled to respond to connection ID %hu: %s", + req.request.header.id, + strerror(errno) + ); + } + + free_packet(&req.request); + free_connection(&req.connection); + + if (response.header.z == false) { + free_packet(&response); + } else { + // Dont free from config + free(response.questions); + free(response.answers); + free(response.authorities); + free(response.resources); + } + + return NULL; +} + +static void* server_listen(void* arg) { + Binding* binding = (Binding*) arg; while(1) { Connection connection; @@ -34,67 +76,55 @@ static void server_listen(Binding* binding) { free_connection(&connection); continue; } + + struct DnsRequest req; + req.connection = connection; + req.request = request; - if(fork() != 0) { + pthread_t thread; + if(pthread_create(&thread, NULL, &server_respond, &req)) { + ERROR("Failed to create thread for dns request ID %hu: %s", + request.header.id, + strerror(errno) + ); free_packet(&request); free_connection(&connection); continue; } - - INFO("Recieved packet request ID %hu", request.header.id); - - Packet response; - handle_query(&request, &response, connection.type); - - if (!write_connection(&connection, &response)) { - ERROR("Failed to respond to connection ID %hu: %s", request.header.id, strerror(errno)); - } - - free_packet(&request); - free_packet(&response); - free_connection(&connection); - exit(EXIT_SUCCESS); } + + return NULL; } static void signal_handler() { printf("\n"); - kill(udp, SIGTERM); - kill(tcp, SIGTERM); + pthread_kill(udp, SIGTERM); + pthread_kill(tcp, SIGTERM); } void server_run(Server* server) { - if ((udp = fork()) == 0) { + if (!pthread_create(&udp, NULL, &server_listen, &server->udp)) { INFO("Listening for connections on UDP"); - server_listen(&server->udp); - exit(EXIT_SUCCESS); + } else { + ERROR("Failed to start UDP thread"); + exit(EXIT_FAILURE); } - if ((tcp = fork()) == 0) { + if (!pthread_create(&tcp, NULL, &server_listen, &server->tcp)) { INFO("Listening for connections on TCP"); - server_listen(&server->tcp); - exit(EXIT_SUCCESS); + } else { + ERROR("Failed to start TCP thread"); + exit(EXIT_FAILURE); } - signal(SIGINT, signal_handler); - int status; - waitpid(udp, &status, 0); - if (status == 0) { - INFO("UDP process closed successfully"); - } else { - ERROR("UDP process failed with error code %d", status); - } - - waitpid(tcp, &status, 0); - if (status == 0) { - INFO("TCP process closed successfully"); - } else { - ERROR("TCP process failed with error code %d", status); - } + pthread_join(udp, NULL); + pthread_join(tcp, NULL); } void server_free(Server* server) { free_binding(&server->udp); free_binding(&server->tcp); + record_map_free(&map); } +