read from config
This commit is contained in:
parent
4cd1cced75
commit
2a75d25632
16 changed files with 717 additions and 69 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1 +1,2 @@
|
|||
bin
|
||||
config
|
||||
|
|
1
Makefile
1
Makefile
|
@ -6,6 +6,7 @@ CCFLAGS = -std=c17 -Wall -Wextra -O2
|
|||
CCFLAGS += $(INCFLAGS)
|
||||
|
||||
LDFLAGS += $(INCFLAGS)
|
||||
LDFLAGS += -lpthread
|
||||
|
||||
BIN = bin
|
||||
APP = $(BIN)/app
|
||||
|
|
476
src/io/config.c
Normal file
476
src/io/config.c
Normal file
|
@ -0,0 +1,476 @@
|
|||
#include <errno.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "config.h"
|
||||
#include "log.h"
|
||||
#include "map.h"
|
||||
|
||||
#define MAX_LEN 1024
|
||||
#define BUF(name) char name[MAX_LEN]
|
||||
|
||||
static int line = 0;
|
||||
|
||||
static bool get_line(FILE* file, const BUF(buf)) {
|
||||
line++;
|
||||
return fgets((char*) buf, MAX_LEN, file) != NULL;
|
||||
}
|
||||
|
||||
static bool is_whitespace(const char* buf) {
|
||||
int i = 0;
|
||||
char c;
|
||||
while(c = buf[i], 1) {
|
||||
if (c == '\n' || c == '\0') return true;
|
||||
if (c != ' ' && c != '\n') return false;
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
static bool get_words(char* buf, char** words, int count) {
|
||||
int last = 0;
|
||||
int offset = 0;
|
||||
int i = 0;
|
||||
|
||||
for(i = 0; i < count; i++) {
|
||||
char c;
|
||||
while(c = buf[offset], c != ' ' && c != '\0' && c != '\n') {
|
||||
offset++;
|
||||
}
|
||||
|
||||
if (offset - last < 1) {
|
||||
return false;
|
||||
}
|
||||
words[i] = buf + last;
|
||||
buf[offset] = '\0';
|
||||
offset++;
|
||||
last = offset;
|
||||
|
||||
if (c == '\0' || c == '\n') {
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
return i + 1 == count;
|
||||
}
|
||||
|
||||
static bool get_int(const char* word, uint32_t* i) {
|
||||
char* end;
|
||||
uint32_t res = (uint32_t) strtol(word, &end, 10);
|
||||
|
||||
if (*end == '\0') {
|
||||
*i = res;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static bool config_read_qtype(const char* qstr, RecordType* qtype) {
|
||||
if (strcmp(qstr, "A") == 0) {
|
||||
*qtype = A;
|
||||
return true;
|
||||
} else if (strcmp(qstr, "NS") == 0) {
|
||||
*qtype = NS;
|
||||
return true;
|
||||
} else if (strcmp(qstr, "CNAME") == 0) {
|
||||
*qtype = CNAME;
|
||||
return true;
|
||||
} else if (strcmp(qstr, "SOA") == 0) {
|
||||
*qtype = SOA;
|
||||
return true;
|
||||
} else if (strcmp(qstr, "PTR") == 0) {
|
||||
*qtype = PTR;
|
||||
return true;
|
||||
} else if (strcmp(qstr, "MX") == 0) {
|
||||
*qtype = MX;
|
||||
return true;
|
||||
} else if (strcmp(qstr, "TXT") == 0) {
|
||||
*qtype = TXT;
|
||||
return true;
|
||||
} else if (strcmp(qstr, "AAAA") == 0) {
|
||||
*qtype = AAAA;
|
||||
return true;
|
||||
} else if (strcmp(qstr, "SRV") == 0) {
|
||||
*qtype = SRV;
|
||||
return true;
|
||||
} else if (strcmp(qstr, "CAA") == 0) {
|
||||
*qtype = CAA;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static bool config_read_class(const char* cstr, uint16_t* class) {
|
||||
if (strcmp(cstr, "IN") == 0) {
|
||||
*class = 1;
|
||||
return true;
|
||||
} else if (strcmp(cstr, "CH") == 0) {
|
||||
*class = 3;
|
||||
return true;
|
||||
} else if (strcmp(cstr, "HS") == 0) {
|
||||
*class = 4;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Format QTYPE CLASS DOMAIN: A IN google.com
|
||||
static bool config_read_question(FILE* file, Question* question) {
|
||||
BUF(buf);
|
||||
if (!get_line(file, buf)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (is_whitespace(buf)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
char* words[3];
|
||||
if (!get_words(&buf[0], &words[0], 3)) {
|
||||
WARN("Invalid question at line %d", line);
|
||||
return false;
|
||||
};
|
||||
|
||||
uint16_t class;
|
||||
if (!config_read_class(words[0], &class)) {
|
||||
WARN("Invalid question class at line %d", line);
|
||||
return false;
|
||||
}
|
||||
|
||||
RecordType qtype;
|
||||
if (!config_read_qtype(words[1], &qtype)) {
|
||||
WARN("Invalid question qtype at line %d", line);
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t domain_len = strlen(words[2]);
|
||||
question->cls = class;
|
||||
question->qtype = qtype;
|
||||
question->domain = malloc(domain_len + 1);
|
||||
question->domain[0] = domain_len;
|
||||
memcpy(question->domain + 1, words[2], domain_len);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static void copy_str(char* from, uint8_t** to) {
|
||||
size_t len = strlen(from);
|
||||
if (len > 255) {
|
||||
len = 255;
|
||||
}
|
||||
|
||||
uint8_t* new = malloc(len + 1);
|
||||
new[0] = len;
|
||||
memcpy(new + 1, from, len);
|
||||
*to = new;
|
||||
}
|
||||
|
||||
static bool config_read_a_record(char* data, ARecord* record) {
|
||||
sscanf(data, "%hhu.%hhu.%hhu.%hhu",
|
||||
&record->addr[0],
|
||||
&record->addr[1],
|
||||
&record->addr[2],
|
||||
&record->addr[3]
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool config_read_ns_record(char* data, NSRecord* record) {
|
||||
copy_str(data, &record->host);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool config_read_cname_record(char* data, CNAMERecord* record) {
|
||||
copy_str(data, &record->host);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool config_read_soa_record(char* data, SOARecord* record) {
|
||||
char* words[7];
|
||||
if (!get_words(&data[0], &words[0], 7)) {
|
||||
WARN("Invalid SOA record data at line %d", line);
|
||||
record->mname = NULL;
|
||||
record->nname = NULL;
|
||||
return false;
|
||||
}
|
||||
|
||||
copy_str(words[0], &record->mname);
|
||||
copy_str(words[1], &record->nname);
|
||||
|
||||
if (!get_int(words[2], &record->serial)) {
|
||||
WARN("Invalid SOA record data at line %d", line);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!get_int(words[3], &record->refresh)) {
|
||||
WARN("Invalid SOA record data at line %d", line);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!get_int(words[4], &record->retry)) {
|
||||
WARN("Invalid SOA record data at line %d", line);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!get_int(words[5], &record->expire)) {
|
||||
WARN("Invalid SOA record data at line %d", line);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!get_int(words[6], &record->minimum)) {
|
||||
WARN("Invalid SOA record data at line %d", line);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool config_read_ptr_record(char* data, PTRRecord* record) {
|
||||
copy_str(data, &record->pointer);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool config_read_mx_record(char* data, MXRecord* record) {
|
||||
char* words[2];
|
||||
if (!get_words(&data[0], &words[0], 2)) {
|
||||
WARN("Invalid MX record data at line %d", line);
|
||||
record->host = NULL;
|
||||
return false;
|
||||
}
|
||||
|
||||
copy_str(words[1], &record->host);
|
||||
|
||||
uint32_t priority;
|
||||
if (!get_int(words[0], &priority)) {
|
||||
WARN("Invalid MX record data at line %d", line);
|
||||
return false;
|
||||
}
|
||||
record->priority = (uint16_t) priority;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool config_read_txt_record(char* data, TXTRecord* record) {
|
||||
int len = strlen(data);
|
||||
uint8_t count = ((uint8_t)len + 254) / 255;
|
||||
record->len = count;
|
||||
record->text = malloc(sizeof(uint8_t*) * count);
|
||||
|
||||
for (uint8_t i = 0; i < count; i++) {
|
||||
uint32_t offset = count * 255;
|
||||
uint32_t part_len = len - offset;
|
||||
if (part_len > 255) part_len = 255;
|
||||
|
||||
uint8_t* part = malloc(part_len + 1);
|
||||
part[0] = part_len;
|
||||
memcpy(part + 1, data + offset, part_len);
|
||||
|
||||
record->text[i] = part;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool config_read_aaaa_record(char* data, AAAARecord* record) {
|
||||
for(int i = 0; i < 8; i++) {
|
||||
if (sscanf(data, "%02hhx%02hhx:",
|
||||
&record->addr[i*2 + 0],
|
||||
&record->addr[i*2 + 1]
|
||||
) == EOF) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool config_read_srv_record(char* data, SRVRecord* record) {
|
||||
char* words[4];
|
||||
if (!get_words(&data[0], &words[0], 4)) {
|
||||
WARN("Invalid SRV record data at line %d", line);
|
||||
record->target = NULL;
|
||||
return false;
|
||||
}
|
||||
|
||||
copy_str(words[3], &record->target);
|
||||
|
||||
uint32_t priority;
|
||||
if (!get_int(words[0], &priority)) {
|
||||
WARN("Invalid SRV record data at line %d", line);
|
||||
return false;
|
||||
}
|
||||
record->priority = (uint16_t) priority;
|
||||
|
||||
uint32_t weight;
|
||||
if (!get_int(words[1], &weight)) {
|
||||
WARN("Invalid SRV record data at line %d", line);
|
||||
return false;
|
||||
}
|
||||
record->weight = (uint16_t) weight;
|
||||
|
||||
uint32_t port;
|
||||
if (!get_int(words[2], &port)) {
|
||||
WARN("Invalid SRV record data at line %d", line);
|
||||
return false;
|
||||
}
|
||||
record->port = (uint16_t) port;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool config_read_caa_record(char* data, CAARecord* record) {
|
||||
char* words[4];
|
||||
if (!get_words(&data[0], &words[0], 4)) {
|
||||
WARN("Invalid SRV record data at line %d", line);
|
||||
record->tag = NULL;
|
||||
record->value = NULL;
|
||||
return false;
|
||||
}
|
||||
|
||||
copy_str(words[2], &record->tag);
|
||||
copy_str(words[3], &record->value);
|
||||
|
||||
uint32_t flags;
|
||||
if (!get_int(words[0], &flags)) {
|
||||
WARN("Invalid SRV record data at line %d", line);
|
||||
return false;
|
||||
}
|
||||
record->flags = (uint8_t) flags;
|
||||
|
||||
uint32_t length;
|
||||
if (!get_int(words[1], &length)) {
|
||||
WARN("Invalid SRV record data at line %d", line);
|
||||
return false;
|
||||
}
|
||||
record->length = (uint8_t) length;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool config_read_record_data(char* data, Record* record) {
|
||||
switch (record->type) {
|
||||
case UNKOWN:
|
||||
// This can never happend in here so uh do nothing i guess
|
||||
return false;
|
||||
case A:
|
||||
return config_read_a_record(data, &record->data.a);
|
||||
case NS:
|
||||
return config_read_ns_record(data, &record->data.ns);
|
||||
case CNAME:
|
||||
return config_read_cname_record(data, &record->data.cname);
|
||||
case SOA:
|
||||
return config_read_soa_record(data, &record->data.soa);
|
||||
case PTR:
|
||||
return config_read_ptr_record(data, &record->data.ptr);
|
||||
case MX:
|
||||
return config_read_mx_record(data, &record->data.mx);
|
||||
case TXT:
|
||||
return config_read_txt_record(data, &record->data.txt);
|
||||
case AAAA:
|
||||
return config_read_aaaa_record(data, &record->data.aaaa);
|
||||
case SRV:
|
||||
return config_read_srv_record(data, &record->data.srv);
|
||||
case CAA:
|
||||
return config_read_caa_record(data, &record->data.caa);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool config_read_record(FILE* file, Record* record, Question* question) {
|
||||
BUF(buf);
|
||||
if (!get_line(file, buf)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (is_whitespace(buf)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
char* words[2];
|
||||
if (!get_words(&buf[0], &words[0], 2)) {
|
||||
WARN("Invalid record at line %d", line);
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t ttl;
|
||||
if (!get_int(words[0], &ttl)) {
|
||||
WARN("Invalid record ttl at line %d", line);
|
||||
return false;
|
||||
}
|
||||
|
||||
record->cls = question->cls;
|
||||
record->type = question->qtype;
|
||||
record->len = 0;
|
||||
record->ttl = ttl;
|
||||
record->domain = malloc(question->domain[0] + 1);
|
||||
memcpy(record->domain, question->domain, question->domain[0] + 1);
|
||||
|
||||
if(!config_read_record_data(words[1], record)) {
|
||||
free_record(record);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static void config_push_record(Record** buf, Record record, uint16_t* capacity, uint16_t* size) {
|
||||
if (size == capacity) {
|
||||
*capacity *= 2;
|
||||
*buf = realloc(*buf, sizeof(Record) * *capacity);
|
||||
}
|
||||
(*buf)[*size] = record;
|
||||
(*size)++;
|
||||
}
|
||||
|
||||
bool load_config(const char* path, RecordMap* map) {
|
||||
FILE* file = fopen(path, "r");
|
||||
if (file == NULL) {
|
||||
ERROR("Failed to open file %s: %s", path, strerror(errno));
|
||||
return false;
|
||||
}
|
||||
|
||||
line = 0;
|
||||
record_map_init(map);
|
||||
|
||||
while (1) {
|
||||
Question* question = malloc(sizeof(Question));
|
||||
if (!config_read_question(file, question)) {
|
||||
free(question);
|
||||
break;
|
||||
}
|
||||
|
||||
INIT_LOG_BUFFER(log);
|
||||
LOGONLY(print_question(question, log));
|
||||
TRACE("Found config question: %s", log);
|
||||
|
||||
Packet* packet = malloc(sizeof(Packet));
|
||||
memset(packet, 0, sizeof(Packet));
|
||||
packet->authorities = NULL;
|
||||
packet->resources = NULL;
|
||||
|
||||
packet->questions = malloc(sizeof(Question));
|
||||
packet->questions[0] = *question;
|
||||
|
||||
uint16_t capacity = 1;
|
||||
packet->answers = malloc(sizeof(Record));
|
||||
|
||||
while(1) {
|
||||
Record record;
|
||||
if (!config_read_record(file, &record, question)) {
|
||||
break;
|
||||
}
|
||||
|
||||
LOGONLY(print_record(&record, log));
|
||||
TRACE("Found config record: %s", log);
|
||||
|
||||
config_push_record(&packet->answers, record, &capacity, &packet->header.answers);
|
||||
}
|
||||
|
||||
record_map_add(map, question, packet);
|
||||
}
|
||||
|
||||
fclose(file);
|
||||
return true;
|
||||
}
|
5
src/io/config.h
Normal file
5
src/io/config.h
Normal file
|
@ -0,0 +1,5 @@
|
|||
#pragma once
|
||||
|
||||
#include "map.h"
|
||||
|
||||
bool load_config(const char* path, RecordMap* map);
|
104
src/io/map.c
Normal file
104
src/io/map.c
Normal file
|
@ -0,0 +1,104 @@
|
|||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "map.h"
|
||||
|
||||
void record_map_init(RecordMap* map) {
|
||||
map->capacity = 0;
|
||||
map->len = 0;
|
||||
map->entries = NULL;
|
||||
}
|
||||
|
||||
void record_map_free(RecordMap* map) {
|
||||
for(uint32_t i = 0; i < map->capacity; i++) {
|
||||
Entry* e = &map->entries[i];
|
||||
if (e->key != NULL) {
|
||||
free_question(e->key);
|
||||
free(e->key);
|
||||
free_packet(e->value);
|
||||
free(e->value);
|
||||
}
|
||||
}
|
||||
free(map->entries);
|
||||
}
|
||||
|
||||
static size_t hash_question(const Question* question) {
|
||||
size_t hash = 5381;
|
||||
for(int i = 0; i < question->domain[0]; i++) {
|
||||
uint8_t c = question->domain[i+1];
|
||||
hash = ((hash << 5) + hash) + c;
|
||||
}
|
||||
hash = ((hash << 5) + hash) + (uint8_t)question->cls;
|
||||
hash = ((hash << 5) + hash) + (uint8_t)question->qtype;
|
||||
return hash;
|
||||
}
|
||||
|
||||
static bool question_equals(const Question* a, const Question* b) {
|
||||
if (a->qtype != b->qtype) return false;
|
||||
if (a->cls != b->cls) return false;
|
||||
if (a->domain[0] != b->domain[0]) return false;
|
||||
return memcmp(a->domain+1, b->domain+1, a->domain[0]) == 0;
|
||||
}
|
||||
|
||||
static Entry* record_map_find(Entry* entries, uint32_t capacity, const Question* key) {
|
||||
uint32_t index = hash_question(key) % capacity;
|
||||
while(true) {
|
||||
Entry* entry = &entries[index];
|
||||
if(entry->key == NULL) {
|
||||
return entry;
|
||||
} else if(question_equals(entry->key, key)) {
|
||||
return entry;
|
||||
}
|
||||
index += 1;
|
||||
index %= capacity;
|
||||
}
|
||||
}
|
||||
|
||||
static void record_map_grow(RecordMap* map, uint32_t capacity) {
|
||||
Entry* entries = malloc(capacity * sizeof(Entry));
|
||||
for(uint32_t i = 0; i < capacity; i++) {
|
||||
entries[i].key = NULL;
|
||||
entries[i].value = NULL;
|
||||
}
|
||||
map->len = 0;
|
||||
for(uint32_t i = 0; i < map->capacity; i++) {
|
||||
Entry* entry = &map->entries[i];
|
||||
if(entry->key == NULL) continue;
|
||||
|
||||
Entry* dest = record_map_find(entries, capacity, entry->key);
|
||||
dest->key = entry->key;
|
||||
dest->value = entry->value;
|
||||
map->len++;
|
||||
}
|
||||
free(map->entries);
|
||||
|
||||
map->entries = entries;
|
||||
map->capacity = capacity;
|
||||
}
|
||||
|
||||
bool record_map_get(const RecordMap* map, const Question* key, Packet* value) {
|
||||
if(map->len == 0) return false;
|
||||
|
||||
Entry* e = record_map_find(map->entries, map->capacity, key);
|
||||
if (e->key == NULL) return false;
|
||||
|
||||
*value = *(e->value);
|
||||
return true;
|
||||
}
|
||||
|
||||
void record_map_add(RecordMap* map, Question* key, Packet* value) {
|
||||
if(map->len + 1 > map->capacity * 0.75) {
|
||||
int capacity = (map->capacity == 0 ? 8 : (2 * map->capacity));
|
||||
record_map_grow(map, capacity);
|
||||
}
|
||||
Entry* e = record_map_find(map->entries, map->capacity, key);
|
||||
bool new_key = e->key == NULL;
|
||||
if(new_key) {
|
||||
map->len++;
|
||||
e->key = key;
|
||||
}
|
||||
|
||||
value->header.z = true;
|
||||
e->value = value;
|
||||
}
|
||||
|
20
src/io/map.h
Normal file
20
src/io/map.h
Normal file
|
@ -0,0 +1,20 @@
|
|||
#pragma once
|
||||
|
||||
#include "../packet/packet.h"
|
||||
|
||||
typedef struct {
|
||||
Question* key;
|
||||
Packet* value;
|
||||
} Entry;
|
||||
|
||||
typedef struct {
|
||||
uint32_t capacity;
|
||||
uint32_t len;
|
||||
Entry* entries;
|
||||
} RecordMap;
|
||||
|
||||
void record_map_init(RecordMap* map);
|
||||
void record_map_free(RecordMap* map);
|
||||
|
||||
bool record_map_get(const RecordMap* map, const Question* key, Packet* value);
|
||||
void record_map_add(RecordMap* map, Question* key, Packet* value);
|
|
@ -79,10 +79,10 @@ bool get_random_a(Packet* packet, IpAddr* addr) {
|
|||
for (uint16_t i = 0; i < packet->header.answers; i++) {
|
||||
Record record = packet->answers[i];
|
||||
if (record.type == A) {
|
||||
create_ip_addr((char*) &record.data.a.addr, addr);
|
||||
create_ip_addr(record.data.a.addr, addr);
|
||||
return true;
|
||||
} else if (record.type == AAAA) {
|
||||
create_ip_addr6((char*) &record.data.aaaa.addr, addr);
|
||||
create_ip_addr6(record.data.aaaa.addr, addr);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -138,10 +138,10 @@ bool get_resolved_ns(Packet* packet, uint8_t* qname, IpAddr* addr) {
|
|||
}
|
||||
|
||||
if (resource.type == A) {
|
||||
create_ip_addr((char*) &record.data.a.addr, addr);
|
||||
create_ip_addr(record.data.a.addr, addr);
|
||||
return true;
|
||||
} else if (resource.type == AAAA) {
|
||||
create_ip_addr6((char*) &record.data.aaaa.addr, addr);
|
||||
create_ip_addr6(record.data.aaaa.addr, addr);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -168,4 +168,4 @@ bool get_unresoled_ns(Packet* packet, uint8_t* qname, Question* question) {
|
|||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -91,4 +91,5 @@ void print_question(Question* question, char* buffer) {
|
|||
question->domain[0],
|
||||
question->domain + 1
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -537,4 +537,4 @@ void print_record(Record* record, char* buffer) {
|
|||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -98,4 +98,4 @@ typedef struct {
|
|||
void read_record(PacketBuffer* buffer, Record* record);
|
||||
void write_record(PacketBuffer* buffer, Record* record);
|
||||
void free_record(Record* record);
|
||||
void print_record(Record* record, char* buffer);
|
||||
void print_record(Record* record, char* buffer);
|
||||
|
|
|
@ -7,12 +7,12 @@
|
|||
#include "addr.h"
|
||||
#include "../io/log.h"
|
||||
|
||||
void create_ip_addr(char* domain, IpAddr* addr) {
|
||||
void create_ip_addr(uint8_t* domain, IpAddr* addr) {
|
||||
addr->type = V4;
|
||||
memcpy(&addr->data.v4.s_addr, domain, 4);
|
||||
}
|
||||
|
||||
void create_ip_addr6(char* domain, IpAddr* addr) {
|
||||
void create_ip_addr6(uint8_t* domain, IpAddr* addr) {
|
||||
addr->type = V6;
|
||||
memcpy(&addr->data.v6.__in6_u.__u6_addr8, domain, 16);
|
||||
}
|
||||
|
@ -76,7 +76,7 @@ void print_socket_addr(SocketAddr* addr, char* buffer) {
|
|||
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 16),
|
||||
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 8),
|
||||
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr),
|
||||
addr->data.v4.sin_port
|
||||
ntohs(addr->data.v4.sin_port)
|
||||
);
|
||||
} else {
|
||||
for(int i = 0; i < 8; i++) {
|
||||
|
@ -85,7 +85,7 @@ void print_socket_addr(SocketAddr* addr, char* buffer) {
|
|||
addr->data.v6.sin6_addr.__in6_u.__u6_addr8[i*2 + 1]
|
||||
);
|
||||
}
|
||||
APPEND(buffer, ":[%hu]", addr->data.v6.sin6_port);
|
||||
APPEND(buffer, ":[%hu]", ntohs(addr->data.v6.sin6_port));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,8 +20,8 @@ typedef struct {
|
|||
} data;
|
||||
} IpAddr;
|
||||
|
||||
void create_ip_addr(char* domain, IpAddr* addr);
|
||||
void create_ip_addr6(char* domain, IpAddr* addr);
|
||||
void create_ip_addr(uint8_t* domain, IpAddr* addr);
|
||||
void create_ip_addr6(uint8_t* domain, IpAddr* addr);
|
||||
void ip_addr_any(IpAddr* addr);
|
||||
void ip_addr_any6(IpAddr* addr);
|
||||
|
||||
|
|
|
@ -144,10 +144,10 @@ bool read_connection(Connection* connection, Packet* packet) {
|
|||
}
|
||||
|
||||
static bool write_udp(Connection* connection, uint8_t* buf, uint16_t len) {
|
||||
//if (len > 512) {
|
||||
if (len > 512) {
|
||||
buf[2] = buf[2] | 0x03;
|
||||
// len = 512;
|
||||
// }
|
||||
len = 512;
|
||||
}
|
||||
return write_udp_socket(
|
||||
&connection->sock.udp.udp,
|
||||
buf,
|
||||
|
|
|
@ -43,9 +43,9 @@ static bool lookup(
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool search(Question* question, Packet* result, BindingType type) {
|
||||
static bool search(Question* question, Packet* result, BindingType type, const RecordMap* map) {
|
||||
IpAddr addr;
|
||||
char ip[4] = {1, 1, 1, 1};
|
||||
uint8_t ip[4] = {1, 1, 1, 1};
|
||||
create_ip_addr(ip, &addr);
|
||||
|
||||
uint16_t port = 53;
|
||||
|
@ -53,6 +53,10 @@ static bool search(Question* question, Packet* result, BindingType type) {
|
|||
create_socket_addr(port, addr, &saddr);
|
||||
|
||||
while(1) {
|
||||
if (record_map_get(map, question, result)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!lookup(question, result, type, saddr)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -75,7 +79,7 @@ static bool search(Question* question, Packet* result, BindingType type) {
|
|||
}
|
||||
|
||||
Packet recurse;
|
||||
if (!search(&new_question, &recurse, type)) {
|
||||
if (!search(&new_question, &recurse, type, map)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -104,7 +108,7 @@ static void push_questions(Question* from, uint8_t from_len, Question** to, uint
|
|||
memcpy(*to + to_len, from, from_len * sizeof(Question));
|
||||
}
|
||||
|
||||
void handle_query(Packet* request, Packet* response, BindingType type) {
|
||||
void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map) {
|
||||
memset(response, 0, sizeof(Packet));
|
||||
response->header.id = request->header.id;
|
||||
response->header.opcode = request->header.opcode;
|
||||
|
@ -121,7 +125,7 @@ void handle_query(Packet* request, Packet* response, BindingType type) {
|
|||
Packet result;
|
||||
memset(&result, 0, sizeof(Packet));
|
||||
result.header.id = response->header.id;
|
||||
if (!search(&request->questions[i], &result, type)) {
|
||||
if (!search(&request->questions[i], &result, type, map)) {
|
||||
response->header.response = SERVFAIL;
|
||||
break;
|
||||
}
|
||||
|
@ -158,9 +162,14 @@ void handle_query(Packet* request, Packet* response, BindingType type) {
|
|||
);
|
||||
response->header.resource_entries += result.header.resource_entries;
|
||||
|
||||
free(result.questions);
|
||||
free(result.answers);
|
||||
free(result.authorities);
|
||||
free(result.resources);
|
||||
if (result.header.z == false) {
|
||||
// Not from cache
|
||||
free(result.questions);
|
||||
free(result.answers);
|
||||
free(result.authorities);
|
||||
free(result.resources);
|
||||
} else {
|
||||
response->header.z = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "../packet/packet.h"
|
||||
#include "../io/map.h"
|
||||
#include "binding.h"
|
||||
|
||||
void handle_query(Packet* request, Packet* response, BindingType type);
|
||||
void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map);
|
||||
|
|
|
@ -1,25 +1,67 @@
|
|||
#define _POSIX_SOURCE
|
||||
#define _POSIX_C_SOURCE 200809L
|
||||
#include <errno.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <unistd.h>
|
||||
#include <sys/wait.h>
|
||||
#include <signal.h>
|
||||
#include <pthread.h>
|
||||
|
||||
#include "addr.h"
|
||||
#include "server.h"
|
||||
#include "resolver.h"
|
||||
#include "../io/log.h"
|
||||
#include "../io/map.h"
|
||||
#include "../io/config.h"
|
||||
|
||||
static pid_t udp, tcp;
|
||||
static pthread_t udp, tcp;
|
||||
static RecordMap map;
|
||||
|
||||
void server_init(uint16_t port, Server* server) {
|
||||
INFO("Server port set to %hu", port);
|
||||
create_binding(UDP, port, &server->udp);
|
||||
create_binding(TCP, port, &server->tcp);
|
||||
load_config("config", &map);
|
||||
}
|
||||
|
||||
static void server_listen(Binding* binding) {
|
||||
struct DnsRequest {
|
||||
Connection connection;
|
||||
Packet request;
|
||||
};
|
||||
|
||||
static void* server_respond(void* arg) {
|
||||
struct DnsRequest req = *(struct DnsRequest*) arg;
|
||||
|
||||
INFO("Recieved packet request ID %hu", req.request.header.id);
|
||||
|
||||
Packet response;
|
||||
handle_query(&req.request, &response, req.connection.type, &map);
|
||||
|
||||
if (!write_connection(&req.connection, &response)) {
|
||||
ERROR("Faled to respond to connection ID %hu: %s",
|
||||
req.request.header.id,
|
||||
strerror(errno)
|
||||
);
|
||||
}
|
||||
|
||||
free_packet(&req.request);
|
||||
free_connection(&req.connection);
|
||||
|
||||
if (response.header.z == false) {
|
||||
free_packet(&response);
|
||||
} else {
|
||||
// Dont free from config
|
||||
free(response.questions);
|
||||
free(response.answers);
|
||||
free(response.authorities);
|
||||
free(response.resources);
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void* server_listen(void* arg) {
|
||||
Binding* binding = (Binding*) arg;
|
||||
while(1) {
|
||||
|
||||
Connection connection;
|
||||
|
@ -34,67 +76,55 @@ static void server_listen(Binding* binding) {
|
|||
free_connection(&connection);
|
||||
continue;
|
||||
}
|
||||
|
||||
struct DnsRequest req;
|
||||
req.connection = connection;
|
||||
req.request = request;
|
||||
|
||||
if(fork() != 0) {
|
||||
pthread_t thread;
|
||||
if(pthread_create(&thread, NULL, &server_respond, &req)) {
|
||||
ERROR("Failed to create thread for dns request ID %hu: %s",
|
||||
request.header.id,
|
||||
strerror(errno)
|
||||
);
|
||||
free_packet(&request);
|
||||
free_connection(&connection);
|
||||
continue;
|
||||
}
|
||||
|
||||
INFO("Recieved packet request ID %hu", request.header.id);
|
||||
|
||||
Packet response;
|
||||
handle_query(&request, &response, connection.type);
|
||||
|
||||
if (!write_connection(&connection, &response)) {
|
||||
ERROR("Failed to respond to connection ID %hu: %s", request.header.id, strerror(errno));
|
||||
}
|
||||
|
||||
free_packet(&request);
|
||||
free_packet(&response);
|
||||
free_connection(&connection);
|
||||
exit(EXIT_SUCCESS);
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void signal_handler() {
|
||||
printf("\n");
|
||||
kill(udp, SIGTERM);
|
||||
kill(tcp, SIGTERM);
|
||||
pthread_kill(udp, SIGTERM);
|
||||
pthread_kill(tcp, SIGTERM);
|
||||
}
|
||||
|
||||
void server_run(Server* server) {
|
||||
if ((udp = fork()) == 0) {
|
||||
if (!pthread_create(&udp, NULL, &server_listen, &server->udp)) {
|
||||
INFO("Listening for connections on UDP");
|
||||
server_listen(&server->udp);
|
||||
exit(EXIT_SUCCESS);
|
||||
} else {
|
||||
ERROR("Failed to start UDP thread");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if ((tcp = fork()) == 0) {
|
||||
if (!pthread_create(&tcp, NULL, &server_listen, &server->tcp)) {
|
||||
INFO("Listening for connections on TCP");
|
||||
server_listen(&server->tcp);
|
||||
exit(EXIT_SUCCESS);
|
||||
} else {
|
||||
ERROR("Failed to start TCP thread");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
signal(SIGINT, signal_handler);
|
||||
|
||||
int status;
|
||||
waitpid(udp, &status, 0);
|
||||
if (status == 0) {
|
||||
INFO("UDP process closed successfully");
|
||||
} else {
|
||||
ERROR("UDP process failed with error code %d", status);
|
||||
}
|
||||
|
||||
waitpid(tcp, &status, 0);
|
||||
if (status == 0) {
|
||||
INFO("TCP process closed successfully");
|
||||
} else {
|
||||
ERROR("TCP process failed with error code %d", status);
|
||||
}
|
||||
pthread_join(udp, NULL);
|
||||
pthread_join(tcp, NULL);
|
||||
}
|
||||
|
||||
void server_free(Server* server) {
|
||||
free_binding(&server->udp);
|
||||
free_binding(&server->tcp);
|
||||
record_map_free(&map);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue