read from config

This commit is contained in:
Freya Murphy 2023-04-08 15:58:03 -04:00
parent 4cd1cced75
commit 2a75d25632
16 changed files with 717 additions and 69 deletions

1
.gitignore vendored
View file

@ -1 +1,2 @@
bin
config

View file

@ -6,6 +6,7 @@ CCFLAGS = -std=c17 -Wall -Wextra -O2
CCFLAGS += $(INCFLAGS)
LDFLAGS += $(INCFLAGS)
LDFLAGS += -lpthread
BIN = bin
APP = $(BIN)/app

476
src/io/config.c Normal file
View file

@ -0,0 +1,476 @@
#include <errno.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "config.h"
#include "log.h"
#include "map.h"
#define MAX_LEN 1024
#define BUF(name) char name[MAX_LEN]
static int line = 0;
static bool get_line(FILE* file, const BUF(buf)) {
line++;
return fgets((char*) buf, MAX_LEN, file) != NULL;
}
static bool is_whitespace(const char* buf) {
int i = 0;
char c;
while(c = buf[i], 1) {
if (c == '\n' || c == '\0') return true;
if (c != ' ' && c != '\n') return false;
i++;
}
}
static bool get_words(char* buf, char** words, int count) {
int last = 0;
int offset = 0;
int i = 0;
for(i = 0; i < count; i++) {
char c;
while(c = buf[offset], c != ' ' && c != '\0' && c != '\n') {
offset++;
}
if (offset - last < 1) {
return false;
}
words[i] = buf + last;
buf[offset] = '\0';
offset++;
last = offset;
if (c == '\0' || c == '\n') {
break;
}
}
return i + 1 == count;
}
static bool get_int(const char* word, uint32_t* i) {
char* end;
uint32_t res = (uint32_t) strtol(word, &end, 10);
if (*end == '\0') {
*i = res;
return true;
} else {
return false;
}
}
static bool config_read_qtype(const char* qstr, RecordType* qtype) {
if (strcmp(qstr, "A") == 0) {
*qtype = A;
return true;
} else if (strcmp(qstr, "NS") == 0) {
*qtype = NS;
return true;
} else if (strcmp(qstr, "CNAME") == 0) {
*qtype = CNAME;
return true;
} else if (strcmp(qstr, "SOA") == 0) {
*qtype = SOA;
return true;
} else if (strcmp(qstr, "PTR") == 0) {
*qtype = PTR;
return true;
} else if (strcmp(qstr, "MX") == 0) {
*qtype = MX;
return true;
} else if (strcmp(qstr, "TXT") == 0) {
*qtype = TXT;
return true;
} else if (strcmp(qstr, "AAAA") == 0) {
*qtype = AAAA;
return true;
} else if (strcmp(qstr, "SRV") == 0) {
*qtype = SRV;
return true;
} else if (strcmp(qstr, "CAA") == 0) {
*qtype = CAA;
return true;
} else {
return false;
}
}
static bool config_read_class(const char* cstr, uint16_t* class) {
if (strcmp(cstr, "IN") == 0) {
*class = 1;
return true;
} else if (strcmp(cstr, "CH") == 0) {
*class = 3;
return true;
} else if (strcmp(cstr, "HS") == 0) {
*class = 4;
return true;
} else {
return false;
}
}
// Format QTYPE CLASS DOMAIN: A IN google.com
static bool config_read_question(FILE* file, Question* question) {
BUF(buf);
if (!get_line(file, buf)) {
return false;
}
if (is_whitespace(buf)) {
return false;
}
char* words[3];
if (!get_words(&buf[0], &words[0], 3)) {
WARN("Invalid question at line %d", line);
return false;
};
uint16_t class;
if (!config_read_class(words[0], &class)) {
WARN("Invalid question class at line %d", line);
return false;
}
RecordType qtype;
if (!config_read_qtype(words[1], &qtype)) {
WARN("Invalid question qtype at line %d", line);
return false;
}
size_t domain_len = strlen(words[2]);
question->cls = class;
question->qtype = qtype;
question->domain = malloc(domain_len + 1);
question->domain[0] = domain_len;
memcpy(question->domain + 1, words[2], domain_len);
return true;
}
static void copy_str(char* from, uint8_t** to) {
size_t len = strlen(from);
if (len > 255) {
len = 255;
}
uint8_t* new = malloc(len + 1);
new[0] = len;
memcpy(new + 1, from, len);
*to = new;
}
static bool config_read_a_record(char* data, ARecord* record) {
sscanf(data, "%hhu.%hhu.%hhu.%hhu",
&record->addr[0],
&record->addr[1],
&record->addr[2],
&record->addr[3]
);
return true;
}
static bool config_read_ns_record(char* data, NSRecord* record) {
copy_str(data, &record->host);
return true;
}
static bool config_read_cname_record(char* data, CNAMERecord* record) {
copy_str(data, &record->host);
return true;
}
static bool config_read_soa_record(char* data, SOARecord* record) {
char* words[7];
if (!get_words(&data[0], &words[0], 7)) {
WARN("Invalid SOA record data at line %d", line);
record->mname = NULL;
record->nname = NULL;
return false;
}
copy_str(words[0], &record->mname);
copy_str(words[1], &record->nname);
if (!get_int(words[2], &record->serial)) {
WARN("Invalid SOA record data at line %d", line);
return false;
}
if (!get_int(words[3], &record->refresh)) {
WARN("Invalid SOA record data at line %d", line);
return false;
}
if (!get_int(words[4], &record->retry)) {
WARN("Invalid SOA record data at line %d", line);
return false;
}
if (!get_int(words[5], &record->expire)) {
WARN("Invalid SOA record data at line %d", line);
return false;
}
if (!get_int(words[6], &record->minimum)) {
WARN("Invalid SOA record data at line %d", line);
return false;
}
return true;
}
static bool config_read_ptr_record(char* data, PTRRecord* record) {
copy_str(data, &record->pointer);
return true;
}
static bool config_read_mx_record(char* data, MXRecord* record) {
char* words[2];
if (!get_words(&data[0], &words[0], 2)) {
WARN("Invalid MX record data at line %d", line);
record->host = NULL;
return false;
}
copy_str(words[1], &record->host);
uint32_t priority;
if (!get_int(words[0], &priority)) {
WARN("Invalid MX record data at line %d", line);
return false;
}
record->priority = (uint16_t) priority;
return true;
}
static bool config_read_txt_record(char* data, TXTRecord* record) {
int len = strlen(data);
uint8_t count = ((uint8_t)len + 254) / 255;
record->len = count;
record->text = malloc(sizeof(uint8_t*) * count);
for (uint8_t i = 0; i < count; i++) {
uint32_t offset = count * 255;
uint32_t part_len = len - offset;
if (part_len > 255) part_len = 255;
uint8_t* part = malloc(part_len + 1);
part[0] = part_len;
memcpy(part + 1, data + offset, part_len);
record->text[i] = part;
}
return true;
}
static bool config_read_aaaa_record(char* data, AAAARecord* record) {
for(int i = 0; i < 8; i++) {
if (sscanf(data, "%02hhx%02hhx:",
&record->addr[i*2 + 0],
&record->addr[i*2 + 1]
) == EOF) {
return false;
}
}
return true;
}
static bool config_read_srv_record(char* data, SRVRecord* record) {
char* words[4];
if (!get_words(&data[0], &words[0], 4)) {
WARN("Invalid SRV record data at line %d", line);
record->target = NULL;
return false;
}
copy_str(words[3], &record->target);
uint32_t priority;
if (!get_int(words[0], &priority)) {
WARN("Invalid SRV record data at line %d", line);
return false;
}
record->priority = (uint16_t) priority;
uint32_t weight;
if (!get_int(words[1], &weight)) {
WARN("Invalid SRV record data at line %d", line);
return false;
}
record->weight = (uint16_t) weight;
uint32_t port;
if (!get_int(words[2], &port)) {
WARN("Invalid SRV record data at line %d", line);
return false;
}
record->port = (uint16_t) port;
return true;
}
static bool config_read_caa_record(char* data, CAARecord* record) {
char* words[4];
if (!get_words(&data[0], &words[0], 4)) {
WARN("Invalid SRV record data at line %d", line);
record->tag = NULL;
record->value = NULL;
return false;
}
copy_str(words[2], &record->tag);
copy_str(words[3], &record->value);
uint32_t flags;
if (!get_int(words[0], &flags)) {
WARN("Invalid SRV record data at line %d", line);
return false;
}
record->flags = (uint8_t) flags;
uint32_t length;
if (!get_int(words[1], &length)) {
WARN("Invalid SRV record data at line %d", line);
return false;
}
record->length = (uint8_t) length;
return true;
}
static bool config_read_record_data(char* data, Record* record) {
switch (record->type) {
case UNKOWN:
// This can never happend in here so uh do nothing i guess
return false;
case A:
return config_read_a_record(data, &record->data.a);
case NS:
return config_read_ns_record(data, &record->data.ns);
case CNAME:
return config_read_cname_record(data, &record->data.cname);
case SOA:
return config_read_soa_record(data, &record->data.soa);
case PTR:
return config_read_ptr_record(data, &record->data.ptr);
case MX:
return config_read_mx_record(data, &record->data.mx);
case TXT:
return config_read_txt_record(data, &record->data.txt);
case AAAA:
return config_read_aaaa_record(data, &record->data.aaaa);
case SRV:
return config_read_srv_record(data, &record->data.srv);
case CAA:
return config_read_caa_record(data, &record->data.caa);
}
return false;
}
static bool config_read_record(FILE* file, Record* record, Question* question) {
BUF(buf);
if (!get_line(file, buf)) {
return false;
}
if (is_whitespace(buf)) {
return false;
}
char* words[2];
if (!get_words(&buf[0], &words[0], 2)) {
WARN("Invalid record at line %d", line);
return false;
}
uint32_t ttl;
if (!get_int(words[0], &ttl)) {
WARN("Invalid record ttl at line %d", line);
return false;
}
record->cls = question->cls;
record->type = question->qtype;
record->len = 0;
record->ttl = ttl;
record->domain = malloc(question->domain[0] + 1);
memcpy(record->domain, question->domain, question->domain[0] + 1);
if(!config_read_record_data(words[1], record)) {
free_record(record);
return false;
}
return true;
}
static void config_push_record(Record** buf, Record record, uint16_t* capacity, uint16_t* size) {
if (size == capacity) {
*capacity *= 2;
*buf = realloc(*buf, sizeof(Record) * *capacity);
}
(*buf)[*size] = record;
(*size)++;
}
bool load_config(const char* path, RecordMap* map) {
FILE* file = fopen(path, "r");
if (file == NULL) {
ERROR("Failed to open file %s: %s", path, strerror(errno));
return false;
}
line = 0;
record_map_init(map);
while (1) {
Question* question = malloc(sizeof(Question));
if (!config_read_question(file, question)) {
free(question);
break;
}
INIT_LOG_BUFFER(log);
LOGONLY(print_question(question, log));
TRACE("Found config question: %s", log);
Packet* packet = malloc(sizeof(Packet));
memset(packet, 0, sizeof(Packet));
packet->authorities = NULL;
packet->resources = NULL;
packet->questions = malloc(sizeof(Question));
packet->questions[0] = *question;
uint16_t capacity = 1;
packet->answers = malloc(sizeof(Record));
while(1) {
Record record;
if (!config_read_record(file, &record, question)) {
break;
}
LOGONLY(print_record(&record, log));
TRACE("Found config record: %s", log);
config_push_record(&packet->answers, record, &capacity, &packet->header.answers);
}
record_map_add(map, question, packet);
}
fclose(file);
return true;
}

