#include #include #include #include #include #undef _POSIX_C_SOURCE #include #include "record.h" #include "buffer.h" uint16_t record_to_id(RecordType type) { switch (type) { case A: return 1; case NS: return 2; case CNAME: return 5; case SOA: return 6; case PTR: return 12; case MX: return 15; case TXT: return 16; case AAAA: return 28; case SRV: return 33; case CAA: return 257; default: return 0; } } void record_from_id(uint16_t i, RecordType* type) { switch (i) { case 1: *type = A; break; case 2: *type = NS; break; case 5: *type = CNAME; break; case 6: *type = SOA; break; case 12: *type = PTR; break; case 15: *type = MX; break; case 16: *type = TXT; break; case 28: *type = AAAA; break; case 33: *type = SRV; break; case 257: *type = CAA; break; default: *type = UNKOWN; } } bool str_to_qtype(const char* qstr, RecordType* qtype) { if (strcasecmp(qstr, "A") == 0) { *qtype = A; return true; } else if (strcasecmp(qstr, "NS") == 0) { *qtype = NS; return true; } else if (strcasecmp(qstr, "CNAME") == 0) { *qtype = CNAME; return true; } else if (strcasecmp(qstr, "SOA") == 0) { *qtype = SOA; return true; } else if (strcasecmp(qstr, "PTR") == 0) { *qtype = PTR; return true; } else if (strcasecmp(qstr, "MX") == 0) { *qtype = MX; return true; } else if (strcasecmp(qstr, "TXT") == 0) { *qtype = TXT; return true; } else if (strcasecmp(qstr, "AAAA") == 0) { *qtype = AAAA; return true; } else if (strcasecmp(qstr, "SRV") == 0) { *qtype = SRV; return true; } else if (strcasecmp(qstr, "CAA") == 0) { *qtype = CAA; return true; } else { return false; } return false; } static void read_a_record(PacketBuffer* buffer, Record* record) { ARecord data; data.addr[0] = buffer_read(buffer); data.addr[1] = buffer_read(buffer); data.addr[2] = buffer_read(buffer); data.addr[3] = buffer_read(buffer); record->data.a = data; } static void read_ns_record(PacketBuffer* buffer, Record* record) { NSRecord data; buffer_read_qname(buffer, &data.host); record->data.ns = data; } static void read_cname_record(PacketBuffer* buffer, Record* record) { CNAMERecord data; buffer_read_qname(buffer, &data.host); record->data.cname = data; } static void read_soa_record(PacketBuffer* buffer, Record* record) { SOARecord data; buffer_read_qname(buffer, &data.mname); buffer_read_qname(buffer, &data.nname); data.serial = buffer_read_int(buffer); data.refresh = buffer_read_int(buffer); data.retry = buffer_read_int(buffer); data.expire = buffer_read_int(buffer); data.minimum = buffer_read_int(buffer); record->data.soa = data; } static void read_ptr_record(PacketBuffer* buffer, Record* record) { PTRRecord data; buffer_read_qname(buffer, &data.pointer); record->data.ptr = data; } static void read_mx_record(PacketBuffer* buffer, Record* record) { MXRecord data; data.priority = buffer_read_short(buffer); buffer_read_qname(buffer, &data.host); record->data.mx = data; } static void read_txt_record(PacketBuffer* buffer, Record* record) { TXTRecord data; data.len = 0; data.text = malloc(sizeof(uint8_t*) * 2); uint8_t capacity = 2; uint8_t total = record->len; while (1) { if (data.len >= capacity) { if (capacity >= 128) { capacity = 255; } else { capacity *= 2; } data.text = realloc(data.text, sizeof(uint8_t*) * capacity); } buffer_read_string(buffer, &data.text[data.len]); if(data.text[data.len][0] == 0) { free(data.text[data.len]); break; } data.len++; total -= data.text[data.len - 1][0] + 1; if (total == 0) break; } record->data.txt = data; } static void read_aaaa_record(PacketBuffer* buffer, Record* record) { AAAARecord data; for (int i = 0; i < 16; i++) { data.addr[i] = buffer_read(buffer); } record->data.aaaa = data; } static void read_srv_record(PacketBuffer* buffer, Record* record) { SRVRecord data; data.priority = buffer_read_short(buffer); data.weight = buffer_read_short(buffer); data.port = buffer_read_short(buffer); buffer_read_qname(buffer, &data.target); record->data.srv = data; } static void read_caa_record(PacketBuffer* buffer, Record* record, int header_pos) { CAARecord data; data.flags = buffer_read(buffer); data.length = buffer_read(buffer); buffer_read_n(buffer, &data.tag, data.length); int value_len = ((int)record->len) + header_pos - buffer_get_index(buffer); buffer_read_n(buffer, &data.value, (uint8_t)value_len); record->data.caa = data; } bool read_record(PacketBuffer* buffer, Record* record) { buffer_read_qname(buffer, &record->domain); uint16_t qtype_num = buffer_read_short(buffer); record_from_id(qtype_num, &record->type); record->cls = buffer_read_short(buffer); record->ttl = buffer_read_int(buffer); record->len = buffer_read_short(buffer); int header_pos = buffer_get_index(buffer); switch (record->type) { case A: read_a_record(buffer, record); break; case NS: read_ns_record(buffer, record); break; case CNAME: read_cname_record(buffer, record); break; case SOA: read_soa_record(buffer, record); break; case PTR: read_ptr_record(buffer, record); break; case MX: read_mx_record(buffer, record); break; case TXT: read_txt_record(buffer, record); break; case AAAA: read_aaaa_record(buffer, record); break; case SRV: read_srv_record(buffer, record); break; case CAA: read_caa_record(buffer, record, header_pos); break; default: buffer_step(buffer, record->len); free(record->domain); return false; } return true; } static void write_a_record(PacketBuffer* buffer, ARecord* data) { buffer_write_short(buffer, 4); buffer_write(buffer, data->addr[0]); buffer_write(buffer, data->addr[1]); buffer_write(buffer, data->addr[2]); buffer_write(buffer, data->addr[3]); } static void write_ns_record(PacketBuffer* buffer, NSRecord* data) { int pos = buffer_get_index(buffer); buffer_write_short(buffer, 0); buffer_write_qname(buffer, data->host); int size = buffer_get_index(buffer) - pos - 2; buffer_set_uint16_t(buffer, (uint16_t)size, pos); } static void write_cname_record(PacketBuffer* buffer, CNAMERecord* data) { int pos = buffer_get_index(buffer); buffer_write_short(buffer, 0); buffer_write_qname(buffer, data->host); int size = buffer_get_index(buffer) - pos - 2; buffer_set_uint16_t(buffer, (uint16_t)size, pos); } static void write_soa_record(PacketBuffer* buffer, SOARecord* data) { int pos = buffer_get_index(buffer); buffer_write_short(buffer, 0); buffer_write_qname(buffer, data->mname); buffer_write_qname(buffer, data->nname); buffer_write_int(buffer, data->serial); buffer_write_int(buffer, data->refresh); buffer_write_int(buffer, data->retry); buffer_write_int(buffer, data->expire); buffer_write_int(buffer, data->minimum); int size = buffer_get_index(buffer) - pos - 2; buffer_set_uint16_t(buffer, (uint16_t)size, pos); } static void write_ptr_record(PacketBuffer* buffer, PTRRecord* data) { int pos = buffer_get_index(buffer); buffer_write_short(buffer, 0); buffer_write_qname(buffer, data->pointer); int size = buffer_get_index(buffer) - pos - 2; buffer_set_uint16_t(buffer, (uint16_t)size, pos); } static void write_mx_record(PacketBuffer* buffer, MXRecord* data) { int pos = buffer_get_index(buffer); buffer_write_short(buffer, 0); buffer_write_short(buffer, data->priority); buffer_write_qname(buffer, data->host); int size = buffer_get_index(buffer) - pos - 2; buffer_set_uint16_t(buffer, (uint16_t)size, pos); } static void write_txt_record(PacketBuffer* buffer, TXTRecord* data) { int pos = buffer_get_index(buffer); buffer_write_short(buffer, 0); if(data->len == 0) { return; } for(uint8_t i = 0; i < data->len; i++) { buffer_write_string(buffer, data->text[i]); } int size = buffer_get_index(buffer) - pos - 2; buffer_set_uint16_t(buffer, (uint16_t)size, pos); } static void write_aaaa_record(PacketBuffer* buffer, AAAARecord* data) { buffer_write_short(buffer, 16); for (int i = 0; i < 16; i++) { buffer_write(buffer, data->addr[i]); } } static void write_srv_record(PacketBuffer* buffer, SRVRecord* data) { int pos = buffer_get_index(buffer); buffer_write_short(buffer, 0); buffer_write_short(buffer, data->priority); buffer_write_short(buffer, data->weight); buffer_write_short(buffer, data->port); buffer_write_qname(buffer, data->target); int size = buffer_get_index(buffer) - pos - 2; buffer_set_uint16_t(buffer, (uint16_t)size, pos); } static void write_caa_record(PacketBuffer* buffer, CAARecord* data) { int pos = buffer_get_index(buffer); buffer_write_short(buffer, 0); buffer_write(buffer, data->flags); buffer_write(buffer, data->length); buffer_write_n(buffer, data->tag + 1, data->tag[0]); buffer_write_n(buffer, data->value + 1, data->value[0]); int size = buffer_get_index(buffer) - pos - 2; buffer_set_uint16_t(buffer, (uint16_t)size, pos); } static void write_record_header(PacketBuffer* buffer, Record* record) { buffer_write_qname(buffer, record->domain); uint16_t id = record_to_id(record->type); buffer_write_short(buffer, id); buffer_write_short(buffer, record->cls); buffer_write_int(buffer, record->ttl); } void write_record(PacketBuffer* buffer, Record* record) { switch(record->type) { case A: write_record_header(buffer, record); write_a_record(buffer, &record->data.a); break; case NS: write_record_header(buffer, record); write_ns_record(buffer, &record->data.ns); break; case CNAME: write_record_header(buffer, record); write_cname_record(buffer, &record->data.cname); break; case SOA: write_record_header(buffer, record); write_soa_record(buffer, &record->data.soa); break; case PTR: write_record_header(buffer, record); write_ptr_record(buffer, &record->data.ptr); break; case MX: write_record_header(buffer, record); write_mx_record(buffer, &record->data.mx); break; case TXT: write_record_header(buffer, record); write_txt_record(buffer, &record->data.txt); break; case AAAA: write_record_header(buffer, record); write_aaaa_record(buffer, &record->data.aaaa); break; case SRV: write_record_header(buffer, record); write_srv_record(buffer, &record->data.srv); break; case CAA: write_record_header(buffer, record); write_caa_record(buffer, &record->data.caa); break; default: break; } } void free_record(Record* record) { free(record->domain); switch (record->type) { case NS: free(record->data.ns.host); break; case CNAME: free(record->data.cname.host); break; case SOA: free(record->data.soa.mname); free(record->data.soa.nname); break; case PTR: free(record->data.ptr.pointer); break; case MX: free(record->data.mx.host); break; case TXT: for (uint8_t i = 0; i < record->data.txt.len; i++) { free(record->data.txt.text[i]); } free(record->data.txt.text); break; case SRV: free(record->data.srv.target); break; case CAA: free(record->data.caa.value); free(record->data.caa.tag); break; default: break; } } static const char* class_to_str(Record* record) { switch(record->cls) { case 1: return "IN"; case 3: return "CH"; case 4: return "HS"; default: return "??"; } } static const char* qtype_to_str(Record* record) { switch(record->type) { case A: return "A"; case NS: return "NS"; case CNAME: return "CNAME"; case SOA: return "SOA"; case PTR: return "PTR"; case MX: return "MX"; case TXT: return "TXT"; case AAAA: return "AAAA"; case SRV: return "SRV"; case CAA: return "CAA"; default: return "UNKOWN"; } } static void print_record_data(Record* record) { switch(record->type) { case A: printf("%hhu.%hhu.%hhu.%hhu", record->data.a.addr[0], record->data.a.addr[1], record->data.a.addr[2], record->data.a.addr[3] ); break; case NS: printf("%.*s", record->data.ns.host[0], record->data.ns.host + 1 ); break; case CNAME: printf("%.*s", record->data.cname.host[0], record->data.cname.host + 1 ); break; case SOA: printf("%.*s %.*s %u %u %u %u %u", record->data.soa.mname[0], record->data.soa.mname + 1, record->data.soa.nname[0], record->data.soa.nname + 1, record->data.soa.serial, record->data.soa.refresh, record->data.soa.retry, record->data.soa.expire, record->data.soa.minimum ); break; case PTR: printf("%.*s", record->data.ptr.pointer[0], record->data.ptr.pointer + 1 ); break; case MX: printf("%.*s %hu", record->data.mx.host[0], record->data.mx.host + 1, record->data.mx.priority ); break; case TXT: for(uint8_t i = 0; i < record->data.txt.len; i++) { printf("%.*s", record->data.txt.text[i][0], record->data.txt.text[i] + 1 ); } break; case AAAA: for(int i = 0; i < 8; i++) { printf("%02hhx%02hhx:", record->data.a.addr[i*2 + 0], record->data.a.addr[i*2 + 1] ); } printf(":"); break; case SRV: printf("SRV (%hu %hu %hu %.*s", record->data.srv.priority, record->data.srv.weight, record->data.srv.port, record->data.srv.target[0], record->data.srv.target + 1 ); break; case CAA: printf("%hhu %.*s %.*s", record->data.caa.flags, record->data.caa.tag[0], record->data.caa.tag + 1, record->data.caa.value[0], record->data.caa.value + 1 ); break; default: break; } } void print_record(Record* record) { printf("%.*s.\t%s %s\t", record->domain[0], record->domain + 1, class_to_str(record), qtype_to_str(record) ); print_record_data(record); printf("\n"); }