Compare commits
1 commit
Author | SHA1 | Date | |
---|---|---|---|
f46d5307fc |
70 changed files with 5583 additions and 3678 deletions
5
.gitignore
vendored
5
.gitignore
vendored
|
@ -1,3 +1,2 @@
|
||||||
bin
|
**/target
|
||||||
config
|
.env
|
||||||
bee
|
|
2447
Cargo.lock
generated
Normal file
2447
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
37
Cargo.toml
Normal file
37
Cargo.toml
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
[package]
|
||||||
|
name = "wrapper"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
# Blazingly fast runtime
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
tracing-subscriber = "0.3.16"
|
||||||
|
tracing = "0.1.37"
|
||||||
|
|
||||||
|
# Allow recursion inside tokio async
|
||||||
|
async-recursion = "1"
|
||||||
|
|
||||||
|
# DNS Caching Layer
|
||||||
|
moka = { version = "0.10.0", features = ["future"] }
|
||||||
|
|
||||||
|
# Mongodb
|
||||||
|
mongodb = { version = "2.4", features = ["tokio-sync"] }
|
||||||
|
futures = "0.3.26"
|
||||||
|
|
||||||
|
# Convert values to json for Mongodb
|
||||||
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
|
||||||
|
# Reading env vars from .env
|
||||||
|
dotenv = "0.15.0"
|
||||||
|
|
||||||
|
# For the meme records
|
||||||
|
rand = "0.8.5"
|
||||||
|
|
||||||
|
# For the http web frontend
|
||||||
|
axum = "0.6.4"
|
||||||
|
tower-http = { version = "0.4.0", features = ["fs"] }
|
||||||
|
tower-cookies = "0.9.0"
|
||||||
|
tower = "0.4.13"
|
||||||
|
bytes = "1.4.0"
|
||||||
|
serde_json = "1"
|
51
Makefile
51
Makefile
|
@ -1,51 +0,0 @@
|
||||||
CC = gcc
|
|
||||||
|
|
||||||
INCFLAGS = -Isrc
|
|
||||||
|
|
||||||
CCFLAGS = -std=gnu99 -Wall -Wextra -pedantic -O2
|
|
||||||
CCFLAGS += $(INCFLAGS)
|
|
||||||
|
|
||||||
LDFLAGS += $(INCFLAGS)
|
|
||||||
LDFLAGS += -lpthread
|
|
||||||
|
|
||||||
BIN = bin
|
|
||||||
APP = $(BIN)/app
|
|
||||||
SRC = $(shell find src -name "*.c")
|
|
||||||
OBJ = $(SRC:%.c=$(BIN)/%.o)
|
|
||||||
|
|
||||||
.PHONY: dirs run clean build install uninstall install-openrc uninstall-openrc
|
|
||||||
|
|
||||||
EOF: clean build
|
|
||||||
|
|
||||||
dirs:
|
|
||||||
mkdir -p ./$(BIN)
|
|
||||||
mkdir -p ./$(BIN)/src
|
|
||||||
mkdir -p ./$(BIN)/src/io
|
|
||||||
mkdir -p ./$(BIN)/src/packet
|
|
||||||
mkdir -p ./$(BIN)/src/server
|
|
||||||
|
|
||||||
run: build
|
|
||||||
$(APP)
|
|
||||||
|
|
||||||
build: dirs ${OBJ}
|
|
||||||
${CC} -o $(APP) $(filter %.o,$^) $(LDFLAGS)
|
|
||||||
|
|
||||||
$(BIN)/%.o: %.c
|
|
||||||
$(CC) -o $@ -c $< $(CCFLAGS)
|
|
||||||
|
|
||||||
clean:
|
|
||||||
rm -rf $(APP)
|
|
||||||
rm -rf $(BIN)
|
|
||||||
|
|
||||||
install:
|
|
||||||
cp $(APP) /usr/local/bin/wrapper
|
|
||||||
|
|
||||||
uninstall:
|
|
||||||
rm /usr/local/bin/wrapper
|
|
||||||
|
|
||||||
install-openrc: install
|
|
||||||
cp service/openrc /etc/init.d/wrapper
|
|
||||||
chmod +x /etc/init.d/wrapper
|
|
||||||
|
|
||||||
uninstall-openrc: uninstall
|
|
||||||
rm /etc/init.d/wrapper
|
|
240
buffer.c
240
buffer.c
|
@ -1,240 +0,0 @@
|
||||||
#include "buffer.h"
|
|
||||||
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <stdlib.h>
|
|
||||||
#include <string.h>
|
|
||||||
|
|
||||||
struct PacketBuffer {
|
|
||||||
uint8_t* arr;
|
|
||||||
int capacity;
|
|
||||||
int index;
|
|
||||||
int size;
|
|
||||||
};
|
|
||||||
|
|
||||||
PacketBuffer* buffer_create(int capacity) {
|
|
||||||
PacketBuffer* buffer = malloc(sizeof(PacketBuffer));
|
|
||||||
buffer->arr = malloc(capacity);
|
|
||||||
buffer->capacity = capacity;
|
|
||||||
buffer->size = 0;
|
|
||||||
buffer->index = 0;
|
|
||||||
return buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_free(PacketBuffer* buffer) {
|
|
||||||
free(buffer->arr);
|
|
||||||
free(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_seek(PacketBuffer* buffer, int index) {
|
|
||||||
buffer->index = index;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t buffer_read(PacketBuffer* buffer) {
|
|
||||||
if (buffer->index > buffer->size) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
uint8_t data = buffer->arr[buffer->index];
|
|
||||||
buffer->index++;
|
|
||||||
return data;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint16_t buffer_read_short(PacketBuffer* buffer) {
|
|
||||||
return
|
|
||||||
(uint16_t) buffer_read(buffer) << 8 |
|
|
||||||
(uint16_t) buffer_read(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t buffer_read_int(PacketBuffer* buffer) {
|
|
||||||
return
|
|
||||||
(uint32_t) buffer_read(buffer) << 24 |
|
|
||||||
(uint32_t) buffer_read(buffer) << 16 |
|
|
||||||
(uint32_t) buffer_read(buffer) << 8 |
|
|
||||||
(uint32_t) buffer_read(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t buffer_get(PacketBuffer* buffer, int index) {
|
|
||||||
if (index > buffer->size) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
uint8_t data = buffer->arr[index];
|
|
||||||
return data;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* buffer_get_range(PacketBuffer* buffer, int start, int len) {
|
|
||||||
uint8_t* arr = malloc(len);
|
|
||||||
for (int i = 0; i < len; i++) {
|
|
||||||
arr[i] = buffer_get(buffer, start + i);
|
|
||||||
}
|
|
||||||
return arr;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint16_t buffer_get_size(PacketBuffer* buffer) {
|
|
||||||
return (uint16_t) buffer->size + 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write(uint8_t** buffer, uint8_t* size, uint8_t* capacity, uint8_t data) {
|
|
||||||
if (*size == *capacity) {
|
|
||||||
*capacity *= 2;
|
|
||||||
*buffer = realloc(*buffer, *capacity);
|
|
||||||
}
|
|
||||||
(*buffer)[*size] = data;
|
|
||||||
(*size)++;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_read_qname(PacketBuffer* buffer, uint8_t** out) {
|
|
||||||
int index = buffer->index;
|
|
||||||
int jumped = 0;
|
|
||||||
|
|
||||||
int max_jumps = 5;
|
|
||||||
int jumps_performed = 0;
|
|
||||||
|
|
||||||
uint8_t length = 0;
|
|
||||||
uint8_t capacity = 8;
|
|
||||||
*out = malloc(capacity);
|
|
||||||
write(out, &length, &capacity, 0);
|
|
||||||
|
|
||||||
while(1) {
|
|
||||||
if (jumps_performed > max_jumps) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t len = buffer_get(buffer, index);
|
|
||||||
|
|
||||||
if ((len & 0xC0) == 0xC0) {
|
|
||||||
if (jumped == 0) {
|
|
||||||
buffer_seek(buffer, index + 2);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint16_t b2 = (uint16_t) buffer_get(buffer, index + 1);
|
|
||||||
uint16_t offset = ((((uint16_t) len) ^ 0xC0) << 8) | b2;
|
|
||||||
index = (int) offset;
|
|
||||||
jumped = 1;
|
|
||||||
jumps_performed++;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
index++;
|
|
||||||
|
|
||||||
if (len == 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (length > 1) {
|
|
||||||
write(out, &length, &capacity, '.');
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* range = buffer_get_range(buffer, index, len);
|
|
||||||
for (uint8_t i = 0; i < len; i++) {
|
|
||||||
write(out, &length, &capacity, range[i]);
|
|
||||||
}
|
|
||||||
free(range);
|
|
||||||
|
|
||||||
index += (int) len;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (jumped == 0) {
|
|
||||||
buffer_seek(buffer, index);
|
|
||||||
}
|
|
||||||
|
|
||||||
(*out)[0] = length - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_read_string(PacketBuffer* buffer, uint8_t** out) {
|
|
||||||
uint8_t len = buffer_read(buffer);
|
|
||||||
buffer_read_n(buffer, out, len);
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len) {
|
|
||||||
*out = malloc(len + 1);
|
|
||||||
*out[0] = len;
|
|
||||||
memcpy(*out + 1, buffer->arr + buffer->index, len);
|
|
||||||
buffer->index += len;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_write(PacketBuffer* buffer, uint8_t data) {
|
|
||||||
if(buffer->index == buffer->capacity) {
|
|
||||||
buffer->capacity *= 2;
|
|
||||||
buffer->arr = realloc(buffer->arr, buffer->capacity);
|
|
||||||
}
|
|
||||||
if (buffer->size < buffer->index) {
|
|
||||||
buffer->size = buffer->index;
|
|
||||||
}
|
|
||||||
buffer->arr[buffer->index] = data;
|
|
||||||
buffer->index++;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_write_short(PacketBuffer* buffer, uint16_t data) {
|
|
||||||
buffer_write(buffer, (uint8_t)(data >> 8));
|
|
||||||
buffer_write(buffer, (uint8_t)(data & 0xFF));
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_write_int(PacketBuffer* buffer, uint32_t data) {
|
|
||||||
buffer_write(buffer, (uint8_t)(data >> 24));
|
|
||||||
buffer_write(buffer, (uint8_t)(data >> 16));
|
|
||||||
buffer_write(buffer, (uint8_t)(data >> 8));
|
|
||||||
buffer_write(buffer, (uint8_t)(data & 0xFF));
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_write_qname(PacketBuffer* buffer, uint8_t* in) {
|
|
||||||
uint8_t part = 0;
|
|
||||||
uint8_t len = in[0];
|
|
||||||
|
|
||||||
buffer_write(buffer, 0);
|
|
||||||
|
|
||||||
if (len == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for(uint8_t i = 0; i < len; i ++) {
|
|
||||||
if (in[i+1] == '.') {
|
|
||||||
buffer_set(buffer, part, buffer->index - (int)part - 1);
|
|
||||||
buffer_write(buffer, 0);
|
|
||||||
part = 0;
|
|
||||||
} else {
|
|
||||||
buffer_write(buffer, in[i+1]);
|
|
||||||
part++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
buffer_set(buffer, part, buffer->index - (int)part - 1);
|
|
||||||
buffer_write(buffer, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_write_string(PacketBuffer* buffer, uint8_t* in) {
|
|
||||||
buffer_write(buffer, in[0]);
|
|
||||||
buffer_write_n(buffer, in + 1, in[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_write_n(PacketBuffer* buffer, uint8_t* in, int len) {
|
|
||||||
if (buffer->size + len >= buffer->capacity) {
|
|
||||||
buffer->capacity *= 2;
|
|
||||||
buffer->capacity += len;
|
|
||||||
buffer->arr = realloc(buffer->arr, buffer->capacity);
|
|
||||||
}
|
|
||||||
memcpy(buffer->arr + buffer->index, in, len);
|
|
||||||
buffer->size += len;
|
|
||||||
buffer->index += len;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_set(PacketBuffer* buffer, uint8_t data, int index) {
|
|
||||||
if (index > buffer->size) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
buffer->arr[index] = data;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_set_uint16_t(PacketBuffer* buffer, uint16_t data, int index) {
|
|
||||||
buffer_set(buffer, (uint8_t)(data >> 8), index);
|
|
||||||
buffer_set(buffer, (uint8_t)(data & 0xFF), index + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
int buffer_get_index(PacketBuffer* buffer) {
|
|
||||||
return buffer->index;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_step(PacketBuffer* buffer, int len) {
|
|
||||||
buffer->index += len;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* buffer_get_ptr(PacketBuffer* buffer) {
|
|
||||||
return buffer->arr;
|
|
||||||
}
|
|
40
public/css/home.css
Normal file
40
public/css/home.css
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
span {
|
||||||
|
margin-top: 5rem;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
width: 45rem;
|
||||||
|
font-size: 2em;
|
||||||
|
}
|
||||||
|
|
||||||
|
#new {
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
width: 100%;
|
||||||
|
padding-top: 2rem;
|
||||||
|
padding-bottom: 1rem;
|
||||||
|
border-bottom: solid 1px var(--gray);
|
||||||
|
}
|
||||||
|
|
||||||
|
#new input, .block {
|
||||||
|
border-radius: 1rem 0 0 1rem;
|
||||||
|
width: 40rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.block {
|
||||||
|
width: 33em;
|
||||||
|
}
|
||||||
|
|
||||||
|
#new button {
|
||||||
|
border-radius: 0 1rem 1rem 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.domain {
|
||||||
|
margin-top: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.domain .delete {
|
||||||
|
border-radius: 0 1rem 1rem 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.domain .edit {
|
||||||
|
border-radius: 0;
|
||||||
|
}
|
18
public/css/login.css
Normal file
18
public/css/login.css
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
#login {
|
||||||
|
margin-top: 20em;
|
||||||
|
}
|
||||||
|
|
||||||
|
#logo {
|
||||||
|
font-size: 6em;
|
||||||
|
font-weight: 750;
|
||||||
|
font-family: bold;
|
||||||
|
margin-bottom: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
form {
|
||||||
|
width: 30rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
form input {
|
||||||
|
width: 100%;
|
||||||
|
}
|
119
public/css/main.css
Normal file
119
public/css/main.css
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
:root {
|
||||||
|
--dark: #222428;
|
||||||
|
--dark-alternate: #2b2e36;
|
||||||
|
--header: #1e1e22;
|
||||||
|
|
||||||
|
--accent: #8849f5;
|
||||||
|
--accent-alternate: #6829d5;
|
||||||
|
--gray: #2f2f3f;
|
||||||
|
--main: #ffffff;
|
||||||
|
--main-alternate: #cccccc;
|
||||||
|
}
|
||||||
|
|
||||||
|
* {
|
||||||
|
padding: 0;
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@font-face {
|
||||||
|
font-family: main;
|
||||||
|
src: url("../fonts/helvetica.ttf") format("truetype");
|
||||||
|
font-display: swap;
|
||||||
|
}
|
||||||
|
|
||||||
|
@font-face {
|
||||||
|
font-family: bold;
|
||||||
|
src: url("../fonts/overpass-bold.otf") format("opentype");
|
||||||
|
font-display: swap;
|
||||||
|
}
|
||||||
|
|
||||||
|
@font-face {
|
||||||
|
font-family: bold-italic;
|
||||||
|
src: url("../fonts/overpass-bold-italic.otf") format("opentype");
|
||||||
|
font-display: swap;
|
||||||
|
}
|
||||||
|
|
||||||
|
html {
|
||||||
|
background-color: var(--dark);
|
||||||
|
font-family: main;
|
||||||
|
color: var(--main);
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.accent {
|
||||||
|
color: var(--accent);
|
||||||
|
}
|
||||||
|
|
||||||
|
.fill {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
input, button, .block {
|
||||||
|
all: unset;
|
||||||
|
display: inline-block;
|
||||||
|
font: main;
|
||||||
|
background-color: var(--dark-alternate);
|
||||||
|
font-size: 1rem;
|
||||||
|
padding: 1rem;
|
||||||
|
border-radius: 1rem;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
button {
|
||||||
|
background-color: var(--accent);
|
||||||
|
width: 5em;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
button:hover {
|
||||||
|
cursor: pointer;
|
||||||
|
background-color: var(--accent-alternate);
|
||||||
|
}
|
||||||
|
|
||||||
|
.delete {
|
||||||
|
background-color: #f54842;
|
||||||
|
}
|
||||||
|
|
||||||
|
.delete:hover {
|
||||||
|
cursor: pointer;
|
||||||
|
background-color: #d52822;
|
||||||
|
}
|
||||||
|
|
||||||
|
form {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
}
|
||||||
|
|
||||||
|
#header {
|
||||||
|
width: calc(100% - 4rem);
|
||||||
|
background-color: var(--header);
|
||||||
|
border-bottom: solid 1px var(--gray);
|
||||||
|
padding: 1rem;
|
||||||
|
padding-left: 3rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
#logo {
|
||||||
|
font-size: 2em;
|
||||||
|
font-weight: 500;
|
||||||
|
font-family: bold;
|
||||||
|
}
|
||||||
|
|
||||||
|
#title {
|
||||||
|
font-size: 2em;
|
||||||
|
font-weight: 300;
|
||||||
|
font-family: sans-serif;
|
||||||
|
padding-left: 1em;
|
||||||
|
}
|
67
public/css/record.css
Normal file
67
public/css/record.css
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
#buttons {
|
||||||
|
margin-top: 2rem;
|
||||||
|
width: 50rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
#buttons button {
|
||||||
|
margin: 0;
|
||||||
|
margin-right: 2rem;
|
||||||
|
border-radius: 10px;
|
||||||
|
width: auto;
|
||||||
|
padding: .75rem 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.record {
|
||||||
|
width: 50rem;
|
||||||
|
background-color: var(--header);
|
||||||
|
padding: 1rem;
|
||||||
|
margin-top: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.header {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.header span {
|
||||||
|
font-family: bold;
|
||||||
|
}
|
||||||
|
|
||||||
|
.header button {
|
||||||
|
margin: 0;
|
||||||
|
margin-left: 2rem;
|
||||||
|
padding: .5rem 1rem;
|
||||||
|
width: auto;
|
||||||
|
border-radius: 5px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.type {
|
||||||
|
margin-right: 1rem;
|
||||||
|
background-color: var(--accent);
|
||||||
|
padding: .25rem .5rem;
|
||||||
|
border-radius: 5px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.domain {
|
||||||
|
color: var(--main-alternate);
|
||||||
|
flex-grow: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.properties {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
}
|
||||||
|
|
||||||
|
.poperty {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: row;
|
||||||
|
border-bottom: solid 1px var(--gray);
|
||||||
|
margin-top: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.key {
|
||||||
|
font-family: bold;
|
||||||
|
width: 5rem;
|
||||||
|
}
|
||||||
|
|
21
public/domain.html
Normal file
21
public/domain.html
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Wrapper - Records</title>
|
||||||
|
|
||||||
|
<meta name="author" content="Tyler Murphy">
|
||||||
|
<meta name="description" content="wrapper records">
|
||||||
|
|
||||||
|
<meta property="og:title" content="wrapper">
|
||||||
|
<meta property="og:description" content="wrapper records">
|
||||||
|
|
||||||
|
<link rel="stylesheet" href="/css/main.css">
|
||||||
|
<link rel="stylesheet" href="/css/record.css">
|
||||||
|
|
||||||
|
<script type="module" src="/js/domain.js"></script>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
</body>
|
||||||
|
</html>
|
BIN
public/fonts/helvetica-bold.ttf
Normal file
BIN
public/fonts/helvetica-bold.ttf
Normal file
Binary file not shown.
BIN
public/fonts/helvetica.ttf
Normal file
BIN
public/fonts/helvetica.ttf
Normal file
Binary file not shown.
BIN
public/fonts/overpass-bold-italic.otf
Normal file
BIN
public/fonts/overpass-bold-italic.otf
Normal file
Binary file not shown.
BIN
public/fonts/overpass-bold.otf
Normal file
BIN
public/fonts/overpass-bold.otf
Normal file
Binary file not shown.
21
public/home.html
Normal file
21
public/home.html
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Wrapper - Domains</title>
|
||||||
|
|
||||||
|
<meta name="author" content="Tyler Murphy">
|
||||||
|
<meta name="description" content="wrapper domains">
|
||||||
|
|
||||||
|
<meta property="og:title" content="wrapper">
|
||||||
|
<meta property="og:description" content="wrapper domains">
|
||||||
|
|
||||||
|
<link rel="stylesheet" href="/css/main.css">
|
||||||
|
<link rel="stylesheet" href="/css/home.css">
|
||||||
|
|
||||||
|
<script type="module" src="/js/home.js"></script>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
</body>
|
||||||
|
</html>
|
51
public/js/api.js
Normal file
51
public/js/api.js
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
const endpoint = '/api'
|
||||||
|
|
||||||
|
const request = async (url, method, body) => {
|
||||||
|
|
||||||
|
let response;
|
||||||
|
|
||||||
|
if (method == 'GET') {
|
||||||
|
response = await fetch(endpoint + url, {
|
||||||
|
method,
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
response = await fetch(endpoint + url, {
|
||||||
|
method,
|
||||||
|
body: JSON.stringify(body),
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (response.status == 401) {
|
||||||
|
location.href = '/login'
|
||||||
|
}
|
||||||
|
const contentType = response.headers.get("content-type");
|
||||||
|
if (contentType && contentType.indexOf("application/json") !== -1) {
|
||||||
|
const json = await response.json()
|
||||||
|
return { status: response.status, msg: json.msg, json }
|
||||||
|
} else {
|
||||||
|
const msg = await response.text();
|
||||||
|
return { status: response.status, msg }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export const login = async (user, pass) => {
|
||||||
|
return await request('/login', 'POST', {user, pass})
|
||||||
|
}
|
||||||
|
|
||||||
|
export const domains = async () => {
|
||||||
|
return await request('/domains', 'GET')
|
||||||
|
}
|
||||||
|
|
||||||
|
export const del_domain = async (domain) => {
|
||||||
|
return await request('/domains', 'DELETE', {domain})
|
||||||
|
}
|
||||||
|
|
||||||
|
export const records = async (domain) => {
|
||||||
|
return await request(`/records?domain=${domain}`, 'GET')
|
||||||
|
}
|
12
public/js/components.js
Normal file
12
public/js/components.js
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
import { div, parse, span } from './main.js';
|
||||||
|
|
||||||
|
export function header(title) {
|
||||||
|
return div({id: 'header'},
|
||||||
|
span({id: 'logo', class: 'accent'},
|
||||||
|
parse("Wrapper")
|
||||||
|
),
|
||||||
|
span({id: 'title'},
|
||||||
|
parse(title)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
95
public/js/domain.js
Normal file
95
public/js/domain.js
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
import { del_domain, domains, records } from './api.js'
|
||||||
|
import { header } from './components.js'
|
||||||
|
import { body, parse, div, input, button, span, is_domain } from './main.js';
|
||||||
|
|
||||||
|
function render(domain, records) {
|
||||||
|
|
||||||
|
let divs = []
|
||||||
|
for (const record of records) {
|
||||||
|
divs.push(gen_record(record))
|
||||||
|
}
|
||||||
|
|
||||||
|
document.body.replaceWith(
|
||||||
|
body({},
|
||||||
|
header(domain),
|
||||||
|
div({id: 'buttons'},
|
||||||
|
button({onclick: (event) => {
|
||||||
|
location.href = '/home'
|
||||||
|
}}, parse("Home")),
|
||||||
|
button({}, parse("New Record")),
|
||||||
|
),
|
||||||
|
...divs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
function gen_record(record) {
|
||||||
|
let domain = record.domain
|
||||||
|
let prefix = record.prefix
|
||||||
|
|
||||||
|
if (prefix.length > 0) {
|
||||||
|
prefix = prefix + '.'
|
||||||
|
}
|
||||||
|
|
||||||
|
let type = Object.keys(record.record)[0]
|
||||||
|
let data = record.record[type]
|
||||||
|
|
||||||
|
let divs = []
|
||||||
|
for (const key in data) {
|
||||||
|
let disp_key;
|
||||||
|
if (key == 'ttl') {
|
||||||
|
disp_key = 'TTL'
|
||||||
|
} else {
|
||||||
|
disp_key = upper(key)
|
||||||
|
}
|
||||||
|
divs.push(
|
||||||
|
div({class: 'poperty'},
|
||||||
|
div({class: 'key'}, parse(disp_key)),
|
||||||
|
div({class: 'value'}, parse(data[key])),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return div({class: 'record'},
|
||||||
|
div({class: 'header'},
|
||||||
|
span({class: 'type'}, parse(type)),
|
||||||
|
span({class: 'prefix'}, parse(prefix)),
|
||||||
|
span({class: 'domain'}, parse(domain)),
|
||||||
|
button({}, parse("Edit")),
|
||||||
|
button({class: 'delete'}, parse("Delete"))
|
||||||
|
),
|
||||||
|
div({class: 'properties'},
|
||||||
|
...divs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
function upper(string) {
|
||||||
|
return string.charAt(0).toUpperCase() + string.slice(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
async function init() {
|
||||||
|
|
||||||
|
const params = new Proxy(new URLSearchParams(window.location.search), {
|
||||||
|
get: (searchParams, prop) => searchParams.get(prop),
|
||||||
|
});
|
||||||
|
|
||||||
|
let domain = params.domain;
|
||||||
|
|
||||||
|
if (!is_domain(domain)) {
|
||||||
|
location.href = '/home'
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
let res = await records(domain);
|
||||||
|
|
||||||
|
if (res.status !== 200) {
|
||||||
|
alert(res.msg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
render(domain, res.json)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
init()
|
91
public/js/home.js
Normal file
91
public/js/home.js
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
import { del_domain, domains } from './api.js'
|
||||||
|
import { header } from './components.js'
|
||||||
|
import { body, parse, div, input, button, span, is_domain } from './main.js';
|
||||||
|
|
||||||
|
function render(domains) {
|
||||||
|
document.body.replaceWith(
|
||||||
|
body({},
|
||||||
|
header('domains'),
|
||||||
|
div({id: 'new'},
|
||||||
|
input({
|
||||||
|
type: 'text',
|
||||||
|
name: 'domain',
|
||||||
|
id: 'domain',
|
||||||
|
placeholder: 'Type domain name to create new records',
|
||||||
|
autocomplete: "off",
|
||||||
|
}),
|
||||||
|
button({onclick: () => {
|
||||||
|
let domain = document.getElementById('domain').value
|
||||||
|
|
||||||
|
if (!is_domain(domain)) {
|
||||||
|
alert("Invalid domain")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
location.href = '/domain?domain='+domain
|
||||||
|
}},
|
||||||
|
parse("Create")
|
||||||
|
)
|
||||||
|
),
|
||||||
|
...domain(domains)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
function domain(domains) {
|
||||||
|
let divs = []
|
||||||
|
for (const domain of domains) {
|
||||||
|
divs.push(
|
||||||
|
div({class: 'domain'},
|
||||||
|
div({class: 'block'},
|
||||||
|
parse(domain)
|
||||||
|
),
|
||||||
|
button({class: 'edit', onclick: (event) => {
|
||||||
|
console.log(event.target.parentElement)
|
||||||
|
let domain = event
|
||||||
|
.target
|
||||||
|
.parentElement
|
||||||
|
.getElementsByClassName('block')[0]
|
||||||
|
.innerText
|
||||||
|
|
||||||
|
if (!is_domain(domain)) {
|
||||||
|
alert("Invalid domain")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
location.href = '/domain?domain='+domain
|
||||||
|
}},
|
||||||
|
parse("Edit")
|
||||||
|
),
|
||||||
|
button({class: 'delete', onclick: async () => {
|
||||||
|
let res = await del_domain(domain)
|
||||||
|
|
||||||
|
if (res.status != 204) {
|
||||||
|
alert(res.msg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
location.reload()
|
||||||
|
}},
|
||||||
|
parse("Delete")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return divs
|
||||||
|
}
|
||||||
|
|
||||||
|
async function init() {
|
||||||
|
|
||||||
|
let res = await domains();
|
||||||
|
|
||||||
|
if (res.status !== 200) {
|
||||||
|
alert(res.msg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
render(res.json)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
init()
|
44
public/js/login.js
Normal file
44
public/js/login.js
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
import { body, div, form, input, p, parse, span} from './main.js'
|
||||||
|
import { login } from './api.js'
|
||||||
|
|
||||||
|
function render() {
|
||||||
|
document.body.replaceWith(
|
||||||
|
body({},
|
||||||
|
div({id: 'login', class: 'fill'},
|
||||||
|
span({id: 'logo'},
|
||||||
|
span({class: 'accent'}, parse('Wrapper'))
|
||||||
|
),
|
||||||
|
form({autocomplete: "off"},
|
||||||
|
input({
|
||||||
|
type: 'text',
|
||||||
|
name: 'user',
|
||||||
|
id: 'user',
|
||||||
|
placeholder: 'Username',
|
||||||
|
autofocus: 1
|
||||||
|
}),
|
||||||
|
input({
|
||||||
|
type: 'password',
|
||||||
|
name: 'pass',
|
||||||
|
id: 'pass',
|
||||||
|
placeholder: 'Password',
|
||||||
|
onkeydown: async (event) => {
|
||||||
|
if (event.key == 'Enter') {
|
||||||
|
event.preventDefault()
|
||||||
|
let user = document.getElementById('user').value
|
||||||
|
let pass = document.getElementById('pass').value
|
||||||
|
|
||||||
|
let res = await login(user, pass)
|
||||||
|
|
||||||
|
if (res.status === 200) {
|
||||||
|
location.href = '/home'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
render()
|
136
public/js/main.js
Normal file
136
public/js/main.js
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
function createElement(name, attrs, ...children) {
|
||||||
|
const el = document.createElement(name);
|
||||||
|
|
||||||
|
for (const attr in attrs) {
|
||||||
|
if(attr.startsWith("on")) {
|
||||||
|
el[attr] = attrs[attr];
|
||||||
|
} else {
|
||||||
|
el.setAttribute(attr, attrs[attr])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const child of children) {
|
||||||
|
if (child == null) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
el.appendChild(child)
|
||||||
|
}
|
||||||
|
|
||||||
|
return el
|
||||||
|
}
|
||||||
|
|
||||||
|
export function createElementNS(name, attrs, ...children) {
|
||||||
|
var svgns = "http://www.w3.org/2000/svg";
|
||||||
|
var el = document.createElementNS(svgns, name);
|
||||||
|
|
||||||
|
for (const attr in attrs) {
|
||||||
|
if(attr.startsWith("on")) {
|
||||||
|
el[attr] = attrs[attr];
|
||||||
|
} else {
|
||||||
|
el.setAttribute(attr, attrs[attr])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const child of children) {
|
||||||
|
if (child == null) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
el.appendChild(child)
|
||||||
|
}
|
||||||
|
|
||||||
|
return el
|
||||||
|
}
|
||||||
|
|
||||||
|
export function p(attrs, ...children) {
|
||||||
|
return createElement("p", attrs, ...children)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function span(attrs, ...children) {
|
||||||
|
return createElement("span", attrs, ...children)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function div(attrs, ...children) {
|
||||||
|
return createElement("div", attrs, ...children)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function a(attrs, ...children) {
|
||||||
|
return createElement("a", attrs, ...children)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function i(attrs, ...children) {
|
||||||
|
return createElement("i", attrs, ...children)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function form(attrs, ...children) {
|
||||||
|
return createElement("form", attrs, ...children)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function img(alt, attrs, ...children) {
|
||||||
|
attrs['onerror'] = (event) => event.target.remove()
|
||||||
|
attrs['alt'] = alt
|
||||||
|
return createElement("img", attrs, ...children)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function input(attrs, ...children) {
|
||||||
|
return createElement("input", attrs, ...children)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function button(attrs, ...children) {
|
||||||
|
return createElement("button", attrs, ...children)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function path(attrs, ...children) {
|
||||||
|
return createElementNS("path", attrs, ...children)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function svg(attrs, ...children) {
|
||||||
|
return createElementNS("svg", attrs, ...children)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function body(attrs, ...children) {
|
||||||
|
return createElement("body", attrs, ...children)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function textarea(attrs, ...children) {
|
||||||
|
return createElement("textarea", attrs, ...children)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function parse(input) {
|
||||||
|
const pattern = /^[ a-zA-Z0-9!@#$%^&*()_+\-=\[\]{};':"\\|,.<>\/?]*$/;
|
||||||
|
|
||||||
|
input = input + '';
|
||||||
|
|
||||||
|
if (!pattern.test(input)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const sanitized = input.replace(/</g, '<').replace(/>/g, '>');
|
||||||
|
return document.createRange().createContextualFragment(sanitized);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function is_domain(domain) {
|
||||||
|
domain = domain.toLowerCase()
|
||||||
|
|
||||||
|
const pattern = /^[a-z0-9_\-.]*$/;
|
||||||
|
if (!pattern.test(domain)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
let parts = domain.split('.').reverse()
|
||||||
|
for (const part of parts) {
|
||||||
|
if (part.length < 1) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parts.length < 2 || parts[0].length < 2) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
const tld_pattern = /^[a-z]*$/;
|
||||||
|
if (!tld_pattern.test(parts[0])) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
21
public/login.html
Normal file
21
public/login.html
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Wrapper - Login</title>
|
||||||
|
|
||||||
|
<meta name="author" content="Tyler Murphy">
|
||||||
|
<meta name="description" content="wrapper dns login">
|
||||||
|
|
||||||
|
<meta property="og:title" content="wrapper">
|
||||||
|
<meta property="og:description" content="wrapper dns login">
|
||||||
|
|
||||||
|
<link rel="stylesheet" href="/css/main.css">
|
||||||
|
<link rel="stylesheet" href="/css/login.css">
|
||||||
|
|
||||||
|
<script type="module" src="/js/login.js"></script>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
</body>
|
||||||
|
</html>
|
9
public/robots.txt
Normal file
9
public/robots.txt
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
User-agent: Googlebot
|
||||||
|
Disallow: /api
|
||||||
|
|
||||||
|
User-agent: Googlebot
|
||||||
|
User-agent: AdsBot-Google
|
||||||
|
Disallow: /api
|
||||||
|
|
||||||
|
User-agent: *
|
||||||
|
Disallow: /api
|
89
readme.md
89
readme.md
|
@ -1,89 +0,0 @@
|
||||||
|
|
||||||
# Wrapper
|
|
||||||
|
|
||||||
A simple and lightweight dns server written in C
|
|
||||||
|
|
||||||
## How to
|
|
||||||
|
|
||||||
Wrapper by default runs on port 53 udp and tcp, which is the default port for DNS. If you wish to change this variable, set the `PORT` environment variable is set to a different port number.
|
|
||||||
|
|
||||||
To set custom records, wrapper reads configuration from `/etc/wrapper.conf`, and if that doesn't exist, `./config`.
|
|
||||||
|
|
||||||
The config file format is a question on its own line, then followed by records on their own line. To separate questions/records from each other, place a extra empty new line between them. For example...
|
|
||||||
|
|
||||||
```
|
|
||||||
IN A google.com
|
|
||||||
IN A 300 12.34.56.78
|
|
||||||
|
|
||||||
IN TXT joe
|
|
||||||
IN TXT 60 biden
|
|
||||||
```
|
|
||||||
|
|
||||||
### Question
|
|
||||||
|
|
||||||
Now to break this down piece by piece, the question is made up of a class, record type, and domain. Domain is self explanitory, but for the other two:
|
|
||||||
|
|
||||||
#### Class
|
|
||||||
|
|
||||||
The valid classes are `IN` (Internet), `CH` (Chaosnet), and `HS` (Hesiod). For most purposes your going to be using IN.
|
|
||||||
|
|
||||||
#### Record type
|
|
||||||
|
|
||||||
The current supported record types are `A`, `NS`, `CNAME`, `SOA`, `PTR`, `MX`, `TXT`, `AAAA`, `SRV`, and `CAA`.
|
|
||||||
|
|
||||||
### Answer
|
|
||||||
|
|
||||||
Answers are very similar to questions, they have a class, record type, but also have a Time to Live (TTL), and its record data. TTL is just the amount of seconds other DNS servers should cache the value, but for record data, it's formatted as such:
|
|
||||||
|
|
||||||
#### Record Data
|
|
||||||
|
|
||||||
- `A` ipv4 (0.0.0.0)
|
|
||||||
- `NS`, `CNAME`, `PTR` domain (google.com)
|
|
||||||
- `SOA` mname, nname serial refresh retry expire minimum (ns1.google.com. dns-admin.google.com. 523655285 900 900 1800 60)
|
|
||||||
- `MX`: priority domain (10 smtp.google.com)
|
|
||||||
- `TXT`: text (Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua)
|
|
||||||
- `AAAA` ipv6 (2607:f8b0:4006:080c:0000:0000:0000:200e::)
|
|
||||||
- `SRV` priority weight port domain (10 10 10 example.com)
|
|
||||||
- `CAA` flags tag value (0 issue "pki.goog")
|
|
||||||
|
|
||||||
Wrapper also has a few joke/meme records for fun. Note these should never acutaly be used for general use, but they are funny.
|
|
||||||
|
|
||||||
- `AR`, `AAAAR`
|
|
||||||
- `CMD` command (neofetch)
|
|
||||||
|
|
||||||
`AR` and `AAAAR` have no record data since they generate ipv4's and ipv6's respectively, and turn into `A` and `AAAA` upon response to sender
|
|
||||||
|
|
||||||
`CMD` runs the command supplied on the host system and returns the std output as a `TXT` record
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
This project is Licensed under the [GPLv3](https://www.gnu.org/licenses/gpl-3.0.en.html)
|
|
||||||
|
|
||||||
## Compilation
|
|
||||||
|
|
||||||
Wrapper only runs on Linux systems that are Posix 1995 compliant
|
|
||||||
|
|
||||||
Make sure to have `gcc` and `make` installed, and then run
|
|
||||||
|
|
||||||
```shell
|
|
||||||
$ make # compiles the program
|
|
||||||
$ sudo make install # installs the binary
|
|
||||||
```
|
|
||||||
|
|
||||||
If you are running openrc, there is a premade service file so you can run
|
|
||||||
|
|
||||||
```shell
|
|
||||||
$ sudo make install-openrc # installs the binary and service file
|
|
||||||
```
|
|
||||||
|
|
||||||
If you wish to remove the program, you can run
|
|
||||||
|
|
||||||
```shell
|
|
||||||
$ sudo make uninstall # removes the binary
|
|
||||||
```
|
|
||||||
|
|
||||||
Or again if your running openrc
|
|
||||||
|
|
||||||
```shell
|
|
||||||
$ sudo make uninstall-openrc # removes the binary and service file
|
|
||||||
```
|
|
|
@ -1,13 +0,0 @@
|
||||||
#!/sbin/openrc-run
|
|
||||||
name="WrapperDNS"
|
|
||||||
description="Wrapper dns server"
|
|
||||||
command=/usr/local/bin/wrapper
|
|
||||||
supervisor=supervise-daemon
|
|
||||||
|
|
||||||
depend() {
|
|
||||||
need net
|
|
||||||
}
|
|
||||||
|
|
||||||
#start_pre() {
|
|
||||||
# export PORT=53
|
|
||||||
#}
|
|
57
src/config.rs
Normal file
57
src/config.rs
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
use std::{env, net::IpAddr, str::FromStr, fmt::Display};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Config {
|
||||||
|
pub dns_fallback: IpAddr,
|
||||||
|
pub dns_port: u16,
|
||||||
|
pub dns_cache_size: u64,
|
||||||
|
|
||||||
|
pub db_host: String,
|
||||||
|
pub db_port: u16,
|
||||||
|
pub db_user: String,
|
||||||
|
pub db_pass: String,
|
||||||
|
|
||||||
|
pub web_user: String,
|
||||||
|
pub web_pass: String,
|
||||||
|
pub web_port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let dns_port = Self::get_var::<u16>("WRAPPER_DNS_PORT", 53);
|
||||||
|
let dns_fallback = Self::get_var::<IpAddr>("WRAPPER_FALLBACK_DNS", [9, 9, 9, 9].into());
|
||||||
|
let dns_cache_size = Self::get_var::<u64>("WRAPPER_CACHE_SIZE", 1000);
|
||||||
|
|
||||||
|
let db_host = Self::get_var::<String>("WRAPPER_DB_HOST", String::from("localhost"));
|
||||||
|
let db_port = Self::get_var::<u16>("WRAPPER_DB_PORT", 27017);
|
||||||
|
let db_user = Self::get_var::<String>("WRAPPER_DB_USER", String::from("root"));
|
||||||
|
let db_pass = Self::get_var::<String>("WRAPPER_DB_PASS", String::from(""));
|
||||||
|
|
||||||
|
let web_user = Self::get_var::<String>("WRAPPER_WEB_USER", String::from("admin"));
|
||||||
|
let web_pass = Self::get_var::<String>("WRAPPER_WEB_PASS", String::from("wrapper"));
|
||||||
|
let web_port = Self::get_var::<u16>("WRAPPER_WEB_PORT", 80);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
dns_fallback,
|
||||||
|
dns_port,
|
||||||
|
dns_cache_size,
|
||||||
|
|
||||||
|
db_host,
|
||||||
|
db_port,
|
||||||
|
db_user,
|
||||||
|
db_pass,
|
||||||
|
|
||||||
|
web_user,
|
||||||
|
web_pass,
|
||||||
|
web_port,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_var<T>(name: &str, default: T) -> T
|
||||||
|
where
|
||||||
|
T: FromStr + Display,
|
||||||
|
{
|
||||||
|
let env = env::var(name).unwrap_or(format!("{default}"));
|
||||||
|
env.parse::<T>().unwrap_or(default)
|
||||||
|
}
|
||||||
|
}
|
146
src/database/mod.rs
Normal file
146
src/database/mod.rs
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
use futures::TryStreamExt;
|
||||||
|
use mongodb::{
|
||||||
|
bson::doc,
|
||||||
|
options::{ClientOptions, Credential, ServerAddress},
|
||||||
|
Client,
|
||||||
|
};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
config::Config,
|
||||||
|
dns::packet::{query::QueryType, record::DnsRecord},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::Result;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Database {
|
||||||
|
client: Client,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct StoredRecord {
|
||||||
|
record: DnsRecord,
|
||||||
|
domain: String,
|
||||||
|
prefix: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StoredRecord {
|
||||||
|
fn get_domain_parts(domain: &str) -> (String, String) {
|
||||||
|
let parts: Vec<&str> = domain.split(".").collect();
|
||||||
|
let len = parts.len();
|
||||||
|
if len == 1 {
|
||||||
|
(String::new(), String::from(parts[0]))
|
||||||
|
} else if len == 2 {
|
||||||
|
(String::new(), String::from(parts.join(".")))
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
String::from(parts[0..len - 2].join(".")),
|
||||||
|
String::from(parts[len - 2..len].join(".")),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<DnsRecord> for StoredRecord {
|
||||||
|
fn from(record: DnsRecord) -> Self {
|
||||||
|
let (prefix, domain) = Self::get_domain_parts(&record.get_domain());
|
||||||
|
Self {
|
||||||
|
record,
|
||||||
|
domain,
|
||||||
|
prefix,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Into<DnsRecord> for StoredRecord {
|
||||||
|
fn into(self) -> DnsRecord {
|
||||||
|
self.record
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Database {
|
||||||
|
pub async fn new(config: Config) -> Result<Self> {
|
||||||
|
let options = ClientOptions::builder()
|
||||||
|
.hosts(vec![ServerAddress::Tcp {
|
||||||
|
host: config.db_host,
|
||||||
|
port: Some(config.db_port),
|
||||||
|
}])
|
||||||
|
.credential(
|
||||||
|
Credential::builder()
|
||||||
|
.username(config.db_user)
|
||||||
|
.password(config.db_pass)
|
||||||
|
.build(),
|
||||||
|
)
|
||||||
|
.max_pool_size(100)
|
||||||
|
.app_name(String::from("wrapper"))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let client = Client::with_options(options)?;
|
||||||
|
|
||||||
|
client
|
||||||
|
.database("wrapper")
|
||||||
|
.run_command(doc! {"ping": 1}, None)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
info!("Connection to mongodb successfully");
|
||||||
|
|
||||||
|
Ok(Database { client })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_records(&self, domain: &str, qtype: QueryType) -> Result<Vec<DnsRecord>> {
|
||||||
|
let (prefix, domain) = StoredRecord::get_domain_parts(domain);
|
||||||
|
Ok(self
|
||||||
|
.get_domain(&domain)
|
||||||
|
.await?
|
||||||
|
.into_iter()
|
||||||
|
.filter(|r| r.prefix == prefix)
|
||||||
|
.filter(|r| {
|
||||||
|
let rqtype = r.record.get_qtype();
|
||||||
|
if qtype == QueryType::A {
|
||||||
|
return rqtype == QueryType::A || rqtype == QueryType::AR;
|
||||||
|
} else if qtype == QueryType::AAAA {
|
||||||
|
return rqtype == QueryType::AAAA || rqtype == QueryType::AAAAR;
|
||||||
|
} else {
|
||||||
|
r.record.get_qtype() == qtype
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.map(|r| r.into())
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_domain(&self, domain: &str) -> Result<Vec<StoredRecord>> {
|
||||||
|
let db = self.client.database("wrapper");
|
||||||
|
let col = db.collection::<StoredRecord>(domain);
|
||||||
|
|
||||||
|
let filter = doc! { "domain": domain };
|
||||||
|
let mut cursor = col.find(filter, None).await?;
|
||||||
|
|
||||||
|
let mut records = Vec::new();
|
||||||
|
while let Some(record) = cursor.try_next().await? {
|
||||||
|
records.push(record);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(records)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn add_record(&self, record: DnsRecord) -> Result<()> {
|
||||||
|
let record = StoredRecord::from(record);
|
||||||
|
let db = self.client.database("wrapper");
|
||||||
|
let col = db.collection::<StoredRecord>(&record.domain);
|
||||||
|
col.insert_one(record, None).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_domains(&self) -> Result<Vec<String>> {
|
||||||
|
let db = self.client.database("wrapper");
|
||||||
|
Ok(db.list_collection_names(None).await?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete_domain(&self, domain: String) -> Result<()> {
|
||||||
|
let db = self.client.database("wrapper");
|
||||||
|
let col = db.collection::<StoredRecord>(&domain);
|
||||||
|
Ok(col.drop(None).await?)
|
||||||
|
}
|
||||||
|
}
|
144
src/dns/binding.rs
Normal file
144
src/dns/binding.rs
Normal file
|
@ -0,0 +1,144 @@
|
||||||
|
use std::{
|
||||||
|
net::{IpAddr, SocketAddr},
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::packet::{buffer::PacketBuffer, Packet};
|
||||||
|
use crate::Result;
|
||||||
|
use tokio::{
|
||||||
|
io::{AsyncReadExt, AsyncWriteExt},
|
||||||
|
net::{TcpListener, TcpStream, UdpSocket},
|
||||||
|
};
|
||||||
|
use tracing::trace;
|
||||||
|
|
||||||
|
pub enum Binding {
|
||||||
|
UDP(Arc<UdpSocket>),
|
||||||
|
TCP(TcpListener),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Binding {
|
||||||
|
pub async fn udp(addr: SocketAddr) -> Result<Self> {
|
||||||
|
let socket = UdpSocket::bind(addr).await?;
|
||||||
|
Ok(Self::UDP(Arc::new(socket)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn tcp(addr: SocketAddr) -> Result<Self> {
|
||||||
|
let socket = TcpListener::bind(addr).await?;
|
||||||
|
Ok(Self::TCP(socket))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn name(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
Binding::UDP(_) => "UDP",
|
||||||
|
Binding::TCP(_) => "TCP",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn connect(&mut self) -> Result<Connection> {
|
||||||
|
match self {
|
||||||
|
Self::UDP(socket) => {
|
||||||
|
let mut buf = [0; 512];
|
||||||
|
let (_, addr) = socket.recv_from(&mut buf).await?;
|
||||||
|
Ok(Connection::UDP(socket.clone(), addr, buf))
|
||||||
|
}
|
||||||
|
Self::TCP(socket) => {
|
||||||
|
let (stream, _) = socket.accept().await?;
|
||||||
|
Ok(Connection::TCP(stream))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum Connection {
|
||||||
|
UDP(Arc<UdpSocket>, SocketAddr, [u8; 512]),
|
||||||
|
TCP(TcpStream),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Connection {
|
||||||
|
pub async fn read_packet(&mut self) -> Result<Packet> {
|
||||||
|
let data = self.read().await?;
|
||||||
|
let mut packet_buffer = PacketBuffer::new(data);
|
||||||
|
|
||||||
|
let packet = Packet::from_buffer(&mut packet_buffer)?;
|
||||||
|
Ok(packet)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn write_packet(self, mut packet: Packet) -> Result<()> {
|
||||||
|
let mut packet_buffer = PacketBuffer::new(Vec::new());
|
||||||
|
packet.write(&mut packet_buffer)?;
|
||||||
|
|
||||||
|
self.write(packet_buffer.buf).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn request_packet(&self, mut packet: Packet, dest: (IpAddr, u16)) -> Result<Packet> {
|
||||||
|
let mut packet_buffer = PacketBuffer::new(Vec::new());
|
||||||
|
packet.write(&mut packet_buffer)?;
|
||||||
|
|
||||||
|
let data = self.request(packet_buffer.buf, dest).await?;
|
||||||
|
let mut packet_buffer = PacketBuffer::new(data);
|
||||||
|
|
||||||
|
let packet = Packet::from_buffer(&mut packet_buffer)?;
|
||||||
|
Ok(packet)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn read(&mut self) -> Result<Vec<u8>> {
|
||||||
|
trace!("Reading DNS packet");
|
||||||
|
match self {
|
||||||
|
Self::UDP(_, _, src) => Ok(Vec::from(*src)),
|
||||||
|
Self::TCP(stream) => {
|
||||||
|
let size = stream.read_u16().await?;
|
||||||
|
let mut buf = Vec::with_capacity(size as usize);
|
||||||
|
stream.read_buf(&mut buf).await?;
|
||||||
|
Ok(buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn write(self, mut buf: Vec<u8>) -> Result<()> {
|
||||||
|
trace!("Returning DNS packet");
|
||||||
|
match self {
|
||||||
|
Self::UDP(socket, addr, _) => {
|
||||||
|
if buf.len() > 512 {
|
||||||
|
buf[2] = buf[2] | 0x03;
|
||||||
|
socket.send_to(&buf[0..512], addr).await?;
|
||||||
|
} else {
|
||||||
|
socket.send_to(&buf, addr).await?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Self::TCP(mut stream) => {
|
||||||
|
stream.write_u16(buf.len() as u16).await?;
|
||||||
|
stream.write(&buf[0..buf.len()]).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn request(&self, buf: Vec<u8>, dest: (IpAddr, u16)) -> Result<Vec<u8>> {
|
||||||
|
match self {
|
||||||
|
Self::UDP(_socket, _addr, _src) => {
|
||||||
|
let local_addr = "[::]:0".parse::<SocketAddr>()?;
|
||||||
|
let socket = UdpSocket::bind(local_addr).await?;
|
||||||
|
socket.send_to(&buf, dest).await?;
|
||||||
|
|
||||||
|
let mut buf = [0; 512];
|
||||||
|
socket.recv_from(&mut buf).await?;
|
||||||
|
|
||||||
|
Ok(Vec::from(buf))
|
||||||
|
}
|
||||||
|
Self::TCP(_stream) => {
|
||||||
|
let mut stream = TcpStream::connect(dest).await?;
|
||||||
|
stream.write_u16((buf.len()) as u16).await?;
|
||||||
|
stream.write_all(&buf[0..buf.len()]).await?;
|
||||||
|
|
||||||
|
stream.readable().await?;
|
||||||
|
let size = stream.read_u16().await?;
|
||||||
|
let mut buf = Vec::with_capacity(size as usize);
|
||||||
|
stream.read_buf(&mut buf).await?;
|
||||||
|
|
||||||
|
Ok(buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
4
src/dns/mod.rs
Normal file
4
src/dns/mod.rs
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
mod binding;
|
||||||
|
pub mod packet;
|
||||||
|
mod resolver;
|
||||||
|
pub mod server;
|
228
src/dns/packet/buffer.rs
Normal file
228
src/dns/packet/buffer.rs
Normal file
|
@ -0,0 +1,228 @@
|
||||||
|
use crate::Result;
|
||||||
|
|
||||||
|
pub struct PacketBuffer {
|
||||||
|
pub buf: Vec<u8>,
|
||||||
|
pub pos: usize,
|
||||||
|
pub size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PacketBuffer {
|
||||||
|
pub fn new(buf: Vec<u8>) -> Self {
|
||||||
|
Self {
|
||||||
|
size: buf.len(),
|
||||||
|
buf,
|
||||||
|
pos: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pos(&self) -> usize {
|
||||||
|
self.pos
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn step(&mut self, steps: usize) -> Result<()> {
|
||||||
|
self.pos += steps;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn seek(&mut self, pos: usize) -> Result<()> {
|
||||||
|
self.pos = pos;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read(&mut self) -> Result<u8> {
|
||||||
|
if self.pos >= self.size {
|
||||||
|
return Err("Tried to read past end of buffer".into());
|
||||||
|
}
|
||||||
|
let res = self.buf[self.pos];
|
||||||
|
self.pos += 1;
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&mut self, pos: usize) -> Result<u8> {
|
||||||
|
if pos >= self.size {
|
||||||
|
return Err("Tried to read past end of buffer".into());
|
||||||
|
}
|
||||||
|
Ok(self.buf[pos])
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
|
||||||
|
if start + len >= self.size {
|
||||||
|
return Err("Tried to read past end of buffer".into());
|
||||||
|
}
|
||||||
|
Ok(&self.buf[start..start + len])
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_u16(&mut self) -> Result<u16> {
|
||||||
|
let res = ((self.read()? as u16) << 8) | (self.read()? as u16);
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_u32(&mut self) -> Result<u32> {
|
||||||
|
let res = ((self.read()? as u32) << 24)
|
||||||
|
| ((self.read()? as u32) << 16)
|
||||||
|
| ((self.read()? as u32) << 8)
|
||||||
|
| (self.read()? as u32);
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_qname(&mut self, outstr: &mut String) -> Result<()> {
|
||||||
|
let mut pos = self.pos();
|
||||||
|
let mut jumped = false;
|
||||||
|
|
||||||
|
let mut delim = "";
|
||||||
|
let max_jumps = 5;
|
||||||
|
let mut jumps_performed = 0;
|
||||||
|
loop {
|
||||||
|
// Dns Packets are untrusted data, so we need to be paranoid. Someone
|
||||||
|
// can craft a packet with a cycle in the jump instructions. This guards
|
||||||
|
// against such packets.
|
||||||
|
if jumps_performed > max_jumps {
|
||||||
|
return Err(format!("Limit of {max_jumps} jumps exceeded").into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let len = self.get(pos)?;
|
||||||
|
|
||||||
|
if (len & 0xC0) == 0xC0 {
|
||||||
|
if !jumped {
|
||||||
|
self.seek(pos + 2)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let b2 = self.get(pos + 1)? as u16;
|
||||||
|
let offset = (((len as u16) ^ 0xC0) << 8) | b2;
|
||||||
|
pos = offset as usize;
|
||||||
|
jumped = true;
|
||||||
|
jumps_performed += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
pos += 1;
|
||||||
|
|
||||||
|
if len == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
outstr.push_str(delim);
|
||||||
|
|
||||||
|
let str_buffer = self.get_range(pos, len as usize)?;
|
||||||
|
outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase());
|
||||||
|
|
||||||
|
delim = ".";
|
||||||
|
|
||||||
|
pos += len as usize;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !jumped {
|
||||||
|
self.seek(pos)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_string(&mut self, outstr: &mut String) -> Result<()> {
|
||||||
|
let len = self.read()?;
|
||||||
|
|
||||||
|
self.read_string_n(outstr, len)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_string_n(&mut self, outstr: &mut String, len: u8) -> Result<()> {
|
||||||
|
let mut pos = self.pos;
|
||||||
|
|
||||||
|
let str_buffer = self.get_range(pos, len as usize)?;
|
||||||
|
|
||||||
|
let mut i = 0;
|
||||||
|
for b in str_buffer {
|
||||||
|
let c = *b as char;
|
||||||
|
if c == '\0' {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
outstr.push(c);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
pos += i;
|
||||||
|
self.seek(pos)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write(&mut self, val: u8) -> Result<()> {
|
||||||
|
if self.size < self.pos {
|
||||||
|
self.size = self.pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.buf.len() <= self.size {
|
||||||
|
self.buf.resize(self.size + 1, 0x00);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.buf[self.pos] = val;
|
||||||
|
self.pos += 1;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write_u8(&mut self, val: u8) -> Result<()> {
|
||||||
|
self.write(val)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write_u16(&mut self, val: u16) -> Result<()> {
|
||||||
|
self.write((val >> 8) as u8)?;
|
||||||
|
self.write((val & 0xFF) as u8)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write_u32(&mut self, val: u32) -> Result<()> {
|
||||||
|
self.write(((val >> 24) & 0xFF) as u8)?;
|
||||||
|
self.write(((val >> 16) & 0xFF) as u8)?;
|
||||||
|
self.write(((val >> 8) & 0xFF) as u8)?;
|
||||||
|
self.write((val & 0xFF) as u8)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write_qname(&mut self, qname: &str) -> Result<()> {
|
||||||
|
for label in qname.split('.') {
|
||||||
|
let len = label.len();
|
||||||
|
|
||||||
|
self.write_u8(len as u8)?;
|
||||||
|
for b in label.as_bytes() {
|
||||||
|
self.write_u8(*b)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !qname.is_empty() {
|
||||||
|
self.write_u8(0)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write_string(&mut self, text: &str) -> Result<()> {
|
||||||
|
for b in text.as_bytes() {
|
||||||
|
self.write_u8(*b)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set(&mut self, pos: usize, val: u8) -> Result<()> {
|
||||||
|
self.buf[pos] = val;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> {
|
||||||
|
self.set(pos, (val >> 8) as u8)?;
|
||||||
|
self.set(pos + 1, (val & 0xFF) as u8)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
102
src/dns/packet/header.rs
Normal file
102
src/dns/packet/header.rs
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
use super::{buffer::PacketBuffer, result::ResultCode};
|
||||||
|
use crate::Result;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct DnsHeader {
|
||||||
|
pub id: u16, // 16 bits
|
||||||
|
|
||||||
|
pub recursion_desired: bool, // 1 bit
|
||||||
|
pub truncated_message: bool, // 1 bit
|
||||||
|
pub authoritative_answer: bool, // 1 bit
|
||||||
|
pub opcode: u8, // 4 bits
|
||||||
|
pub response: bool, // 1 bit
|
||||||
|
|
||||||
|
pub rescode: ResultCode, // 4 bits
|
||||||
|
pub checking_disabled: bool, // 1 bit
|
||||||
|
pub authed_data: bool, // 1 bit
|
||||||
|
pub z: bool, // 1 bit
|
||||||
|
pub recursion_available: bool, // 1 bit
|
||||||
|
|
||||||
|
pub questions: u16, // 16 bits
|
||||||
|
pub answers: u16, // 16 bits
|
||||||
|
pub authoritative_entries: u16, // 16 bits
|
||||||
|
pub resource_entries: u16, // 16 bits
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DnsHeader {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
id: 0,
|
||||||
|
|
||||||
|
recursion_desired: false,
|
||||||
|
truncated_message: false,
|
||||||
|
authoritative_answer: false,
|
||||||
|
opcode: 0,
|
||||||
|
response: false,
|
||||||
|
|
||||||
|
rescode: ResultCode::NOERROR,
|
||||||
|
checking_disabled: false,
|
||||||
|
authed_data: false,
|
||||||
|
z: false,
|
||||||
|
recursion_available: false,
|
||||||
|
|
||||||
|
questions: 0,
|
||||||
|
answers: 0,
|
||||||
|
authoritative_entries: 0,
|
||||||
|
resource_entries: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read(&mut self, buffer: &mut PacketBuffer) -> Result<()> {
|
||||||
|
self.id = buffer.read_u16()?;
|
||||||
|
let flags = buffer.read_u16()?;
|
||||||
|
let a = (flags >> 8) as u8;
|
||||||
|
let b = (flags & 0xFF) as u8;
|
||||||
|
self.recursion_desired = (a & (1 << 0)) > 0;
|
||||||
|
self.truncated_message = (a & (1 << 1)) > 0;
|
||||||
|
self.authoritative_answer = (a & (1 << 2)) > 0;
|
||||||
|
self.opcode = (a >> 3) & 0x0F;
|
||||||
|
self.response = (a & (1 << 7)) > 0;
|
||||||
|
|
||||||
|
self.rescode = ResultCode::from_num(b & 0x0F);
|
||||||
|
self.checking_disabled = (b & (1 << 4)) > 0;
|
||||||
|
self.authed_data = (b & (1 << 5)) > 0;
|
||||||
|
self.z = (b & (1 << 6)) > 0;
|
||||||
|
self.recursion_available = (b & (1 << 7)) > 0;
|
||||||
|
|
||||||
|
self.questions = buffer.read_u16()?;
|
||||||
|
self.answers = buffer.read_u16()?;
|
||||||
|
self.authoritative_entries = buffer.read_u16()?;
|
||||||
|
self.resource_entries = buffer.read_u16()?;
|
||||||
|
|
||||||
|
// Return the constant header size
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write(&self, buffer: &mut PacketBuffer) -> Result<()> {
|
||||||
|
buffer.write_u16(self.id)?;
|
||||||
|
|
||||||
|
buffer.write_u8(
|
||||||
|
(self.recursion_desired as u8)
|
||||||
|
| ((self.truncated_message as u8) << 1)
|
||||||
|
| ((self.authoritative_answer as u8) << 2)
|
||||||
|
| (self.opcode << 3)
|
||||||
|
| ((self.response as u8) << 7),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
buffer.write_u8(
|
||||||
|
(self.rescode as u8)
|
||||||
|
| ((self.checking_disabled as u8) << 4)
|
||||||
|
| ((self.authed_data as u8) << 5)
|
||||||
|
| ((self.z as u8) << 6)
|
||||||
|
| ((self.recursion_available as u8) << 7),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
buffer.write_u16(self.questions)?;
|
||||||
|
buffer.write_u16(self.answers)?;
|
||||||
|
buffer.write_u16(self.authoritative_entries)?;
|
||||||
|
buffer.write_u16(self.resource_entries)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
128
src/dns/packet/mod.rs
Normal file
128
src/dns/packet/mod.rs
Normal file
|
@ -0,0 +1,128 @@
|
||||||
|
use std::net::IpAddr;
|
||||||
|
|
||||||
|
use self::{
|
||||||
|
buffer::PacketBuffer, header::DnsHeader, query::QueryType, question::DnsQuestion,
|
||||||
|
record::DnsRecord,
|
||||||
|
};
|
||||||
|
use crate::Result;
|
||||||
|
|
||||||
|
pub mod buffer;
|
||||||
|
pub mod header;
|
||||||
|
pub mod query;
|
||||||
|
pub mod question;
|
||||||
|
pub mod record;
|
||||||
|
pub mod result;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct Packet {
|
||||||
|
pub header: DnsHeader,
|
||||||
|
pub questions: Vec<DnsQuestion>,
|
||||||
|
pub answers: Vec<DnsRecord>,
|
||||||
|
pub authorities: Vec<DnsRecord>,
|
||||||
|
pub resources: Vec<DnsRecord>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Packet {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
header: DnsHeader::new(),
|
||||||
|
questions: Vec::new(),
|
||||||
|
answers: Vec::new(),
|
||||||
|
authorities: Vec::new(),
|
||||||
|
resources: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_buffer(buffer: &mut PacketBuffer) -> Result<Self> {
|
||||||
|
let mut result = Self::new();
|
||||||
|
result.header.read(buffer)?;
|
||||||
|
|
||||||
|
for _ in 0..result.header.questions {
|
||||||
|
let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0));
|
||||||
|
question.read(buffer)?;
|
||||||
|
result.questions.push(question);
|
||||||
|
}
|
||||||
|
|
||||||
|
for _ in 0..result.header.answers {
|
||||||
|
let rec = DnsRecord::read(buffer)?;
|
||||||
|
result.answers.push(rec);
|
||||||
|
}
|
||||||
|
for _ in 0..result.header.authoritative_entries {
|
||||||
|
let rec = DnsRecord::read(buffer)?;
|
||||||
|
result.authorities.push(rec);
|
||||||
|
}
|
||||||
|
for _ in 0..result.header.resource_entries {
|
||||||
|
let rec = DnsRecord::read(buffer)?;
|
||||||
|
result.resources.push(rec);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write(&mut self, buffer: &mut PacketBuffer) -> Result<()> {
|
||||||
|
self.header.questions = self.questions.len() as u16;
|
||||||
|
self.header.answers = self.answers.len() as u16;
|
||||||
|
self.header.authoritative_entries = self.authorities.len() as u16;
|
||||||
|
self.header.resource_entries = self.resources.len() as u16;
|
||||||
|
|
||||||
|
self.header.write(buffer)?;
|
||||||
|
|
||||||
|
for question in &self.questions {
|
||||||
|
question.write(buffer)?;
|
||||||
|
}
|
||||||
|
for rec in &self.answers {
|
||||||
|
rec.write(buffer)?;
|
||||||
|
}
|
||||||
|
for rec in &self.authorities {
|
||||||
|
rec.write(buffer)?;
|
||||||
|
}
|
||||||
|
for rec in &self.resources {
|
||||||
|
rec.write(buffer)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_random_a(&self) -> Option<IpAddr> {
|
||||||
|
self.answers
|
||||||
|
.iter()
|
||||||
|
.filter_map(|record| match record {
|
||||||
|
DnsRecord::A { addr, .. } => Some(IpAddr::V4(*addr)),
|
||||||
|
DnsRecord::AAAA { addr, .. } => Some(IpAddr::V6(*addr)),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.next()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator<Item = (&'a str, &'a str)> {
|
||||||
|
self.authorities
|
||||||
|
.iter()
|
||||||
|
.filter_map(|record| match record {
|
||||||
|
DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.filter(move |(domain, _)| qname.ends_with(*domain))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_resolved_ns(&self, qname: &str) -> Option<IpAddr> {
|
||||||
|
self.get_ns(qname)
|
||||||
|
.flat_map(|(_, host)| {
|
||||||
|
self.resources
|
||||||
|
.iter()
|
||||||
|
.filter_map(move |record| match record {
|
||||||
|
DnsRecord::A { domain, addr, .. } if domain == host => {
|
||||||
|
Some(IpAddr::V4(*addr))
|
||||||
|
}
|
||||||
|
DnsRecord::AAAA { domain, addr, .. } if domain == host => {
|
||||||
|
Some(IpAddr::V6(*addr))
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.next()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_unresolved_ns<'a>(&'a self, qname: &'a str) -> Option<&'a str> {
|
||||||
|
self.get_ns(qname).map(|(_, host)| host).next()
|
||||||
|
}
|
||||||
|
}
|
78
src/dns/packet/query.rs
Normal file
78
src/dns/packet/query.rs
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
#[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)]
|
||||||
|
pub enum QueryType {
|
||||||
|
UNKNOWN(u16),
|
||||||
|
A, // 1
|
||||||
|
NS, // 2
|
||||||
|
CNAME, // 5
|
||||||
|
SOA, // 6
|
||||||
|
PTR, // 12
|
||||||
|
MX, // 15
|
||||||
|
TXT, // 16
|
||||||
|
AAAA, // 28
|
||||||
|
SRV, // 33
|
||||||
|
OPT, // 41
|
||||||
|
CAA, // 257
|
||||||
|
AR, // 1000
|
||||||
|
AAAAR, // 1001
|
||||||
|
}
|
||||||
|
|
||||||
|
impl QueryType {
|
||||||
|
pub fn to_num(&self) -> u16 {
|
||||||
|
match *self {
|
||||||
|
Self::UNKNOWN(x) => x,
|
||||||
|
Self::A => 1,
|
||||||
|
Self::NS => 2,
|
||||||
|
Self::CNAME => 5,
|
||||||
|
Self::SOA => 6,
|
||||||
|
Self::PTR => 12,
|
||||||
|
Self::MX => 15,
|
||||||
|
Self::TXT => 16,
|
||||||
|
Self::AAAA => 28,
|
||||||
|
Self::SRV => 33,
|
||||||
|
Self::OPT => 41,
|
||||||
|
Self::CAA => 257,
|
||||||
|
Self::AR => 1000,
|
||||||
|
Self::AAAAR => 1001,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_num(num: u16) -> Self {
|
||||||
|
match num {
|
||||||
|
1 => Self::A,
|
||||||
|
2 => Self::NS,
|
||||||
|
5 => Self::CNAME,
|
||||||
|
6 => Self::SOA,
|
||||||
|
12 => Self::PTR,
|
||||||
|
15 => Self::MX,
|
||||||
|
16 => Self::TXT,
|
||||||
|
28 => Self::AAAA,
|
||||||
|
33 => Self::SRV,
|
||||||
|
41 => Self::OPT,
|
||||||
|
257 => Self::CAA,
|
||||||
|
1000 => Self::AR,
|
||||||
|
1001 => Self::AAAAR,
|
||||||
|
_ => Self::UNKNOWN(num),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn allowed_actions(&self) -> (bool, bool) {
|
||||||
|
// 0. duplicates allowed
|
||||||
|
// 1. allowed to be created by database
|
||||||
|
match self {
|
||||||
|
QueryType::UNKNOWN(_) => (false, false),
|
||||||
|
QueryType::A => (true, true),
|
||||||
|
QueryType::NS => (false, true),
|
||||||
|
QueryType::CNAME => (false, true),
|
||||||
|
QueryType::SOA => (false, false),
|
||||||
|
QueryType::PTR => (false, true),
|
||||||
|
QueryType::MX => (false, true),
|
||||||
|
QueryType::TXT => (true, true),
|
||||||
|
QueryType::AAAA => (true, true),
|
||||||
|
QueryType::SRV => (false, true),
|
||||||
|
QueryType::OPT => (false, false),
|
||||||
|
QueryType::CAA => (false, true),
|
||||||
|
QueryType::AR => (false, true),
|
||||||
|
QueryType::AAAAR => (false, true),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
31
src/dns/packet/question.rs
Normal file
31
src/dns/packet/question.rs
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
use super::{buffer::PacketBuffer, query::QueryType, Result};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
pub struct DnsQuestion {
|
||||||
|
pub name: String,
|
||||||
|
pub qtype: QueryType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DnsQuestion {
|
||||||
|
pub fn new(name: String, qtype: QueryType) -> Self {
|
||||||
|
Self { name, qtype }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read(&mut self, buffer: &mut PacketBuffer) -> Result<()> {
|
||||||
|
buffer.read_qname(&mut self.name)?;
|
||||||
|
self.qtype = QueryType::from_num(buffer.read_u16()?); // qtype
|
||||||
|
let _ = buffer.read_u16()?; // class
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write(&self, buffer: &mut PacketBuffer) -> Result<()> {
|
||||||
|
buffer.write_qname(&self.name)?;
|
||||||
|
|
||||||
|
let typenum = self.qtype.to_num();
|
||||||
|
buffer.write_u16(typenum)?;
|
||||||
|
buffer.write_u16(1)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
544
src/dns/packet/record.rs
Normal file
544
src/dns/packet/record.rs
Normal file
|
@ -0,0 +1,544 @@
|
||||||
|
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||||
|
|
||||||
|
use rand::RngCore;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tracing::{trace, warn};
|
||||||
|
|
||||||
|
use super::{buffer::PacketBuffer, query::QueryType, Result};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
|
||||||
|
pub enum DnsRecord {
|
||||||
|
UNKNOWN {
|
||||||
|
domain: String,
|
||||||
|
qtype: u16,
|
||||||
|
data_len: u16,
|
||||||
|
ttl: u32,
|
||||||
|
}, // 0
|
||||||
|
A {
|
||||||
|
domain: String,
|
||||||
|
addr: Ipv4Addr,
|
||||||
|
ttl: u32,
|
||||||
|
}, // 1
|
||||||
|
NS {
|
||||||
|
domain: String,
|
||||||
|
host: String,
|
||||||
|
ttl: u32,
|
||||||
|
}, // 2
|
||||||
|
CNAME {
|
||||||
|
domain: String,
|
||||||
|
host: String,
|
||||||
|
ttl: u32,
|
||||||
|
}, // 5
|
||||||
|
SOA {
|
||||||
|
domain: String,
|
||||||
|
mname: String,
|
||||||
|
nname: String,
|
||||||
|
serial: u32,
|
||||||
|
refresh: u32,
|
||||||
|
retry: u32,
|
||||||
|
expire: u32,
|
||||||
|
minimum: u32,
|
||||||
|
ttl: u32,
|
||||||
|
}, // 6
|
||||||
|
PTR {
|
||||||
|
domain: String,
|
||||||
|
pointer: String,
|
||||||
|
ttl: u32,
|
||||||
|
}, // 12
|
||||||
|
MX {
|
||||||
|
domain: String,
|
||||||
|
priority: u16,
|
||||||
|
host: String,
|
||||||
|
ttl: u32,
|
||||||
|
}, // 15
|
||||||
|
TXT {
|
||||||
|
domain: String,
|
||||||
|
text: Vec<String>,
|
||||||
|
ttl: u32,
|
||||||
|
}, //16
|
||||||
|
AAAA {
|
||||||
|
domain: String,
|
||||||
|
addr: Ipv6Addr,
|
||||||
|
ttl: u32,
|
||||||
|
}, // 28
|
||||||
|
SRV {
|
||||||
|
domain: String,
|
||||||
|
priority: u16,
|
||||||
|
weight: u16,
|
||||||
|
port: u16,
|
||||||
|
target: String,
|
||||||
|
ttl: u32,
|
||||||
|
}, // 33
|
||||||
|
CAA {
|
||||||
|
domain: String,
|
||||||
|
flags: u8,
|
||||||
|
length: u8,
|
||||||
|
tag: String,
|
||||||
|
value: String,
|
||||||
|
ttl: u32,
|
||||||
|
}, // 257
|
||||||
|
AR {
|
||||||
|
domain: String,
|
||||||
|
ttl: u32,
|
||||||
|
},
|
||||||
|
AAAAR {
|
||||||
|
domain: String,
|
||||||
|
ttl: u32,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DnsRecord {
|
||||||
|
pub fn read(buffer: &mut PacketBuffer) -> Result<Self> {
|
||||||
|
let mut domain = String::new();
|
||||||
|
buffer.read_qname(&mut domain)?;
|
||||||
|
|
||||||
|
let qtype_num = buffer.read_u16()?;
|
||||||
|
let qtype = QueryType::from_num(qtype_num);
|
||||||
|
let _ = buffer.read_u16()?;
|
||||||
|
let ttl = buffer.read_u32()?;
|
||||||
|
let data_len = buffer.read_u16()?;
|
||||||
|
|
||||||
|
trace!("Reading DNS Record TYPE: {:?}", qtype);
|
||||||
|
|
||||||
|
let header_pos = buffer.pos();
|
||||||
|
|
||||||
|
match qtype {
|
||||||
|
QueryType::A => {
|
||||||
|
let raw_addr = buffer.read_u32()?;
|
||||||
|
let addr = Ipv4Addr::new(
|
||||||
|
((raw_addr >> 24) & 0xFF) as u8,
|
||||||
|
((raw_addr >> 16) & 0xFF) as u8,
|
||||||
|
((raw_addr >> 8) & 0xFF) as u8,
|
||||||
|
(raw_addr & 0xFF) as u8,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(Self::A { domain, addr, ttl })
|
||||||
|
}
|
||||||
|
QueryType::AAAA => {
|
||||||
|
let raw_addr1 = buffer.read_u32()?;
|
||||||
|
let raw_addr2 = buffer.read_u32()?;
|
||||||
|
let raw_addr3 = buffer.read_u32()?;
|
||||||
|
let raw_addr4 = buffer.read_u32()?;
|
||||||
|
let addr = Ipv6Addr::new(
|
||||||
|
((raw_addr1 >> 16) & 0xFFFF) as u16,
|
||||||
|
(raw_addr1 & 0xFFFF) as u16,
|
||||||
|
((raw_addr2 >> 16) & 0xFFFF) as u16,
|
||||||
|
(raw_addr2 & 0xFFFF) as u16,
|
||||||
|
((raw_addr3 >> 16) & 0xFFFF) as u16,
|
||||||
|
(raw_addr3 & 0xFFFF) as u16,
|
||||||
|
((raw_addr4 >> 16) & 0xFFFF) as u16,
|
||||||
|
(raw_addr4 & 0xFFFF) as u16,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(Self::AAAA { domain, addr, ttl })
|
||||||
|
}
|
||||||
|
QueryType::NS => {
|
||||||
|
let mut ns = String::new();
|
||||||
|
buffer.read_qname(&mut ns)?;
|
||||||
|
|
||||||
|
Ok(Self::NS {
|
||||||
|
domain,
|
||||||
|
host: ns,
|
||||||
|
ttl,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
QueryType::CNAME => {
|
||||||
|
let mut cname = String::new();
|
||||||
|
buffer.read_qname(&mut cname)?;
|
||||||
|
|
||||||
|
Ok(Self::CNAME {
|
||||||
|
domain,
|
||||||
|
host: cname,
|
||||||
|
ttl,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
QueryType::SOA => {
|
||||||
|
let mut mname = String::new();
|
||||||
|
buffer.read_qname(&mut mname)?;
|
||||||
|
|
||||||
|
let mut nname = String::new();
|
||||||
|
buffer.read_qname(&mut nname)?;
|
||||||
|
|
||||||
|
let serial = buffer.read_u32()?;
|
||||||
|
let refresh = buffer.read_u32()?;
|
||||||
|
let retry = buffer.read_u32()?;
|
||||||
|
let expire = buffer.read_u32()?;
|
||||||
|
let minimum = buffer.read_u32()?;
|
||||||
|
|
||||||
|
Ok(Self::SOA {
|
||||||
|
domain,
|
||||||
|
mname,
|
||||||
|
nname,
|
||||||
|
serial,
|
||||||
|
refresh,
|
||||||
|
retry,
|
||||||
|
expire,
|
||||||
|
minimum,
|
||||||
|
ttl,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
QueryType::PTR => {
|
||||||
|
let mut pointer = String::new();
|
||||||
|
buffer.read_qname(&mut pointer)?;
|
||||||
|
|
||||||
|
Ok(Self::PTR {
|
||||||
|
domain,
|
||||||
|
pointer,
|
||||||
|
ttl,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
QueryType::MX => {
|
||||||
|
let priority = buffer.read_u16()?;
|
||||||
|
let mut mx = String::new();
|
||||||
|
buffer.read_qname(&mut mx)?;
|
||||||
|
|
||||||
|
Ok(Self::MX {
|
||||||
|
domain,
|
||||||
|
priority,
|
||||||
|
host: mx,
|
||||||
|
ttl,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
QueryType::TXT => {
|
||||||
|
let mut text = Vec::new();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let mut s = String::new();
|
||||||
|
buffer.read_string(&mut s)?;
|
||||||
|
|
||||||
|
if s.len() == 0 {
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
text.push(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self::TXT { domain, text, ttl })
|
||||||
|
}
|
||||||
|
QueryType::SRV => {
|
||||||
|
let priority = buffer.read_u16()?;
|
||||||
|
let weight = buffer.read_u16()?;
|
||||||
|
let port = buffer.read_u16()?;
|
||||||
|
|
||||||
|
let mut target = String::new();
|
||||||
|
buffer.read_qname(&mut target)?;
|
||||||
|
|
||||||
|
Ok(Self::SRV {
|
||||||
|
domain,
|
||||||
|
priority,
|
||||||
|
weight,
|
||||||
|
port,
|
||||||
|
target,
|
||||||
|
ttl,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
QueryType::CAA => {
|
||||||
|
let flags = buffer.read()?;
|
||||||
|
let length = buffer.read()?;
|
||||||
|
|
||||||
|
let mut tag = String::new();
|
||||||
|
buffer.read_string_n(&mut tag, length)?;
|
||||||
|
|
||||||
|
let value_len = (data_len as usize) + header_pos - buffer.pos;
|
||||||
|
let mut value = String::new();
|
||||||
|
buffer.read_string_n(&mut value, value_len as u8)?;
|
||||||
|
|
||||||
|
Ok(Self::CAA {
|
||||||
|
domain,
|
||||||
|
flags,
|
||||||
|
length,
|
||||||
|
tag,
|
||||||
|
value,
|
||||||
|
ttl,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
QueryType::UNKNOWN(_) | _ => {
|
||||||
|
buffer.step(data_len as usize)?;
|
||||||
|
|
||||||
|
Ok(Self::UNKNOWN {
|
||||||
|
domain,
|
||||||
|
qtype: qtype_num,
|
||||||
|
data_len,
|
||||||
|
ttl,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write(&self, buffer: &mut PacketBuffer) -> Result<usize> {
|
||||||
|
let start_pos = buffer.pos();
|
||||||
|
|
||||||
|
trace!("Writing DNS Record {:?}", self);
|
||||||
|
|
||||||
|
match *self {
|
||||||
|
Self::A {
|
||||||
|
ref domain,
|
||||||
|
ref addr,
|
||||||
|
ttl,
|
||||||
|
} => {
|
||||||
|
buffer.write_qname(domain)?;
|
||||||
|
buffer.write_u16(QueryType::A.to_num())?;
|
||||||
|
buffer.write_u16(1)?;
|
||||||
|
buffer.write_u32(ttl)?;
|
||||||
|
buffer.write_u16(4)?;
|
||||||
|
|
||||||
|
let octets = addr.octets();
|
||||||
|
buffer.write_u8(octets[0])?;
|
||||||
|
buffer.write_u8(octets[1])?;
|
||||||
|
buffer.write_u8(octets[2])?;
|
||||||
|
buffer.write_u8(octets[3])?;
|
||||||
|
}
|
||||||
|
Self::NS {
|
||||||
|
ref domain,
|
||||||
|
ref host,
|
||||||
|
ttl,
|
||||||
|
} => {
|
||||||
|
buffer.write_qname(domain)?;
|
||||||
|
buffer.write_u16(QueryType::NS.to_num())?;
|
||||||
|
buffer.write_u16(1)?;
|
||||||
|
buffer.write_u32(ttl)?;
|
||||||
|
|
||||||
|
let pos = buffer.pos();
|
||||||
|
buffer.write_u16(0)?;
|
||||||
|
|
||||||
|
buffer.write_qname(host)?;
|
||||||
|
|
||||||
|
let size = buffer.pos() - (pos + 2);
|
||||||
|
buffer.set_u16(pos, size as u16)?;
|
||||||
|
}
|
||||||
|
Self::CNAME {
|
||||||
|
ref domain,
|
||||||
|
ref host,
|
||||||
|
ttl,
|
||||||
|
} => {
|
||||||
|
buffer.write_qname(domain)?;
|
||||||
|
buffer.write_u16(QueryType::CNAME.to_num())?;
|
||||||
|
buffer.write_u16(1)?;
|
||||||
|
buffer.write_u32(ttl)?;
|
||||||
|
|
||||||
|
let pos = buffer.pos();
|
||||||
|
buffer.write_u16(0)?;
|
||||||
|
|
||||||
|
buffer.write_qname(host)?;
|
||||||
|
|
||||||
|
let size = buffer.pos() - (pos + 2);
|
||||||
|
buffer.set_u16(pos, size as u16)?;
|
||||||
|
}
|
||||||
|
Self::SOA {
|
||||||
|
ref domain,
|
||||||
|
ref mname,
|
||||||
|
ref nname,
|
||||||
|
serial,
|
||||||
|
refresh,
|
||||||
|
retry,
|
||||||
|
expire,
|
||||||
|
minimum,
|
||||||
|
ttl,
|
||||||
|
} => {
|
||||||
|
buffer.write_qname(domain)?;
|
||||||
|
buffer.write_u16(QueryType::SOA.to_num())?;
|
||||||
|
buffer.write_u16(1)?;
|
||||||
|
buffer.write_u32(ttl)?;
|
||||||
|
|
||||||
|
let pos = buffer.pos();
|
||||||
|
buffer.write_u16(0)?;
|
||||||
|
|
||||||
|
buffer.write_qname(mname)?;
|
||||||
|
buffer.write_qname(nname)?;
|
||||||
|
buffer.write_u32(serial)?;
|
||||||
|
buffer.write_u32(refresh)?;
|
||||||
|
buffer.write_u32(retry)?;
|
||||||
|
buffer.write_u32(expire)?;
|
||||||
|
buffer.write_u32(minimum)?;
|
||||||
|
|
||||||
|
let size = buffer.pos() - (pos + 2);
|
||||||
|
buffer.set_u16(pos, size as u16)?;
|
||||||
|
}
|
||||||
|
Self::PTR {
|
||||||
|
ref domain,
|
||||||
|
ref pointer,
|
||||||
|
ttl,
|
||||||
|
} => {
|
||||||
|
buffer.write_qname(domain)?;
|
||||||
|
buffer.write_u16(QueryType::NS.to_num())?;
|
||||||
|
buffer.write_u16(1)?;
|
||||||
|
buffer.write_u32(ttl)?;
|
||||||
|
|
||||||
|
let pos = buffer.pos();
|
||||||
|
buffer.write_u16(0)?;
|
||||||
|
|
||||||
|
buffer.write_qname(&pointer)?;
|
||||||
|
|
||||||
|
let size = buffer.pos() - (pos + 2);
|
||||||
|
buffer.set_u16(pos, size as u16)?;
|
||||||
|
}
|
||||||
|
Self::MX {
|
||||||
|
ref domain,
|
||||||
|
priority,
|
||||||
|
ref host,
|
||||||
|
ttl,
|
||||||
|
} => {
|
||||||
|
buffer.write_qname(domain)?;
|
||||||
|
buffer.write_u16(QueryType::MX.to_num())?;
|
||||||
|
buffer.write_u16(1)?;
|
||||||
|
buffer.write_u32(ttl)?;
|
||||||
|
|
||||||
|
let pos = buffer.pos();
|
||||||
|
buffer.write_u16(0)?;
|
||||||
|
|
||||||
|
buffer.write_u16(priority)?;
|
||||||
|
buffer.write_qname(host)?;
|
||||||
|
|
||||||
|
let size = buffer.pos() - (pos + 2);
|
||||||
|
buffer.set_u16(pos, size as u16)?;
|
||||||
|
}
|
||||||
|
Self::TXT {
|
||||||
|
ref domain,
|
||||||
|
ref text,
|
||||||
|
ttl,
|
||||||
|
} => {
|
||||||
|
buffer.write_qname(&domain)?;
|
||||||
|
buffer.write_u16(QueryType::TXT.to_num())?;
|
||||||
|
buffer.write_u16(1)?;
|
||||||
|
buffer.write_u32(ttl)?;
|
||||||
|
|
||||||
|
let pos = buffer.pos();
|
||||||
|
buffer.write_u16(0)?;
|
||||||
|
|
||||||
|
if text.is_empty() {
|
||||||
|
return Ok(buffer.pos() - start_pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
for s in text {
|
||||||
|
buffer.write_u8(s.len() as u8)?;
|
||||||
|
buffer.write_string(&s)?;
|
||||||
|
}
|
||||||
|
let size = buffer.pos() - (pos + 2);
|
||||||
|
buffer.set_u16(pos, size as u16)?;
|
||||||
|
}
|
||||||
|
Self::AAAA {
|
||||||
|
ref domain,
|
||||||
|
ref addr,
|
||||||
|
ttl,
|
||||||
|
} => {
|
||||||
|
buffer.write_qname(domain)?;
|
||||||
|
buffer.write_u16(QueryType::AAAA.to_num())?;
|
||||||
|
buffer.write_u16(1)?;
|
||||||
|
buffer.write_u32(ttl)?;
|
||||||
|
buffer.write_u16(16)?;
|
||||||
|
|
||||||
|
for octet in &addr.segments() {
|
||||||
|
buffer.write_u16(*octet)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Self::SRV {
|
||||||
|
ref domain,
|
||||||
|
priority,
|
||||||
|
weight,
|
||||||
|
port,
|
||||||
|
ref target,
|
||||||
|
ttl,
|
||||||
|
} => {
|
||||||
|
buffer.write_qname(domain)?;
|
||||||
|
buffer.write_u16(QueryType::SRV.to_num())?;
|
||||||
|
buffer.write_u16(1)?;
|
||||||
|
buffer.write_u32(ttl)?;
|
||||||
|
|
||||||
|
let pos = buffer.pos();
|
||||||
|
buffer.write_u16(0)?;
|
||||||
|
|
||||||
|
buffer.write_u16(priority)?;
|
||||||
|
buffer.write_u16(weight)?;
|
||||||
|
buffer.write_u16(port)?;
|
||||||
|
buffer.write_qname(target)?;
|
||||||
|
|
||||||
|
let size = buffer.pos() - (pos + 2);
|
||||||
|
buffer.set_u16(pos, size as u16)?;
|
||||||
|
}
|
||||||
|
Self::CAA {
|
||||||
|
ref domain,
|
||||||
|
flags,
|
||||||
|
length,
|
||||||
|
ref tag,
|
||||||
|
ref value,
|
||||||
|
ttl,
|
||||||
|
} => {
|
||||||
|
buffer.write_qname(domain)?;
|
||||||
|
buffer.write_u16(QueryType::CAA.to_num())?;
|
||||||
|
buffer.write_u16(1)?;
|
||||||
|
buffer.write_u32(ttl)?;
|
||||||
|
|
||||||
|
let pos = buffer.pos();
|
||||||
|
buffer.write_u16(0)?;
|
||||||
|
|
||||||
|
buffer.write_u8(flags)?;
|
||||||
|
buffer.write_u8(length)?;
|
||||||
|
buffer.write_string(tag)?;
|
||||||
|
buffer.write_string(value)?;
|
||||||
|
|
||||||
|
let size = buffer.pos() - (pos + 2);
|
||||||
|
buffer.set_u16(pos, size as u16)?;
|
||||||
|
}
|
||||||
|
Self::AR { ref domain, ttl } => {
|
||||||
|
buffer.write_qname(domain)?;
|
||||||
|
buffer.write_u16(QueryType::A.to_num())?;
|
||||||
|
buffer.write_u16(1)?;
|
||||||
|
buffer.write_u32(ttl)?;
|
||||||
|
buffer.write_u16(4)?;
|
||||||
|
|
||||||
|
let mut rand = rand::thread_rng();
|
||||||
|
buffer.write_u32(rand.next_u32())?;
|
||||||
|
}
|
||||||
|
Self::AAAAR { ref domain, ttl } => {
|
||||||
|
buffer.write_qname(domain)?;
|
||||||
|
buffer.write_u16(QueryType::A.to_num())?;
|
||||||
|
buffer.write_u16(1)?;
|
||||||
|
buffer.write_u32(ttl)?;
|
||||||
|
buffer.write_u16(4)?;
|
||||||
|
|
||||||
|
let mut rand = rand::thread_rng();
|
||||||
|
buffer.write_u32(rand.next_u32())?;
|
||||||
|
buffer.write_u32(rand.next_u32())?;
|
||||||
|
buffer.write_u32(rand.next_u32())?;
|
||||||
|
buffer.write_u32(rand.next_u32())?;
|
||||||
|
}
|
||||||
|
Self::UNKNOWN { .. } => {
|
||||||
|
warn!("Skipping record: {self:?}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(buffer.pos() - start_pos)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_domain(&self) -> String {
|
||||||
|
self.get_shared_domain().0
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_qtype(&self) -> QueryType {
|
||||||
|
self.get_shared_domain().1
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_ttl(&self) -> u32 {
|
||||||
|
self.get_shared_domain().2
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_shared_domain(&self) -> (String, QueryType, u32) {
|
||||||
|
match self {
|
||||||
|
DnsRecord::UNKNOWN {
|
||||||
|
domain, ttl, qtype, ..
|
||||||
|
} => (domain.clone(), QueryType::UNKNOWN(*qtype), *ttl),
|
||||||
|
DnsRecord::AAAA { domain, ttl, .. } => (domain.clone(), QueryType::AAAA, *ttl),
|
||||||
|
DnsRecord::A { domain, ttl, .. } => (domain.clone(), QueryType::A, *ttl),
|
||||||
|
DnsRecord::NS { domain, ttl, .. } => (domain.clone(), QueryType::NS, *ttl),
|
||||||
|
DnsRecord::CNAME { domain, ttl, .. } => (domain.clone(), QueryType::CNAME, *ttl),
|
||||||
|
DnsRecord::SOA { domain, ttl, .. } => (domain.clone(), QueryType::SOA, *ttl),
|
||||||
|
DnsRecord::PTR { domain, ttl, .. } => (domain.clone(), QueryType::PTR, *ttl),
|
||||||
|
DnsRecord::MX { domain, ttl, .. } => (domain.clone(), QueryType::MX, *ttl),
|
||||||
|
DnsRecord::TXT { domain, ttl, .. } => (domain.clone(), QueryType::TXT, *ttl),
|
||||||
|
DnsRecord::SRV { domain, ttl, .. } => (domain.clone(), QueryType::SRV, *ttl),
|
||||||
|
DnsRecord::CAA { domain, ttl, .. } => (domain.clone(), QueryType::CAA, *ttl),
|
||||||
|
DnsRecord::AR { domain, ttl, .. } => (domain.clone(), QueryType::AR, *ttl),
|
||||||
|
DnsRecord::AAAAR { domain, ttl, .. } => (domain.clone(), QueryType::AAAAR, *ttl),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
22
src/dns/packet/result.rs
Normal file
22
src/dns/packet/result.rs
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||||
|
pub enum ResultCode {
|
||||||
|
NOERROR = 0,
|
||||||
|
FORMERR = 1,
|
||||||
|
SERVFAIL = 2,
|
||||||
|
NXDOMAIN = 3,
|
||||||
|
NOTIMP = 4,
|
||||||
|
REFUSED = 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ResultCode {
|
||||||
|
pub fn from_num(num: u8) -> Self {
|
||||||
|
match num {
|
||||||
|
1 => Self::FORMERR,
|
||||||
|
2 => Self::SERVFAIL,
|
||||||
|
3 => Self::NXDOMAIN,
|
||||||
|
4 => Self::NOTIMP,
|
||||||
|
5 => Self::REFUSED,
|
||||||
|
0 | _ => Self::NOERROR,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
230
src/dns/resolver.rs
Normal file
230
src/dns/resolver.rs
Normal file
|
@ -0,0 +1,230 @@
|
||||||
|
use super::binding::Connection;
|
||||||
|
use super::packet::{query::QueryType, question::DnsQuestion, result::ResultCode, Packet};
|
||||||
|
use crate::Result;
|
||||||
|
use crate::{config::Config, database::Database, get_time};
|
||||||
|
use async_recursion::async_recursion;
|
||||||
|
use moka::future::Cache;
|
||||||
|
use std::{net::IpAddr, sync::Arc, time::Duration};
|
||||||
|
use tracing::{error, trace};
|
||||||
|
|
||||||
|
pub struct Resolver {
|
||||||
|
request_id: u16,
|
||||||
|
connection: Connection,
|
||||||
|
config: Arc<Config>,
|
||||||
|
database: Arc<Database>,
|
||||||
|
cache: Cache<DnsQuestion, (Packet, u64)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Resolver {
|
||||||
|
pub fn new(
|
||||||
|
request_id: u16,
|
||||||
|
connection: Connection,
|
||||||
|
config: Arc<Config>,
|
||||||
|
database: Arc<Database>,
|
||||||
|
cache: Cache<DnsQuestion, (Packet, u64)>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
request_id,
|
||||||
|
connection,
|
||||||
|
config,
|
||||||
|
database,
|
||||||
|
cache,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn lookup_database(&self, question: &DnsQuestion) -> Option<Packet> {
|
||||||
|
let records = match self
|
||||||
|
.database
|
||||||
|
.get_records(&question.name, question.qtype)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(record) => record,
|
||||||
|
Err(err) => {
|
||||||
|
error!("{err}");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if records.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut packet = Packet::new();
|
||||||
|
|
||||||
|
packet.header.id = self.request_id;
|
||||||
|
packet.header.questions = 1;
|
||||||
|
packet.header.answers = records.len() as u16;
|
||||||
|
packet.header.recursion_desired = true;
|
||||||
|
packet
|
||||||
|
.questions
|
||||||
|
.push(DnsQuestion::new(question.name.to_string(), question.qtype));
|
||||||
|
|
||||||
|
for record in records {
|
||||||
|
packet.answers.push(record);
|
||||||
|
}
|
||||||
|
|
||||||
|
trace!(
|
||||||
|
"Found stored value for {:?} {}",
|
||||||
|
question.qtype,
|
||||||
|
question.name
|
||||||
|
);
|
||||||
|
|
||||||
|
Some(packet)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn lookup_cache(&self, question: &DnsQuestion) -> Option<Packet> {
|
||||||
|
let Some((packet, date)) = self.cache.get(&question) else {
|
||||||
|
return None
|
||||||
|
};
|
||||||
|
|
||||||
|
let now = get_time();
|
||||||
|
let diff = Duration::from_millis(now - date).as_secs() as u32;
|
||||||
|
|
||||||
|
for answer in &packet.answers {
|
||||||
|
let ttl = answer.get_ttl();
|
||||||
|
if diff > ttl {
|
||||||
|
self.cache.invalidate(&question).await;
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
trace!(
|
||||||
|
"Found cached value for {:?} {}",
|
||||||
|
question.qtype,
|
||||||
|
question.name
|
||||||
|
);
|
||||||
|
|
||||||
|
Some(packet)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn lookup_fallback(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
|
||||||
|
let mut packet = Packet::new();
|
||||||
|
|
||||||
|
packet.header.id = self.request_id;
|
||||||
|
packet.header.questions = 1;
|
||||||
|
packet.header.recursion_desired = true;
|
||||||
|
packet
|
||||||
|
.questions
|
||||||
|
.push(DnsQuestion::new(question.name.to_string(), question.qtype));
|
||||||
|
|
||||||
|
let packet = match self.connection.request_packet(packet, server).await {
|
||||||
|
Ok(packet) => packet,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to complete nameserver request: {e}");
|
||||||
|
let mut packet = Packet::new();
|
||||||
|
packet.header.rescode = ResultCode::SERVFAIL;
|
||||||
|
packet
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
packet
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn lookup(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
|
||||||
|
if let Some(packet) = self.lookup_cache(question).await {
|
||||||
|
return packet;
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(packet) = self.lookup_database(question).await {
|
||||||
|
return packet;
|
||||||
|
};
|
||||||
|
|
||||||
|
trace!(
|
||||||
|
"Attempting lookup of {:?} {} with ns {}",
|
||||||
|
question.qtype,
|
||||||
|
question.name,
|
||||||
|
server.0
|
||||||
|
);
|
||||||
|
|
||||||
|
self.lookup_fallback(question, server).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_recursion]
|
||||||
|
async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet {
|
||||||
|
let question = DnsQuestion::new(qname.to_string(), qtype);
|
||||||
|
let mut ns = self.config.dns_fallback.clone();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let ns_copy = ns;
|
||||||
|
|
||||||
|
let server = (ns_copy, 53);
|
||||||
|
let response = self.lookup(&question, server).await;
|
||||||
|
|
||||||
|
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
|
||||||
|
self.cache
|
||||||
|
.insert(question, (response.clone(), get_time()))
|
||||||
|
.await;
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.header.rescode == ResultCode::NXDOMAIN {
|
||||||
|
self.cache
|
||||||
|
.insert(question, (response.clone(), get_time()))
|
||||||
|
.await;
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(new_ns) = response.get_resolved_ns(qname) {
|
||||||
|
ns = new_ns;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let new_ns_name = match response.get_unresolved_ns(qname) {
|
||||||
|
Some(x) => x,
|
||||||
|
None => {
|
||||||
|
self.cache
|
||||||
|
.insert(question, (response.clone(), get_time()))
|
||||||
|
.await;
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await;
|
||||||
|
|
||||||
|
if let Some(new_ns) = recursive_response.get_random_a() {
|
||||||
|
ns = new_ns;
|
||||||
|
} else {
|
||||||
|
self.cache
|
||||||
|
.insert(question, (response.clone(), get_time()))
|
||||||
|
.await;
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handle_query(mut self) -> Result<()> {
|
||||||
|
let mut request = self.connection.read_packet().await?;
|
||||||
|
|
||||||
|
let mut packet = Packet::new();
|
||||||
|
packet.header.id = request.header.id;
|
||||||
|
packet.header.recursion_desired = true;
|
||||||
|
packet.header.recursion_available = true;
|
||||||
|
packet.header.response = true;
|
||||||
|
|
||||||
|
if let Some(question) = request.questions.pop() {
|
||||||
|
trace!("Received query: {question:?}");
|
||||||
|
|
||||||
|
let result = self.recursive_lookup(&question.name, question.qtype).await;
|
||||||
|
packet.questions.push(question.clone());
|
||||||
|
packet.header.rescode = result.header.rescode;
|
||||||
|
|
||||||
|
for rec in result.answers {
|
||||||
|
trace!("Answer: {rec:?}");
|
||||||
|
packet.answers.push(rec);
|
||||||
|
}
|
||||||
|
for rec in result.authorities {
|
||||||
|
trace!("Authority: {rec:?}");
|
||||||
|
packet.authorities.push(rec);
|
||||||
|
}
|
||||||
|
for rec in result.resources {
|
||||||
|
trace!("Resource: {rec:?}");
|
||||||
|
packet.resources.push(rec);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
packet.header.rescode = ResultCode::FORMERR;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.connection.write_packet(packet).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
85
src/dns/server.rs
Normal file
85
src/dns/server.rs
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
use super::{
|
||||||
|
binding::Binding,
|
||||||
|
packet::{question::DnsQuestion, Packet},
|
||||||
|
resolver::Resolver,
|
||||||
|
};
|
||||||
|
use crate::{config::Config, database::Database, Result};
|
||||||
|
use moka::future::Cache;
|
||||||
|
use std::{net::SocketAddr, sync::Arc, time::Duration};
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
|
use tracing::{error, info};
|
||||||
|
|
||||||
|
pub struct DnsServer {
|
||||||
|
addr: SocketAddr,
|
||||||
|
config: Arc<Config>,
|
||||||
|
database: Arc<Database>,
|
||||||
|
cache: Cache<DnsQuestion, (Packet, u64)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DnsServer {
|
||||||
|
pub async fn new(config: Config, database: Database) -> Result<Self> {
|
||||||
|
let addr = format!("[::]:{}", config.dns_port).parse::<SocketAddr>()?;
|
||||||
|
let cache = Cache::builder()
|
||||||
|
.time_to_live(Duration::from_secs(60 * 60))
|
||||||
|
.max_capacity(config.dns_cache_size)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
info!("Created DNS cache with size of {}", config.dns_cache_size);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
addr,
|
||||||
|
config: Arc::new(config),
|
||||||
|
database: Arc::new(database),
|
||||||
|
cache,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run(&self) -> Result<(JoinHandle<()>, JoinHandle<()>)> {
|
||||||
|
let tcp = Binding::tcp(self.addr).await?;
|
||||||
|
let tcp_handle = self.listen(tcp);
|
||||||
|
|
||||||
|
let udp = Binding::udp(self.addr).await?;
|
||||||
|
let udp_handle = self.listen(udp);
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Fallback DNS Server is set to: {:?}",
|
||||||
|
self.config.dns_fallback
|
||||||
|
);
|
||||||
|
info!(
|
||||||
|
"Listening for TCP and UDP traffic on [::]:{}",
|
||||||
|
self.config.dns_port
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok((udp_handle, tcp_handle))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn listen(&self, mut binding: Binding) -> JoinHandle<()> {
|
||||||
|
let config = self.config.clone();
|
||||||
|
let database = self.database.clone();
|
||||||
|
let cache = self.cache.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut id = 0;
|
||||||
|
loop {
|
||||||
|
let Ok(connection) = binding.connect().await else { continue };
|
||||||
|
info!("Received request on {}", binding.name());
|
||||||
|
|
||||||
|
let resolver = Resolver::new(
|
||||||
|
id,
|
||||||
|
connection,
|
||||||
|
config.clone(),
|
||||||
|
database.clone(),
|
||||||
|
cache.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let name = binding.name().to_string();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(err) = resolver.handle_query().await {
|
||||||
|
error!("{} request {} failed: {:?}", name, id, err);
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
id += 1;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
525
src/io/config.c
525
src/io/config.c
|
@ -1,525 +0,0 @@
|
||||||
#include <errno.h>
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <string.h>
|
|
||||||
#include <stdlib.h>
|
|
||||||
|
|
||||||
#include "config.h"
|
|
||||||
#include "log.h"
|
|
||||||
#include "map.h"
|
|
||||||
|
|
||||||
#define MAX_LEN 1024
|
|
||||||
#define BUF(name) char name[MAX_LEN]
|
|
||||||
|
|
||||||
static int line = 0;
|
|
||||||
|
|
||||||
static bool get_line(FILE* file, const BUF(buf)) {
|
|
||||||
line++;
|
|
||||||
return fgets((char*) buf, MAX_LEN, file) != NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool is_whitespace(const char* buf) {
|
|
||||||
int i = 0;
|
|
||||||
char c;
|
|
||||||
while(c = buf[i], 1) {
|
|
||||||
if (c == '\n' || c == '\0') return true;
|
|
||||||
if (c != ' ' && c != '\n') return false;
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool get_words(char* buf, char** words, int count) {
|
|
||||||
int last = 0;
|
|
||||||
int offset = 0;
|
|
||||||
int i = 0;
|
|
||||||
|
|
||||||
while(1) {
|
|
||||||
char c;
|
|
||||||
while(1) {
|
|
||||||
if (offset == MAX_LEN) return false;
|
|
||||||
c = buf[offset];
|
|
||||||
|
|
||||||
if (c == '\0' || c == '\n') {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (c == ' ' && i + 1 != count) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
offset++;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (offset - last < 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
words[i] = buf + last;
|
|
||||||
buf[offset] = '\0';
|
|
||||||
offset++;
|
|
||||||
last = offset;
|
|
||||||
|
|
||||||
if (c == '\0' || c == '\n') {
|
|
||||||
break;
|
|
||||||
} else if (i + 1 == count) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
return i + 1 == count;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool get_int(const char* word, uint32_t* i) {
|
|
||||||
char* end;
|
|
||||||
uint32_t res = (uint32_t) strtol(word, &end, 10);
|
|
||||||
|
|
||||||
if (*end == '\0') {
|
|
||||||
*i = res;
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool config_read_qtype(const char* qstr, RecordType* qtype) {
|
|
||||||
if (strcmp(qstr, "A") == 0) {
|
|
||||||
*qtype = A;
|
|
||||||
return true;
|
|
||||||
} else if (strcmp(qstr, "NS") == 0) {
|
|
||||||
*qtype = NS;
|
|
||||||
return true;
|
|
||||||
} else if (strcmp(qstr, "CNAME") == 0) {
|
|
||||||
*qtype = CNAME;
|
|
||||||
return true;
|
|
||||||
} else if (strcmp(qstr, "SOA") == 0) {
|
|
||||||
*qtype = SOA;
|
|
||||||
return true;
|
|
||||||
} else if (strcmp(qstr, "PTR") == 0) {
|
|
||||||
*qtype = PTR;
|
|
||||||
return true;
|
|
||||||
} else if (strcmp(qstr, "MX") == 0) {
|
|
||||||
*qtype = MX;
|
|
||||||
return true;
|
|
||||||
} else if (strcmp(qstr, "TXT") == 0) {
|
|
||||||
*qtype = TXT;
|
|
||||||
return true;
|
|
||||||
} else if (strcmp(qstr, "AAAA") == 0) {
|
|
||||||
*qtype = AAAA;
|
|
||||||
return true;
|
|
||||||
} else if (strcmp(qstr, "SRV") == 0) {
|
|
||||||
*qtype = SRV;
|
|
||||||
return true;
|
|
||||||
} else if (strcmp(qstr, "CAA") == 0) {
|
|
||||||
*qtype = CAA;
|
|
||||||
return true;
|
|
||||||
} else if (strcmp(qstr, "CMD") == 0) {
|
|
||||||
*qtype = CMD;
|
|
||||||
return true;
|
|
||||||
} else if(strcmp(qstr, "AR") == 0) {
|
|
||||||
*qtype = AR;
|
|
||||||
return true;
|
|
||||||
} else if(strcmp(qstr, "AAAAR") == 0) {
|
|
||||||
*qtype = AAAAR;
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool config_read_class(const char* cstr, uint16_t* class) {
|
|
||||||
if (strcmp(cstr, "IN") == 0) {
|
|
||||||
*class = 1;
|
|
||||||
return true;
|
|
||||||
} else if (strcmp(cstr, "CH") == 0) {
|
|
||||||
*class = 3;
|
|
||||||
return true;
|
|
||||||
} else if (strcmp(cstr, "HS") == 0) {
|
|
||||||
*class = 4;
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Format QTYPE CLASS DOMAIN: A IN google.com
|
|
||||||
static bool config_read_question(FILE* file, Question* question) {
|
|
||||||
BUF(buf);
|
|
||||||
if (!get_line(file, buf)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_whitespace(buf)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
char* words[3];
|
|
||||||
if (!get_words(&buf[0], &words[0], 3)) {
|
|
||||||
WARN("Invalid question at line %d", line);
|
|
||||||
return false;
|
|
||||||
};
|
|
||||||
|
|
||||||
uint16_t class;
|
|
||||||
if (!config_read_class(words[0], &class)) {
|
|
||||||
WARN("Invalid question class at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
RecordType qtype;
|
|
||||||
if (!config_read_qtype(words[1], &qtype)) {
|
|
||||||
WARN("Invalid question qtype at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t domain_len = strlen(words[2]);
|
|
||||||
question->cls = class;
|
|
||||||
question->qtype = qtype;
|
|
||||||
question->domain = malloc(domain_len + 1);
|
|
||||||
question->domain[0] = domain_len;
|
|
||||||
memcpy(question->domain + 1, words[2], domain_len);
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void copy_str(char* from, uint8_t** to) {
|
|
||||||
size_t len = strlen(from);
|
|
||||||
if (len > 255) {
|
|
||||||
len = 255;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* new = malloc(len + 1);
|
|
||||||
new[0] = len;
|
|
||||||
memcpy(new + 1, from, len);
|
|
||||||
*to = new;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool config_read_a_record(char* data, ARecord* record) {
|
|
||||||
sscanf(data, "%hhu.%hhu.%hhu.%hhu",
|
|
||||||
&record->addr[0],
|
|
||||||
&record->addr[1],
|
|
||||||
&record->addr[2],
|
|
||||||
&record->addr[3]
|
|
||||||
);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool config_read_ns_record(char* data, NSRecord* record) {
|
|
||||||
copy_str(data, &record->host);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool config_read_cname_record(char* data, CNAMERecord* record) {
|
|
||||||
copy_str(data, &record->host);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool config_read_soa_record(char* data, SOARecord* record) {
|
|
||||||
char* words[7];
|
|
||||||
if (!get_words(&data[0], &words[0], 7)) {
|
|
||||||
WARN("Invalid SOA record data at line %d", line);
|
|
||||||
record->mname = NULL;
|
|
||||||
record->nname = NULL;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
copy_str(words[0], &record->mname);
|
|
||||||
copy_str(words[1], &record->nname);
|
|
||||||
|
|
||||||
if (!get_int(words[2], &record->serial)) {
|
|
||||||
WARN("Invalid SOA record data at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!get_int(words[3], &record->refresh)) {
|
|
||||||
WARN("Invalid SOA record data at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!get_int(words[4], &record->retry)) {
|
|
||||||
WARN("Invalid SOA record data at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!get_int(words[5], &record->expire)) {
|
|
||||||
WARN("Invalid SOA record data at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!get_int(words[6], &record->minimum)) {
|
|
||||||
WARN("Invalid SOA record data at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool config_read_ptr_record(char* data, PTRRecord* record) {
|
|
||||||
copy_str(data, &record->pointer);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool config_read_mx_record(char* data, MXRecord* record) {
|
|
||||||
char* words[2];
|
|
||||||
if (!get_words(&data[0], &words[0], 2)) {
|
|
||||||
WARN("Invalid MX record data at line %d", line);
|
|
||||||
record->host = NULL;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
copy_str(words[1], &record->host);
|
|
||||||
|
|
||||||
uint32_t priority;
|
|
||||||
if (!get_int(words[0], &priority)) {
|
|
||||||
WARN("Invalid MX record data at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
record->priority = (uint16_t) priority;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool config_read_txt_record(char* data, TXTRecord* record) {
|
|
||||||
int len = strlen(data);
|
|
||||||
uint8_t count = ((uint8_t)len + 254) / 255;
|
|
||||||
record->len = count;
|
|
||||||
record->text = malloc(sizeof(uint8_t*) * count);
|
|
||||||
|
|
||||||
for (uint8_t i = 0; i < count; i++) {
|
|
||||||
uint32_t offset = count * 255;
|
|
||||||
uint32_t part_len = len - offset;
|
|
||||||
if (part_len > 255) part_len = 255;
|
|
||||||
|
|
||||||
uint8_t* part = malloc(part_len + 1);
|
|
||||||
part[0] = part_len;
|
|
||||||
memcpy(part + 1, data + offset, part_len);
|
|
||||||
|
|
||||||
record->text[i] = part;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool config_read_aaaa_record(char* data, AAAARecord* record) {
|
|
||||||
for(int i = 0; i < 8; i++) {
|
|
||||||
if (sscanf(data, "%02hhx%02hhx:",
|
|
||||||
&record->addr[i*2 + 0],
|
|
||||||
&record->addr[i*2 + 1]
|
|
||||||
) == EOF) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool config_read_srv_record(char* data, SRVRecord* record) {
|
|
||||||
char* words[4];
|
|
||||||
if (!get_words(&data[0], &words[0], 4)) {
|
|
||||||
WARN("Invalid SRV record data at line %d", line);
|
|
||||||
record->target = NULL;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
copy_str(words[3], &record->target);
|
|
||||||
|
|
||||||
uint32_t priority;
|
|
||||||
if (!get_int(words[0], &priority)) {
|
|
||||||
WARN("Invalid SRV record data at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
record->priority = (uint16_t) priority;
|
|
||||||
|
|
||||||
uint32_t weight;
|
|
||||||
if (!get_int(words[1], &weight)) {
|
|
||||||
WARN("Invalid SRV record data at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
record->weight = (uint16_t) weight;
|
|
||||||
|
|
||||||
uint32_t port;
|
|
||||||
if (!get_int(words[2], &port)) {
|
|
||||||
WARN("Invalid SRV record data at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
record->port = (uint16_t) port;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool config_read_caa_record(char* data, CAARecord* record) {
|
|
||||||
char* words[4];
|
|
||||||
if (!get_words(&data[0], &words[0], 4)) {
|
|
||||||
WARN("Invalid SRV record data at line %d", line);
|
|
||||||
record->tag = NULL;
|
|
||||||
record->value = NULL;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
copy_str(words[2], &record->tag);
|
|
||||||
copy_str(words[3], &record->value);
|
|
||||||
|
|
||||||
uint32_t flags;
|
|
||||||
if (!get_int(words[0], &flags)) {
|
|
||||||
WARN("Invalid SRV record data at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
record->flags = (uint8_t) flags;
|
|
||||||
|
|
||||||
uint32_t length;
|
|
||||||
if (!get_int(words[1], &length)) {
|
|
||||||
WARN("Invalid SRV record data at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
record->length = (uint8_t) length;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool config_read_cmd_record(char* data, CMDRecord* record) {
|
|
||||||
int len = strlen(data);
|
|
||||||
record->command = malloc(len);
|
|
||||||
memcpy(record->command, data, len);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool config_read_record_data(char* data, Record* record) {
|
|
||||||
switch (record->type) {
|
|
||||||
case UNKOWN:
|
|
||||||
// This can never happend in here so uh do nothing i guess
|
|
||||||
return false;
|
|
||||||
case A:
|
|
||||||
return config_read_a_record(data, &record->data.a);
|
|
||||||
case NS:
|
|
||||||
return config_read_ns_record(data, &record->data.ns);
|
|
||||||
case CNAME:
|
|
||||||
return config_read_cname_record(data, &record->data.cname);
|
|
||||||
case SOA:
|
|
||||||
return config_read_soa_record(data, &record->data.soa);
|
|
||||||
case PTR:
|
|
||||||
return config_read_ptr_record(data, &record->data.ptr);
|
|
||||||
case MX:
|
|
||||||
return config_read_mx_record(data, &record->data.mx);
|
|
||||||
case TXT:
|
|
||||||
return config_read_txt_record(data, &record->data.txt);
|
|
||||||
case AAAA:
|
|
||||||
return config_read_aaaa_record(data, &record->data.aaaa);
|
|
||||||
case SRV:
|
|
||||||
return config_read_srv_record(data, &record->data.srv);
|
|
||||||
case CAA:
|
|
||||||
return config_read_caa_record(data, &record->data.caa);
|
|
||||||
case CMD:
|
|
||||||
return config_read_cmd_record(data, &record->data.cmd);
|
|
||||||
case AR:
|
|
||||||
case AAAAR:
|
|
||||||
memset(&record->data, 0, sizeof(record->data));
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool config_read_record(FILE* file, Record* record, Question* question) {
|
|
||||||
BUF(buf);
|
|
||||||
if (!get_line(file, buf)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_whitespace(buf)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
char* words[4];
|
|
||||||
if (!get_words(&buf[0], &words[0], 4)) {
|
|
||||||
WARN("Invalid record at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint16_t class;
|
|
||||||
if (!config_read_class(words[0], &class)) {
|
|
||||||
WARN("Invalid question class at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
RecordType qtype;
|
|
||||||
if (!config_read_qtype(words[1], &qtype)) {
|
|
||||||
WARN("Invalid question qtype at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t ttl;
|
|
||||||
if (!get_int(words[2], &ttl)) {
|
|
||||||
WARN("Invalid record ttl at line %d", line);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
record->cls = class;
|
|
||||||
record->type = qtype;
|
|
||||||
record->len = 0;
|
|
||||||
record->ttl = ttl;
|
|
||||||
record->domain = malloc(question->domain[0] + 1);
|
|
||||||
memcpy(record->domain, question->domain, question->domain[0] + 1);
|
|
||||||
|
|
||||||
if(!config_read_record_data(words[3], record)) {
|
|
||||||
free_record(record);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void config_push_record(Record** buf, Record record, uint16_t* capacity, uint16_t* size) {
|
|
||||||
if (size == capacity) {
|
|
||||||
*capacity *= 2;
|
|
||||||
*buf = realloc(*buf, sizeof(Record) * *capacity);
|
|
||||||
}
|
|
||||||
(*buf)[*size] = record;
|
|
||||||
(*size)++;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool load_config(const char* path, RecordMap* map) {
|
|
||||||
FILE* file = fopen(path, "r");
|
|
||||||
if (file == NULL) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
INFO("Using config file at path: %s", path);
|
|
||||||
|
|
||||||
line = 0;
|
|
||||||
record_map_init(map);
|
|
||||||
|
|
||||||
while (1) {
|
|
||||||
Question* question = malloc(sizeof(Question));
|
|
||||||
if (!config_read_question(file, question)) {
|
|
||||||
free(question);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
INIT_LOG_BUFFER(log);
|
|
||||||
LOGONLY(print_question(question, log));
|
|
||||||
TRACE("Found config question: %s", log);
|
|
||||||
|
|
||||||
Packet* packet = malloc(sizeof(Packet));
|
|
||||||
memset(packet, 0, sizeof(Packet));
|
|
||||||
packet->authorities = NULL;
|
|
||||||
packet->resources = NULL;
|
|
||||||
|
|
||||||
packet->questions = malloc(sizeof(Question));
|
|
||||||
packet->questions[0] = *question;
|
|
||||||
|
|
||||||
uint16_t capacity = 1;
|
|
||||||
packet->answers = malloc(sizeof(Record));
|
|
||||||
|
|
||||||
while(1) {
|
|
||||||
Record record;
|
|
||||||
if (!config_read_record(file, &record, question)) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
LOGONLY(print_record(&record, log));
|
|
||||||
TRACE("Found config record: %s", log);
|
|
||||||
|
|
||||||
config_push_record(&packet->answers, record, &capacity, &packet->header.answers);
|
|
||||||
}
|
|
||||||
|
|
||||||
record_map_add(map, question, packet);
|
|
||||||
}
|
|
||||||
|
|
||||||
fclose(file);
|
|
||||||
return true;
|
|
||||||
}
|
|
|
@ -1,5 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "map.h"
|
|
||||||
|
|
||||||
bool load_config(const char* path, RecordMap* map);
|
|
49
src/io/log.c
49
src/io/log.c
|
@ -1,49 +0,0 @@
|
||||||
#include <stdarg.h>
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <stdlib.h>
|
|
||||||
#include <string.h>
|
|
||||||
#include <time.h>
|
|
||||||
|
|
||||||
#include "log.h"
|
|
||||||
|
|
||||||
#ifdef LOG
|
|
||||||
|
|
||||||
void logmsg(LogLevel level, const char* msg, ...) {
|
|
||||||
|
|
||||||
INIT_LOG_BOUNDS
|
|
||||||
INIT_LOG_BUFFER(buffer)
|
|
||||||
|
|
||||||
time_t now = time(NULL);
|
|
||||||
struct tm *tm = localtime(&now);
|
|
||||||
APPEND(buffer, "\x1b[97m%02d:%02d:%02d ", tm->tm_hour, tm->tm_min, tm->tm_sec);
|
|
||||||
|
|
||||||
switch (level) {
|
|
||||||
case DEBUG:
|
|
||||||
APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 95, "DEBUG");
|
|
||||||
break;
|
|
||||||
case TRACE:
|
|
||||||
APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 96, "TRACE");
|
|
||||||
break;
|
|
||||||
case INFO:
|
|
||||||
APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 92, "INFO");
|
|
||||||
break;
|
|
||||||
case WARN:
|
|
||||||
APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 93, "WARN");
|
|
||||||
break;
|
|
||||||
case ERROR:
|
|
||||||
APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 91, "ERROR");
|
|
||||||
break;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
va_list valist;
|
|
||||||
va_start(valist, msg);
|
|
||||||
t += vsnprintf(buffer + t, BUF_LENGTH - t, msg, valist);
|
|
||||||
va_end(valist);
|
|
||||||
|
|
||||||
APPEND(buffer, "\n");
|
|
||||||
|
|
||||||
fwrite(&buffer, t, 1, stdout);
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
45
src/io/log.h
45
src/io/log.h
|
@ -1,45 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
#define LOG
|
|
||||||
#ifdef LOG
|
|
||||||
|
|
||||||
typedef enum {
|
|
||||||
DEBUG,
|
|
||||||
TRACE,
|
|
||||||
INFO,
|
|
||||||
WARN,
|
|
||||||
ERROR,
|
|
||||||
} LogLevel;
|
|
||||||
|
|
||||||
#define BUF_LENGTH 256
|
|
||||||
#define INIT_LOG_BUFFER(name) char name[BUF_LENGTH];
|
|
||||||
#define INIT_LOG_BOUNDS int t = 0;
|
|
||||||
#define APPEND(buffer, msg, ...) t += snprintf(buffer + t, BUF_LENGTH - t, msg, ##__VA_ARGS__);
|
|
||||||
#define LOGONLY(code) code
|
|
||||||
|
|
||||||
void logmsg(LogLevel level, const char* msg, ...)
|
|
||||||
__attribute__ ((__format__(printf, 2, 3)));
|
|
||||||
|
|
||||||
#define DEBUG(msg, ...) logmsg(DEBUG, msg, ##__VA_ARGS__)
|
|
||||||
#define TRACE(msg, ...) logmsg(TRACE, msg, ##__VA_ARGS__)
|
|
||||||
#define INFO(msg, ...) logmsg(INFO, msg, ##__VA_ARGS__)
|
|
||||||
#define WARN(msg, ...) logmsg(WARN, msg, ##__VA_ARGS__)
|
|
||||||
#define ERROR(msg, ...) logmsg(ERROR, msg, ##__VA_ARGS__)
|
|
||||||
|
|
||||||
#else
|
|
||||||
|
|
||||||
#define BUF_LENGTH
|
|
||||||
#define INIT_LOG_BUFFER(name)
|
|
||||||
#define INIT_LOG_BOUNDS
|
|
||||||
#define APPEND(buffer, msg, ...)
|
|
||||||
#define LOGONLY(code)
|
|
||||||
|
|
||||||
#define DEBUG(msg, ...)
|
|
||||||
#define TRACE(msg, ...)
|
|
||||||
#define INFO(msg, ...)
|
|
||||||
#define WARN(msg, ...)
|
|
||||||
#define ERROR(msg, ...)
|
|
||||||
|
|
||||||
#endif
|
|
104
src/io/map.c
104
src/io/map.c
|
@ -1,104 +0,0 @@
|
||||||
#include <string.h>
|
|
||||||
#include <stdlib.h>
|
|
||||||
|
|
||||||
#include "map.h"
|
|
||||||
|
|
||||||
void record_map_init(RecordMap* map) {
|
|
||||||
map->capacity = 0;
|
|
||||||
map->len = 0;
|
|
||||||
map->entries = NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
void record_map_free(RecordMap* map) {
|
|
||||||
for(uint32_t i = 0; i < map->capacity; i++) {
|
|
||||||
Entry* e = &map->entries[i];
|
|
||||||
if (e->key != NULL) {
|
|
||||||
free_question(e->key);
|
|
||||||
free(e->key);
|
|
||||||
free_packet(e->value);
|
|
||||||
free(e->value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
free(map->entries);
|
|
||||||
}
|
|
||||||
|
|
||||||
static size_t hash_question(const Question* question) {
|
|
||||||
size_t hash = 5381;
|
|
||||||
for(int i = 0; i < question->domain[0]; i++) {
|
|
||||||
uint8_t c = question->domain[i+1];
|
|
||||||
hash = ((hash << 5) + hash) + c;
|
|
||||||
}
|
|
||||||
hash = ((hash << 5) + hash) + (uint8_t)question->cls;
|
|
||||||
hash = ((hash << 5) + hash) + (uint8_t)question->qtype;
|
|
||||||
return hash;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool question_equals(const Question* a, const Question* b) {
|
|
||||||
if (a->qtype != b->qtype) return false;
|
|
||||||
if (a->cls != b->cls) return false;
|
|
||||||
if (a->domain[0] != b->domain[0]) return false;
|
|
||||||
return memcmp(a->domain+1, b->domain+1, a->domain[0]) == 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
static Entry* record_map_find(Entry* entries, uint32_t capacity, const Question* key) {
|
|
||||||
uint32_t index = hash_question(key) % capacity;
|
|
||||||
while(true) {
|
|
||||||
Entry* entry = &entries[index];
|
|
||||||
if(entry->key == NULL) {
|
|
||||||
return entry;
|
|
||||||
} else if(question_equals(entry->key, key)) {
|
|
||||||
return entry;
|
|
||||||
}
|
|
||||||
index += 1;
|
|
||||||
index %= capacity;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void record_map_grow(RecordMap* map, uint32_t capacity) {
|
|
||||||
Entry* entries = malloc(capacity * sizeof(Entry));
|
|
||||||
for(uint32_t i = 0; i < capacity; i++) {
|
|
||||||
entries[i].key = NULL;
|
|
||||||
entries[i].value = NULL;
|
|
||||||
}
|
|
||||||
map->len = 0;
|
|
||||||
for(uint32_t i = 0; i < map->capacity; i++) {
|
|
||||||
Entry* entry = &map->entries[i];
|
|
||||||
if(entry->key == NULL) continue;
|
|
||||||
|
|
||||||
Entry* dest = record_map_find(entries, capacity, entry->key);
|
|
||||||
dest->key = entry->key;
|
|
||||||
dest->value = entry->value;
|
|
||||||
map->len++;
|
|
||||||
}
|
|
||||||
free(map->entries);
|
|
||||||
|
|
||||||
map->entries = entries;
|
|
||||||
map->capacity = capacity;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool record_map_get(const RecordMap* map, const Question* key, Packet* value) {
|
|
||||||
if(map->len == 0) return false;
|
|
||||||
|
|
||||||
Entry* e = record_map_find(map->entries, map->capacity, key);
|
|
||||||
if (e->key == NULL) return false;
|
|
||||||
|
|
||||||
*value = *(e->value);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void record_map_add(RecordMap* map, Question* key, Packet* value) {
|
|
||||||
if(map->len + 1 > map->capacity * 0.75) {
|
|
||||||
int capacity = (map->capacity == 0 ? 8 : (2 * map->capacity));
|
|
||||||
record_map_grow(map, capacity);
|
|
||||||
}
|
|
||||||
Entry* e = record_map_find(map->entries, map->capacity, key);
|
|
||||||
bool new_key = e->key == NULL;
|
|
||||||
if(new_key) {
|
|
||||||
map->len++;
|
|
||||||
e->key = key;
|
|
||||||
}
|
|
||||||
|
|
||||||
value->header.z = true;
|
|
||||||
e->value = value;
|
|
||||||
}
|
|
||||||
|
|
20
src/io/map.h
20
src/io/map.h
|
@ -1,20 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "../packet/packet.h"
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
Question* key;
|
|
||||||
Packet* value;
|
|
||||||
} Entry;
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
uint32_t capacity;
|
|
||||||
uint32_t len;
|
|
||||||
Entry* entries;
|
|
||||||
} RecordMap;
|
|
||||||
|
|
||||||
void record_map_init(RecordMap* map);
|
|
||||||
void record_map_free(RecordMap* map);
|
|
||||||
|
|
||||||
bool record_map_get(const RecordMap* map, const Question* key, Packet* value);
|
|
||||||
void record_map_add(RecordMap* map, Question* key, Packet* value);
|
|
31
src/main.c
31
src/main.c
|
@ -1,31 +0,0 @@
|
||||||
#include "server/server.h"
|
|
||||||
|
|
||||||
#include <stdlib.h>
|
|
||||||
|
|
||||||
#define DEFAULT_PORT 53
|
|
||||||
|
|
||||||
static uint16_t get_port(const char* port_str) {
|
|
||||||
if (port_str == NULL) {
|
|
||||||
return DEFAULT_PORT;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint16_t port;
|
|
||||||
if ((port = strtoul(port_str, NULL, 10)) == 0) {
|
|
||||||
return DEFAULT_PORT;
|
|
||||||
}
|
|
||||||
|
|
||||||
return port;
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(void) {
|
|
||||||
|
|
||||||
const char* port_str = getenv("PORT");
|
|
||||||
uint16_t port = get_port(port_str);
|
|
||||||
|
|
||||||
Server server;
|
|
||||||
server_init(port, &server);
|
|
||||||
server_run(&server);
|
|
||||||
server_free(&server);
|
|
||||||
|
|
||||||
return EXIT_SUCCESS;
|
|
||||||
}
|
|
64
src/main.rs
Normal file
64
src/main.rs
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
use config::Config;
|
||||||
|
|
||||||
|
use database::Database;
|
||||||
|
use dotenv::dotenv;
|
||||||
|
use dns::server::DnsServer;
|
||||||
|
use tracing::{error, metadata::LevelFilter};
|
||||||
|
use tracing_subscriber::{
|
||||||
|
filter::filter_fn, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, Layer,
|
||||||
|
};
|
||||||
|
use web::WebServer;
|
||||||
|
|
||||||
|
mod config;
|
||||||
|
mod database;
|
||||||
|
mod dns;
|
||||||
|
mod web;
|
||||||
|
|
||||||
|
type Error = Box<dyn std::error::Error>;
|
||||||
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
if let Err(err) = run().await {
|
||||||
|
error!("{err}")
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run() -> Result<()> {
|
||||||
|
dotenv().ok();
|
||||||
|
|
||||||
|
tracing_subscriber::registry()
|
||||||
|
.with(
|
||||||
|
tracing_subscriber::fmt::layer()
|
||||||
|
.with_filter(LevelFilter::TRACE)
|
||||||
|
.with_filter(filter_fn(|metadata| {
|
||||||
|
metadata.target().starts_with("wrapper")
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
let config = Config::new();
|
||||||
|
let database = Database::new(config.clone()).await?;
|
||||||
|
|
||||||
|
let dns_server = DnsServer::new(config.clone(), database.clone()).await?;
|
||||||
|
let (udp, tcp) = dns_server.run().await?;
|
||||||
|
|
||||||
|
let web_server = WebServer::new(config, database).await?;
|
||||||
|
let web = web_server.run().await?;
|
||||||
|
|
||||||
|
tokio::join!(udp).0?;
|
||||||
|
tokio::join!(tcp).0?;
|
||||||
|
tokio::join!(web).0?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_time() -> u64 {
|
||||||
|
let start = SystemTime::now();
|
||||||
|
let since_the_epoch = start
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.expect("Time went backwards");
|
||||||
|
since_the_epoch.as_millis() as u64
|
||||||
|
}
|
|
@ -1,250 +0,0 @@
|
||||||
#include "buffer.h"
|
|
||||||
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <stdlib.h>
|
|
||||||
#include <string.h>
|
|
||||||
|
|
||||||
struct PacketBuffer {
|
|
||||||
uint8_t* arr;
|
|
||||||
int capacity;
|
|
||||||
int index;
|
|
||||||
int size;
|
|
||||||
};
|
|
||||||
|
|
||||||
PacketBuffer* buffer_create(int capacity) {
|
|
||||||
PacketBuffer* buffer = malloc(sizeof(PacketBuffer));
|
|
||||||
buffer->arr = malloc(capacity);
|
|
||||||
memset(buffer->arr, 0, capacity);
|
|
||||||
buffer->capacity = capacity;
|
|
||||||
buffer->size = 0;
|
|
||||||
buffer->index = 0;
|
|
||||||
return buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_free(PacketBuffer* buffer) {
|
|
||||||
free(buffer->arr);
|
|
||||||
free(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_seek(PacketBuffer* buffer, int index) {
|
|
||||||
buffer->index = index;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t buffer_read(PacketBuffer* buffer) {
|
|
||||||
if (buffer->index >= buffer->size) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
uint8_t data = buffer->arr[buffer->index];
|
|
||||||
buffer->index++;
|
|
||||||
return data;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint16_t buffer_read_short(PacketBuffer* buffer) {
|
|
||||||
return
|
|
||||||
(uint16_t) buffer_read(buffer) << 8 |
|
|
||||||
(uint16_t) buffer_read(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t buffer_read_int(PacketBuffer* buffer) {
|
|
||||||
return
|
|
||||||
(uint32_t) buffer_read(buffer) << 24 |
|
|
||||||
(uint32_t) buffer_read(buffer) << 16 |
|
|
||||||
(uint32_t) buffer_read(buffer) << 8 |
|
|
||||||
(uint32_t) buffer_read(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t buffer_get(PacketBuffer* buffer, int index) {
|
|
||||||
if (index >= buffer->size) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
uint8_t data = buffer->arr[index];
|
|
||||||
return data;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* buffer_get_range(PacketBuffer* buffer, int start, int len) {
|
|
||||||
uint8_t* arr = malloc(len);
|
|
||||||
for (int i = 0; i < len; i++) {
|
|
||||||
arr[i] = buffer_get(buffer, start + i);
|
|
||||||
}
|
|
||||||
return arr;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint16_t buffer_get_size(PacketBuffer* buffer) {
|
|
||||||
return (uint16_t) buffer->size + 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write(uint8_t** buffer, uint8_t* size, uint8_t* capacity, uint8_t data) {
|
|
||||||
if (*size >= *capacity) {
|
|
||||||
if (*capacity >= 128) {
|
|
||||||
*capacity = 255;
|
|
||||||
} else {
|
|
||||||
*capacity *= 2;
|
|
||||||
}
|
|
||||||
*buffer = realloc(*buffer, *capacity);
|
|
||||||
}
|
|
||||||
(*buffer)[*size] = data;
|
|
||||||
(*size)++;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_read_qname(PacketBuffer* buffer, uint8_t** out) {
|
|
||||||
int index = buffer->index;
|
|
||||||
int jumped = 0;
|
|
||||||
|
|
||||||
int max_jumps = 5;
|
|
||||||
int jumps_performed = 0;
|
|
||||||
|
|
||||||
uint8_t length = 0;
|
|
||||||
uint8_t capacity = 8;
|
|
||||||
*out = malloc(capacity);
|
|
||||||
write(out, &length, &capacity, 0);
|
|
||||||
|
|
||||||
while(1) {
|
|
||||||
if (jumps_performed > max_jumps) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t len = buffer_get(buffer, index);
|
|
||||||
|
|
||||||
if ((len & 0xC0) == 0xC0) {
|
|
||||||
if (jumped == 0) {
|
|
||||||
buffer_seek(buffer, index + 2);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint16_t b2 = (uint16_t) buffer_get(buffer, index + 1);
|
|
||||||
uint16_t offset = ((((uint16_t) len) ^ 0xC0) << 8) | b2;
|
|
||||||
index = (int) offset;
|
|
||||||
jumped = 1;
|
|
||||||
jumps_performed++;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
index++;
|
|
||||||
|
|
||||||
if (len == 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (length > 1) {
|
|
||||||
write(out, &length, &capacity, '.');
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* range = buffer_get_range(buffer, index, len);
|
|
||||||
for (uint8_t i = 0; i < len; i++) {
|
|
||||||
write(out, &length, &capacity, range[i]);
|
|
||||||
}
|
|
||||||
free(range);
|
|
||||||
|
|
||||||
index += (int) len;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (jumped == 0) {
|
|
||||||
buffer_seek(buffer, index);
|
|
||||||
}
|
|
||||||
|
|
||||||
(*out)[0] = length - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_read_string(PacketBuffer* buffer, uint8_t** out) {
|
|
||||||
uint8_t len = buffer_read(buffer);
|
|
||||||
buffer_read_n(buffer, out, len);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void buffer_expand(PacketBuffer* buffer, int capacity) {
|
|
||||||
if (buffer->capacity >= capacity) return;
|
|
||||||
|
|
||||||
buffer->arr = realloc(buffer->arr, capacity);
|
|
||||||
memset(buffer->arr + buffer->capacity, 0, capacity - buffer->capacity);
|
|
||||||
buffer->capacity = capacity;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len) {
|
|
||||||
buffer_expand(buffer, buffer->index + len + 1);
|
|
||||||
*out = malloc(len + 1);
|
|
||||||
*out[0] = len;
|
|
||||||
memcpy(*out + 1, buffer->arr + buffer->index, len);
|
|
||||||
buffer->index += len;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_write(PacketBuffer* buffer, uint8_t data) {
|
|
||||||
if(buffer->index >= buffer->capacity) {
|
|
||||||
buffer->capacity *= 2;
|
|
||||||
buffer->arr = realloc(buffer->arr, buffer->capacity);
|
|
||||||
}
|
|
||||||
if (buffer->size < buffer->index) {
|
|
||||||
buffer->size = buffer->index;
|
|
||||||
}
|
|
||||||
buffer->arr[buffer->index] = data;
|
|
||||||
buffer->index++;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_write_short(PacketBuffer* buffer, uint16_t data) {
|
|
||||||
buffer_write(buffer, (uint8_t)(data >> 8));
|
|
||||||
buffer_write(buffer, (uint8_t)(data & 0xFF));
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_write_int(PacketBuffer* buffer, uint32_t data) {
|
|
||||||
buffer_write(buffer, (uint8_t)(data >> 24));
|
|
||||||
buffer_write(buffer, (uint8_t)(data >> 16));
|
|
||||||
buffer_write(buffer, (uint8_t)(data >> 8));
|
|
||||||
buffer_write(buffer, (uint8_t)(data & 0xFF));
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_write_qname(PacketBuffer* buffer, uint8_t* in) {
|
|
||||||
uint8_t part = 0;
|
|
||||||
uint8_t len = in[0];
|
|
||||||
|
|
||||||
buffer_write(buffer, 0);
|
|
||||||
|
|
||||||
if (len == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for(uint8_t i = 0; i < len; i ++) {
|
|
||||||
if (in[i+1] == '.') {
|
|
||||||
buffer_set(buffer, part, buffer->index - (int)part - 1);
|
|
||||||
buffer_write(buffer, 0);
|
|
||||||
part = 0;
|
|
||||||
} else {
|
|
||||||
buffer_write(buffer, in[i+1]);
|
|
||||||
part++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
buffer_set(buffer, part, buffer->index - (int)part - 1);
|
|
||||||
buffer_write(buffer, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_write_string(PacketBuffer* buffer, uint8_t* in) {
|
|
||||||
buffer_write(buffer, in[0]);
|
|
||||||
buffer_write_n(buffer, in + 1, in[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_write_n(PacketBuffer* buffer, uint8_t* in, int len) {
|
|
||||||
buffer_expand(buffer, buffer->index + len + 1);
|
|
||||||
memcpy(buffer->arr + buffer->index, in, len);
|
|
||||||
buffer->size += len;
|
|
||||||
buffer->index += len;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_set(PacketBuffer* buffer, uint8_t data, int index) {
|
|
||||||
if (index > buffer->size) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
buffer->arr[index] = data;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_set_uint16_t(PacketBuffer* buffer, uint16_t data, int index) {
|
|
||||||
buffer_set(buffer, (uint8_t)(data >> 8), index);
|
|
||||||
buffer_set(buffer, (uint8_t)(data & 0xFF), index + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
int buffer_get_index(PacketBuffer* buffer) {
|
|
||||||
return buffer->index;
|
|
||||||
}
|
|
||||||
|
|
||||||
void buffer_step(PacketBuffer* buffer, int len) {
|
|
||||||
buffer->index += len;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* buffer_get_ptr(PacketBuffer* buffer) {
|
|
||||||
return buffer->arr;
|
|
||||||
}
|
|
|
@ -1,51 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
typedef struct PacketBuffer PacketBuffer;
|
|
||||||
|
|
||||||
PacketBuffer* buffer_create(int capacity);
|
|
||||||
|
|
||||||
void buffer_free(PacketBuffer* buffer);
|
|
||||||
|
|
||||||
void buffer_seek(PacketBuffer* buffer, int index);
|
|
||||||
|
|
||||||
uint8_t buffer_read(PacketBuffer* buffer);
|
|
||||||
|
|
||||||
uint16_t buffer_read_short(PacketBuffer* buffer);
|
|
||||||
|
|
||||||
uint32_t buffer_read_int(PacketBuffer* buffer);
|
|
||||||
|
|
||||||
uint8_t buffer_get(PacketBuffer* buffer, int index);
|
|
||||||
|
|
||||||
uint8_t* buffer_get_range(PacketBuffer* buffer, int start, int len);
|
|
||||||
|
|
||||||
uint16_t buffer_get_size(PacketBuffer* buffer);
|
|
||||||
|
|
||||||
void buffer_read_qname(PacketBuffer* buffer, uint8_t** out);
|
|
||||||
|
|
||||||
void buffer_read_string(PacketBuffer* buffer, uint8_t** out);
|
|
||||||
|
|
||||||
void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len);
|
|
||||||
|
|
||||||
void buffer_write(PacketBuffer* buffer, uint8_t data);
|
|
||||||
|
|
||||||
void buffer_write_short(PacketBuffer* buffer, uint16_t data);
|
|
||||||
|
|
||||||
void buffer_write_int(PacketBuffer* buffer, uint32_t data);
|
|
||||||
|
|
||||||
void buffer_write_qname(PacketBuffer* buffer, uint8_t* in);
|
|
||||||
|
|
||||||
void buffer_write_string(PacketBuffer* buffer, uint8_t* in);
|
|
||||||
|
|
||||||
void buffer_write_n(PacketBuffer* buffer, uint8_t* in, int len);
|
|
||||||
|
|
||||||
void buffer_set(PacketBuffer* buffer, uint8_t data, int index);
|
|
||||||
|
|
||||||
void buffer_set_uint16_t(PacketBuffer* buffer, uint16_t data, int index);
|
|
||||||
|
|
||||||
int buffer_get_index(PacketBuffer* buffer);
|
|
||||||
|
|
||||||
void buffer_step(PacketBuffer* buffer, int len);
|
|
||||||
|
|
||||||
uint8_t* buffer_get_ptr(PacketBuffer* buffer);
|
|
|
@ -1,95 +0,0 @@
|
||||||
#include "header.h"
|
|
||||||
#include "buffer.h"
|
|
||||||
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <stdlib.h>
|
|
||||||
#include <string.h>
|
|
||||||
|
|
||||||
uint8_t rescode_to_id(ResultCode code) {
|
|
||||||
switch(code) {
|
|
||||||
case NOERROR:
|
|
||||||
return 0;
|
|
||||||
case FORMERR:
|
|
||||||
return 1;
|
|
||||||
case SERVFAIL:
|
|
||||||
return 2;
|
|
||||||
case NXDOMAIN:
|
|
||||||
return 3;
|
|
||||||
case NOTIMP:
|
|
||||||
return 4;
|
|
||||||
case REFUSED:
|
|
||||||
return 5;
|
|
||||||
default:
|
|
||||||
return 2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ResultCode rescode_from_id(uint8_t id) {
|
|
||||||
switch(id) {
|
|
||||||
case 0:
|
|
||||||
return NOERROR;
|
|
||||||
case 1:
|
|
||||||
return FORMERR;
|
|
||||||
case 2:
|
|
||||||
return SERVFAIL;
|
|
||||||
case 3:
|
|
||||||
return NXDOMAIN;
|
|
||||||
case 4:
|
|
||||||
return NOTIMP;
|
|
||||||
case 5:
|
|
||||||
return REFUSED;
|
|
||||||
default:
|
|
||||||
return FORMERR;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#define MAX(var, max) var = var > max ? max : var;
|
|
||||||
|
|
||||||
void read_header(PacketBuffer* buffer, Header* header) {
|
|
||||||
// memset(header, 0, sizeof(Header));
|
|
||||||
header->id = buffer_read_short(buffer);
|
|
||||||
|
|
||||||
uint8_t a = buffer_read(buffer);
|
|
||||||
header->recursion_desired = (a & (1 << 0)) > 0;
|
|
||||||
header->truncated_message = (a & (1 << 1)) > 0;
|
|
||||||
header->authorative_answer = (a & (1 << 2)) > 0;
|
|
||||||
header->opcode = (a >> 3) & 0x0F;
|
|
||||||
header->response = (a & (1 << 7)) > 0;
|
|
||||||
|
|
||||||
uint8_t b = buffer_read(buffer);
|
|
||||||
header->rescode = rescode_from_id(b & 0x0F);
|
|
||||||
header->checking_disabled = (b & (1 << 4)) > 0;
|
|
||||||
header->authed_data = (b& (1 << 4)) > 0;
|
|
||||||
header->z = (b & (1 << 6)) > 0;
|
|
||||||
header->recursion_available = (b & (1 << 7)) > 0;
|
|
||||||
|
|
||||||
header->questions = buffer_read_short(buffer);
|
|
||||||
header->answers = buffer_read_short(buffer);
|
|
||||||
header->authoritative_entries = buffer_read_short(buffer);
|
|
||||||
header->resource_entries = buffer_read_short(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
void write_header(PacketBuffer* buffer, Header* header) {
|
|
||||||
buffer_write_short(buffer, header->id);
|
|
||||||
|
|
||||||
buffer_write(buffer,
|
|
||||||
((uint8_t) header->recursion_desired) |
|
|
||||||
((uint8_t) header->truncated_message << 1) |
|
|
||||||
((uint8_t) header->authorative_answer << 2) |
|
|
||||||
(header->opcode << 3) |
|
|
||||||
((uint8_t) header->response << 7)
|
|
||||||
);
|
|
||||||
|
|
||||||
buffer_write(buffer,
|
|
||||||
(rescode_to_id(header->rescode)) |
|
|
||||||
((uint8_t) header->checking_disabled << 4) |
|
|
||||||
((uint8_t) header->authed_data << 5) |
|
|
||||||
((uint8_t) header->z << 6) |
|
|
||||||
((uint8_t) header->recursion_available << 7)
|
|
||||||
);
|
|
||||||
|
|
||||||
buffer_write_short(buffer, header->questions);
|
|
||||||
buffer_write_short(buffer, header->answers);
|
|
||||||
buffer_write_short(buffer, header->authoritative_entries);
|
|
||||||
buffer_write_short(buffer, header->resource_entries);
|
|
||||||
}
|
|
|
@ -1,41 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "buffer.h"
|
|
||||||
|
|
||||||
#include <stdbool.h>
|
|
||||||
|
|
||||||
typedef enum {
|
|
||||||
NOERROR, // 0
|
|
||||||
FORMERR, // 1
|
|
||||||
SERVFAIL, // 2
|
|
||||||
NXDOMAIN, // 3,
|
|
||||||
NOTIMP, // 4
|
|
||||||
REFUSED, // 5
|
|
||||||
} ResultCode;
|
|
||||||
|
|
||||||
uint8_t rescode_to_id(ResultCode code);
|
|
||||||
ResultCode rescode_from_id(uint8_t id);
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
uint16_t id;
|
|
||||||
|
|
||||||
bool recursion_desired; // 1 bit
|
|
||||||
bool truncated_message; // 1 bit
|
|
||||||
bool authorative_answer; // 1 bit
|
|
||||||
uint8_t opcode; // 4 bits
|
|
||||||
bool response; // 1 bit
|
|
||||||
|
|
||||||
ResultCode rescode; // 4 bits
|
|
||||||
bool checking_disabled; // 1 bit
|
|
||||||
bool authed_data; // 1 bit
|
|
||||||
bool z; // 1 bit
|
|
||||||
bool recursion_available; // 1 bit
|
|
||||||
|
|
||||||
uint16_t questions; // 16 bits
|
|
||||||
uint16_t answers; // 16 bits
|
|
||||||
uint16_t authoritative_entries; // 16 bits
|
|
||||||
uint16_t resource_entries; // 16 bits
|
|
||||||
} Header;
|
|
||||||
|
|
||||||
void read_header(PacketBuffer* buffer, Header* header);
|
|
||||||
void write_header(PacketBuffer* buffer, Header* header);
|
|
|
@ -1,185 +0,0 @@
|
||||||
#include <stdbool.h>
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <string.h>
|
|
||||||
#include <stdlib.h>
|
|
||||||
|
|
||||||
#include "packet.h"
|
|
||||||
#include "buffer.h"
|
|
||||||
#include "header.h"
|
|
||||||
#include "question.h"
|
|
||||||
#include "record.h"
|
|
||||||
#include "../io/log.h"
|
|
||||||
|
|
||||||
void read_packet(PacketBuffer* buffer, Packet* packet) {
|
|
||||||
read_header(buffer, &packet->header);
|
|
||||||
|
|
||||||
packet->questions = malloc(sizeof(Question) * packet->header.questions);
|
|
||||||
for(uint16_t i = 0; i < packet->header.questions; i++) {
|
|
||||||
if (!read_question(buffer, &packet->questions[i])) {
|
|
||||||
i--;
|
|
||||||
packet->header.questions--;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
packet->answers = malloc(sizeof(Record) * packet->header.answers);
|
|
||||||
for(uint16_t i = 0; i < packet->header.answers; i++) {
|
|
||||||
if (!read_record(buffer, &packet->answers[i])) {
|
|
||||||
i--;
|
|
||||||
packet->header.answers--;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
packet->authorities = malloc(sizeof(Record) * packet->header.authoritative_entries);
|
|
||||||
for(uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
|
|
||||||
if (!read_record(buffer, &packet->authorities[i])) {
|
|
||||||
i--;
|
|
||||||
packet->header.authoritative_entries--;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
packet->resources = malloc(sizeof(Record) * packet->header.resource_entries);
|
|
||||||
for(uint16_t i = 0; i < packet->header.resource_entries; i++) {
|
|
||||||
if (!read_record(buffer, &packet->resources[i])) {
|
|
||||||
i--;
|
|
||||||
packet->header.resource_entries--;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void write_packet(PacketBuffer* buffer, Packet* packet) {
|
|
||||||
write_header(buffer, &packet->header);
|
|
||||||
|
|
||||||
for(uint16_t i = 0; i < packet->header.questions; i++) {
|
|
||||||
write_question(buffer, &packet->questions[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
for(uint16_t i = 0; i < packet->header.answers; i++) {
|
|
||||||
write_record(buffer, &packet->answers[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
for(uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
|
|
||||||
write_record(buffer, &packet->authorities[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
for(uint16_t i = 0; i < packet->header.resource_entries; i++) {
|
|
||||||
write_record(buffer, &packet->resources[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void free_packet(Packet* packet) {
|
|
||||||
|
|
||||||
for(uint16_t i = 0; i < packet->header.questions; i++) {
|
|
||||||
free_question(&packet->questions[i]);
|
|
||||||
}
|
|
||||||
free(packet->questions);
|
|
||||||
|
|
||||||
for(uint16_t i = 0; i < packet->header.answers; i++) {
|
|
||||||
free_record(&packet->answers[i]);
|
|
||||||
}
|
|
||||||
free(packet->answers);
|
|
||||||
|
|
||||||
for(uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
|
|
||||||
free_record(&packet->authorities[i]);
|
|
||||||
}
|
|
||||||
free(packet->authorities);
|
|
||||||
|
|
||||||
for(uint16_t i = 0; i < packet->header.resource_entries; i++) {
|
|
||||||
free_record(&packet->resources[i]);
|
|
||||||
}
|
|
||||||
free(packet->resources);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool get_random_a(Packet* packet, IpAddr* addr) {
|
|
||||||
for (uint16_t i = 0; i < packet->header.answers; i++) {
|
|
||||||
Record record = packet->answers[i];
|
|
||||||
if (record.type == A) {
|
|
||||||
create_ip_addr(record.data.a.addr, addr);
|
|
||||||
return true;
|
|
||||||
} else if (record.type == AAAA) {
|
|
||||||
create_ip_addr6(record.data.aaaa.addr, addr);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool ends_with(uint8_t* full, uint8_t* end) {
|
|
||||||
uint8_t check = end[0];
|
|
||||||
uint8_t len = full[0];
|
|
||||||
|
|
||||||
if (check > len) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for(uint8_t i = 0; i < check; i++) {
|
|
||||||
if (end[check - 1 - i] != full[len - 1 - i]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool equals(uint8_t* a, uint8_t* b) {
|
|
||||||
if (a[0] != b[0]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for(uint8_t i = 1; i < a[0] + 1; i++) {
|
|
||||||
if(a[i] != b[i]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool get_resolved_ns(Packet* packet, uint8_t* qname, IpAddr* addr) {
|
|
||||||
for (uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
|
|
||||||
Record record = packet->authorities[i];
|
|
||||||
if (record.type != NS) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if(!ends_with(qname, record.domain)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (uint16_t i = 0; i < packet->header.resource_entries; i++) {
|
|
||||||
Record resource = packet->resources[i];
|
|
||||||
if (!equals(record.data.ns.host, resource.domain)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (resource.type == A) {
|
|
||||||
create_ip_addr(record.data.a.addr, addr);
|
|
||||||
return true;
|
|
||||||
} else if (resource.type == AAAA) {
|
|
||||||
create_ip_addr6(record.data.aaaa.addr, addr);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool get_unresoled_ns(Packet* packet, uint8_t* qname, Question* question) {
|
|
||||||
for (uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
|
|
||||||
Record record = packet->authorities[i];
|
|
||||||
if (record.type != NS) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if(!ends_with(qname, record.domain)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t* host = record.data.ns.host;
|
|
||||||
|
|
||||||
question->qtype = NS;
|
|
||||||
question->domain = malloc(host[0] + 1);
|
|
||||||
memcpy(question->domain, host, host[0] + 1);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "buffer.h"
|
|
||||||
#include "question.h"
|
|
||||||
#include "header.h"
|
|
||||||
#include "record.h"
|
|
||||||
#include "../server/addr.h"
|
|
||||||
|
|
||||||
#include <stdbool.h>
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
Header header;
|
|
||||||
Question* questions;
|
|
||||||
Record* answers;
|
|
||||||
Record* authorities;
|
|
||||||
Record* resources;
|
|
||||||
} Packet;
|
|
||||||
|
|
||||||
void read_packet(PacketBuffer* buffer, Packet* packet);
|
|
||||||
void write_packet(PacketBuffer* buffer, Packet* packet);
|
|
||||||
void free_packet(Packet* packet);
|
|
||||||
|
|
||||||
bool get_random_a(Packet* packet, IpAddr* addr);
|
|
||||||
bool get_resolved_ns(Packet* packet, uint8_t* qname, IpAddr* addr);
|
|
||||||
bool get_unresoled_ns(Packet* packet, uint8_t* qname, Question* question);
|
|
|
@ -1,110 +0,0 @@
|
||||||
#include <stdio.h>
|
|
||||||
#include <stdlib.h>
|
|
||||||
#include <stdbool.h>
|
|
||||||
|
|
||||||
#include "question.h"
|
|
||||||
#include "buffer.h"
|
|
||||||
#include "record.h"
|
|
||||||
#include "../io/log.h"
|
|
||||||
|
|
||||||
bool read_question(PacketBuffer* buffer, Question* question) {
|
|
||||||
buffer_read_qname(buffer, &question->domain);
|
|
||||||
|
|
||||||
uint16_t qtype_num = buffer_read_short(buffer);
|
|
||||||
record_from_id(qtype_num, &question->qtype);
|
|
||||||
question->cls = buffer_read_short(buffer);
|
|
||||||
|
|
||||||
if (question->qtype == UNKOWN) {
|
|
||||||
free(question->domain);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
INIT_LOG_BUFFER(log)
|
|
||||||
LOGONLY(print_question(question, log);)
|
|
||||||
TRACE("Reading question: %s", log);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void write_question(PacketBuffer* buffer, Question* question) {
|
|
||||||
buffer_write_qname(buffer, question->domain);
|
|
||||||
|
|
||||||
uint16_t id = record_to_id(question->qtype);
|
|
||||||
buffer_write_short(buffer, id);
|
|
||||||
|
|
||||||
buffer_write_short(buffer, question->cls);
|
|
||||||
|
|
||||||
INIT_LOG_BUFFER(log)
|
|
||||||
LOGONLY(print_question(question, log);)
|
|
||||||
TRACE("Writing question: %s", log);
|
|
||||||
}
|
|
||||||
|
|
||||||
void free_question(Question* question) {
|
|
||||||
free(question->domain);
|
|
||||||
}
|
|
||||||
|
|
||||||
void print_question(Question* question, char* buffer) {
|
|
||||||
INIT_LOG_BOUNDS
|
|
||||||
switch (question->cls) {
|
|
||||||
case 1:
|
|
||||||
APPEND(buffer, "IN ");;
|
|
||||||
break;
|
|
||||||
case 3:
|
|
||||||
APPEND(buffer, "CH ");
|
|
||||||
break;
|
|
||||||
case 4:
|
|
||||||
APPEND(buffer, "HS ");
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
APPEND(buffer, "?? ");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
switch(question->qtype) {
|
|
||||||
case UNKOWN:
|
|
||||||
APPEND(buffer, "UNKOWN ");
|
|
||||||
break;
|
|
||||||
case A:
|
|
||||||
APPEND(buffer, "A ");
|
|
||||||
break;
|
|
||||||
case NS:
|
|
||||||
APPEND(buffer, "NS ");
|
|
||||||
break;
|
|
||||||
case CNAME:
|
|
||||||
APPEND(buffer, "CNAME ");
|
|
||||||
break;
|
|
||||||
case SOA:
|
|
||||||
APPEND(buffer, "SOA ");
|
|
||||||
break;
|
|
||||||
case PTR:
|
|
||||||
APPEND(buffer, "PTR ");
|
|
||||||
break;
|
|
||||||
case MX:
|
|
||||||
APPEND(buffer, "MX ");
|
|
||||||
break;
|
|
||||||
case TXT:
|
|
||||||
APPEND(buffer, "TXT ");
|
|
||||||
break;
|
|
||||||
case AAAA:
|
|
||||||
APPEND(buffer, "AAAA ");
|
|
||||||
break;
|
|
||||||
case SRV:
|
|
||||||
APPEND(buffer, "SRV ");
|
|
||||||
break;
|
|
||||||
case CAA:
|
|
||||||
APPEND(buffer, "CAA ");
|
|
||||||
break;
|
|
||||||
case CMD:
|
|
||||||
APPEND(buffer, "CMD ");
|
|
||||||
break;
|
|
||||||
case AR:
|
|
||||||
APPEND(buffer, "A ");
|
|
||||||
break;
|
|
||||||
case AAAAR:
|
|
||||||
APPEND(buffer, "AAAAR ");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
APPEND(buffer, "%.*s",
|
|
||||||
question->domain[0],
|
|
||||||
question->domain + 1
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,15 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "buffer.h"
|
|
||||||
#include "record.h"
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
uint8_t* domain;
|
|
||||||
RecordType qtype;
|
|
||||||
uint16_t cls;
|
|
||||||
} Question;
|
|
||||||
|
|
||||||
bool read_question(PacketBuffer* buffer, Question* question);
|
|
||||||
void write_question(PacketBuffer* buffer, Question* question);
|
|
||||||
void free_question(Question* question);
|
|
||||||
void print_question(Question* question, char* buffer);
|
|
|
@ -1,670 +0,0 @@
|
||||||
#include <stdbool.h>
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <stdlib.h>
|
|
||||||
#include <string.h>
|
|
||||||
#include <time.h>
|
|
||||||
|
|
||||||
#undef _POSIX_C_SOURCE
|
|
||||||
#include <stdio.h>
|
|
||||||
|
|
||||||
#include "record.h"
|
|
||||||
#include "buffer.h"
|
|
||||||
#include "../io/log.h"
|
|
||||||
|
|
||||||
uint16_t record_to_id(RecordType type) {
|
|
||||||
switch (type) {
|
|
||||||
case A:
|
|
||||||
return 1;
|
|
||||||
case NS:
|
|
||||||
return 2;
|
|
||||||
case CNAME:
|
|
||||||
return 5;
|
|
||||||
case SOA:
|
|
||||||
return 6;
|
|
||||||
case PTR:
|
|
||||||
return 12;
|
|
||||||
case MX:
|
|
||||||
return 15;
|
|
||||||
case TXT:
|
|
||||||
return 16;
|
|
||||||
case AAAA:
|
|
||||||
return 28;
|
|
||||||
case SRV:
|
|
||||||
return 33;
|
|
||||||
case CAA:
|
|
||||||
return 257;
|
|
||||||
case CMD:
|
|
||||||
return 16;
|
|
||||||
case AR:
|
|
||||||
return 1;
|
|
||||||
case AAAAR:
|
|
||||||
return 28;
|
|
||||||
default:
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void record_from_id(uint16_t i, RecordType* type) {
|
|
||||||
switch (i) {
|
|
||||||
case 1:
|
|
||||||
*type = A;
|
|
||||||
break;
|
|
||||||
case 2:
|
|
||||||
*type = NS;
|
|
||||||
break;
|
|
||||||
case 5:
|
|
||||||
*type = CNAME;
|
|
||||||
break;
|
|
||||||
case 6:
|
|
||||||
*type = SOA;
|
|
||||||
break;
|
|
||||||
case 12:
|
|
||||||
*type = PTR;
|
|
||||||
break;
|
|
||||||
case 15:
|
|
||||||
*type = MX;
|
|
||||||
break;
|
|
||||||
case 16:
|
|
||||||
*type = TXT;
|
|
||||||
break;
|
|
||||||
case 28:
|
|
||||||
*type = AAAA;
|
|
||||||
break;
|
|
||||||
case 33:
|
|
||||||
*type = SRV;
|
|
||||||
break;
|
|
||||||
case 257:
|
|
||||||
*type = CAA;
|
|
||||||
break;
|
|
||||||
case 1000:
|
|
||||||
*type = CMD;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
*type = UNKOWN;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void read_a_record(PacketBuffer* buffer, Record* record) {
|
|
||||||
ARecord data;
|
|
||||||
data.addr[0] = buffer_read(buffer);
|
|
||||||
data.addr[1] = buffer_read(buffer);
|
|
||||||
data.addr[2] = buffer_read(buffer);
|
|
||||||
data.addr[3] = buffer_read(buffer);
|
|
||||||
|
|
||||||
record->data.a = data;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void read_ns_record(PacketBuffer* buffer, Record* record) {
|
|
||||||
NSRecord data;
|
|
||||||
buffer_read_qname(buffer, &data.host);
|
|
||||||
|
|
||||||
record->data.ns = data;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void read_cname_record(PacketBuffer* buffer, Record* record) {
|
|
||||||
CNAMERecord data;
|
|
||||||
buffer_read_qname(buffer, &data.host);
|
|
||||||
|
|
||||||
record->data.cname = data;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void read_soa_record(PacketBuffer* buffer, Record* record) {
|
|
||||||
SOARecord data;
|
|
||||||
buffer_read_qname(buffer, &data.mname);
|
|
||||||
buffer_read_qname(buffer, &data.nname);
|
|
||||||
data.serial = buffer_read_int(buffer);
|
|
||||||
data.refresh = buffer_read_int(buffer);
|
|
||||||
data.retry = buffer_read_int(buffer);
|
|
||||||
data.expire = buffer_read_int(buffer);
|
|
||||||
data.minimum = buffer_read_int(buffer);
|
|
||||||
|
|
||||||
record->data.soa = data;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void read_ptr_record(PacketBuffer* buffer, Record* record) {
|
|
||||||
PTRRecord data;
|
|
||||||
buffer_read_qname(buffer, &data.pointer);
|
|
||||||
|
|
||||||
record->data.ptr = data;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void read_mx_record(PacketBuffer* buffer, Record* record) {
|
|
||||||
MXRecord data;
|
|
||||||
data.priority = buffer_read_short(buffer);
|
|
||||||
buffer_read_qname(buffer, &data.host);
|
|
||||||
|
|
||||||
record->data.mx = data;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void read_txt_record(PacketBuffer* buffer, Record* record) {
|
|
||||||
TXTRecord data;
|
|
||||||
data.len = 0;
|
|
||||||
data.text = malloc(sizeof(uint8_t*) * 2);
|
|
||||||
uint8_t capacity = 2;
|
|
||||||
uint8_t total = record->len;
|
|
||||||
while (1) {
|
|
||||||
if (data.len >= capacity) {
|
|
||||||
if (capacity >= 128) {
|
|
||||||
capacity = 255;
|
|
||||||
} else {
|
|
||||||
capacity *= 2;
|
|
||||||
}
|
|
||||||
data.text = realloc(data.text, sizeof(uint8_t*) * capacity);
|
|
||||||
}
|
|
||||||
|
|
||||||
buffer_read_string(buffer, &data.text[data.len]);
|
|
||||||
if(data.text[data.len][0] == 0) {
|
|
||||||
free(data.text[data.len]);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
data.len++;
|
|
||||||
|
|
||||||
total -= data.text[data.len - 1][0] + 1;
|
|
||||||
if (total == 0) break;
|
|
||||||
}
|
|
||||||
|
|
||||||
record->data.txt = data;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void read_aaaa_record(PacketBuffer* buffer, Record* record) {
|
|
||||||
AAAARecord data;
|
|
||||||
for (int i = 0; i < 16; i++) {
|
|
||||||
data.addr[i] = buffer_read(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
record->data.aaaa = data;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void read_srv_record(PacketBuffer* buffer, Record* record) {
|
|
||||||
SRVRecord data;
|
|
||||||
data.priority = buffer_read_short(buffer);
|
|
||||||
data.weight = buffer_read_short(buffer);
|
|
||||||
data.port = buffer_read_short(buffer);
|
|
||||||
buffer_read_qname(buffer, &data.target);
|
|
||||||
|
|
||||||
record->data.srv = data;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void read_caa_record(PacketBuffer* buffer, Record* record, int header_pos) {
|
|
||||||
CAARecord data;
|
|
||||||
data.flags = buffer_read(buffer);
|
|
||||||
data.length = buffer_read(buffer);
|
|
||||||
buffer_read_n(buffer, &data.tag, data.length);
|
|
||||||
int value_len = ((int)record->len) + header_pos - buffer_get_index(buffer);
|
|
||||||
buffer_read_n(buffer, &data.value, (uint8_t)value_len);
|
|
||||||
|
|
||||||
record->data.caa = data;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool read_record(PacketBuffer* buffer, Record* record) {
|
|
||||||
buffer_read_qname(buffer, &record->domain);
|
|
||||||
|
|
||||||
uint16_t qtype_num = buffer_read_short(buffer);
|
|
||||||
record_from_id(qtype_num, &record->type);
|
|
||||||
|
|
||||||
record->cls = buffer_read_short(buffer);
|
|
||||||
record->ttl = buffer_read_int(buffer);
|
|
||||||
record->len = buffer_read_short(buffer);
|
|
||||||
|
|
||||||
int header_pos = buffer_get_index(buffer);
|
|
||||||
|
|
||||||
switch (record->type) {
|
|
||||||
case A:
|
|
||||||
read_a_record(buffer, record);
|
|
||||||
break;
|
|
||||||
case NS:
|
|
||||||
read_ns_record(buffer, record);
|
|
||||||
break;
|
|
||||||
case CNAME:
|
|
||||||
read_cname_record(buffer, record);
|
|
||||||
break;
|
|
||||||
case SOA:
|
|
||||||
read_soa_record(buffer, record);
|
|
||||||
break;
|
|
||||||
case PTR:
|
|
||||||
read_ptr_record(buffer, record);
|
|
||||||
break;
|
|
||||||
case MX:
|
|
||||||
read_mx_record(buffer, record);
|
|
||||||
break;
|
|
||||||
case TXT:
|
|
||||||
read_txt_record(buffer, record);
|
|
||||||
break;
|
|
||||||
case AAAA:
|
|
||||||
read_aaaa_record(buffer, record);
|
|
||||||
break;
|
|
||||||
case SRV:
|
|
||||||
read_srv_record(buffer, record);
|
|
||||||
break;
|
|
||||||
case CAA:
|
|
||||||
read_caa_record(buffer, record, header_pos);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
buffer_step(buffer, record->len);
|
|
||||||
free(record->domain);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
INIT_LOG_BUFFER(log)
|
|
||||||
LOGONLY(print_record(record, log);)
|
|
||||||
TRACE("Reading record: %s", log);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write_a_record(PacketBuffer* buffer, ARecord* data) {
|
|
||||||
buffer_write_short(buffer, 4);
|
|
||||||
buffer_write(buffer, data->addr[0]);
|
|
||||||
buffer_write(buffer, data->addr[1]);
|
|
||||||
buffer_write(buffer, data->addr[2]);
|
|
||||||
buffer_write(buffer, data->addr[3]);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write_ns_record(PacketBuffer* buffer, NSRecord* data) {
|
|
||||||
int pos = buffer_get_index(buffer);
|
|
||||||
buffer_write_short(buffer, 0);
|
|
||||||
|
|
||||||
buffer_write_qname(buffer, data->host);
|
|
||||||
|
|
||||||
int size = buffer_get_index(buffer) - pos - 2;
|
|
||||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write_cname_record(PacketBuffer* buffer, CNAMERecord* data) {
|
|
||||||
int pos = buffer_get_index(buffer);
|
|
||||||
buffer_write_short(buffer, 0);
|
|
||||||
|
|
||||||
buffer_write_qname(buffer, data->host);
|
|
||||||
|
|
||||||
int size = buffer_get_index(buffer) - pos - 2;
|
|
||||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write_soa_record(PacketBuffer* buffer, SOARecord* data) {
|
|
||||||
int pos = buffer_get_index(buffer);
|
|
||||||
buffer_write_short(buffer, 0);
|
|
||||||
|
|
||||||
buffer_write_qname(buffer, data->mname);
|
|
||||||
buffer_write_qname(buffer, data->nname);
|
|
||||||
buffer_write_int(buffer, data->serial);
|
|
||||||
buffer_write_int(buffer, data->refresh);
|
|
||||||
buffer_write_int(buffer, data->retry);
|
|
||||||
buffer_write_int(buffer, data->expire);
|
|
||||||
buffer_write_int(buffer, data->minimum);
|
|
||||||
|
|
||||||
int size = buffer_get_index(buffer) - pos - 2;
|
|
||||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write_ptr_record(PacketBuffer* buffer, PTRRecord* data) {
|
|
||||||
int pos = buffer_get_index(buffer);
|
|
||||||
buffer_write_short(buffer, 0);
|
|
||||||
|
|
||||||
buffer_write_qname(buffer, data->pointer);
|
|
||||||
|
|
||||||
int size = buffer_get_index(buffer) - pos - 2;
|
|
||||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write_mx_record(PacketBuffer* buffer, MXRecord* data) {
|
|
||||||
int pos = buffer_get_index(buffer);
|
|
||||||
buffer_write_short(buffer, 0);
|
|
||||||
|
|
||||||
buffer_write_short(buffer, data->priority);
|
|
||||||
buffer_write_qname(buffer, data->host);
|
|
||||||
|
|
||||||
int size = buffer_get_index(buffer) - pos - 2;
|
|
||||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write_txt_record(PacketBuffer* buffer, TXTRecord* data) {
|
|
||||||
int pos = buffer_get_index(buffer);
|
|
||||||
buffer_write_short(buffer, 0);
|
|
||||||
|
|
||||||
if(data->len == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for(uint8_t i = 0; i < data->len; i++) {
|
|
||||||
buffer_write_string(buffer, data->text[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
int size = buffer_get_index(buffer) - pos - 2;
|
|
||||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write_aaaa_record(PacketBuffer* buffer, AAAARecord* data) {
|
|
||||||
buffer_write_short(buffer, 16);
|
|
||||||
|
|
||||||
for (int i = 0; i < 16; i++) {
|
|
||||||
buffer_write(buffer, data->addr[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write_srv_record(PacketBuffer* buffer, SRVRecord* data) {
|
|
||||||
int pos = buffer_get_index(buffer);
|
|
||||||
buffer_write_short(buffer, 0);
|
|
||||||
|
|
||||||
buffer_write_short(buffer, data->priority);
|
|
||||||
buffer_write_short(buffer, data->weight);
|
|
||||||
buffer_write_short(buffer, data->port);
|
|
||||||
buffer_write_qname(buffer, data->target);
|
|
||||||
|
|
||||||
int size = buffer_get_index(buffer) - pos - 2;
|
|
||||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write_caa_record(PacketBuffer* buffer, CAARecord* data) {
|
|
||||||
int pos = buffer_get_index(buffer);
|
|
||||||
buffer_write_short(buffer, 0);
|
|
||||||
buffer_write(buffer, data->flags);
|
|
||||||
buffer_write(buffer, data->length);
|
|
||||||
buffer_write_n(buffer, data->tag + 1, data->tag[0]);
|
|
||||||
buffer_write_n(buffer, data->value + 1, data->value[0]);
|
|
||||||
|
|
||||||
int size = buffer_get_index(buffer) - pos - 2;
|
|
||||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write_string(TXTRecord* record, uint8_t* capacity, const char* string, uint8_t len) {
|
|
||||||
if (len < 1) return;
|
|
||||||
if (record->len >= *capacity) {
|
|
||||||
if (*capacity >= 128) {
|
|
||||||
*capacity = 255;
|
|
||||||
} else {
|
|
||||||
*capacity *= 2;
|
|
||||||
}
|
|
||||||
record->text = realloc(record->text, sizeof(uint8_t*) * *capacity);
|
|
||||||
}
|
|
||||||
record->text[record->len] = malloc(len + 1);
|
|
||||||
record->text[record->len][0] = len;
|
|
||||||
memcpy(record->text[record->len] + 1, string, len);
|
|
||||||
record->len++;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void free_text(TXTRecord* record) {
|
|
||||||
for (uint8_t i = 0; i < record->len; i++) {
|
|
||||||
free(record->text[i]);
|
|
||||||
}
|
|
||||||
free(record->text);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write_cmd_record(PacketBuffer* buffer, CMDRecord* data) {
|
|
||||||
FILE* output = popen(data->command, "r");
|
|
||||||
TXTRecord res;
|
|
||||||
|
|
||||||
uint8_t capacity = 1;
|
|
||||||
res.len = 0;
|
|
||||||
res.text = malloc(capacity * sizeof(uint8_t*));
|
|
||||||
|
|
||||||
if (output == NULL) {
|
|
||||||
write_string(&res, &capacity, "Failed to execute command", 25);
|
|
||||||
write_txt_record(buffer, &res);
|
|
||||||
free_text(&res);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
char in[255];
|
|
||||||
char c;
|
|
||||||
int i = 0;
|
|
||||||
|
|
||||||
while (1) {
|
|
||||||
if (res.len >= 255) break;
|
|
||||||
|
|
||||||
c = getc(output);
|
|
||||||
if (c == EOF || c == '\0') {
|
|
||||||
write_string(&res, &capacity, in, i);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
in[i] = c;
|
|
||||||
i++;
|
|
||||||
|
|
||||||
if (i == 255) {
|
|
||||||
write_string(&res, &capacity, in, i);
|
|
||||||
i = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
write_txt_record(buffer, &res);
|
|
||||||
free_text(&res);
|
|
||||||
|
|
||||||
pclose(output);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write_ar_record(PacketBuffer* buffer) {
|
|
||||||
srand(time(NULL));
|
|
||||||
ARecord res;
|
|
||||||
|
|
||||||
for (int i = 0; i < 4; i++) {
|
|
||||||
res.addr[i] = (uint8_t) (rand() * 255);
|
|
||||||
}
|
|
||||||
|
|
||||||
write_a_record(buffer, &res);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write_aaaar_record(PacketBuffer* buffer) {
|
|
||||||
srand(time(NULL));
|
|
||||||
|
|
||||||
AAAARecord res;
|
|
||||||
for (int i = 0; i < 16; i++) {
|
|
||||||
res.addr[i] = (uint8_t) (rand() * 255);
|
|
||||||
}
|
|
||||||
|
|
||||||
write_aaaa_record(buffer, &res);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void write_record_header(PacketBuffer* buffer, Record* record) {
|
|
||||||
buffer_write_qname(buffer, record->domain);
|
|
||||||
uint16_t id = record_to_id(record->type);
|
|
||||||
buffer_write_short(buffer, id);
|
|
||||||
buffer_write_short(buffer, record->cls);
|
|
||||||
buffer_write_int(buffer, record->ttl);
|
|
||||||
}
|
|
||||||
|
|
||||||
void write_record(PacketBuffer* buffer, Record* record) {
|
|
||||||
switch(record->type) {
|
|
||||||
case A:
|
|
||||||
write_record_header(buffer, record);
|
|
||||||
write_a_record(buffer, &record->data.a);
|
|
||||||
break;
|
|
||||||
case NS:
|
|
||||||
write_record_header(buffer, record);
|
|
||||||
write_ns_record(buffer, &record->data.ns);
|
|
||||||
break;
|
|
||||||
case CNAME:
|
|
||||||
write_record_header(buffer, record);
|
|
||||||
write_cname_record(buffer, &record->data.cname);
|
|
||||||
break;
|
|
||||||
case SOA:
|
|
||||||
write_record_header(buffer, record);
|
|
||||||
write_soa_record(buffer, &record->data.soa);
|
|
||||||
break;
|
|
||||||
case PTR:
|
|
||||||
write_record_header(buffer, record);
|
|
||||||
write_ptr_record(buffer, &record->data.ptr);
|
|
||||||
break;
|
|
||||||
case MX:
|
|
||||||
write_record_header(buffer, record);
|
|
||||||
write_mx_record(buffer, &record->data.mx);
|
|
||||||
break;
|
|
||||||
case TXT:
|
|
||||||
write_record_header(buffer, record);
|
|
||||||
write_txt_record(buffer, &record->data.txt);
|
|
||||||
break;
|
|
||||||
case AAAA:
|
|
||||||
write_record_header(buffer, record);
|
|
||||||
write_aaaa_record(buffer, &record->data.aaaa);
|
|
||||||
break;
|
|
||||||
case SRV:
|
|
||||||
write_record_header(buffer, record);
|
|
||||||
write_srv_record(buffer, &record->data.srv);
|
|
||||||
break;
|
|
||||||
case CAA:
|
|
||||||
write_record_header(buffer, record);
|
|
||||||
write_caa_record(buffer, &record->data.caa);
|
|
||||||
break;
|
|
||||||
case CMD:
|
|
||||||
write_record_header(buffer, record);
|
|
||||||
write_cmd_record(buffer, &record->data.cmd);
|
|
||||||
break;
|
|
||||||
case AR:
|
|
||||||
write_record_header(buffer, record);
|
|
||||||
write_ar_record(buffer);
|
|
||||||
break;
|
|
||||||
case AAAAR:
|
|
||||||
write_record_header(buffer, record);
|
|
||||||
write_aaaar_record(buffer);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
INIT_LOG_BUFFER(log)
|
|
||||||
LOGONLY(print_record(record, log);)
|
|
||||||
TRACE("Writing record: %s", log);
|
|
||||||
}
|
|
||||||
|
|
||||||
void free_record(Record* record) {
|
|
||||||
free(record->domain);
|
|
||||||
switch (record->type) {
|
|
||||||
case NS:
|
|
||||||
free(record->data.ns.host);
|
|
||||||
break;
|
|
||||||
case CNAME:
|
|
||||||
free(record->data.cname.host);
|
|
||||||
break;
|
|
||||||
case SOA:
|
|
||||||
free(record->data.soa.mname);
|
|
||||||
free(record->data.soa.nname);
|
|
||||||
break;
|
|
||||||
case PTR:
|
|
||||||
free(record->data.ptr.pointer);
|
|
||||||
break;
|
|
||||||
case MX:
|
|
||||||
free(record->data.mx.host);
|
|
||||||
break;
|
|
||||||
case TXT:
|
|
||||||
for (uint8_t i = 0; i < record->data.txt.len; i++) {
|
|
||||||
free(record->data.txt.text[i]);
|
|
||||||
}
|
|
||||||
free(record->data.txt.text);
|
|
||||||
break;
|
|
||||||
case SRV:
|
|
||||||
free(record->data.srv.target);
|
|
||||||
break;
|
|
||||||
case CAA:
|
|
||||||
free(record->data.caa.value);
|
|
||||||
free(record->data.caa.tag);
|
|
||||||
break;
|
|
||||||
case CMD:
|
|
||||||
free(record->data.cmd.command);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void print_record(Record* record, char* buffer) {
|
|
||||||
INIT_LOG_BOUNDS
|
|
||||||
switch(record->type) {
|
|
||||||
case UNKOWN:
|
|
||||||
APPEND(buffer, "UNKOWN");
|
|
||||||
break;
|
|
||||||
case A:
|
|
||||||
APPEND(buffer, "A (%hhu.%hhu.%hhu.%hhu)",
|
|
||||||
record->data.a.addr[0],
|
|
||||||
record->data.a.addr[1],
|
|
||||||
record->data.a.addr[2],
|
|
||||||
record->data.a.addr[3]
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
case NS:
|
|
||||||
APPEND(buffer, "NS (%.*s)",
|
|
||||||
record->data.ns.host[0],
|
|
||||||
record->data.ns.host + 1
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
case CNAME:
|
|
||||||
APPEND(buffer, "CNAME (%.*s)",
|
|
||||||
record->data.cname.host[0],
|
|
||||||
record->data.cname.host + 1
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
case SOA:
|
|
||||||
APPEND(buffer, "SOA (%.*s %.*s %u %u %u %u %u)",
|
|
||||||
record->data.soa.mname[0],
|
|
||||||
record->data.soa.mname + 1,
|
|
||||||
record->data.soa.nname[0],
|
|
||||||
record->data.soa.nname + 1,
|
|
||||||
record->data.soa.serial,
|
|
||||||
record->data.soa.refresh,
|
|
||||||
record->data.soa.retry,
|
|
||||||
record->data.soa.expire,
|
|
||||||
record->data.soa.minimum
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
case PTR:
|
|
||||||
APPEND(buffer, "PTR (%.*s)",
|
|
||||||
record->data.ptr.pointer[0],
|
|
||||||
record->data.ptr.pointer + 1
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
case MX:
|
|
||||||
APPEND(buffer, "MX (%.*s %hu)",
|
|
||||||
record->data.mx.host[0],
|
|
||||||
record->data.mx.host + 1,
|
|
||||||
record->data.mx.priority
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
case TXT:
|
|
||||||
APPEND(buffer, "TXT (");
|
|
||||||
for(uint8_t i = 0; i < record->data.txt.len; i++) {
|
|
||||||
APPEND(buffer, "\"%.*s\"",
|
|
||||||
record->data.txt.text[i][0],
|
|
||||||
record->data.txt.text[i] + 1
|
|
||||||
);
|
|
||||||
}
|
|
||||||
APPEND(buffer, ")");
|
|
||||||
break;
|
|
||||||
case AAAA:
|
|
||||||
APPEND(buffer, "AAAA (");
|
|
||||||
for(int i = 0; i < 8; i++) {
|
|
||||||
APPEND(buffer, "%02hhx%02hhx:",
|
|
||||||
record->data.a.addr[i*2 + 0],
|
|
||||||
record->data.a.addr[i*2 + 1]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
APPEND(buffer, ":)");
|
|
||||||
break;
|
|
||||||
case SRV:
|
|
||||||
APPEND(buffer, "SRV (%hu %hu %hu %.*s)",
|
|
||||||
record->data.srv.priority,
|
|
||||||
record->data.srv.weight,
|
|
||||||
record->data.srv.port,
|
|
||||||
record->data.srv.target[0],
|
|
||||||
record->data.srv.target + 1
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
case CAA:
|
|
||||||
APPEND(buffer, "CAA (%hhu %.*s %.*s)",
|
|
||||||
record->data.caa.flags,
|
|
||||||
record->data.caa.tag[0],
|
|
||||||
record->data.caa.tag + 1,
|
|
||||||
record->data.caa.value[0],
|
|
||||||
record->data.caa.value + 1
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
case CMD:
|
|
||||||
APPEND(buffer, "CMD (%s)",
|
|
||||||
record->data.cmd.command
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
case AR:
|
|
||||||
APPEND(buffer, "AR");
|
|
||||||
break;
|
|
||||||
case AAAAR:
|
|
||||||
APPEND(buffer, "AAAAR");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,110 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "buffer.h"
|
|
||||||
|
|
||||||
#include <string.h>
|
|
||||||
#include <stdbool.h>
|
|
||||||
|
|
||||||
typedef enum {
|
|
||||||
UNKOWN,
|
|
||||||
A, // 1
|
|
||||||
NS, // 2
|
|
||||||
CNAME, // 5
|
|
||||||
SOA, // 6
|
|
||||||
PTR, // 12
|
|
||||||
MX, // 15
|
|
||||||
TXT, // 16
|
|
||||||
AAAA, // 28
|
|
||||||
SRV, // 33
|
|
||||||
CAA, // 257
|
|
||||||
CMD, // 1000
|
|
||||||
AR, // 1001
|
|
||||||
AAAAR // 1002
|
|
||||||
} RecordType;
|
|
||||||
|
|
||||||
uint16_t record_to_id(RecordType type);
|
|
||||||
void record_from_id(uint16_t i, RecordType* type);
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
uint8_t addr[4];
|
|
||||||
} ARecord;
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
uint8_t* host;
|
|
||||||
} NSRecord;
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
uint8_t* host;
|
|
||||||
} CNAMERecord;
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
uint8_t* mname;
|
|
||||||
uint8_t* nname;
|
|
||||||
uint32_t serial;
|
|
||||||
uint32_t refresh;
|
|
||||||
uint32_t retry;
|
|
||||||
uint32_t expire;
|
|
||||||
uint32_t minimum;
|
|
||||||
} SOARecord;
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
uint8_t* pointer;
|
|
||||||
} PTRRecord;
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
uint16_t priority;
|
|
||||||
uint8_t* host;
|
|
||||||
} MXRecord;
|
|
||||||
|
|
||||||
typedef struct TXTRecord {
|
|
||||||
uint8_t** text;
|
|
||||||
uint8_t len;
|
|
||||||
} TXTRecord;
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
uint8_t addr[16];
|
|
||||||
} AAAARecord;
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
uint16_t priority;
|
|
||||||
uint16_t weight;
|
|
||||||
uint16_t port;
|
|
||||||
uint8_t* target;
|
|
||||||
} SRVRecord;
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
uint8_t flags;
|
|
||||||
uint8_t length;
|
|
||||||
uint8_t* tag;
|
|
||||||
uint8_t* value;
|
|
||||||
} CAARecord;
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
char* command;
|
|
||||||
} CMDRecord;
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
uint32_t ttl;
|
|
||||||
uint16_t cls;
|
|
||||||
uint16_t len;
|
|
||||||
uint8_t* domain;
|
|
||||||
RecordType type;
|
|
||||||
union data {
|
|
||||||
ARecord a;
|
|
||||||
NSRecord ns;
|
|
||||||
CNAMERecord cname;
|
|
||||||
SOARecord soa;
|
|
||||||
PTRRecord ptr;
|
|
||||||
MXRecord mx;
|
|
||||||
TXTRecord txt;
|
|
||||||
AAAARecord aaaa;
|
|
||||||
SRVRecord srv;
|
|
||||||
CAARecord caa;
|
|
||||||
CMDRecord cmd;
|
|
||||||
} data;
|
|
||||||
} Record;
|
|
||||||
|
|
||||||
bool read_record(PacketBuffer* buffer, Record* record);
|
|
||||||
void write_record(PacketBuffer* buffer, Record* record);
|
|
||||||
void free_record(Record* record);
|
|
||||||
void print_record(Record* record, char* buffer);
|
|
|
@ -1,257 +0,0 @@
|
||||||
#include <errno.h>
|
|
||||||
#include <netinet/in.h>
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <stdlib.h>
|
|
||||||
#include <string.h>
|
|
||||||
#include <unistd.h>
|
|
||||||
|
|
||||||
#include "addr.h"
|
|
||||||
#include "../io/log.h"
|
|
||||||
|
|
||||||
void create_ip_addr(uint8_t* domain, IpAddr* addr) {
|
|
||||||
addr->type = V4;
|
|
||||||
memcpy(&addr->data.v4, domain, 4);
|
|
||||||
}
|
|
||||||
|
|
||||||
void create_ip_addr6(uint8_t* domain, IpAddr* addr) {
|
|
||||||
addr->type = V6;
|
|
||||||
memcpy(&addr->data.v6, domain, 16);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ip_addr_any(IpAddr* addr) {
|
|
||||||
addr->type = V4;
|
|
||||||
addr->data.v4.s_addr = htonl(INADDR_ANY);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ip_addr_any6(IpAddr* addr) {
|
|
||||||
addr->type = V6;
|
|
||||||
addr->data.v6 = in6addr_any;
|
|
||||||
}
|
|
||||||
|
|
||||||
static struct sockaddr_in create_socket_addr_v4(IpAddr addr, uint16_t port) {
|
|
||||||
struct sockaddr_in socketaddr;
|
|
||||||
memset(&socketaddr, 0, sizeof(socketaddr));
|
|
||||||
socketaddr.sin_family = AF_INET;
|
|
||||||
socketaddr.sin_port = htons(port);
|
|
||||||
socketaddr.sin_addr = addr.data.v4;
|
|
||||||
return socketaddr;
|
|
||||||
}
|
|
||||||
|
|
||||||
static struct sockaddr_in6 create_socket_addr_v6(IpAddr addr, uint16_t port) {
|
|
||||||
struct sockaddr_in6 socketaddr;
|
|
||||||
memset(&socketaddr, 0, sizeof(socketaddr));
|
|
||||||
socketaddr.sin6_family = AF_INET6;
|
|
||||||
socketaddr.sin6_port = htons(port);
|
|
||||||
socketaddr.sin6_addr = addr.data.v6;
|
|
||||||
return socketaddr;
|
|
||||||
}
|
|
||||||
|
|
||||||
static size_t get_addr_len(AddrType type) {
|
|
||||||
if (type == V4) {
|
|
||||||
return sizeof(struct sockaddr_in);
|
|
||||||
} else if (type == V6) {
|
|
||||||
return sizeof(struct sockaddr_in6);
|
|
||||||
} else {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void create_socket_addr(uint16_t port, IpAddr addr, SocketAddr* socket) {
|
|
||||||
socket->type = addr.type;
|
|
||||||
if (addr.type == V4) {
|
|
||||||
socket->data.v4 = create_socket_addr_v4(addr, port);
|
|
||||||
} else if(addr.type == V6) {
|
|
||||||
socket->data.v6 = create_socket_addr_v6(addr, port);
|
|
||||||
} else {
|
|
||||||
ERROR("Tried to create socketaddr with invalid protocol type");
|
|
||||||
exit(EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
socket->len = get_addr_len(addr.type);
|
|
||||||
}
|
|
||||||
|
|
||||||
void print_socket_addr(SocketAddr* addr, char* buffer) {
|
|
||||||
INIT_LOG_BOUNDS
|
|
||||||
if(addr->type == V4) {
|
|
||||||
APPEND(buffer, "%hhu.%hhu.%hhu.%hhu:%hu",
|
|
||||||
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 24),
|
|
||||||
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 16),
|
|
||||||
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 8),
|
|
||||||
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr),
|
|
||||||
ntohs(addr->data.v4.sin_port)
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
uint8_t* a = (uint8_t*) &addr->data.v6.sin6_addr;
|
|
||||||
for(int i = 0; i < 8; i++) {
|
|
||||||
APPEND(buffer, "%02hhx%02hhx:",
|
|
||||||
a[i*2 + 0],
|
|
||||||
a[i*2 + 1]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
APPEND(buffer, ":[%hu]", ntohs(addr->data.v6.sin6_port));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#define ADDR_DOMAIN(addr, var) \
|
|
||||||
struct sockaddr* var; \
|
|
||||||
if (addr->type == V4) { \
|
|
||||||
var = (struct sockaddr*) &addr->data.v4; \
|
|
||||||
} else if (addr->type == V6) { \
|
|
||||||
var = (struct sockaddr*) &addr->data.v6; \
|
|
||||||
} else { \
|
|
||||||
return -1; \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define ADDR_AFNET(type, var) \
|
|
||||||
int var; \
|
|
||||||
if (type == V4) { \
|
|
||||||
var = AF_INET; \
|
|
||||||
} else if (type == V6) { \
|
|
||||||
var = AF_INET6; \
|
|
||||||
} else { \
|
|
||||||
return -1; \
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t create_udp_socket(AddrType type, UdpSocket* sock) {
|
|
||||||
ADDR_AFNET(type, __domain)
|
|
||||||
sock->type = type;
|
|
||||||
sock->sockfd = socket(__domain, SOCK_DGRAM, 0);
|
|
||||||
return sock->sockfd;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t bind_udp_socket(SocketAddr* addr, UdpSocket* sock) {
|
|
||||||
if (addr->type == V6) {
|
|
||||||
int v6OnlyEnabled = 0;
|
|
||||||
int32_t res = setsockopt(
|
|
||||||
sock->sockfd,
|
|
||||||
IPPROTO_IPV6,
|
|
||||||
IPV6_V6ONLY,
|
|
||||||
&v6OnlyEnabled,
|
|
||||||
sizeof(v6OnlyEnabled)
|
|
||||||
);
|
|
||||||
if (res < 0) return res;
|
|
||||||
}
|
|
||||||
ADDR_DOMAIN(addr, __addr)
|
|
||||||
return bind(sock->sockfd, __addr, addr->len);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t read_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr) {
|
|
||||||
clientaddr->type = socket->type;
|
|
||||||
clientaddr->len = get_addr_len(socket->type);
|
|
||||||
ADDR_DOMAIN(clientaddr, __addr)
|
|
||||||
return recvfrom(
|
|
||||||
socket->sockfd,
|
|
||||||
buffer,
|
|
||||||
(size_t) len,
|
|
||||||
MSG_WAITALL,
|
|
||||||
__addr,
|
|
||||||
(uint32_t*) &clientaddr->len
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t write_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr) {
|
|
||||||
ADDR_DOMAIN(clientaddr, __addr)
|
|
||||||
return sendto(
|
|
||||||
socket->sockfd,
|
|
||||||
buffer,
|
|
||||||
(size_t) len,
|
|
||||||
MSG_CONFIRM,
|
|
||||||
__addr,
|
|
||||||
(uint32_t) clientaddr->len
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
static int get_socket_error(int fd) {
|
|
||||||
int err = 1;
|
|
||||||
socklen_t len = sizeof err;
|
|
||||||
if (-1 == getsockopt(fd, SOL_SOCKET, SO_ERROR, (char *)&err, &len))
|
|
||||||
return 0;
|
|
||||||
if (err)
|
|
||||||
errno = err;
|
|
||||||
return err;
|
|
||||||
}
|
|
||||||
|
|
||||||
static int close_socket(int fd) {
|
|
||||||
if (fd >= 0) {
|
|
||||||
get_socket_error(fd); // first clear any errors, which can cause close to fail
|
|
||||||
if (shutdown(fd, SHUT_RDWR) < 0) // secondly, terminate the 'reliable' delivery
|
|
||||||
if (errno != ENOTCONN && errno != EINVAL) // SGI causes EINVAL
|
|
||||||
perror("shutdown");
|
|
||||||
if (close(fd) < 0) // finally call close()
|
|
||||||
perror("close");
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t close_udp_socket(UdpSocket* socket) {
|
|
||||||
return close_socket(socket->sockfd);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t create_tcp_socket(AddrType type, TcpSocket* sock) {
|
|
||||||
ADDR_AFNET(type, __domain)
|
|
||||||
sock->type = type;
|
|
||||||
sock->sockfd = socket(__domain, SOCK_STREAM, 0);
|
|
||||||
return sock->sockfd;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t bind_tcp_socket(SocketAddr* addr, TcpSocket* sock) {
|
|
||||||
if (addr->type == V6) {
|
|
||||||
int v6OnlyEnabled = 0;
|
|
||||||
int32_t res = setsockopt(
|
|
||||||
sock->sockfd,
|
|
||||||
IPPROTO_IPV6,
|
|
||||||
IPV6_V6ONLY,
|
|
||||||
&v6OnlyEnabled,
|
|
||||||
sizeof(v6OnlyEnabled)
|
|
||||||
);
|
|
||||||
if (res < 0) return res;
|
|
||||||
}
|
|
||||||
ADDR_DOMAIN(addr, __addr)
|
|
||||||
return bind(sock->sockfd, __addr, addr->len);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t listen_tcp_socket(TcpSocket* socket, uint32_t max) {
|
|
||||||
return listen(socket->sockfd, max);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t accept_tcp_socket(TcpSocket* socket, TcpStream* stream) {
|
|
||||||
stream->clientaddr.type = socket->type;
|
|
||||||
memset(&stream->clientaddr, 0, sizeof(SocketAddr));
|
|
||||||
SocketAddr* addr = &stream->clientaddr;
|
|
||||||
ADDR_DOMAIN(addr, __addr)
|
|
||||||
stream->streamfd = accept(
|
|
||||||
socket->sockfd,
|
|
||||||
__addr,
|
|
||||||
(uint32_t*) &stream->clientaddr.len
|
|
||||||
);
|
|
||||||
return stream->streamfd;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t close_tcp_socket(TcpSocket* socket) {
|
|
||||||
return close_socket(socket->sockfd);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t connect_tcp_stream(SocketAddr* servaddr, TcpStream* stream) {
|
|
||||||
TcpSocket socket;
|
|
||||||
int32_t res = create_tcp_socket(servaddr->type, &socket);
|
|
||||||
if (res < 0) return res;
|
|
||||||
stream->clientaddr = *servaddr;
|
|
||||||
stream->streamfd = socket.sockfd;
|
|
||||||
ADDR_DOMAIN(servaddr, __addr)
|
|
||||||
return connect(
|
|
||||||
socket.sockfd,
|
|
||||||
__addr,
|
|
||||||
servaddr->len
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t read_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) {
|
|
||||||
return recv(stream->streamfd, buffer, len, MSG_WAITALL);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t write_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) {
|
|
||||||
return send(stream->streamfd, buffer, len, MSG_NOSIGNAL);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t close_tcp_stream(TcpStream* stream) {
|
|
||||||
return close_socket(stream->streamfd);
|
|
||||||
}
|
|
|
@ -1,69 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "../packet/record.h"
|
|
||||||
|
|
||||||
#include <string.h>
|
|
||||||
#include <netinet/in.h>
|
|
||||||
#include <sys/socket.h>
|
|
||||||
#include <sys/types.h>
|
|
||||||
|
|
||||||
typedef enum {
|
|
||||||
V4,
|
|
||||||
V6
|
|
||||||
} AddrType;
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
AddrType type;
|
|
||||||
union {
|
|
||||||
struct in_addr v4;
|
|
||||||
struct in6_addr v6;
|
|
||||||
} data;
|
|
||||||
} IpAddr;
|
|
||||||
|
|
||||||
void create_ip_addr(uint8_t* domain, IpAddr* addr);
|
|
||||||
void create_ip_addr6(uint8_t* domain, IpAddr* addr);
|
|
||||||
void ip_addr_any(IpAddr* addr);
|
|
||||||
void ip_addr_any6(IpAddr* addr);
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
AddrType type;
|
|
||||||
union {
|
|
||||||
struct sockaddr_in v4;
|
|
||||||
struct sockaddr_in6 v6;
|
|
||||||
} data;
|
|
||||||
size_t len;
|
|
||||||
} SocketAddr;
|
|
||||||
|
|
||||||
void create_socket_addr(uint16_t port, IpAddr addr, SocketAddr* socket);
|
|
||||||
void print_socket_addr(SocketAddr* addr, char* buffer);
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
AddrType type;
|
|
||||||
uint32_t sockfd;
|
|
||||||
} UdpSocket;
|
|
||||||
|
|
||||||
int32_t create_udp_socket(AddrType type, UdpSocket* socket);
|
|
||||||
int32_t bind_udp_socket(SocketAddr* addr, UdpSocket* socket);
|
|
||||||
int32_t read_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr);
|
|
||||||
int32_t write_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr);
|
|
||||||
int32_t close_udp_socket(UdpSocket* socket);
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
AddrType type;
|
|
||||||
uint32_t sockfd;
|
|
||||||
} TcpSocket;
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
SocketAddr clientaddr;
|
|
||||||
uint32_t streamfd;
|
|
||||||
} TcpStream;
|
|
||||||
|
|
||||||
int32_t create_tcp_socket(AddrType type, TcpSocket* socket);
|
|
||||||
int32_t bind_tcp_socket(SocketAddr* addr, TcpSocket* socket);
|
|
||||||
int32_t listen_tcp_socket(TcpSocket* socket, uint32_t max);
|
|
||||||
int32_t accept_tcp_socket(TcpSocket* socket, TcpStream* stream);
|
|
||||||
int32_t close_tcp_socket(TcpSocket* socket);
|
|
||||||
int32_t connect_tcp_stream(SocketAddr* servaddr, TcpStream* stream);
|
|
||||||
int32_t read_tcp_stream(TcpStream* stream, void* buffer, uint16_t len);
|
|
||||||
int32_t write_tcp_stream(TcpStream* stream, void* buffer, uint16_t len);
|
|
||||||
int32_t close_tcp_stream(TcpStream* stream);
|
|
|
@ -1,246 +0,0 @@
|
||||||
#include <netinet/in.h>
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <stdlib.h>
|
|
||||||
#include <string.h>
|
|
||||||
#include <unistd.h>
|
|
||||||
#include <errno.h>
|
|
||||||
|
|
||||||
#include "addr.h"
|
|
||||||
#include "binding.h"
|
|
||||||
#include "../io/log.h"
|
|
||||||
|
|
||||||
static void create_udp_binding(UdpSocket* socket, uint16_t port) {
|
|
||||||
if (create_udp_socket(V6, socket) < 0) {
|
|
||||||
ERROR("Failed to create UDP socket: %s", strerror(errno));
|
|
||||||
exit(EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
|
|
||||||
IpAddr addr;
|
|
||||||
ip_addr_any6(&addr);
|
|
||||||
|
|
||||||
SocketAddr socketaddr;
|
|
||||||
create_socket_addr(port, addr, &socketaddr);
|
|
||||||
|
|
||||||
if (bind_udp_socket(&socketaddr, socket) < 0) {
|
|
||||||
ERROR("Failed to bind UDP socket on port %hu: %s", port, strerror(errno));
|
|
||||||
exit(EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void create_tcp_binding(TcpSocket* socket, uint16_t port) {
|
|
||||||
if (create_tcp_socket(V6, socket) < 0) {
|
|
||||||
ERROR("Failed to create TCP socket: %s", strerror(errno));
|
|
||||||
exit(EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
|
|
||||||
IpAddr addr;
|
|
||||||
ip_addr_any6(&addr);
|
|
||||||
|
|
||||||
SocketAddr socketaddr;
|
|
||||||
create_socket_addr(port, addr, &socketaddr);
|
|
||||||
|
|
||||||
if (bind_tcp_socket(&socketaddr, socket) < 0) {
|
|
||||||
ERROR("Failed to bind TCP socket on port %hu: %s", port, strerror(errno));
|
|
||||||
exit(EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (listen_tcp_socket(socket, 5) < 0) {
|
|
||||||
ERROR("Failed to listen on TCP socket: %s", strerror(errno));
|
|
||||||
exit(EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void create_binding(BindingType type, uint16_t port, Binding* binding) {
|
|
||||||
binding->type = type;
|
|
||||||
if (type == UDP) {
|
|
||||||
create_udp_binding(&binding->sock.udp, port);
|
|
||||||
} else if(type == TCP) {
|
|
||||||
create_tcp_binding(&binding->sock.tcp, port);
|
|
||||||
} else {
|
|
||||||
exit(EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void free_binding(Binding* binding) {
|
|
||||||
if (binding->type == UDP) {
|
|
||||||
close_udp_socket(&binding->sock.udp);
|
|
||||||
} else if(binding->type == TCP) {
|
|
||||||
close_tcp_socket(&binding->sock.tcp);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool accept_connection(Binding* binding, Connection* connection) {
|
|
||||||
connection->type = binding->type;
|
|
||||||
|
|
||||||
if(binding->type == UDP) {
|
|
||||||
connection->sock.udp.udp = binding->sock.udp;
|
|
||||||
memset(&connection->sock.udp.clientaddr, 0, sizeof(SocketAddr));
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (accept_tcp_socket(&binding->sock.tcp, &connection->sock.tcp) < 0) {
|
|
||||||
ERROR("Failed to accept TCP connection: %s", strerror(errno));
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void read_to_packet(uint8_t* buf, uint16_t len, Packet* packet) {
|
|
||||||
PacketBuffer* pkbuffer = buffer_create(len);
|
|
||||||
for (int i = 0; i < len; i++) {
|
|
||||||
buffer_write(pkbuffer, buf[i]);
|
|
||||||
}
|
|
||||||
buffer_seek(pkbuffer, 0);
|
|
||||||
read_packet(pkbuffer, packet);
|
|
||||||
buffer_free(pkbuffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool read_udp(Connection* connection, Packet* packet) {
|
|
||||||
uint8_t buffer[512];
|
|
||||||
int32_t n = read_udp_socket(
|
|
||||||
&connection->sock.udp.udp,
|
|
||||||
buffer,
|
|
||||||
512,
|
|
||||||
&connection->sock.udp.clientaddr
|
|
||||||
);
|
|
||||||
if (n < 0) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
read_to_packet(buffer, n, packet);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool read_tcp(Connection* connection, Packet* packet) {
|
|
||||||
uint16_t len;
|
|
||||||
if ( read_tcp_stream(
|
|
||||||
&connection->sock.tcp,
|
|
||||||
&len,
|
|
||||||
sizeof(uint16_t)
|
|
||||||
) < 2) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
len = ntohs(len);
|
|
||||||
|
|
||||||
uint8_t buffer[len];
|
|
||||||
if (read_tcp_stream(
|
|
||||||
&connection->sock.tcp,
|
|
||||||
buffer,
|
|
||||||
len
|
|
||||||
) < len) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
read_to_packet(buffer, len, packet);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool read_connection(Connection* connection, Packet* packet) {
|
|
||||||
if (connection->type == UDP) {
|
|
||||||
return read_udp(connection, packet);
|
|
||||||
} else if (connection->type == TCP) {
|
|
||||||
return read_tcp(connection, packet);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool write_udp(Connection* connection, uint8_t* buf, uint16_t len) {
|
|
||||||
if (len > 512) {
|
|
||||||
buf[2] = buf[2] | 0x03;
|
|
||||||
len = 512;
|
|
||||||
}
|
|
||||||
return write_udp_socket(
|
|
||||||
&connection->sock.udp.udp,
|
|
||||||
buf,
|
|
||||||
len,
|
|
||||||
&connection->sock.udp.clientaddr
|
|
||||||
) == len;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool write_tcp(Connection* connection, uint8_t* buf, uint16_t len) {
|
|
||||||
uint16_t net_len = htons(len);
|
|
||||||
if (write_tcp_stream(
|
|
||||||
&connection->sock.tcp,
|
|
||||||
&net_len,
|
|
||||||
sizeof(uint16_t)
|
|
||||||
) < 0) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (write_tcp_stream(
|
|
||||||
&connection->sock.tcp,
|
|
||||||
buf,
|
|
||||||
len
|
|
||||||
) < 0) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool write_connection(Connection* connection, Packet* packet) {
|
|
||||||
PacketBuffer* pkbuffer = buffer_create(64);
|
|
||||||
write_packet(pkbuffer, packet);
|
|
||||||
uint16_t len = buffer_get_size(pkbuffer);
|
|
||||||
uint8_t* buffer = buffer_get_ptr(pkbuffer);
|
|
||||||
bool success = false;
|
|
||||||
if(connection->type == UDP) {
|
|
||||||
success = write_udp(connection, buffer, len);
|
|
||||||
} else if(connection->type == TCP) {
|
|
||||||
success = write_tcp(connection, buffer, len);
|
|
||||||
};
|
|
||||||
buffer_free(pkbuffer);
|
|
||||||
return success;
|
|
||||||
}
|
|
||||||
|
|
||||||
void free_connection(Connection* connection) {
|
|
||||||
if (connection->type == TCP) {
|
|
||||||
close_tcp_stream(&connection->sock.tcp);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool create_udp_request(SocketAddr* addr, Connection* request) {
|
|
||||||
if ( create_udp_socket(addr->type, &request->sock.udp.udp) < 0) {
|
|
||||||
ERROR("Failed to connect to UDP socket: %s", strerror(errno));
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
request->sock.udp.clientaddr = *addr;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool create_tcp_request(SocketAddr* addr, Connection* request) {
|
|
||||||
if( connect_tcp_stream(addr, &request->sock.tcp) < 0) {
|
|
||||||
ERROR("Failed to connect to TCP socket: %s", strerror(errno));
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool create_request(BindingType type, SocketAddr* addr, Connection* request) {
|
|
||||||
request->type = type;
|
|
||||||
if (type == UDP) {
|
|
||||||
return create_udp_request(addr, request);
|
|
||||||
} else if (type == TCP) {
|
|
||||||
return create_tcp_request(addr, request);
|
|
||||||
} else {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool request_packet(Connection* request, Packet* in, Packet* out) {
|
|
||||||
if (!write_connection(request, in)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (!read_connection(request, out)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void free_request(Connection* connection) {
|
|
||||||
if (connection->type == UDP) {
|
|
||||||
close_udp_socket(&connection->sock.udp.udp);
|
|
||||||
} else if (connection->type == TCP) {
|
|
||||||
close_tcp_stream(&connection->sock.tcp);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,42 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "../packet/packet.h"
|
|
||||||
#include "addr.h"
|
|
||||||
|
|
||||||
#include <netinet/in.h>
|
|
||||||
|
|
||||||
typedef enum {
|
|
||||||
UDP,
|
|
||||||
TCP
|
|
||||||
} BindingType;
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
BindingType type;
|
|
||||||
union {
|
|
||||||
UdpSocket udp;
|
|
||||||
TcpSocket tcp;
|
|
||||||
} sock;
|
|
||||||
} Binding;
|
|
||||||
|
|
||||||
void create_binding(BindingType type, uint16_t port, Binding* binding);
|
|
||||||
void free_binding(Binding* binding);
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
BindingType type;
|
|
||||||
union {
|
|
||||||
struct {
|
|
||||||
UdpSocket udp;
|
|
||||||
SocketAddr clientaddr;
|
|
||||||
} udp;
|
|
||||||
TcpStream tcp;
|
|
||||||
} sock;
|
|
||||||
} Connection;
|
|
||||||
|
|
||||||
bool accept_connection(Binding* binding, Connection* connection);
|
|
||||||
bool read_connection(Connection* connection, Packet* packet);
|
|
||||||
bool write_connection(Connection* connection, Packet* packet);
|
|
||||||
void free_connection(Connection* connection);
|
|
||||||
|
|
||||||
bool create_request(BindingType type, SocketAddr* addr, Connection* request);
|
|
||||||
bool request_packet(Connection* request, Packet* in, Packet* out);
|
|
||||||
void free_request(Connection* connection);
|
|
|
@ -1,176 +0,0 @@
|
||||||
#include <errno.h>
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <stdlib.h>
|
|
||||||
|
|
||||||
#include "resolver.h"
|
|
||||||
#include "addr.h"
|
|
||||||
#include "binding.h"
|
|
||||||
#include "../io/log.h"
|
|
||||||
|
|
||||||
static bool lookup(
|
|
||||||
Question* question,
|
|
||||||
Packet* response,
|
|
||||||
BindingType type,
|
|
||||||
SocketAddr addr
|
|
||||||
) {
|
|
||||||
INIT_LOG_BUFFER(log)
|
|
||||||
LOGONLY(print_socket_addr(&addr, log);)
|
|
||||||
TRACE("Attempting lookup on fallback dns %s", log);
|
|
||||||
|
|
||||||
Connection request;
|
|
||||||
if (!create_request(type, &addr, &request)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
Packet req;
|
|
||||||
memset(&req, 0, sizeof(Packet));
|
|
||||||
req.header.id = response->header.id;
|
|
||||||
req.header.opcode = response->header.opcode;
|
|
||||||
req.header.questions = 1;
|
|
||||||
req.header.recursion_desired = true;
|
|
||||||
req.questions = malloc(sizeof(Question));
|
|
||||||
req.questions[0] = *question;
|
|
||||||
|
|
||||||
if (!request_packet(&request, &req, response)) {
|
|
||||||
free_request(&request);
|
|
||||||
free(req.questions);
|
|
||||||
ERROR("Failed to request fallback dns: %s", strerror(errno));
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
free_request(&request);
|
|
||||||
free(req.questions);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool search(Question* question, Packet* result, BindingType type, const RecordMap* map) {
|
|
||||||
IpAddr addr;
|
|
||||||
uint8_t ip[4] = {9,9,9,9};
|
|
||||||
create_ip_addr(ip, &addr);
|
|
||||||
|
|
||||||
uint16_t port = 53;
|
|
||||||
SocketAddr saddr;
|
|
||||||
create_socket_addr(port, addr, &saddr);
|
|
||||||
|
|
||||||
while(1) {
|
|
||||||
if (record_map_get(map, question, result)) {
|
|
||||||
TRACE("Found answer in user defined config");
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!lookup(question, result, type, saddr)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (result->header.answers > 0 && result->header.rescode == NOERROR) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (result->header.rescode == NXDOMAIN) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (get_resolved_ns(result, question->domain, &addr)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
Question new_question;
|
|
||||||
if (!get_unresoled_ns(result, question->domain, &new_question)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
Packet recurse;
|
|
||||||
if (!search(&new_question, &recurse, type, map)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
free_question(&new_question);
|
|
||||||
|
|
||||||
IpAddr random;
|
|
||||||
if (!get_random_a(&recurse, &random)) {
|
|
||||||
free_packet(&recurse);
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
free_packet(&recurse);
|
|
||||||
addr = random;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void push_records(Record* from, uint8_t from_len, Record** to, uint8_t to_len) {
|
|
||||||
if(from_len < 1) return;
|
|
||||||
*to = realloc(*to, sizeof(Record) * (from_len + to_len));
|
|
||||||
memcpy(*to + to_len, from, from_len * sizeof(Record));
|
|
||||||
}
|
|
||||||
|
|
||||||
static void push_questions(Question* from, uint8_t from_len, Question** to, uint8_t to_len) {
|
|
||||||
if(from_len < 1) return;
|
|
||||||
*to = realloc(*to, sizeof(Question) * (from_len + to_len));
|
|
||||||
memcpy(*to + to_len, from, from_len * sizeof(Question));
|
|
||||||
}
|
|
||||||
|
|
||||||
void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map) {
|
|
||||||
memset(response, 0, sizeof(Packet));
|
|
||||||
response->header.id = request->header.id;
|
|
||||||
response->header.opcode = request->header.opcode;
|
|
||||||
response->header.recursion_desired = true;
|
|
||||||
response->header.recursion_available = true;
|
|
||||||
response->header.response = true;
|
|
||||||
|
|
||||||
if (request->header.questions < 1) {
|
|
||||||
response->header.response = FORMERR;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (uint16_t i = 0; i < request->header.questions; i++) {
|
|
||||||
Packet result;
|
|
||||||
memset(&result, 0, sizeof(Packet));
|
|
||||||
result.header.id = response->header.id;
|
|
||||||
if (!search(&request->questions[i], &result, type, map)) {
|
|
||||||
response->header.response = SERVFAIL;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
push_questions(
|
|
||||||
result.questions,
|
|
||||||
result.header.questions,
|
|
||||||
&response->questions,
|
|
||||||
response->header.questions
|
|
||||||
);
|
|
||||||
response->header.questions += result.header.questions;
|
|
||||||
|
|
||||||
push_records(
|
|
||||||
result.answers,
|
|
||||||
result.header.answers,
|
|
||||||
&response->answers,
|
|
||||||
response->header.answers
|
|
||||||
);
|
|
||||||
response->header.answers += result.header.answers;
|
|
||||||
|
|
||||||
push_records(
|
|
||||||
result.authorities,
|
|
||||||
result.header.authoritative_entries,
|
|
||||||
&response->authorities,
|
|
||||||
response->header.authoritative_entries
|
|
||||||
);
|
|
||||||
response->header.authoritative_entries += result.header.authoritative_entries;
|
|
||||||
|
|
||||||
push_records(
|
|
||||||
result.resources,
|
|
||||||
result.header.resource_entries,
|
|
||||||
&response->resources,
|
|
||||||
response->header.resource_entries
|
|
||||||
);
|
|
||||||
response->header.resource_entries += result.header.resource_entries;
|
|
||||||
|
|
||||||
if (result.header.z == false) {
|
|
||||||
// Not from cache
|
|
||||||
free(result.questions);
|
|
||||||
free(result.answers);
|
|
||||||
free(result.authorities);
|
|
||||||
free(result.resources);
|
|
||||||
} else {
|
|
||||||
response->header.z = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,7 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "../packet/packet.h"
|
|
||||||
#include "../io/map.h"
|
|
||||||
#include "binding.h"
|
|
||||||
|
|
||||||
void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map);
|
|
|
@ -1,142 +0,0 @@
|
||||||
#include <errno.h>
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <stdlib.h>
|
|
||||||
#include <unistd.h>
|
|
||||||
#include <sys/wait.h>
|
|
||||||
#include <signal.h>
|
|
||||||
#include <pthread.h>
|
|
||||||
|
|
||||||
#include "addr.h"
|
|
||||||
#include "server.h"
|
|
||||||
#include "resolver.h"
|
|
||||||
#include "../io/log.h"
|
|
||||||
#include "../io/map.h"
|
|
||||||
#include "../io/config.h"
|
|
||||||
|
|
||||||
static pthread_t udp, tcp;
|
|
||||||
static RecordMap map;
|
|
||||||
|
|
||||||
void server_init(uint16_t port, Server* server) {
|
|
||||||
INFO("Server port set to %hu", port);
|
|
||||||
create_binding(UDP, port, &server->udp);
|
|
||||||
create_binding(TCP, port, &server->tcp);
|
|
||||||
|
|
||||||
if (load_config("/etc/wrapper.conf", &map)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (load_config("config", &map)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
WARN("No dns config files were found");
|
|
||||||
}
|
|
||||||
|
|
||||||
struct DnsRequest {
|
|
||||||
Connection connection;
|
|
||||||
Packet request;
|
|
||||||
};
|
|
||||||
|
|
||||||
static void* server_respond(void* arg) {
|
|
||||||
struct DnsRequest req = *(struct DnsRequest*) arg;
|
|
||||||
|
|
||||||
INFO("Recieved packet request ID %hu", req.request.header.id);
|
|
||||||
|
|
||||||
Packet response;
|
|
||||||
handle_query(&req.request, &response, req.connection.type, &map);
|
|
||||||
|
|
||||||
if (!write_connection(&req.connection, &response)) {
|
|
||||||
ERROR("Failed to respond to connection ID %hu: %s",
|
|
||||||
req.request.header.id,
|
|
||||||
strerror(errno)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
free_packet(&req.request);
|
|
||||||
free_connection(&req.connection);
|
|
||||||
|
|
||||||
if (response.header.z == false) {
|
|
||||||
free_packet(&response);
|
|
||||||
} else {
|
|
||||||
// Dont free from config
|
|
||||||
free(response.questions);
|
|
||||||
free(response.answers);
|
|
||||||
free(response.authorities);
|
|
||||||
free(response.resources);
|
|
||||||
}
|
|
||||||
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void* server_listen(void* arg) {
|
|
||||||
Binding* binding = (Binding*) arg;
|
|
||||||
while(1) {
|
|
||||||
|
|
||||||
Connection connection;
|
|
||||||
if (!accept_connection(binding, &connection)) {
|
|
||||||
ERROR("Failed to accept connection");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
Packet request;
|
|
||||||
if (!read_connection(&connection, &request)) {
|
|
||||||
ERROR("Failed to read connection");
|
|
||||||
free_connection(&connection);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct DnsRequest req;
|
|
||||||
req.connection = connection;
|
|
||||||
req.request = request;
|
|
||||||
|
|
||||||
pthread_t thread;
|
|
||||||
if(pthread_create(&thread, NULL, &server_respond, &req)) {
|
|
||||||
ERROR("Failed to create thread for dns request ID %hu: %s",
|
|
||||||
request.header.id,
|
|
||||||
strerror(errno)
|
|
||||||
);
|
|
||||||
free_packet(&request);
|
|
||||||
free_connection(&connection);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
pthread_detach(thread);
|
|
||||||
}
|
|
||||||
|
|
||||||
pthread_detach(pthread_self());
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void signal_handler() {
|
|
||||||
printf("\n");
|
|
||||||
pthread_kill(udp, SIGTERM);
|
|
||||||
pthread_kill(tcp, SIGTERM);
|
|
||||||
}
|
|
||||||
|
|
||||||
void server_run(Server* server) {
|
|
||||||
if (!pthread_create(&udp, NULL, &server_listen, &server->udp)) {
|
|
||||||
INFO("Listening for connections on UDP");
|
|
||||||
} else {
|
|
||||||
ERROR("Failed to start UDP thread");
|
|
||||||
exit(EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!pthread_create(&tcp, NULL, &server_listen, &server->tcp)) {
|
|
||||||
INFO("Listening for connections on TCP");
|
|
||||||
} else {
|
|
||||||
ERROR("Failed to start TCP thread");
|
|
||||||
exit(EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
signal(SIGINT, signal_handler);
|
|
||||||
|
|
||||||
pthread_join(udp, NULL);
|
|
||||||
pthread_join(tcp, NULL);
|
|
||||||
|
|
||||||
pthread_exit(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
void server_free(Server* server) {
|
|
||||||
free_binding(&server->udp);
|
|
||||||
free_binding(&server->tcp);
|
|
||||||
record_map_free(&map);
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,12 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "binding.h"
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
Binding udp;
|
|
||||||
Binding tcp;
|
|
||||||
} Server;
|
|
||||||
|
|
||||||
void server_init(uint16_t port, Server* server);
|
|
||||||
void server_run(Server* server);
|
|
||||||
void server_free(Server* server);
|
|
156
src/web/api.rs
Normal file
156
src/web/api.rs
Normal file
|
@ -0,0 +1,156 @@
|
||||||
|
use std::net::IpAddr;
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::Query,
|
||||||
|
response::Response,
|
||||||
|
routing::{get, post, put, delete},
|
||||||
|
Extension, Router,
|
||||||
|
};
|
||||||
|
use moka::future::Cache;
|
||||||
|
use rand::distributions::{Alphanumeric, DistString};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use tower_cookies::{Cookie, Cookies};
|
||||||
|
|
||||||
|
use crate::{config::Config, database::Database, dns::packet::record::DnsRecord};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
extract::{Authorized, Body, RequestIp},
|
||||||
|
http::{json, text},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn router() -> Router {
|
||||||
|
Router::new()
|
||||||
|
.route("/login", post(login))
|
||||||
|
.route("/domains", get(list_domains))
|
||||||
|
.route("/domains", delete(delete_domain))
|
||||||
|
.route("/records", get(get_domain))
|
||||||
|
.route("/records", put(add_record))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_domains(_: Authorized, Extension(database): Extension<Database>) -> Response {
|
||||||
|
let domains = match database.get_domains().await {
|
||||||
|
Ok(domains) => domains,
|
||||||
|
Err(err) => return text(500, &format!("{err}")),
|
||||||
|
};
|
||||||
|
|
||||||
|
let Ok(domains) = serde_json::to_string(&domains) else {
|
||||||
|
return text(500, "Failed to fetch domains")
|
||||||
|
};
|
||||||
|
|
||||||
|
json(200, &domains)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct DomainRequest {
|
||||||
|
domain: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_domain(
|
||||||
|
_: Authorized,
|
||||||
|
Extension(database): Extension<Database>,
|
||||||
|
Query(query): Query<DomainRequest>,
|
||||||
|
) -> Response {
|
||||||
|
let records = match database.get_domain(&query.domain).await {
|
||||||
|
Ok(records) => records,
|
||||||
|
Err(err) => return text(500, &format!("{err}")),
|
||||||
|
};
|
||||||
|
|
||||||
|
let Ok(records) = serde_json::to_string(&records) else {
|
||||||
|
return text(500, "Failed to fetch records")
|
||||||
|
};
|
||||||
|
|
||||||
|
json(200, &records)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete_domain(
|
||||||
|
_: Authorized,
|
||||||
|
Extension(database): Extension<Database>,
|
||||||
|
Body(body): Body,
|
||||||
|
) -> Response {
|
||||||
|
|
||||||
|
let Ok(request) = serde_json::from_str::<DomainRequest>(&body) else {
|
||||||
|
return text(400, "Missing request parameters")
|
||||||
|
};
|
||||||
|
|
||||||
|
let Ok(domains) = database.get_domains().await else {
|
||||||
|
return text(500, "Failed to delete domain")
|
||||||
|
};
|
||||||
|
|
||||||
|
if !domains.contains(&request.domain) {
|
||||||
|
return text(400, "Domain does not exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
if database.delete_domain(request.domain).await.is_err() {
|
||||||
|
return text(500, "Failed to delete domain")
|
||||||
|
};
|
||||||
|
|
||||||
|
return text(204, "Successfully deleted domain")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn add_record(
|
||||||
|
_: Authorized,
|
||||||
|
Extension(database): Extension<Database>,
|
||||||
|
Body(body): Body,
|
||||||
|
) -> Response {
|
||||||
|
let Ok(record) = serde_json::from_str::<DnsRecord>(&body) else {
|
||||||
|
return text(400, "Invalid DNS record")
|
||||||
|
};
|
||||||
|
|
||||||
|
let allowed = record.get_qtype().allowed_actions();
|
||||||
|
if !allowed.1 {
|
||||||
|
return text(400, "Not allowed to create record")
|
||||||
|
}
|
||||||
|
|
||||||
|
let Ok(records) = database.get_records(&record.get_domain(), record.get_qtype()).await else {
|
||||||
|
return text(500, "Failed to complete record check");
|
||||||
|
};
|
||||||
|
|
||||||
|
if !records.is_empty() && !allowed.0 {
|
||||||
|
return text(400, "Not allowed to create duplicate record")
|
||||||
|
};
|
||||||
|
|
||||||
|
if records.contains(&record) {
|
||||||
|
return text(400, "Not allowed to create duplicate record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(err) = database.add_record(record).await {
|
||||||
|
return text(500, &format!("{err}"));
|
||||||
|
}
|
||||||
|
|
||||||
|
return text(201, "Added record to database successfully");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct LoginRequest {
|
||||||
|
user: String,
|
||||||
|
pass: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn login(
|
||||||
|
Extension(config): Extension<Config>,
|
||||||
|
Extension(cache): Extension<Cache<String, IpAddr>>,
|
||||||
|
RequestIp(ip): RequestIp,
|
||||||
|
cookies: Cookies,
|
||||||
|
Body(body): Body,
|
||||||
|
) -> Response {
|
||||||
|
let Ok(request) = serde_json::from_str::<LoginRequest>(&body) else {
|
||||||
|
return text(400, "Missing request parameters")
|
||||||
|
};
|
||||||
|
|
||||||
|
if request.user != config.web_user || request.pass != config.web_pass {
|
||||||
|
return text(400, "Invalid credentials");
|
||||||
|
};
|
||||||
|
|
||||||
|
let token = Alphanumeric.sample_string(&mut rand::thread_rng(), 128);
|
||||||
|
|
||||||
|
cache.insert(token.clone(), ip).await;
|
||||||
|
|
||||||
|
let mut cookie = Cookie::new("auth", token);
|
||||||
|
cookie.set_secure(true);
|
||||||
|
cookie.set_http_only(true);
|
||||||
|
cookie.set_path("/");
|
||||||
|
|
||||||
|
cookies.add(cookie);
|
||||||
|
|
||||||
|
text(200, "Successfully logged in")
|
||||||
|
}
|
139
src/web/extract.rs
Normal file
139
src/web/extract.rs
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
use std::{
|
||||||
|
io::Read,
|
||||||
|
net::{IpAddr, SocketAddr},
|
||||||
|
};
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
async_trait,
|
||||||
|
body::HttpBody,
|
||||||
|
extract::{ConnectInfo, FromRequest, FromRequestParts},
|
||||||
|
http::{request::Parts, Request},
|
||||||
|
response::Response,
|
||||||
|
BoxError,
|
||||||
|
};
|
||||||
|
use bytes::Bytes;
|
||||||
|
use moka::future::Cache;
|
||||||
|
use tower_cookies::Cookies;
|
||||||
|
|
||||||
|
use super::http::text;
|
||||||
|
|
||||||
|
pub struct Authorized;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<S> FromRequestParts<S> for Authorized
|
||||||
|
where
|
||||||
|
S: Send + Sync,
|
||||||
|
{
|
||||||
|
type Rejection = Response;
|
||||||
|
|
||||||
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
|
let Ok(Some(cookies)) = Option::<Cookies>::from_request_parts(parts, state).await else {
|
||||||
|
return Err(text(403, "No cookies provided"))
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(token) = cookies.get("auth") else {
|
||||||
|
return Err(text(403, "No auth token provided"))
|
||||||
|
};
|
||||||
|
|
||||||
|
let auth_ip: IpAddr;
|
||||||
|
{
|
||||||
|
let Some(cache) = parts.extensions.get::<Cache<String, IpAddr>>() else {
|
||||||
|
return Err(text(500, "Failed to load auth store"))
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(ip) = cache.get(token.value()) else {
|
||||||
|
return Err(text(401, "Unauthorized"))
|
||||||
|
};
|
||||||
|
|
||||||
|
auth_ip = ip
|
||||||
|
}
|
||||||
|
|
||||||
|
let Ok(Some(RequestIp(ip))) = Option::<RequestIp>::from_request_parts(parts, state).await else {
|
||||||
|
return Err(text(403, "You have no ip"))
|
||||||
|
};
|
||||||
|
|
||||||
|
if auth_ip != ip {
|
||||||
|
return Err(text(403, "Auth token does not match current ip"));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RequestIp(pub IpAddr);
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<S> FromRequestParts<S> for RequestIp
|
||||||
|
where
|
||||||
|
S: Send + Sync,
|
||||||
|
{
|
||||||
|
type Rejection = Response;
|
||||||
|
|
||||||
|
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||||
|
let headers = &parts.headers;
|
||||||
|
|
||||||
|
let forwardedfor = headers
|
||||||
|
.get("x-forwarded-for")
|
||||||
|
.and_then(|h| h.to_str().ok())
|
||||||
|
.and_then(|h| {
|
||||||
|
h.split(',')
|
||||||
|
.rev()
|
||||||
|
.find_map(|s| s.trim().parse::<IpAddr>().ok())
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(forwardedfor) = forwardedfor {
|
||||||
|
return Ok(Self(forwardedfor));
|
||||||
|
}
|
||||||
|
|
||||||
|
let realip = headers
|
||||||
|
.get("x-real-ip")
|
||||||
|
.and_then(|hv| hv.to_str().ok())
|
||||||
|
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||||
|
|
||||||
|
if let Some(realip) = realip {
|
||||||
|
return Ok(Self(realip));
|
||||||
|
}
|
||||||
|
|
||||||
|
let realip = headers
|
||||||
|
.get("x-real-ip")
|
||||||
|
.and_then(|hv| hv.to_str().ok())
|
||||||
|
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||||
|
|
||||||
|
if let Some(realip) = realip {
|
||||||
|
return Ok(Self(realip));
|
||||||
|
}
|
||||||
|
|
||||||
|
let info = parts.extensions.get::<ConnectInfo<SocketAddr>>();
|
||||||
|
|
||||||
|
if let Some(info) = info {
|
||||||
|
return Ok(Self(info.0.ip()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(text(403, "You have no ip"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Body(pub String);
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<S, B> FromRequest<S, B> for Body
|
||||||
|
where
|
||||||
|
B: HttpBody + Sync + Send + 'static,
|
||||||
|
B::Data: Send,
|
||||||
|
B::Error: Into<BoxError>,
|
||||||
|
S: Send + Sync,
|
||||||
|
{
|
||||||
|
type Rejection = Response;
|
||||||
|
|
||||||
|
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
|
let Ok(bytes) = Bytes::from_request(req, state).await else {
|
||||||
|
return Err(text(413, "Payload too large"));
|
||||||
|
};
|
||||||
|
|
||||||
|
let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else {
|
||||||
|
return Err(text(400, "Invalid utf8 body"))
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self(body))
|
||||||
|
}
|
||||||
|
}
|
31
src/web/file.rs
Normal file
31
src/web/file.rs
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
use axum::{extract::Path, response::Response};
|
||||||
|
|
||||||
|
use super::http::serve;
|
||||||
|
|
||||||
|
pub async fn js(Path(path): Path<String>) -> Response {
|
||||||
|
let path = format!("/js/{path}");
|
||||||
|
serve(&path).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn css(Path(path): Path<String>) -> Response {
|
||||||
|
let path = format!("/css/{path}");
|
||||||
|
serve(&path).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn fonts(Path(path): Path<String>) -> Response {
|
||||||
|
let path = format!("/fonts/{path}");
|
||||||
|
serve(&path).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn image(Path(path): Path<String>) -> Response {
|
||||||
|
let path = format!("/image/{path}");
|
||||||
|
serve(&path).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn favicon() -> Response {
|
||||||
|
serve("/favicon.ico").await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn robots() -> Response {
|
||||||
|
serve("/robots.txt").await
|
||||||
|
}
|
50
src/web/http.rs
Normal file
50
src/web/http.rs
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
use axum::{
|
||||||
|
body::Body,
|
||||||
|
http::{header::HeaderName, HeaderValue, Request, StatusCode},
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
};
|
||||||
|
use std::str;
|
||||||
|
use tower::ServiceExt;
|
||||||
|
use tower_http::services::ServeFile;
|
||||||
|
|
||||||
|
pub fn text(code: u16, msg: &str) -> Response {
|
||||||
|
(status_code(code), msg.to_owned()).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn json(code: u16, json: &str) -> Response {
|
||||||
|
let mut res = (status_code(code), json.to_owned()).into_response();
|
||||||
|
res.headers_mut().insert(
|
||||||
|
HeaderName::from_static("content-type"),
|
||||||
|
HeaderValue::from_static("application/json"),
|
||||||
|
);
|
||||||
|
res
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn serve(path: &str) -> Response {
|
||||||
|
if !path.chars().any(|c| c == '.') {
|
||||||
|
return text(403, "Invalid file path");
|
||||||
|
}
|
||||||
|
|
||||||
|
let path = format!("public{path}");
|
||||||
|
let file = ServeFile::new(path);
|
||||||
|
|
||||||
|
let Ok(mut res) = file.oneshot(Request::new(Body::empty())).await else {
|
||||||
|
tracing::error!("Error while fetching file");
|
||||||
|
return text(500, "Error when fetching file")
|
||||||
|
};
|
||||||
|
|
||||||
|
if res.status() != StatusCode::OK {
|
||||||
|
return text(404, "File not found");
|
||||||
|
}
|
||||||
|
|
||||||
|
res.headers_mut().insert(
|
||||||
|
HeaderName::from_static("cache-control"),
|
||||||
|
HeaderValue::from_static("max-age=300"),
|
||||||
|
);
|
||||||
|
|
||||||
|
res.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn status_code(code: u16) -> StatusCode {
|
||||||
|
StatusCode::from_u16(code).map_or(StatusCode::OK, |code| code)
|
||||||
|
}
|
82
src/web/mod.rs
Normal file
82
src/web/mod.rs
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
use std::net::{IpAddr, SocketAddr, TcpListener};
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use axum::routing::get;
|
||||||
|
use axum::{Extension, Router};
|
||||||
|
use moka::future::Cache;
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
|
use tower_cookies::CookieManagerLayer;
|
||||||
|
use tracing::{error, info};
|
||||||
|
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::database::Database;
|
||||||
|
use crate::Result;
|
||||||
|
|
||||||
|
mod api;
|
||||||
|
mod extract;
|
||||||
|
mod file;
|
||||||
|
mod http;
|
||||||
|
mod pages;
|
||||||
|
|
||||||
|
pub struct WebServer {
|
||||||
|
config: Config,
|
||||||
|
database: Database,
|
||||||
|
addr: SocketAddr,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WebServer {
|
||||||
|
pub async fn new(config: Config, database: Database) -> Result<Self> {
|
||||||
|
let addr = format!("[::]:{}", config.web_port).parse::<SocketAddr>()?;
|
||||||
|
Ok(Self {
|
||||||
|
config,
|
||||||
|
database,
|
||||||
|
addr,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run(&self) -> Result<JoinHandle<()>> {
|
||||||
|
let config = self.config.clone();
|
||||||
|
let database = self.database.clone();
|
||||||
|
let listener = TcpListener::bind(self.addr)?;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Listening for HTTP traffic on [::]:{}",
|
||||||
|
self.config.web_port
|
||||||
|
);
|
||||||
|
|
||||||
|
let app = Self::router(config, database);
|
||||||
|
let server = axum::Server::from_tcp(listener)?;
|
||||||
|
|
||||||
|
let web_handle = tokio::spawn(async move {
|
||||||
|
if let Err(err) = server
|
||||||
|
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
error!("{err}");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(web_handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn router(config: Config, database: Database) -> Router {
|
||||||
|
let cache: Cache<String, IpAddr> = Cache::builder()
|
||||||
|
.time_to_live(Duration::from_secs(60 * 15))
|
||||||
|
.max_capacity(config.dns_cache_size)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Router::new()
|
||||||
|
.nest("/", pages::router())
|
||||||
|
.nest("/api", api::router())
|
||||||
|
.layer(Extension(config))
|
||||||
|
.layer(Extension(cache))
|
||||||
|
.layer(Extension(database))
|
||||||
|
.layer(CookieManagerLayer::new())
|
||||||
|
.route("/js/*path", get(file::js))
|
||||||
|
.route("/css/*path", get(file::css))
|
||||||
|
.route("/fonts/*path", get(file::fonts))
|
||||||
|
.route("/image/*path", get(file::image))
|
||||||
|
.route("/favicon.ico", get(file::favicon))
|
||||||
|
.route("/robots.txt", get(file::robots))
|
||||||
|
}
|
||||||
|
}
|
31
src/web/pages.rs
Normal file
31
src/web/pages.rs
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
use axum::{response::Response, routing::get, Router};
|
||||||
|
|
||||||
|
use super::{extract::Authorized, http::serve};
|
||||||
|
|
||||||
|
pub fn router() -> Router {
|
||||||
|
Router::new()
|
||||||
|
.route("/", get(root))
|
||||||
|
.route("/login", get(login))
|
||||||
|
.route("/home", get(home))
|
||||||
|
.route("/domain", get(domain))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn root(user: Option<Authorized>) -> Response {
|
||||||
|
if user.is_some() {
|
||||||
|
home().await
|
||||||
|
} else {
|
||||||
|
login().await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn login() -> Response {
|
||||||
|
serve("/login.html").await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn home() -> Response {
|
||||||
|
serve("/home.html").await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn domain() -> Response {
|
||||||
|
serve("/domain.html").await
|
||||||
|
}
|
Loading…
Reference in a new issue