5
src/io/config.h Normal file
View file

@ -0,0 +1,5 @@
#pragma once
#include "map.h"
bool load_config(const char* path, RecordMap* map);

104
src/io/map.c Normal file
View file

@ -0,0 +1,104 @@
#include <string.h>
#include <stdlib.h>
#include "map.h"
void record_map_init(RecordMap* map) {
map->capacity = 0;
map->len = 0;
map->entries = NULL;
}
void record_map_free(RecordMap* map) {
for(uint32_t i = 0; i < map->capacity; i++) {
Entry* e = &map->entries[i];
if (e->key != NULL) {
free_question(e->key);
free(e->key);
free_packet(e->value);
free(e->value);
}
}
free(map->entries);
}
static size_t hash_question(const Question* question) {
size_t hash = 5381;
for(int i = 0; i < question->domain[0]; i++) {
uint8_t c = question->domain[i+1];
hash = ((hash << 5) + hash) + c;
}
hash = ((hash << 5) + hash) + (uint8_t)question->cls;
hash = ((hash << 5) + hash) + (uint8_t)question->qtype;
return hash;
}
static bool question_equals(const Question* a, const Question* b) {
if (a->qtype != b->qtype) return false;
if (a->cls != b->cls) return false;
if (a->domain[0] != b->domain[0]) return false;
return memcmp(a->domain+1, b->domain+1, a->domain[0]) == 0;
}
static Entry* record_map_find(Entry* entries, uint32_t capacity, const Question* key) {
uint32_t index = hash_question(key) % capacity;
while(true) {
Entry* entry = &entries[index];
if(entry->key == NULL) {
return entry;
} else if(question_equals(entry->key, key)) {
return entry;
}
index += 1;
index %= capacity;
}
}
static void record_map_grow(RecordMap* map, uint32_t capacity) {
Entry* entries = malloc(capacity * sizeof(Entry));
for(uint32_t i = 0; i < capacity; i++) {
entries[i].key = NULL;
entries[i].value = NULL;
}
map->len = 0;
for(uint32_t i = 0; i < map->capacity; i++) {
Entry* entry = &map->entries[i];
if(entry->key == NULL) continue;
Entry* dest = record_map_find(entries, capacity, entry->key);
dest->key = entry->key;
dest->value = entry->value;
map->len++;
}
free(map->entries);
map->entries = entries;
map->capacity = capacity;
}
bool record_map_get(const RecordMap* map, const Question* key, Packet* value) {
if(map->len == 0) return false;
Entry* e = record_map_find(map->entries, map->capacity, key);
if (e->key == NULL) return false;
*value = *(e->value);
return true;
}
void record_map_add(RecordMap* map, Question* key, Packet* value) {
if(map->len + 1 > map->capacity * 0.75) {
int capacity = (map->capacity == 0 ? 8 : (2 * map->capacity));
record_map_grow(map, capacity);
}
Entry* e = record_map_find(map->entries, map->capacity, key);
bool new_key = e->key == NULL;
if(new_key) {
map->len++;
e->key = key;
}
value->header.z = true;
e->value = value;
}

