summaryrefslogtreecommitdiff
path: root/src/io
diff options
context:
space:
mode:
authorTyler Murphy <tylerm@tylerm.dev>2023-04-08 15:58:03 -0400
committerTyler Murphy <tylerm@tylerm.dev>2023-04-08 15:58:03 -0400
commit2a75d25632eb2ac966a3c52acbcc790d32abeef3 (patch)
tree239345c0666aee232fd50024ecabc0d9fc156a1c /src/io
parentgoofy ahh bin folder (diff)
downloadwrapper-2a75d25632eb2ac966a3c52acbcc790d32abeef3.tar.gz
wrapper-2a75d25632eb2ac966a3c52acbcc790d32abeef3.tar.bz2
wrapper-2a75d25632eb2ac966a3c52acbcc790d32abeef3.zip
read from config
Diffstat (limited to 'src/io')
-rw-r--r--src/io/config.c476
-rw-r--r--src/io/config.h5
-rw-r--r--src/io/map.c104
-rw-r--r--src/io/map.h20
4 files changed, 605 insertions, 0 deletions
diff --git a/src/io/config.c b/src/io/config.c
new file mode 100644
index 0000000..7bd4522
--- /dev/null
+++ b/src/io/config.c
@@ -0,0 +1,476 @@
+#include <errno.h>
+#include <stdio.h>
+#include <string.h>
+#include <stdlib.h>
+
+#include "config.h"
+#include "log.h"
+#include "map.h"
+
+#define MAX_LEN 1024
+#define BUF(name) char name[MAX_LEN]
+
+static int line = 0;
+
+static bool get_line(FILE* file, const BUF(buf)) {
+ line++;
+ return fgets((char*) buf, MAX_LEN, file) != NULL;
+}
+
+static bool is_whitespace(const char* buf) {
+ int i = 0;
+ char c;
+ while(c = buf[i], 1) {
+ if (c == '\n' || c == '\0') return true;
+ if (c != ' ' && c != '\n') return false;
+ i++;
+ }
+}
+
+static bool get_words(char* buf, char** words, int count) {
+ int last = 0;
+ int offset = 0;
+ int i = 0;
+
+ for(i = 0; i < count; i++) {
+ char c;
+ while(c = buf[offset], c != ' ' && c != '\0' && c != '\n') {
+ offset++;
+ }
+
+ if (offset - last < 1) {
+ return false;
+ }
+ words[i] = buf + last;
+ buf[offset] = '\0';
+ offset++;
+ last = offset;
+
+ if (c == '\0' || c == '\n') {
+ break;
+ }
+
+ }
+ return i + 1 == count;
+}
+
+static bool get_int(const char* word, uint32_t* i) {
+ char* end;
+ uint32_t res = (uint32_t) strtol(word, &end, 10);
+
+ if (*end == '\0') {
+ *i = res;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+static bool config_read_qtype(const char* qstr, RecordType* qtype) {
+ if (strcmp(qstr, "A") == 0) {
+ *qtype = A;
+ return true;
+ } else if (strcmp(qstr, "NS") == 0) {
+ *qtype = NS;
+ return true;
+ } else if (strcmp(qstr, "CNAME") == 0) {
+ *qtype = CNAME;
+ return true;
+ } else if (strcmp(qstr, "SOA") == 0) {
+ *qtype = SOA;
+ return true;
+ } else if (strcmp(qstr, "PTR") == 0) {
+ *qtype = PTR;
+ return true;
+ } else if (strcmp(qstr, "MX") == 0) {
+ *qtype = MX;
+ return true;
+ } else if (strcmp(qstr, "TXT") == 0) {
+ *qtype = TXT;
+ return true;
+ } else if (strcmp(qstr, "AAAA") == 0) {
+ *qtype = AAAA;
+ return true;
+ } else if (strcmp(qstr, "SRV") == 0) {
+ *qtype = SRV;
+ return true;
+ } else if (strcmp(qstr, "CAA") == 0) {
+ *qtype = CAA;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+static bool config_read_class(const char* cstr, uint16_t* class) {
+ if (strcmp(cstr, "IN") == 0) {
+ *class = 1;
+ return true;
+ } else if (strcmp(cstr, "CH") == 0) {
+ *class = 3;
+ return true;
+ } else if (strcmp(cstr, "HS") == 0) {
+ *class = 4;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+// Format QTYPE CLASS DOMAIN: A IN google.com
+static bool config_read_question(FILE* file, Question* question) {
+ BUF(buf);
+ if (!get_line(file, buf)) {
+ return false;
+ }
+
+ if (is_whitespace(buf)) {
+ return false;
+ }
+
+ char* words[3];
+ if (!get_words(&buf[0], &words[0], 3)) {
+ WARN("Invalid question at line %d", line);
+ return false;
+ };
+
+ uint16_t class;
+ if (!config_read_class(words[0], &class)) {
+ WARN("Invalid question class at line %d", line);
+ return false;
+ }
+
+ RecordType qtype;
+ if (!config_read_qtype(words[1], &qtype)) {
+ WARN("Invalid question qtype at line %d", line);
+ return false;
+ }
+
+ size_t domain_len = strlen(words[2]);
+ question->cls = class;
+ question->qtype = qtype;
+ question->domain = malloc(domain_len + 1);
+ question->domain[0] = domain_len;
+ memcpy(question->domain + 1, words[2], domain_len);
+
+ return true;
+}
+
+static void copy_str(char* from, uint8_t** to) {
+ size_t len = strlen(from);
+ if (len > 255) {
+ len = 255;
+ }
+
+ uint8_t* new = malloc(len + 1);
+ new[0] = len;
+ memcpy(new + 1, from, len);
+ *to = new;
+}
+
+static bool config_read_a_record(char* data, ARecord* record) {
+ sscanf(data, "%hhu.%hhu.%hhu.%hhu",
+ &record->addr[0],
+ &record->addr[1],
+ &record->addr[2],
+ &record->addr[3]
+ );
+ return true;
+}
+
+static bool config_read_ns_record(char* data, NSRecord* record) {
+ copy_str(data, &record->host);
+ return true;
+}
+
+static bool config_read_cname_record(char* data, CNAMERecord* record) {
+ copy_str(data, &record->host);
+ return true;
+}
+
+static bool config_read_soa_record(char* data, SOARecord* record) {
+ char* words[7];
+ if (!get_words(&data[0], &words[0], 7)) {
+ WARN("Invalid SOA record data at line %d", line);
+ record->mname = NULL;
+ record->nname = NULL;
+ return false;
+ }
+
+ copy_str(words[0], &record->mname);
+ copy_str(words[1], &record->nname);
+
+ if (!get_int(words[2], &record->serial)) {
+ WARN("Invalid SOA record data at line %d", line);
+ return false;
+ }
+
+ if (!get_int(words[3], &record->refresh)) {
+ WARN("Invalid SOA record data at line %d", line);
+ return false;
+ }
+
+ if (!get_int(words[4], &record->retry)) {
+ WARN("Invalid SOA record data at line %d", line);
+ return false;
+ }
+
+ if (!get_int(words[5], &record->expire)) {
+ WARN("Invalid SOA record data at line %d", line);
+ return false;
+ }
+
+ if (!get_int(words[6], &record->minimum)) {
+ WARN("Invalid SOA record data at line %d", line);
+ return false;
+ }
+
+ return true;
+}
+
+static bool config_read_ptr_record(char* data, PTRRecord* record) {
+ copy_str(data, &record->pointer);
+ return true;
+}
+
+static bool config_read_mx_record(char* data, MXRecord* record) {
+ char* words[2];
+ if (!get_words(&data[0], &words[0], 2)) {
+ WARN("Invalid MX record data at line %d", line);
+ record->host = NULL;
+ return false;
+ }
+
+ copy_str(words[1], &record->host);
+
+ uint32_t priority;
+ if (!get_int(words[0], &priority)) {
+ WARN("Invalid MX record data at line %d", line);
+ return false;
+ }
+ record->priority = (uint16_t) priority;
+
+ return true;
+}
+
+static bool config_read_txt_record(char* data, TXTRecord* record) {
+ int len = strlen(data);
+ uint8_t count = ((uint8_t)len + 254) / 255;
+ record->len = count;
+ record->text = malloc(sizeof(uint8_t*) * count);
+
+ for (uint8_t i = 0; i < count; i++) {
+ uint32_t offset = count * 255;
+ uint32_t part_len = len - offset;
+ if (part_len > 255) part_len = 255;
+
+ uint8_t* part = malloc(part_len + 1);
+ part[0] = part_len;
+ memcpy(part + 1, data + offset, part_len);
+
+ record->text[i] = part;
+ }
+
+ return true;
+}
+
+static bool config_read_aaaa_record(char* data, AAAARecord* record) {
+ for(int i = 0; i < 8; i++) {
+ if (sscanf(data, "%02hhx%02hhx:",
+ &record->addr[i*2 + 0],
+ &record->addr[i*2 + 1]
+ ) == EOF) {
+ return false;
+ }
+ }
+ return true;
+}
+
+static bool config_read_srv_record(char* data, SRVRecord* record) {
+ char* words[4];
+ if (!get_words(&data[0], &words[0], 4)) {
+ WARN("Invalid SRV record data at line %d", line);
+ record->target = NULL;
+ return false;
+ }
+
+ copy_str(words[3], &record->target);
+
+ uint32_t priority;
+ if (!get_int(words[0], &priority)) {
+ WARN("Invalid SRV record data at line %d", line);
+ return false;
+ }
+ record->priority = (uint16_t) priority;
+
+ uint32_t weight;
+ if (!get_int(words[1], &weight)) {
+ WARN("Invalid SRV record data at line %d", line);
+ return false;
+ }
+ record->weight = (uint16_t) weight;
+
+ uint32_t port;
+ if (!get_int(words[2], &port)) {
+ WARN("Invalid SRV record data at line %d", line);
+ return false;
+ }
+ record->port = (uint16_t) port;
+
+ return true;
+}
+
+static bool config_read_caa_record(char* data, CAARecord* record) {
+ char* words[4];
+ if (!get_words(&data[0], &words[0], 4)) {
+ WARN("Invalid SRV record data at line %d", line);
+ record->tag = NULL;
+ record->value = NULL;
+ return false;
+ }
+
+ copy_str(words[2], &record->tag);
+ copy_str(words[3], &record->value);
+
+ uint32_t flags;
+ if (!get_int(words[0], &flags)) {
+ WARN("Invalid SRV record data at line %d", line);
+ return false;
+ }
+ record->flags = (uint8_t) flags;
+
+ uint32_t length;
+ if (!get_int(words[1], &length)) {
+ WARN("Invalid SRV record data at line %d", line);
+ return false;
+ }
+ record->length = (uint8_t) length;
+
+ return true;
+}
+
+static bool config_read_record_data(char* data, Record* record) {
+ switch (record->type) {
+ case UNKOWN:
+ // This can never happend in here so uh do nothing i guess
+ return false;
+ case A:
+ return config_read_a_record(data, &record->data.a);
+ case NS:
+ return config_read_ns_record(data, &record->data.ns);
+ case CNAME:
+ return config_read_cname_record(data, &record->data.cname);
+ case SOA:
+ return config_read_soa_record(data, &record->data.soa);
+ case PTR:
+ return config_read_ptr_record(data, &record->data.ptr);
+ case MX:
+ return config_read_mx_record(data, &record->data.mx);
+ case TXT:
+ return config_read_txt_record(data, &record->data.txt);
+ case AAAA:
+ return config_read_aaaa_record(data, &record->data.aaaa);
+ case SRV:
+ return config_read_srv_record(data, &record->data.srv);
+ case CAA:
+ return config_read_caa_record(data, &record->data.caa);
+ }
+ return false;
+}
+
+static bool config_read_record(FILE* file, Record* record, Question* question) {
+ BUF(buf);
+ if (!get_line(file, buf)) {
+ return false;
+ }
+
+ if (is_whitespace(buf)) {
+ return false;
+ }
+
+ char* words[2];
+ if (!get_words(&buf[0], &words[0], 2)) {
+ WARN("Invalid record at line %d", line);
+ return false;
+ }
+
+ uint32_t ttl;
+ if (!get_int(words[0], &ttl)) {
+ WARN("Invalid record ttl at line %d", line);
+ return false;
+ }
+
+ record->cls = question->cls;
+ record->type = question->qtype;
+ record->len = 0;
+ record->ttl = ttl;
+ record->domain = malloc(question->domain[0] + 1);
+ memcpy(record->domain, question->domain, question->domain[0] + 1);
+
+ if(!config_read_record_data(words[1], record)) {
+ free_record(record);
+ return false;
+ }
+
+ return true;
+}
+
+static void config_push_record(Record** buf, Record record, uint16_t* capacity, uint16_t* size) {
+ if (size == capacity) {
+ *capacity *= 2;
+ *buf = realloc(*buf, sizeof(Record) * *capacity);
+ }
+ (*buf)[*size] = record;
+ (*size)++;
+}
+
+bool load_config(const char* path, RecordMap* map) {
+ FILE* file = fopen(path, "r");
+ if (file == NULL) {
+ ERROR("Failed to open file %s: %s", path, strerror(errno));
+ return false;
+ }
+
+ line = 0;
+ record_map_init(map);
+
+ while (1) {
+ Question* question = malloc(sizeof(Question));
+ if (!config_read_question(file, question)) {
+ free(question);
+ break;
+ }
+
+ INIT_LOG_BUFFER(log);
+ LOGONLY(print_question(question, log));
+ TRACE("Found config question: %s", log);
+
+ Packet* packet = malloc(sizeof(Packet));
+ memset(packet, 0, sizeof(Packet));
+ packet->authorities = NULL;
+ packet->resources = NULL;
+
+ packet->questions = malloc(sizeof(Question));
+ packet->questions[0] = *question;
+
+ uint16_t capacity = 1;
+ packet->answers = malloc(sizeof(Record));
+
+ while(1) {
+ Record record;
+ if (!config_read_record(file, &record, question)) {
+ break;
+ }
+
+ LOGONLY(print_record(&record, log));
+ TRACE("Found config record: %s", log);
+
+ config_push_record(&packet->answers, record, &capacity, &packet->header.answers);
+ }
+
+ record_map_add(map, question, packet);
+ }
+
+ fclose(file);
+ return true;
+}
diff --git a/src/io/config.h b/src/io/config.h
new file mode 100644
index 0000000..7da7e98
--- /dev/null
+++ b/src/io/config.h
@@ -0,0 +1,5 @@
+#pragma once
+
+#include "map.h"
+
+bool load_config(const char* path, RecordMap* map);
diff --git a/src/io/map.c b/src/io/map.c
new file mode 100644
index 0000000..cb4642e
--- /dev/null
+++ b/src/io/map.c
@@ -0,0 +1,104 @@
+#include <string.h>
+#include <stdlib.h>
+
+#include "map.h"
+
+void record_map_init(RecordMap* map) {
+ map->capacity = 0;
+ map->len = 0;
+ map->entries = NULL;
+}
+
+void record_map_free(RecordMap* map) {
+ for(uint32_t i = 0; i < map->capacity; i++) {
+ Entry* e = &map->entries[i];
+ if (e->key != NULL) {
+ free_question(e->key);
+ free(e->key);
+ free_packet(e->value);
+ free(e->value);
+ }
+ }
+ free(map->entries);
+}
+
+static size_t hash_question(const Question* question) {
+ size_t hash = 5381;
+ for(int i = 0; i < question->domain[0]; i++) {
+ uint8_t c = question->domain[i+1];
+ hash = ((hash << 5) + hash) + c;
+ }
+ hash = ((hash << 5) + hash) + (uint8_t)question->cls;
+ hash = ((hash << 5) + hash) + (uint8_t)question->qtype;
+ return hash;
+}
+
+static bool question_equals(const Question* a, const Question* b) {
+ if (a->qtype != b->qtype) return false;
+ if (a->cls != b->cls) return false;
+ if (a->domain[0] != b->domain[0]) return false;
+ return memcmp(a->domain+1, b->domain+1, a->domain[0]) == 0;
+}
+
+static Entry* record_map_find(Entry* entries, uint32_t capacity, const Question* key) {
+ uint32_t index = hash_question(key) % capacity;
+ while(true) {
+ Entry* entry = &entries[index];
+ if(entry->key == NULL) {
+ return entry;
+ } else if(question_equals(entry->key, key)) {
+ return entry;
+ }
+ index += 1;
+ index %= capacity;
+ }
+}
+
+static void record_map_grow(RecordMap* map, uint32_t capacity) {
+ Entry* entries = malloc(capacity * sizeof(Entry));
+ for(uint32_t i = 0; i < capacity; i++) {
+ entries[i].key = NULL;
+ entries[i].value = NULL;
+ }
+ map->len = 0;
+ for(uint32_t i = 0; i < map->capacity; i++) {
+ Entry* entry = &map->entries[i];
+ if(entry->key == NULL) continue;
+
+ Entry* dest = record_map_find(entries, capacity, entry->key);
+ dest->key = entry->key;
+ dest->value = entry->value;
+ map->len++;
+ }
+ free(map->entries);
+
+ map->entries = entries;
+ map->capacity = capacity;
+}
+
+bool record_map_get(const RecordMap* map, const Question* key, Packet* value) {
+ if(map->len == 0) return false;
+
+ Entry* e = record_map_find(map->entries, map->capacity, key);
+ if (e->key == NULL) return false;
+
+ *value = *(e->value);
+ return true;
+}
+
+void record_map_add(RecordMap* map, Question* key, Packet* value) {
+ if(map->len + 1 > map->capacity * 0.75) {
+ int capacity = (map->capacity == 0 ? 8 : (2 * map->capacity));
+ record_map_grow(map, capacity);
+ }
+ Entry* e = record_map_find(map->entries, map->capacity, key);
+ bool new_key = e->key == NULL;
+ if(new_key) {
+ map->len++;
+ e->key = key;
+ }
+
+ value->header.z = true;
+ e->value = value;
+}
+
diff --git a/src/io/map.h b/src/io/map.h
new file mode 100644
index 0000000..84b40fb
--- /dev/null
+++ b/src/io/map.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include "../packet/packet.h"
+
+typedef struct {
+ Question* key;
+ Packet* value;
+} Entry;
+
+typedef struct {
+ uint32_t capacity;
+ uint32_t len;
+ Entry* entries;
+} RecordMap;
+
+void record_map_init(RecordMap* map);
+void record_map_free(RecordMap* map);
+
+bool record_map_get(const RecordMap* map, const Question* key, Packet* value);
+void record_map_add(RecordMap* map, Question* key, Packet* value);