20
src/io/map.h Normal file
View file

@ -0,0 +1,20 @@
#pragma once
#include "../packet/packet.h"
typedef struct {
Question* key;
Packet* value;
} Entry;
typedef struct {
uint32_t capacity;
uint32_t len;
Entry* entries;
} RecordMap;
void record_map_init(RecordMap* map);
void record_map_free(RecordMap* map);
bool record_map_get(const RecordMap* map, const Question* key, Packet* value);
void record_map_add(RecordMap* map, Question* key, Packet* value);

View file

@ -79,10 +79,10 @@ bool get_random_a(Packet* packet, IpAddr* addr) {
for (uint16_t i = 0; i < packet->header.answers; i++) {
Record record = packet->answers[i];
if (record.type == A) {
create_ip_addr((char*) &record.data.a.addr, addr);
create_ip_addr(record.data.a.addr, addr);
return true;
} else if (record.type == AAAA) {
create_ip_addr6((char*) &record.data.aaaa.addr, addr);
create_ip_addr6(record.data.aaaa.addr, addr);
return true;
}
}
@ -138,10 +138,10 @@ bool get_resolved_ns(Packet* packet, uint8_t* qname, IpAddr* addr) {
}
if (resource.type == A) {
create_ip_addr((char*) &record.data.a.addr, addr);
create_ip_addr(record.data.a.addr, addr);
return true;
} else if (resource.type == AAAA) {
create_ip_addr6((char*) &record.data.aaaa.addr, addr);
create_ip_addr6(record.data.aaaa.addr, addr);
return true;
}
}

View file

@ -92,3 +92,4 @@ void print_question(Question* question, char* buffer) {
question->domain + 1
);
}

View file

@ -7,12 +7,12 @@
#include "addr.h"
#include "../io/log.h"
void create_ip_addr(char* domain, IpAddr* addr) {
void create_ip_addr(uint8_t* domain, IpAddr* addr) {
addr->type = V4;
memcpy(&addr->data.v4.s_addr, domain, 4);
}
void create_ip_addr6(char* domain, IpAddr* addr) {
void create_ip_addr6(uint8_t* domain, IpAddr* addr) {
addr->type = V6;
memcpy(&addr->data.v6.__in6_u.__u6_addr8, domain, 16);
}
@ -76,7 +76,7 @@ void print_socket_addr(SocketAddr* addr, char* buffer) {
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 16),
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 8),
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr),
addr->data.v4.sin_port
ntohs(addr->data.v4.sin_port)
);
} else {
for(int i = 0; i < 8; i++) {
@ -85,7 +85,7 @@ void print_socket_addr(SocketAddr* addr, char* buffer) {
addr->data.v6.sin6_addr.__in6_u.__u6_addr8[i*2 + 1]
);
}
APPEND(buffer, ":[%hu]", addr->data.v6.sin6_port);
APPEND(buffer, ":[%hu]", ntohs(addr->data.v6.sin6_port));
}
}

View file

@ -20,8 +20,8 @@ typedef struct {
} data;
} IpAddr;
void create_ip_addr(char* domain, IpAddr* addr);
void create_ip_addr6(char* domain, IpAddr* addr);
void create_ip_addr(uint8_t* domain, IpAddr* addr);
void create_ip_addr6(uint8_t* domain, IpAddr* addr);
void ip_addr_any(IpAddr* addr);
void ip_addr_any6(IpAddr* addr);

View file

@ -144,10 +144,10 @@ bool read_connection(Connection* connection, Packet* packet) {
}
static bool write_udp(Connection* connection, uint8_t* buf, uint16_t len) {
//if (len > 512) {
if (len > 512) {
buf[2] = buf[2] | 0x03;
// len = 512;
// }
len = 512;
}
return write_udp_socket(
&connection->sock.udp.udp,
buf,

View file

@ -43,9 +43,9 @@ static bool lookup(
return true;
}
static bool search(Question* question, Packet* result, BindingType type) {
static bool search(Question* question, Packet* result, BindingType type, const RecordMap* map) {
IpAddr addr;
char ip[4] = {1, 1, 1, 1};
uint8_t ip[4] = {1, 1, 1, 1};
create_ip_addr(ip, &addr);
uint16_t port = 53;
@ -53,6 +53,10 @@ static bool search(Question* question, Packet* result, BindingType type) {
create_socket_addr(port, addr, &saddr);
while(1) {
if (record_map_get(map, question, result)) {
return true;
}
if (!lookup(question, result, type, saddr)) {
return false;
}
@ -75,7 +79,7 @@ static bool search(Question* question, Packet* result, BindingType type) {
}
Packet recurse;
if (!search(&new_question, &recurse, type)) {
if (!search(&new_question, &recurse, type, map)) {
return false;
}
@ -104,7 +108,7 @@ static void push_questions(Question* from, uint8_t from_len, Question** to, uint
memcpy(*to + to_len, from, from_len * sizeof(Question));
}
void handle_query(Packet* request, Packet* response, BindingType type) {
void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map) {
memset(response, 0, sizeof(Packet));
response->header.id = request->header.id;
response->header.opcode = request->header.opcode;
@ -121,7 +125,7 @@ void handle_query(Packet* request, Packet* response, BindingType type) {
Packet result;
memset(&result, 0, sizeof(Packet));
result.header.id = response->header.id;
if (!search(&request->questions[i], &result, type)) {
if (!search(&request->questions[i], &result, type, map)) {
response->header.response = SERVFAIL;
break;
}
@ -158,9 +162,14 @@ void handle_query(Packet* request, Packet* response, BindingType type) {
);
response->header.resource_entries += result.header.resource_entries;
if (result.header.z == false) {
// Not from cache
free(result.questions);
free(result.answers);
free(result.authorities);
free(result.resources);
} else {
response->header.z = true;
}
}
}

View file

@ -1,6 +1,7 @@
#pragma once
#include "../packet/packet.h"
#include "../io/map.h"
#include "binding.h"
void handle_query(Packet* request, Packet* response, BindingType type);
void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map);

View file

@ -1,25 +1,67 @@
#define _POSIX_SOURCE
#define _POSIX_C_SOURCE 200809L
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <sys/wait.h>
#include <signal.h>
#include <pthread.h>
#include "addr.h"
#include "server.h"
#include "resolver.h"
#include "../io/log.h"
#include "../io/map.h"
#include "../io/config.h"
static pid_t udp, tcp;
static pthread_t udp, tcp;
static RecordMap map;
void server_init(uint16_t port, Server* server) {
INFO("Server port set to %hu", port);
create_binding(UDP, port, &server->udp);
create_binding(TCP, port, &server->tcp);
load_config("config", &map);
}
static void server_listen(Binding* binding) {
struct DnsRequest {
Connection connection;
Packet request;
};
static void* server_respond(void* arg) {
struct DnsRequest req = *(struct DnsRequest*) arg;
INFO("Recieved packet request ID %hu", req.request.header.id);
Packet response;
handle_query(&req.request, &response, req.connection.type, &map);
if (!write_connection(&req.connection, &response)) {
ERROR("Faled to respond to connection ID %hu: %s",
req.request.header.id,
strerror(errno)
);
}
free_packet(&req.request);
free_connection(&req.connection);
if (response.header.z == false) {
free_packet(&response);
} else {
// Dont free from config
free(response.questions);
free(response.answers);
free(response.authorities);
free(response.resources);
}
return NULL;
}
static void* server_listen(void* arg) {
Binding* binding = (Binding*) arg;
while(1) {
Connection connection;
@ -35,66 +77,54 @@ static void server_listen(Binding* binding) {
continue;
}
if(fork() != 0) {
struct DnsRequest req;
req.connection = connection;
req.request = request;
pthread_t thread;
if(pthread_create(&thread, NULL, &server_respond, &req)) {
ERROR("Failed to create thread for dns request ID %hu: %s",
request.header.id,
strerror(errno)
);
free_packet(&request);
free_connection(&connection);
continue;
}
INFO("Recieved packet request ID %hu", request.header.id);
Packet response;
handle_query(&request, &response, connection.type);
if (!write_connection(&connection, &response)) {
ERROR("Failed to respond to connection ID %hu: %s", request.header.id, strerror(errno));
}
free_packet(&request);
free_packet(&response);
free_connection(&connection);
exit(EXIT_SUCCESS);
}
return NULL;
}
static void signal_handler() {
printf("\n");
kill(udp, SIGTERM);
kill(tcp, SIGTERM);
pthread_kill(udp, SIGTERM);
pthread_kill(tcp, SIGTERM);
}
void server_run(Server* server) {
if ((udp = fork()) == 0) {
if (!pthread_create(&udp, NULL, &server_listen, &server->udp)) {
INFO("Listening for connections on UDP");
server_listen(&server->udp);
exit(EXIT_SUCCESS);
} else {
ERROR("Failed to start UDP thread");
exit(EXIT_FAILURE);
}
if ((tcp = fork()) == 0) {
if (!pthread_create(&tcp, NULL, &server_listen, &server->tcp)) {
INFO("Listening for connections on TCP");
server_listen(&server->tcp);
exit(EXIT_SUCCESS);
} else {
ERROR("Failed to start TCP thread");
exit(EXIT_FAILURE);
}
signal(SIGINT, signal_handler);
int status;
waitpid(udp, &status, 0);
if (status == 0) {
INFO("UDP process closed successfully");
} else {
ERROR("UDP process failed with error code %d", status);
}
waitpid(tcp, &status, 0);
if (status == 0) {
INFO("TCP process closed successfully");
} else {
ERROR("TCP process failed with error code %d", status);
}
pthread_join(udp, NULL);
pthread_join(tcp, NULL);
}
void server_free(Server* server) {
free_binding(&server->udp);
free_binding(&server->tcp);
record_map_free(&map);
}