Compare commits

..

1 commit

Author SHA1 Message Date
f46d5307fc move old rust ver to own archive branch 2023-04-05 23:06:29 -04:00
70 changed files with 5583 additions and 3678 deletions

5
.gitignore vendored
View file

@ -1,3 +1,2 @@
bin **/target
config .env
bee

2447
Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

37
Cargo.toml Normal file
View file

@ -0,0 +1,37 @@
[package]
name = "wrapper"
version = "0.1.0"
edition = "2021"
[dependencies]
# Blazingly fast runtime
tokio = { version = "1", features = ["full"] }
tracing-subscriber = "0.3.16"
tracing = "0.1.37"
# Allow recursion inside tokio async
async-recursion = "1"
# DNS Caching Layer
moka = { version = "0.10.0", features = ["future"] }
# Mongodb
mongodb = { version = "2.4", features = ["tokio-sync"] }
futures = "0.3.26"
# Convert values to json for Mongodb
serde = { version = "1.0", features = ["derive"] }
# Reading env vars from .env
dotenv = "0.15.0"
# For the meme records
rand = "0.8.5"
# For the http web frontend
axum = "0.6.4"
tower-http = { version = "0.4.0", features = ["fs"] }
tower-cookies = "0.9.0"
tower = "0.4.13"
bytes = "1.4.0"
serde_json = "1"

View file

@ -1,51 +0,0 @@
CC = gcc
INCFLAGS = -Isrc
CCFLAGS = -std=gnu99 -Wall -Wextra -pedantic -O2
CCFLAGS += $(INCFLAGS)
LDFLAGS += $(INCFLAGS)
LDFLAGS += -lpthread
BIN = bin
APP = $(BIN)/app
SRC = $(shell find src -name "*.c")
OBJ = $(SRC:%.c=$(BIN)/%.o)
.PHONY: dirs run clean build install uninstall install-openrc uninstall-openrc
EOF: clean build
dirs:
mkdir -p ./$(BIN)
mkdir -p ./$(BIN)/src
mkdir -p ./$(BIN)/src/io
mkdir -p ./$(BIN)/src/packet
mkdir -p ./$(BIN)/src/server
run: build
$(APP)
build: dirs ${OBJ}
${CC} -o $(APP) $(filter %.o,$^) $(LDFLAGS)
$(BIN)/%.o: %.c
$(CC) -o $@ -c $< $(CCFLAGS)
clean:
rm -rf $(APP)
rm -rf $(BIN)
install:
cp $(APP) /usr/local/bin/wrapper
uninstall:
rm /usr/local/bin/wrapper
install-openrc: install
cp service/openrc /etc/init.d/wrapper
chmod +x /etc/init.d/wrapper
uninstall-openrc: uninstall
rm /etc/init.d/wrapper

240
buffer.c
View file

@ -1,240 +0,0 @@
#include "buffer.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
struct PacketBuffer {
uint8_t* arr;
int capacity;
int index;
int size;
};
PacketBuffer* buffer_create(int capacity) {
PacketBuffer* buffer = malloc(sizeof(PacketBuffer));
buffer->arr = malloc(capacity);
buffer->capacity = capacity;
buffer->size = 0;
buffer->index = 0;
return buffer;
}
void buffer_free(PacketBuffer* buffer) {
free(buffer->arr);
free(buffer);
}
void buffer_seek(PacketBuffer* buffer, int index) {
buffer->index = index;
}
uint8_t buffer_read(PacketBuffer* buffer) {
if (buffer->index > buffer->size) {
return 0;
}
uint8_t data = buffer->arr[buffer->index];
buffer->index++;
return data;
}
uint16_t buffer_read_short(PacketBuffer* buffer) {
return
(uint16_t) buffer_read(buffer) << 8 |
(uint16_t) buffer_read(buffer);
}
uint32_t buffer_read_int(PacketBuffer* buffer) {
return
(uint32_t) buffer_read(buffer) << 24 |
(uint32_t) buffer_read(buffer) << 16 |
(uint32_t) buffer_read(buffer) << 8 |
(uint32_t) buffer_read(buffer);
}
uint8_t buffer_get(PacketBuffer* buffer, int index) {
if (index > buffer->size) {
return 0;
}
uint8_t data = buffer->arr[index];
return data;
}
uint8_t* buffer_get_range(PacketBuffer* buffer, int start, int len) {
uint8_t* arr = malloc(len);
for (int i = 0; i < len; i++) {
arr[i] = buffer_get(buffer, start + i);
}
return arr;
}
uint16_t buffer_get_size(PacketBuffer* buffer) {
return (uint16_t) buffer->size + 1;
}
static void write(uint8_t** buffer, uint8_t* size, uint8_t* capacity, uint8_t data) {
if (*size == *capacity) {
*capacity *= 2;
*buffer = realloc(*buffer, *capacity);
}
(*buffer)[*size] = data;
(*size)++;
}
void buffer_read_qname(PacketBuffer* buffer, uint8_t** out) {
int index = buffer->index;
int jumped = 0;
int max_jumps = 5;
int jumps_performed = 0;
uint8_t length = 0;
uint8_t capacity = 8;
*out = malloc(capacity);
write(out, &length, &capacity, 0);
while(1) {
if (jumps_performed > max_jumps) {
break;
}
uint8_t len = buffer_get(buffer, index);
if ((len & 0xC0) == 0xC0) {
if (jumped == 0) {
buffer_seek(buffer, index + 2);
}
uint16_t b2 = (uint16_t) buffer_get(buffer, index + 1);
uint16_t offset = ((((uint16_t) len) ^ 0xC0) << 8) | b2;
index = (int) offset;
jumped = 1;
jumps_performed++;
continue;
}
index++;
if (len == 0) {
break;
}
if (length > 1) {
write(out, &length, &capacity, '.');
}
uint8_t* range = buffer_get_range(buffer, index, len);
for (uint8_t i = 0; i < len; i++) {
write(out, &length, &capacity, range[i]);
}
free(range);
index += (int) len;
}
if (jumped == 0) {
buffer_seek(buffer, index);
}
(*out)[0] = length - 1;
}
void buffer_read_string(PacketBuffer* buffer, uint8_t** out) {
uint8_t len = buffer_read(buffer);
buffer_read_n(buffer, out, len);
}
void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len) {
*out = malloc(len + 1);
*out[0] = len;
memcpy(*out + 1, buffer->arr + buffer->index, len);
buffer->index += len;
}
void buffer_write(PacketBuffer* buffer, uint8_t data) {
if(buffer->index == buffer->capacity) {
buffer->capacity *= 2;
buffer->arr = realloc(buffer->arr, buffer->capacity);
}
if (buffer->size < buffer->index) {
buffer->size = buffer->index;
}
buffer->arr[buffer->index] = data;
buffer->index++;
}
void buffer_write_short(PacketBuffer* buffer, uint16_t data) {
buffer_write(buffer, (uint8_t)(data >> 8));
buffer_write(buffer, (uint8_t)(data & 0xFF));
}
void buffer_write_int(PacketBuffer* buffer, uint32_t data) {
buffer_write(buffer, (uint8_t)(data >> 24));
buffer_write(buffer, (uint8_t)(data >> 16));
buffer_write(buffer, (uint8_t)(data >> 8));
buffer_write(buffer, (uint8_t)(data & 0xFF));
}
void buffer_write_qname(PacketBuffer* buffer, uint8_t* in) {
uint8_t part = 0;
uint8_t len = in[0];
buffer_write(buffer, 0);
if (len == 0) {
return;
}
for(uint8_t i = 0; i < len; i ++) {
if (in[i+1] == '.') {
buffer_set(buffer, part, buffer->index - (int)part - 1);
buffer_write(buffer, 0);
part = 0;
} else {
buffer_write(buffer, in[i+1]);
part++;
}
}
buffer_set(buffer, part, buffer->index - (int)part - 1);
buffer_write(buffer, 0);
}
void buffer_write_string(PacketBuffer* buffer, uint8_t* in) {
buffer_write(buffer, in[0]);
buffer_write_n(buffer, in + 1, in[0]);
}
void buffer_write_n(PacketBuffer* buffer, uint8_t* in, int len) {
if (buffer->size + len >= buffer->capacity) {
buffer->capacity *= 2;
buffer->capacity += len;
buffer->arr = realloc(buffer->arr, buffer->capacity);
}
memcpy(buffer->arr + buffer->index, in, len);
buffer->size += len;
buffer->index += len;
}
void buffer_set(PacketBuffer* buffer, uint8_t data, int index) {
if (index > buffer->size) {
return;
}
buffer->arr[index] = data;
}
void buffer_set_uint16_t(PacketBuffer* buffer, uint16_t data, int index) {
buffer_set(buffer, (uint8_t)(data >> 8), index);
buffer_set(buffer, (uint8_t)(data & 0xFF), index + 1);
}
int buffer_get_index(PacketBuffer* buffer) {
return buffer->index;
}
void buffer_step(PacketBuffer* buffer, int len) {
buffer->index += len;
}
uint8_t* buffer_get_ptr(PacketBuffer* buffer) {
return buffer->arr;
}

40
public/css/home.css Normal file
View file

@ -0,0 +1,40 @@
span {
margin-top: 5rem;
margin-bottom: 1rem;
width: 45rem;
font-size: 2em;
}
#new {
display: flex;
justify-content: center;
width: 100%;
padding-top: 2rem;
padding-bottom: 1rem;
border-bottom: solid 1px var(--gray);
}
#new input, .block {
border-radius: 1rem 0 0 1rem;
width: 40rem;
}
.block {
width: 33em;
}
#new button {
border-radius: 0 1rem 1rem 0;
}
.domain {
margin-top: 2rem;
}
.domain .delete {
border-radius: 0 1rem 1rem 0;
}
.domain .edit {
border-radius: 0;
}

18
public/css/login.css Normal file
View file

@ -0,0 +1,18 @@
#login {
margin-top: 20em;
}
#logo {
font-size: 6em;
font-weight: 750;
font-family: bold;
margin-bottom: 2rem;
}
form {
width: 30rem;
}
form input {
width: 100%;
}

119
public/css/main.css Normal file
View file

@ -0,0 +1,119 @@
:root {
--dark: #222428;
--dark-alternate: #2b2e36;
--header: #1e1e22;
--accent: #8849f5;
--accent-alternate: #6829d5;
--gray: #2f2f3f;
--main: #ffffff;
--main-alternate: #cccccc;
}
* {
padding: 0;
margin: 0;
}
@font-face {
font-family: main;
src: url("../fonts/helvetica.ttf") format("truetype");
font-display: swap;
}
@font-face {
font-family: bold;
src: url("../fonts/overpass-bold.otf") format("opentype");
font-display: swap;
}
@font-face {
font-family: bold-italic;
src: url("../fonts/overpass-bold-italic.otf") format("opentype");
font-display: swap;
}
html {
background-color: var(--dark);
font-family: main;
color: var(--main);
width: 100%;
height: 100%;
}
body {
width: 100%;
height: 100%;
display: flex;
flex-direction: column;
align-items: center;
}
.accent {
color: var(--accent);
}
.fill {
width: 100%;
height: 100%;
display: flex;
flex-direction: column;
align-items: center;
}
input, button, .block {
all: unset;
display: inline-block;
font: main;
background-color: var(--dark-alternate);
font-size: 1rem;
padding: 1rem;
border-radius: 1rem;
margin-bottom: 20px;
}
button {
background-color: var(--accent);
width: 5em;
text-align: center;
}
button:hover {
cursor: pointer;
background-color: var(--accent-alternate);
}
.delete {
background-color: #f54842;
}
.delete:hover {
cursor: pointer;
background-color: #d52822;
}
form {
display: flex;
flex-direction: column;
}
#header {
width: calc(100% - 4rem);
background-color: var(--header);
border-bottom: solid 1px var(--gray);
padding: 1rem;
padding-left: 3rem;
}
#logo {
font-size: 2em;
font-weight: 500;
font-family: bold;
}
#title {
font-size: 2em;
font-weight: 300;
font-family: sans-serif;
padding-left: 1em;
}

67
public/css/record.css Normal file
View file

@ -0,0 +1,67 @@
#buttons {
margin-top: 2rem;
width: 50rem;
}
#buttons button {
margin: 0;
margin-right: 2rem;
border-radius: 10px;
width: auto;
padding: .75rem 1rem;
}
.record {
width: 50rem;
background-color: var(--header);
padding: 1rem;
margin-top: 2rem;
}
.header {
display: flex;
align-items: center;
margin-bottom: 1rem;
}
.header span {
font-family: bold;
}
.header button {
margin: 0;
margin-left: 2rem;
padding: .5rem 1rem;
width: auto;
border-radius: 5px;
}
.type {
margin-right: 1rem;
background-color: var(--accent);
padding: .25rem .5rem;
border-radius: 5px;
}
.domain {
color: var(--main-alternate);
flex-grow: 1;
}
.properties {
display: flex;
flex-direction: column;
}
.poperty {
display: flex;
flex-direction: row;
border-bottom: solid 1px var(--gray);
margin-top: 1rem;
}
.key {
font-family: bold;
width: 5rem;
}

21
public/domain.html Normal file
View file

@ -0,0 +1,21 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Wrapper - Records</title>
<meta name="author" content="Tyler Murphy">
<meta name="description" content="wrapper records">
<meta property="og:title" content="wrapper">
<meta property="og:description" content="wrapper records">
<link rel="stylesheet" href="/css/main.css">
<link rel="stylesheet" href="/css/record.css">
<script type="module" src="/js/domain.js"></script>
</head>
<body>
</body>
</html>

Binary file not shown.

BIN
public/fonts/helvetica.ttf Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

21
public/home.html Normal file
View file

@ -0,0 +1,21 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Wrapper - Domains</title>
<meta name="author" content="Tyler Murphy">
<meta name="description" content="wrapper domains">
<meta property="og:title" content="wrapper">
<meta property="og:description" content="wrapper domains">
<link rel="stylesheet" href="/css/main.css">
<link rel="stylesheet" href="/css/home.css">
<script type="module" src="/js/home.js"></script>
</head>
<body>
</body>
</html>

51
public/js/api.js Normal file
View file

@ -0,0 +1,51 @@
const endpoint = '/api'
const request = async (url, method, body) => {
let response;
if (method == 'GET') {
response = await fetch(endpoint + url, {
method,
headers: {
'Content-Type': 'application/json'
}
});
} else {
response = await fetch(endpoint + url, {
method,
body: JSON.stringify(body),
headers: {
'Content-Type': 'application/json'
}
});
}
if (response.status == 401) {
location.href = '/login'
}
const contentType = response.headers.get("content-type");
if (contentType && contentType.indexOf("application/json") !== -1) {
const json = await response.json()
return { status: response.status, msg: json.msg, json }
} else {
const msg = await response.text();
return { status: response.status, msg }
}
}
export const login = async (user, pass) => {
return await request('/login', 'POST', {user, pass})
}
export const domains = async () => {
return await request('/domains', 'GET')
}
export const del_domain = async (domain) => {
return await request('/domains', 'DELETE', {domain})
}
export const records = async (domain) => {
return await request(`/records?domain=${domain}`, 'GET')
}

12
public/js/components.js Normal file
View file

@ -0,0 +1,12 @@
import { div, parse, span } from './main.js';
export function header(title) {
return div({id: 'header'},
span({id: 'logo', class: 'accent'},
parse("Wrapper")
),
span({id: 'title'},
parse(title)
),
)
}

95
public/js/domain.js Normal file
View file

@ -0,0 +1,95 @@
import { del_domain, domains, records } from './api.js'
import { header } from './components.js'
import { body, parse, div, input, button, span, is_domain } from './main.js';
function render(domain, records) {
let divs = []
for (const record of records) {
divs.push(gen_record(record))
}
document.body.replaceWith(
body({},
header(domain),
div({id: 'buttons'},
button({onclick: (event) => {
location.href = '/home'
}}, parse("Home")),
button({}, parse("New Record")),
),
...divs
)
)
}
function gen_record(record) {
let domain = record.domain
let prefix = record.prefix
if (prefix.length > 0) {
prefix = prefix + '.'
}
let type = Object.keys(record.record)[0]
let data = record.record[type]
let divs = []
for (const key in data) {
let disp_key;
if (key == 'ttl') {
disp_key = 'TTL'
} else {
disp_key = upper(key)
}
divs.push(
div({class: 'poperty'},
div({class: 'key'}, parse(disp_key)),
div({class: 'value'}, parse(data[key])),
)
)
}
return div({class: 'record'},
div({class: 'header'},
span({class: 'type'}, parse(type)),
span({class: 'prefix'}, parse(prefix)),
span({class: 'domain'}, parse(domain)),
button({}, parse("Edit")),
button({class: 'delete'}, parse("Delete"))
),
div({class: 'properties'},
...divs
)
)
}
function upper(string) {
return string.charAt(0).toUpperCase() + string.slice(1);
}
async function init() {
const params = new Proxy(new URLSearchParams(window.location.search), {
get: (searchParams, prop) => searchParams.get(prop),
});
let domain = params.domain;
if (!is_domain(domain)) {
location.href = '/home'
return
}
let res = await records(domain);
if (res.status !== 200) {
alert(res.msg)
return
}
render(domain, res.json)
}
init()

91
public/js/home.js Normal file
View file

@ -0,0 +1,91 @@
import { del_domain, domains } from './api.js'
import { header } from './components.js'
import { body, parse, div, input, button, span, is_domain } from './main.js';
function render(domains) {
document.body.replaceWith(
body({},
header('domains'),
div({id: 'new'},
input({
type: 'text',
name: 'domain',
id: 'domain',
placeholder: 'Type domain name to create new records',
autocomplete: "off",
}),
button({onclick: () => {
let domain = document.getElementById('domain').value
if (!is_domain(domain)) {
alert("Invalid domain")
return
}
location.href = '/domain?domain='+domain
}},
parse("Create")
)
),
...domain(domains)
)
)
}
function domain(domains) {
let divs = []
for (const domain of domains) {
divs.push(
div({class: 'domain'},
div({class: 'block'},
parse(domain)
),
button({class: 'edit', onclick: (event) => {
console.log(event.target.parentElement)
let domain = event
.target
.parentElement
.getElementsByClassName('block')[0]
.innerText
if (!is_domain(domain)) {
alert("Invalid domain")
return
}
location.href = '/domain?domain='+domain
}},
parse("Edit")
),
button({class: 'delete', onclick: async () => {
let res = await del_domain(domain)
if (res.status != 204) {
alert(res.msg)
return
}
location.reload()
}},
parse("Delete")
)
)
)
}
return divs
}
async function init() {
let res = await domains();
if (res.status !== 200) {
alert(res.msg)
return
}
render(res.json)
}
init()

44
public/js/login.js Normal file
View file

@ -0,0 +1,44 @@
import { body, div, form, input, p, parse, span} from './main.js'
import { login } from './api.js'
function render() {
document.body.replaceWith(
body({},
div({id: 'login', class: 'fill'},
span({id: 'logo'},
span({class: 'accent'}, parse('Wrapper'))
),
form({autocomplete: "off"},
input({
type: 'text',
name: 'user',
id: 'user',
placeholder: 'Username',
autofocus: 1
}),
input({
type: 'password',
name: 'pass',
id: 'pass',
placeholder: 'Password',
onkeydown: async (event) => {
if (event.key == 'Enter') {
event.preventDefault()
let user = document.getElementById('user').value
let pass = document.getElementById('pass').value
let res = await login(user, pass)
if (res.status === 200) {
location.href = '/home'
}
}
}
})
)
)
)
)
}
render()

136
public/js/main.js Normal file
View file

@ -0,0 +1,136 @@
function createElement(name, attrs, ...children) {
const el = document.createElement(name);
for (const attr in attrs) {
if(attr.startsWith("on")) {
el[attr] = attrs[attr];
} else {
el.setAttribute(attr, attrs[attr])
}
}
for (const child of children) {
if (child == null) {
continue
}
el.appendChild(child)
}
return el
}
export function createElementNS(name, attrs, ...children) {
var svgns = "http://www.w3.org/2000/svg";
var el = document.createElementNS(svgns, name);
for (const attr in attrs) {
if(attr.startsWith("on")) {
el[attr] = attrs[attr];
} else {
el.setAttribute(attr, attrs[attr])
}
}
for (const child of children) {
if (child == null) {
continue
}
el.appendChild(child)
}
return el
}
export function p(attrs, ...children) {
return createElement("p", attrs, ...children)
}
export function span(attrs, ...children) {
return createElement("span", attrs, ...children)
}
export function div(attrs, ...children) {
return createElement("div", attrs, ...children)
}
export function a(attrs, ...children) {
return createElement("a", attrs, ...children)
}
export function i(attrs, ...children) {
return createElement("i", attrs, ...children)
}
export function form(attrs, ...children) {
return createElement("form", attrs, ...children)
}
export function img(alt, attrs, ...children) {
attrs['onerror'] = (event) => event.target.remove()
attrs['alt'] = alt
return createElement("img", attrs, ...children)
}
export function input(attrs, ...children) {
return createElement("input", attrs, ...children)
}
export function button(attrs, ...children) {
return createElement("button", attrs, ...children)
}
export function path(attrs, ...children) {
return createElementNS("path", attrs, ...children)
}
export function svg(attrs, ...children) {
return createElementNS("svg", attrs, ...children)
}
export function body(attrs, ...children) {
return createElement("body", attrs, ...children)
}
export function textarea(attrs, ...children) {
return createElement("textarea", attrs, ...children)
}
export function parse(input) {
const pattern = /^[ a-zA-Z0-9!@#$%^&*()_+\-=\[\]{};':"\\|,.<>\/?]*$/;
input = input + '';
if (!pattern.test(input)) {
return null;
}
const sanitized = input.replace(/</g, '&lt;').replace(/>/g, '&gt;');
return document.createRange().createContextualFragment(sanitized);
}
export function is_domain(domain) {
domain = domain.toLowerCase()
const pattern = /^[a-z0-9_\-.]*$/;
if (!pattern.test(domain)) {
return false
}
let parts = domain.split('.').reverse()
for (const part of parts) {
if (part.length < 1) {
return false
}
}
if (parts.length < 2 || parts[0].length < 2) {
return false
}
const tld_pattern = /^[a-z]*$/;
if (!tld_pattern.test(parts[0])) {
return false
}
return true
}

21
public/login.html Normal file
View file

@ -0,0 +1,21 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Wrapper - Login</title>
<meta name="author" content="Tyler Murphy">
<meta name="description" content="wrapper dns login">
<meta property="og:title" content="wrapper">
<meta property="og:description" content="wrapper dns login">
<link rel="stylesheet" href="/css/main.css">
<link rel="stylesheet" href="/css/login.css">
<script type="module" src="/js/login.js"></script>
</head>
<body>
</body>
</html>

9
public/robots.txt Normal file
View file

@ -0,0 +1,9 @@
User-agent: Googlebot
Disallow: /api
User-agent: Googlebot
User-agent: AdsBot-Google
Disallow: /api
User-agent: *
Disallow: /api

View file

@ -1,89 +0,0 @@
# Wrapper
A simple and lightweight dns server written in C
## How to
Wrapper by default runs on port 53 udp and tcp, which is the default port for DNS. If you wish to change this variable, set the `PORT` environment variable is set to a different port number.
To set custom records, wrapper reads configuration from `/etc/wrapper.conf`, and if that doesn't exist, `./config`.
The config file format is a question on its own line, then followed by records on their own line. To separate questions/records from each other, place a extra empty new line between them. For example...
```
IN A google.com
IN A 300 12.34.56.78
IN TXT joe
IN TXT 60 biden
```
### Question
Now to break this down piece by piece, the question is made up of a class, record type, and domain. Domain is self explanitory, but for the other two:
#### Class
The valid classes are `IN` (Internet), `CH` (Chaosnet), and `HS` (Hesiod). For most purposes your going to be using IN.
#### Record type
The current supported record types are `A`, `NS`, `CNAME`, `SOA`, `PTR`, `MX`, `TXT`, `AAAA`, `SRV`, and `CAA`.
### Answer
Answers are very similar to questions, they have a class, record type, but also have a Time to Live (TTL), and its record data. TTL is just the amount of seconds other DNS servers should cache the value, but for record data, it's formatted as such:
#### Record Data
- `A` ipv4 (0.0.0.0)
- `NS`, `CNAME`, `PTR` domain (google.com)
- `SOA` mname, nname serial refresh retry expire minimum (ns1.google.com. dns-admin.google.com. 523655285 900 900 1800 60)
- `MX`: priority domain (10 smtp.google.com)
- `TXT`: text (Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua)
- `AAAA` ipv6 (2607:f8b0:4006:080c:0000:0000:0000:200e::)
- `SRV` priority weight port domain (10 10 10 example.com)
- `CAA` flags tag value (0 issue "pki.goog")
Wrapper also has a few joke/meme records for fun. Note these should never acutaly be used for general use, but they are funny.
- `AR`, `AAAAR`
- `CMD` command (neofetch)
`AR` and `AAAAR` have no record data since they generate ipv4's and ipv6's respectively, and turn into `A` and `AAAA` upon response to sender
`CMD` runs the command supplied on the host system and returns the std output as a `TXT` record
## License
This project is Licensed under the [GPLv3](https://www.gnu.org/licenses/gpl-3.0.en.html)
## Compilation
Wrapper only runs on Linux systems that are Posix 1995 compliant
Make sure to have `gcc` and `make` installed, and then run
```shell
$ make # compiles the program
$ sudo make install # installs the binary
```
If you are running openrc, there is a premade service file so you can run
```shell
$ sudo make install-openrc # installs the binary and service file
```
If you wish to remove the program, you can run
```shell
$ sudo make uninstall # removes the binary
```
Or again if your running openrc
```shell
$ sudo make uninstall-openrc # removes the binary and service file
```

View file

@ -1,13 +0,0 @@
#!/sbin/openrc-run
name="WrapperDNS"
description="Wrapper dns server"
command=/usr/local/bin/wrapper
supervisor=supervise-daemon
depend() {
need net
}
#start_pre() {
# export PORT=53
#}

57
src/config.rs Normal file
View file

@ -0,0 +1,57 @@
use std::{env, net::IpAddr, str::FromStr, fmt::Display};
#[derive(Clone)]
pub struct Config {
pub dns_fallback: IpAddr,
pub dns_port: u16,
pub dns_cache_size: u64,
pub db_host: String,
pub db_port: u16,
pub db_user: String,
pub db_pass: String,
pub web_user: String,
pub web_pass: String,
pub web_port: u16,
}
impl Config {
pub fn new() -> Self {
let dns_port = Self::get_var::<u16>("WRAPPER_DNS_PORT", 53);
let dns_fallback = Self::get_var::<IpAddr>("WRAPPER_FALLBACK_DNS", [9, 9, 9, 9].into());
let dns_cache_size = Self::get_var::<u64>("WRAPPER_CACHE_SIZE", 1000);
let db_host = Self::get_var::<String>("WRAPPER_DB_HOST", String::from("localhost"));
let db_port = Self::get_var::<u16>("WRAPPER_DB_PORT", 27017);
let db_user = Self::get_var::<String>("WRAPPER_DB_USER", String::from("root"));
let db_pass = Self::get_var::<String>("WRAPPER_DB_PASS", String::from(""));
let web_user = Self::get_var::<String>("WRAPPER_WEB_USER", String::from("admin"));
let web_pass = Self::get_var::<String>("WRAPPER_WEB_PASS", String::from("wrapper"));
let web_port = Self::get_var::<u16>("WRAPPER_WEB_PORT", 80);
Self {
dns_fallback,
dns_port,
dns_cache_size,
db_host,
db_port,
db_user,
db_pass,
web_user,
web_pass,
web_port,
}
}
fn get_var<T>(name: &str, default: T) -> T
where
T: FromStr + Display,
{
let env = env::var(name).unwrap_or(format!("{default}"));
env.parse::<T>().unwrap_or(default)
}
}

146
src/database/mod.rs Normal file
View file

@ -0,0 +1,146 @@
use futures::TryStreamExt;
use mongodb::{
bson::doc,
options::{ClientOptions, Credential, ServerAddress},
Client,
};
use serde::{Deserialize, Serialize};
use tracing::info;
use crate::{
config::Config,
dns::packet::{query::QueryType, record::DnsRecord},
};
use crate::Result;
#[derive(Clone)]
pub struct Database {
client: Client,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct StoredRecord {
record: DnsRecord,
domain: String,
prefix: String,
}
impl StoredRecord {
fn get_domain_parts(domain: &str) -> (String, String) {
let parts: Vec<&str> = domain.split(".").collect();
let len = parts.len();
if len == 1 {
(String::new(), String::from(parts[0]))
} else if len == 2 {
(String::new(), String::from(parts.join(".")))
} else {
(
String::from(parts[0..len - 2].join(".")),
String::from(parts[len - 2..len].join(".")),
)
}
}
}
impl From<DnsRecord> for StoredRecord {
fn from(record: DnsRecord) -> Self {
let (prefix, domain) = Self::get_domain_parts(&record.get_domain());
Self {
record,
domain,
prefix,
}
}
}
impl Into<DnsRecord> for StoredRecord {
fn into(self) -> DnsRecord {
self.record
}
}
impl Database {
pub async fn new(config: Config) -> Result<Self> {
let options = ClientOptions::builder()
.hosts(vec![ServerAddress::Tcp {
host: config.db_host,
port: Some(config.db_port),
}])
.credential(
Credential::builder()
.username(config.db_user)
.password(config.db_pass)
.build(),
)
.max_pool_size(100)
.app_name(String::from("wrapper"))
.build();
let client = Client::with_options(options)?;
client
.database("wrapper")
.run_command(doc! {"ping": 1}, None)
.await?;
info!("Connection to mongodb successfully");
Ok(Database { client })
}
pub async fn get_records(&self, domain: &str, qtype: QueryType) -> Result<Vec<DnsRecord>> {
let (prefix, domain) = StoredRecord::get_domain_parts(domain);
Ok(self
.get_domain(&domain)
.await?
.into_iter()
.filter(|r| r.prefix == prefix)
.filter(|r| {
let rqtype = r.record.get_qtype();
if qtype == QueryType::A {
return rqtype == QueryType::A || rqtype == QueryType::AR;
} else if qtype == QueryType::AAAA {
return rqtype == QueryType::AAAA || rqtype == QueryType::AAAAR;
} else {
r.record.get_qtype() == qtype
}
})
.map(|r| r.into())
.collect())
}
pub async fn get_domain(&self, domain: &str) -> Result<Vec<StoredRecord>> {
let db = self.client.database("wrapper");
let col = db.collection::<StoredRecord>(domain);
let filter = doc! { "domain": domain };
let mut cursor = col.find(filter, None).await?;
let mut records = Vec::new();
while let Some(record) = cursor.try_next().await? {
records.push(record);
}
Ok(records)
}
pub async fn add_record(&self, record: DnsRecord) -> Result<()> {
let record = StoredRecord::from(record);
let db = self.client.database("wrapper");
let col = db.collection::<StoredRecord>(&record.domain);
col.insert_one(record, None).await?;
Ok(())
}
pub async fn get_domains(&self) -> Result<Vec<String>> {
let db = self.client.database("wrapper");
Ok(db.list_collection_names(None).await?)
}
pub async fn delete_domain(&self, domain: String) -> Result<()> {
let db = self.client.database("wrapper");
let col = db.collection::<StoredRecord>(&domain);
Ok(col.drop(None).await?)
}
}

144
src/dns/binding.rs Normal file
View file

@ -0,0 +1,144 @@
use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
};
use super::packet::{buffer::PacketBuffer, Packet};
use crate::Result;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream, UdpSocket},
};
use tracing::trace;
pub enum Binding {
UDP(Arc<UdpSocket>),
TCP(TcpListener),
}
impl Binding {
pub async fn udp(addr: SocketAddr) -> Result<Self> {
let socket = UdpSocket::bind(addr).await?;
Ok(Self::UDP(Arc::new(socket)))
}
pub async fn tcp(addr: SocketAddr) -> Result<Self> {
let socket = TcpListener::bind(addr).await?;
Ok(Self::TCP(socket))
}
pub fn name(&self) -> &str {
match self {
Binding::UDP(_) => "UDP",
Binding::TCP(_) => "TCP",
}
}
pub async fn connect(&mut self) -> Result<Connection> {
match self {
Self::UDP(socket) => {
let mut buf = [0; 512];
let (_, addr) = socket.recv_from(&mut buf).await?;
Ok(Connection::UDP(socket.clone(), addr, buf))
}
Self::TCP(socket) => {
let (stream, _) = socket.accept().await?;
Ok(Connection::TCP(stream))
}
}
}
}
pub enum Connection {
UDP(Arc<UdpSocket>, SocketAddr, [u8; 512]),
TCP(TcpStream),
}
impl Connection {
pub async fn read_packet(&mut self) -> Result<Packet> {
let data = self.read().await?;
let mut packet_buffer = PacketBuffer::new(data);
let packet = Packet::from_buffer(&mut packet_buffer)?;
Ok(packet)
}
pub async fn write_packet(self, mut packet: Packet) -> Result<()> {
let mut packet_buffer = PacketBuffer::new(Vec::new());
packet.write(&mut packet_buffer)?;
self.write(packet_buffer.buf).await?;
Ok(())
}
pub async fn request_packet(&self, mut packet: Packet, dest: (IpAddr, u16)) -> Result<Packet> {
let mut packet_buffer = PacketBuffer::new(Vec::new());
packet.write(&mut packet_buffer)?;
let data = self.request(packet_buffer.buf, dest).await?;
let mut packet_buffer = PacketBuffer::new(data);
let packet = Packet::from_buffer(&mut packet_buffer)?;
Ok(packet)
}
async fn read(&mut self) -> Result<Vec<u8>> {
trace!("Reading DNS packet");
match self {
Self::UDP(_, _, src) => Ok(Vec::from(*src)),
Self::TCP(stream) => {
let size = stream.read_u16().await?;
let mut buf = Vec::with_capacity(size as usize);
stream.read_buf(&mut buf).await?;
Ok(buf)
}
}
}
async fn write(self, mut buf: Vec<u8>) -> Result<()> {
trace!("Returning DNS packet");
match self {
Self::UDP(socket, addr, _) => {
if buf.len() > 512 {
buf[2] = buf[2] | 0x03;
socket.send_to(&buf[0..512], addr).await?;
} else {
socket.send_to(&buf, addr).await?;
}
Ok(())
}
Self::TCP(mut stream) => {
stream.write_u16(buf.len() as u16).await?;
stream.write(&buf[0..buf.len()]).await?;
Ok(())
}
}
}
async fn request(&self, buf: Vec<u8>, dest: (IpAddr, u16)) -> Result<Vec<u8>> {
match self {
Self::UDP(_socket, _addr, _src) => {
let local_addr = "[::]:0".parse::<SocketAddr>()?;
let socket = UdpSocket::bind(local_addr).await?;
socket.send_to(&buf, dest).await?;
let mut buf = [0; 512];
socket.recv_from(&mut buf).await?;
Ok(Vec::from(buf))
}
Self::TCP(_stream) => {
let mut stream = TcpStream::connect(dest).await?;
stream.write_u16((buf.len()) as u16).await?;
stream.write_all(&buf[0..buf.len()]).await?;
stream.readable().await?;
let size = stream.read_u16().await?;
let mut buf = Vec::with_capacity(size as usize);
stream.read_buf(&mut buf).await?;
Ok(buf)
}
}
}
}

4
src/dns/mod.rs Normal file
View file

@ -0,0 +1,4 @@
mod binding;
pub mod packet;
mod resolver;
pub mod server;

228
src/dns/packet/buffer.rs Normal file
View file

@ -0,0 +1,228 @@
use crate::Result;
pub struct PacketBuffer {
pub buf: Vec<u8>,
pub pos: usize,
pub size: usize,
}
impl PacketBuffer {
pub fn new(buf: Vec<u8>) -> Self {
Self {
size: buf.len(),
buf,
pos: 0,
}
}
pub fn pos(&self) -> usize {
self.pos
}
pub fn step(&mut self, steps: usize) -> Result<()> {
self.pos += steps;
Ok(())
}
pub fn seek(&mut self, pos: usize) -> Result<()> {
self.pos = pos;
Ok(())
}
pub fn read(&mut self) -> Result<u8> {
if self.pos >= self.size {
return Err("Tried to read past end of buffer".into());
}
let res = self.buf[self.pos];
self.pos += 1;
Ok(res)
}
pub fn get(&mut self, pos: usize) -> Result<u8> {
if pos >= self.size {
return Err("Tried to read past end of buffer".into());
}
Ok(self.buf[pos])
}
pub fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
if start + len >= self.size {
return Err("Tried to read past end of buffer".into());
}
Ok(&self.buf[start..start + len])
}
pub fn read_u16(&mut self) -> Result<u16> {
let res = ((self.read()? as u16) << 8) | (self.read()? as u16);
Ok(res)
}
pub fn read_u32(&mut self) -> Result<u32> {
let res = ((self.read()? as u32) << 24)
| ((self.read()? as u32) << 16)
| ((self.read()? as u32) << 8)
| (self.read()? as u32);
Ok(res)
}
pub fn read_qname(&mut self, outstr: &mut String) -> Result<()> {
let mut pos = self.pos();
let mut jumped = false;
let mut delim = "";
let max_jumps = 5;
let mut jumps_performed = 0;
loop {
// Dns Packets are untrusted data, so we need to be paranoid. Someone
// can craft a packet with a cycle in the jump instructions. This guards
// against such packets.
if jumps_performed > max_jumps {
return Err(format!("Limit of {max_jumps} jumps exceeded").into());
}
let len = self.get(pos)?;
if (len & 0xC0) == 0xC0 {
if !jumped {
self.seek(pos + 2)?;
}
let b2 = self.get(pos + 1)? as u16;
let offset = (((len as u16) ^ 0xC0) << 8) | b2;
pos = offset as usize;
jumped = true;
jumps_performed += 1;
continue;
}
pos += 1;
if len == 0 {
break;
}
outstr.push_str(delim);
let str_buffer = self.get_range(pos, len as usize)?;
outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase());
delim = ".";
pos += len as usize;
}
if !jumped {
self.seek(pos)?;
}
Ok(())
}
pub fn read_string(&mut self, outstr: &mut String) -> Result<()> {
let len = self.read()?;
self.read_string_n(outstr, len)?;
Ok(())
}
pub fn read_string_n(&mut self, outstr: &mut String, len: u8) -> Result<()> {
let mut pos = self.pos;
let str_buffer = self.get_range(pos, len as usize)?;
let mut i = 0;
for b in str_buffer {
let c = *b as char;
if c == '\0' {
break;
}
outstr.push(c);
i += 1;
}
pos += i;
self.seek(pos)?;
Ok(())
}
pub fn write(&mut self, val: u8) -> Result<()> {
if self.size < self.pos {
self.size = self.pos;
}
if self.buf.len() <= self.size {
self.buf.resize(self.size + 1, 0x00);
}
self.buf[self.pos] = val;
self.pos += 1;
Ok(())
}
pub fn write_u8(&mut self, val: u8) -> Result<()> {
self.write(val)?;
Ok(())
}
pub fn write_u16(&mut self, val: u16) -> Result<()> {
self.write((val >> 8) as u8)?;
self.write((val & 0xFF) as u8)?;
Ok(())
}
pub fn write_u32(&mut self, val: u32) -> Result<()> {
self.write(((val >> 24) & 0xFF) as u8)?;
self.write(((val >> 16) & 0xFF) as u8)?;
self.write(((val >> 8) & 0xFF) as u8)?;
self.write((val & 0xFF) as u8)?;
Ok(())
}
pub fn write_qname(&mut self, qname: &str) -> Result<()> {
for label in qname.split('.') {
let len = label.len();
self.write_u8(len as u8)?;
for b in label.as_bytes() {
self.write_u8(*b)?;
}
}
if !qname.is_empty() {
self.write_u8(0)?;
}
Ok(())
}
pub fn write_string(&mut self, text: &str) -> Result<()> {
for b in text.as_bytes() {
self.write_u8(*b)?;
}
Ok(())
}
pub fn set(&mut self, pos: usize, val: u8) -> Result<()> {
self.buf[pos] = val;
Ok(())
}
pub fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> {
self.set(pos, (val >> 8) as u8)?;
self.set(pos + 1, (val & 0xFF) as u8)?;
Ok(())
}
}

102
src/dns/packet/header.rs Normal file
View file

@ -0,0 +1,102 @@
use super::{buffer::PacketBuffer, result::ResultCode};
use crate::Result;
#[derive(Clone, Debug)]
pub struct DnsHeader {
pub id: u16, // 16 bits
pub recursion_desired: bool, // 1 bit
pub truncated_message: bool, // 1 bit
pub authoritative_answer: bool, // 1 bit
pub opcode: u8, // 4 bits
pub response: bool, // 1 bit
pub rescode: ResultCode, // 4 bits
pub checking_disabled: bool, // 1 bit
pub authed_data: bool, // 1 bit
pub z: bool, // 1 bit
pub recursion_available: bool, // 1 bit
pub questions: u16, // 16 bits
pub answers: u16, // 16 bits
pub authoritative_entries: u16, // 16 bits
pub resource_entries: u16, // 16 bits
}
impl DnsHeader {
pub fn new() -> Self {
Self {
id: 0,
recursion_desired: false,
truncated_message: false,
authoritative_answer: false,
opcode: 0,
response: false,
rescode: ResultCode::NOERROR,
checking_disabled: false,
authed_data: false,
z: false,
recursion_available: false,
questions: 0,
answers: 0,
authoritative_entries: 0,
resource_entries: 0,
}
}
pub fn read(&mut self, buffer: &mut PacketBuffer) -> Result<()> {
self.id = buffer.read_u16()?;
let flags = buffer.read_u16()?;
let a = (flags >> 8) as u8;
let b = (flags & 0xFF) as u8;
self.recursion_desired = (a & (1 << 0)) > 0;
self.truncated_message = (a & (1 << 1)) > 0;
self.authoritative_answer = (a & (1 << 2)) > 0;
self.opcode = (a >> 3) & 0x0F;
self.response = (a & (1 << 7)) > 0;
self.rescode = ResultCode::from_num(b & 0x0F);
self.checking_disabled = (b & (1 << 4)) > 0;
self.authed_data = (b & (1 << 5)) > 0;
self.z = (b & (1 << 6)) > 0;
self.recursion_available = (b & (1 << 7)) > 0;
self.questions = buffer.read_u16()?;
self.answers = buffer.read_u16()?;
self.authoritative_entries = buffer.read_u16()?;
self.resource_entries = buffer.read_u16()?;
// Return the constant header size
Ok(())
}
pub fn write(&self, buffer: &mut PacketBuffer) -> Result<()> {
buffer.write_u16(self.id)?;
buffer.write_u8(
(self.recursion_desired as u8)
| ((self.truncated_message as u8) << 1)
| ((self.authoritative_answer as u8) << 2)
| (self.opcode << 3)
| ((self.response as u8) << 7),
)?;
buffer.write_u8(
(self.rescode as u8)
| ((self.checking_disabled as u8) << 4)
| ((self.authed_data as u8) << 5)
| ((self.z as u8) << 6)
| ((self.recursion_available as u8) << 7),
)?;
buffer.write_u16(self.questions)?;
buffer.write_u16(self.answers)?;
buffer.write_u16(self.authoritative_entries)?;
buffer.write_u16(self.resource_entries)?;
Ok(())
}
}

128
src/dns/packet/mod.rs Normal file
View file

@ -0,0 +1,128 @@
use std::net::IpAddr;
use self::{
buffer::PacketBuffer, header::DnsHeader, query::QueryType, question::DnsQuestion,
record::DnsRecord,
};
use crate::Result;
pub mod buffer;
pub mod header;
pub mod query;
pub mod question;
pub mod record;
pub mod result;
#[derive(Clone, Debug)]
pub struct Packet {
pub header: DnsHeader,
pub questions: Vec<DnsQuestion>,
pub answers: Vec<DnsRecord>,
pub authorities: Vec<DnsRecord>,
pub resources: Vec<DnsRecord>,
}
impl Packet {
pub fn new() -> Self {
Self {
header: DnsHeader::new(),
questions: Vec::new(),
answers: Vec::new(),
authorities: Vec::new(),
resources: Vec::new(),
}
}
pub fn from_buffer(buffer: &mut PacketBuffer) -> Result<Self> {
let mut result = Self::new();
result.header.read(buffer)?;
for _ in 0..result.header.questions {
let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0));
question.read(buffer)?;
result.questions.push(question);
}
for _ in 0..result.header.answers {
let rec = DnsRecord::read(buffer)?;
result.answers.push(rec);
}
for _ in 0..result.header.authoritative_entries {
let rec = DnsRecord::read(buffer)?;
result.authorities.push(rec);
}
for _ in 0..result.header.resource_entries {
let rec = DnsRecord::read(buffer)?;
result.resources.push(rec);
}
Ok(result)
}
pub fn write(&mut self, buffer: &mut PacketBuffer) -> Result<()> {
self.header.questions = self.questions.len() as u16;
self.header.answers = self.answers.len() as u16;
self.header.authoritative_entries = self.authorities.len() as u16;
self.header.resource_entries = self.resources.len() as u16;
self.header.write(buffer)?;
for question in &self.questions {
question.write(buffer)?;
}
for rec in &self.answers {
rec.write(buffer)?;
}
for rec in &self.authorities {
rec.write(buffer)?;
}
for rec in &self.resources {
rec.write(buffer)?;
}
Ok(())
}
pub fn get_random_a(&self) -> Option<IpAddr> {
self.answers
.iter()
.filter_map(|record| match record {
DnsRecord::A { addr, .. } => Some(IpAddr::V4(*addr)),
DnsRecord::AAAA { addr, .. } => Some(IpAddr::V6(*addr)),
_ => None,
})
.next()
}
fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator<Item = (&'a str, &'a str)> {
self.authorities
.iter()
.filter_map(|record| match record {
DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())),
_ => None,
})
.filter(move |(domain, _)| qname.ends_with(*domain))
}
pub fn get_resolved_ns(&self, qname: &str) -> Option<IpAddr> {
self.get_ns(qname)
.flat_map(|(_, host)| {
self.resources
.iter()
.filter_map(move |record| match record {
DnsRecord::A { domain, addr, .. } if domain == host => {
Some(IpAddr::V4(*addr))
}
DnsRecord::AAAA { domain, addr, .. } if domain == host => {
Some(IpAddr::V6(*addr))
}
_ => None,
})
})
.next()
}
pub fn get_unresolved_ns<'a>(&'a self, qname: &'a str) -> Option<&'a str> {
self.get_ns(qname).map(|(_, host)| host).next()
}
}

78
src/dns/packet/query.rs Normal file
View file

@ -0,0 +1,78 @@
#[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)]
pub enum QueryType {
UNKNOWN(u16),
A, // 1
NS, // 2
CNAME, // 5
SOA, // 6
PTR, // 12
MX, // 15
TXT, // 16
AAAA, // 28
SRV, // 33
OPT, // 41
CAA, // 257
AR, // 1000
AAAAR, // 1001
}
impl QueryType {
pub fn to_num(&self) -> u16 {
match *self {
Self::UNKNOWN(x) => x,
Self::A => 1,
Self::NS => 2,
Self::CNAME => 5,
Self::SOA => 6,
Self::PTR => 12,
Self::MX => 15,
Self::TXT => 16,
Self::AAAA => 28,
Self::SRV => 33,
Self::OPT => 41,
Self::CAA => 257,
Self::AR => 1000,
Self::AAAAR => 1001,
}
}
pub fn from_num(num: u16) -> Self {
match num {
1 => Self::A,
2 => Self::NS,
5 => Self::CNAME,
6 => Self::SOA,
12 => Self::PTR,
15 => Self::MX,
16 => Self::TXT,
28 => Self::AAAA,
33 => Self::SRV,
41 => Self::OPT,
257 => Self::CAA,
1000 => Self::AR,
1001 => Self::AAAAR,
_ => Self::UNKNOWN(num),
}
}
pub fn allowed_actions(&self) -> (bool, bool) {
// 0. duplicates allowed
// 1. allowed to be created by database
match self {
QueryType::UNKNOWN(_) => (false, false),
QueryType::A => (true, true),
QueryType::NS => (false, true),
QueryType::CNAME => (false, true),
QueryType::SOA => (false, false),
QueryType::PTR => (false, true),
QueryType::MX => (false, true),
QueryType::TXT => (true, true),
QueryType::AAAA => (true, true),
QueryType::SRV => (false, true),
QueryType::OPT => (false, false),
QueryType::CAA => (false, true),
QueryType::AR => (false, true),
QueryType::AAAAR => (false, true),
}
}
}

View file

@ -0,0 +1,31 @@
use super::{buffer::PacketBuffer, query::QueryType, Result};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DnsQuestion {
pub name: String,
pub qtype: QueryType,
}
impl DnsQuestion {
pub fn new(name: String, qtype: QueryType) -> Self {
Self { name, qtype }
}
pub fn read(&mut self, buffer: &mut PacketBuffer) -> Result<()> {
buffer.read_qname(&mut self.name)?;
self.qtype = QueryType::from_num(buffer.read_u16()?); // qtype
let _ = buffer.read_u16()?; // class
Ok(())
}
pub fn write(&self, buffer: &mut PacketBuffer) -> Result<()> {
buffer.write_qname(&self.name)?;
let typenum = self.qtype.to_num();
buffer.write_u16(typenum)?;
buffer.write_u16(1)?;
Ok(())
}
}

544
src/dns/packet/record.rs Normal file
View file

@ -0,0 +1,544 @@
use std::net::{Ipv4Addr, Ipv6Addr};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use tracing::{trace, warn};
use super::{buffer::PacketBuffer, query::QueryType, Result};
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub enum DnsRecord {
UNKNOWN {
domain: String,
qtype: u16,
data_len: u16,
ttl: u32,
}, // 0
A {
domain: String,
addr: Ipv4Addr,
ttl: u32,
}, // 1
NS {
domain: String,
host: String,
ttl: u32,
}, // 2
CNAME {
domain: String,
host: String,
ttl: u32,
}, // 5
SOA {
domain: String,
mname: String,
nname: String,
serial: u32,
refresh: u32,
retry: u32,
expire: u32,
minimum: u32,
ttl: u32,
}, // 6
PTR {
domain: String,
pointer: String,
ttl: u32,
}, // 12
MX {
domain: String,
priority: u16,
host: String,
ttl: u32,
}, // 15
TXT {
domain: String,
text: Vec<String>,
ttl: u32,
}, //16
AAAA {
domain: String,
addr: Ipv6Addr,
ttl: u32,
}, // 28
SRV {
domain: String,
priority: u16,
weight: u16,
port: u16,
target: String,
ttl: u32,
}, // 33
CAA {
domain: String,
flags: u8,
length: u8,
tag: String,
value: String,
ttl: u32,
}, // 257
AR {
domain: String,
ttl: u32,
},
AAAAR {
domain: String,
ttl: u32,
},
}
impl DnsRecord {
pub fn read(buffer: &mut PacketBuffer) -> Result<Self> {
let mut domain = String::new();
buffer.read_qname(&mut domain)?;
let qtype_num = buffer.read_u16()?;
let qtype = QueryType::from_num(qtype_num);
let _ = buffer.read_u16()?;
let ttl = buffer.read_u32()?;
let data_len = buffer.read_u16()?;
trace!("Reading DNS Record TYPE: {:?}", qtype);
let header_pos = buffer.pos();
match qtype {
QueryType::A => {
let raw_addr = buffer.read_u32()?;
let addr = Ipv4Addr::new(
((raw_addr >> 24) & 0xFF) as u8,
((raw_addr >> 16) & 0xFF) as u8,
((raw_addr >> 8) & 0xFF) as u8,
(raw_addr & 0xFF) as u8,
);
Ok(Self::A { domain, addr, ttl })
}
QueryType::AAAA => {
let raw_addr1 = buffer.read_u32()?;
let raw_addr2 = buffer.read_u32()?;
let raw_addr3 = buffer.read_u32()?;
let raw_addr4 = buffer.read_u32()?;
let addr = Ipv6Addr::new(
((raw_addr1 >> 16) & 0xFFFF) as u16,
(raw_addr1 & 0xFFFF) as u16,
((raw_addr2 >> 16) & 0xFFFF) as u16,
(raw_addr2 & 0xFFFF) as u16,
((raw_addr3 >> 16) & 0xFFFF) as u16,
(raw_addr3 & 0xFFFF) as u16,
((raw_addr4 >> 16) & 0xFFFF) as u16,
(raw_addr4 & 0xFFFF) as u16,
);
Ok(Self::AAAA { domain, addr, ttl })
}
QueryType::NS => {
let mut ns = String::new();
buffer.read_qname(&mut ns)?;
Ok(Self::NS {
domain,
host: ns,
ttl,
})
}
QueryType::CNAME => {
let mut cname = String::new();
buffer.read_qname(&mut cname)?;
Ok(Self::CNAME {
domain,
host: cname,
ttl,
})
}
QueryType::SOA => {
let mut mname = String::new();
buffer.read_qname(&mut mname)?;
let mut nname = String::new();
buffer.read_qname(&mut nname)?;
let serial = buffer.read_u32()?;
let refresh = buffer.read_u32()?;
let retry = buffer.read_u32()?;
let expire = buffer.read_u32()?;
let minimum = buffer.read_u32()?;
Ok(Self::SOA {
domain,
mname,
nname,
serial,
refresh,
retry,
expire,
minimum,
ttl,
})
}
QueryType::PTR => {
let mut pointer = String::new();
buffer.read_qname(&mut pointer)?;
Ok(Self::PTR {
domain,
pointer,
ttl,
})
}
QueryType::MX => {
let priority = buffer.read_u16()?;
let mut mx = String::new();
buffer.read_qname(&mut mx)?;
Ok(Self::MX {
domain,
priority,
host: mx,
ttl,
})
}
QueryType::TXT => {
let mut text = Vec::new();
loop {
let mut s = String::new();
buffer.read_string(&mut s)?;
if s.len() == 0 {
break;
} else {
text.push(s);
}
}
Ok(Self::TXT { domain, text, ttl })
}
QueryType::SRV => {
let priority = buffer.read_u16()?;
let weight = buffer.read_u16()?;
let port = buffer.read_u16()?;
let mut target = String::new();
buffer.read_qname(&mut target)?;
Ok(Self::SRV {
domain,
priority,
weight,
port,
target,
ttl,
})
}
QueryType::CAA => {
let flags = buffer.read()?;
let length = buffer.read()?;
let mut tag = String::new();
buffer.read_string_n(&mut tag, length)?;
let value_len = (data_len as usize) + header_pos - buffer.pos;
let mut value = String::new();
buffer.read_string_n(&mut value, value_len as u8)?;
Ok(Self::CAA {
domain,
flags,
length,
tag,
value,
ttl,
})
}
QueryType::UNKNOWN(_) | _ => {
buffer.step(data_len as usize)?;
Ok(Self::UNKNOWN {
domain,
qtype: qtype_num,
data_len,
ttl,
})
}
}
}
pub fn write(&self, buffer: &mut PacketBuffer) -> Result<usize> {
let start_pos = buffer.pos();
trace!("Writing DNS Record {:?}", self);
match *self {
Self::A {
ref domain,
ref addr,
ttl,
} => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::A.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
buffer.write_u16(4)?;
let octets = addr.octets();
buffer.write_u8(octets[0])?;
buffer.write_u8(octets[1])?;
buffer.write_u8(octets[2])?;
buffer.write_u8(octets[3])?;
}
Self::NS {
ref domain,
ref host,
ttl,
} => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::NS.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
let pos = buffer.pos();
buffer.write_u16(0)?;
buffer.write_qname(host)?;
let size = buffer.pos() - (pos + 2);
buffer.set_u16(pos, size as u16)?;
}
Self::CNAME {
ref domain,
ref host,
ttl,
} => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::CNAME.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
let pos = buffer.pos();
buffer.write_u16(0)?;
buffer.write_qname(host)?;
let size = buffer.pos() - (pos + 2);
buffer.set_u16(pos, size as u16)?;
}
Self::SOA {
ref domain,
ref mname,
ref nname,
serial,
refresh,
retry,
expire,
minimum,
ttl,
} => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::SOA.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
let pos = buffer.pos();
buffer.write_u16(0)?;
buffer.write_qname(mname)?;
buffer.write_qname(nname)?;
buffer.write_u32(serial)?;
buffer.write_u32(refresh)?;
buffer.write_u32(retry)?;
buffer.write_u32(expire)?;
buffer.write_u32(minimum)?;
let size = buffer.pos() - (pos + 2);
buffer.set_u16(pos, size as u16)?;
}
Self::PTR {
ref domain,
ref pointer,
ttl,
} => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::NS.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
let pos = buffer.pos();
buffer.write_u16(0)?;
buffer.write_qname(&pointer)?;
let size = buffer.pos() - (pos + 2);
buffer.set_u16(pos, size as u16)?;
}
Self::MX {
ref domain,
priority,
ref host,
ttl,
} => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::MX.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
let pos = buffer.pos();
buffer.write_u16(0)?;
buffer.write_u16(priority)?;
buffer.write_qname(host)?;
let size = buffer.pos() - (pos + 2);
buffer.set_u16(pos, size as u16)?;
}
Self::TXT {
ref domain,
ref text,
ttl,
} => {
buffer.write_qname(&domain)?;
buffer.write_u16(QueryType::TXT.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
let pos = buffer.pos();
buffer.write_u16(0)?;
if text.is_empty() {
return Ok(buffer.pos() - start_pos);
}
for s in text {
buffer.write_u8(s.len() as u8)?;
buffer.write_string(&s)?;
}
let size = buffer.pos() - (pos + 2);
buffer.set_u16(pos, size as u16)?;
}
Self::AAAA {
ref domain,
ref addr,
ttl,
} => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::AAAA.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
buffer.write_u16(16)?;
for octet in &addr.segments() {
buffer.write_u16(*octet)?;
}
}
Self::SRV {
ref domain,
priority,
weight,
port,
ref target,
ttl,
} => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::SRV.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
let pos = buffer.pos();
buffer.write_u16(0)?;
buffer.write_u16(priority)?;
buffer.write_u16(weight)?;
buffer.write_u16(port)?;
buffer.write_qname(target)?;
let size = buffer.pos() - (pos + 2);
buffer.set_u16(pos, size as u16)?;
}
Self::CAA {
ref domain,
flags,
length,
ref tag,
ref value,
ttl,
} => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::CAA.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
let pos = buffer.pos();
buffer.write_u16(0)?;
buffer.write_u8(flags)?;
buffer.write_u8(length)?;
buffer.write_string(tag)?;
buffer.write_string(value)?;
let size = buffer.pos() - (pos + 2);
buffer.set_u16(pos, size as u16)?;
}
Self::AR { ref domain, ttl } => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::A.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
buffer.write_u16(4)?;
let mut rand = rand::thread_rng();
buffer.write_u32(rand.next_u32())?;
}
Self::AAAAR { ref domain, ttl } => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::A.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
buffer.write_u16(4)?;
let mut rand = rand::thread_rng();
buffer.write_u32(rand.next_u32())?;
buffer.write_u32(rand.next_u32())?;
buffer.write_u32(rand.next_u32())?;
buffer.write_u32(rand.next_u32())?;
}
Self::UNKNOWN { .. } => {
warn!("Skipping record: {self:?}");
}
}
Ok(buffer.pos() - start_pos)
}
pub fn get_domain(&self) -> String {
self.get_shared_domain().0
}
pub fn get_qtype(&self) -> QueryType {
self.get_shared_domain().1
}
pub fn get_ttl(&self) -> u32 {
self.get_shared_domain().2
}
fn get_shared_domain(&self) -> (String, QueryType, u32) {
match self {
DnsRecord::UNKNOWN {
domain, ttl, qtype, ..
} => (domain.clone(), QueryType::UNKNOWN(*qtype), *ttl),
DnsRecord::AAAA { domain, ttl, .. } => (domain.clone(), QueryType::AAAA, *ttl),
DnsRecord::A { domain, ttl, .. } => (domain.clone(), QueryType::A, *ttl),
DnsRecord::NS { domain, ttl, .. } => (domain.clone(), QueryType::NS, *ttl),
DnsRecord::CNAME { domain, ttl, .. } => (domain.clone(), QueryType::CNAME, *ttl),
DnsRecord::SOA { domain, ttl, .. } => (domain.clone(), QueryType::SOA, *ttl),
DnsRecord::PTR { domain, ttl, .. } => (domain.clone(), QueryType::PTR, *ttl),
DnsRecord::MX { domain, ttl, .. } => (domain.clone(), QueryType::MX, *ttl),
DnsRecord::TXT { domain, ttl, .. } => (domain.clone(), QueryType::TXT, *ttl),
DnsRecord::SRV { domain, ttl, .. } => (domain.clone(), QueryType::SRV, *ttl),
DnsRecord::CAA { domain, ttl, .. } => (domain.clone(), QueryType::CAA, *ttl),
DnsRecord::AR { domain, ttl, .. } => (domain.clone(), QueryType::AR, *ttl),
DnsRecord::AAAAR { domain, ttl, .. } => (domain.clone(), QueryType::AAAAR, *ttl),
}
}
}

22
src/dns/packet/result.rs Normal file
View file

@ -0,0 +1,22 @@
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ResultCode {
NOERROR = 0,
FORMERR = 1,
SERVFAIL = 2,
NXDOMAIN = 3,
NOTIMP = 4,
REFUSED = 5,
}
impl ResultCode {
pub fn from_num(num: u8) -> Self {
match num {
1 => Self::FORMERR,
2 => Self::SERVFAIL,
3 => Self::NXDOMAIN,
4 => Self::NOTIMP,
5 => Self::REFUSED,
0 | _ => Self::NOERROR,
}
}
}

230
src/dns/resolver.rs Normal file
View file

@ -0,0 +1,230 @@
use super::binding::Connection;
use super::packet::{query::QueryType, question::DnsQuestion, result::ResultCode, Packet};
use crate::Result;
use crate::{config::Config, database::Database, get_time};
use async_recursion::async_recursion;
use moka::future::Cache;
use std::{net::IpAddr, sync::Arc, time::Duration};
use tracing::{error, trace};
pub struct Resolver {
request_id: u16,
connection: Connection,
config: Arc<Config>,
database: Arc<Database>,
cache: Cache<DnsQuestion, (Packet, u64)>,
}
impl Resolver {
pub fn new(
request_id: u16,
connection: Connection,
config: Arc<Config>,
database: Arc<Database>,
cache: Cache<DnsQuestion, (Packet, u64)>,
) -> Self {
Self {
request_id,
connection,
config,
database,
cache,
}
}
async fn lookup_database(&self, question: &DnsQuestion) -> Option<Packet> {
let records = match self
.database
.get_records(&question.name, question.qtype)
.await
{
Ok(record) => record,
Err(err) => {
error!("{err}");
return None;
}
};
if records.is_empty() {
return None;
}
let mut packet = Packet::new();
packet.header.id = self.request_id;
packet.header.questions = 1;
packet.header.answers = records.len() as u16;
packet.header.recursion_desired = true;
packet
.questions
.push(DnsQuestion::new(question.name.to_string(), question.qtype));
for record in records {
packet.answers.push(record);
}
trace!(
"Found stored value for {:?} {}",
question.qtype,
question.name
);
Some(packet)
}
async fn lookup_cache(&self, question: &DnsQuestion) -> Option<Packet> {
let Some((packet, date)) = self.cache.get(&question) else {
return None
};
let now = get_time();
let diff = Duration::from_millis(now - date).as_secs() as u32;
for answer in &packet.answers {
let ttl = answer.get_ttl();
if diff > ttl {
self.cache.invalidate(&question).await;
return None;
}
}
trace!(
"Found cached value for {:?} {}",
question.qtype,
question.name
);
Some(packet)
}
async fn lookup_fallback(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
let mut packet = Packet::new();
packet.header.id = self.request_id;
packet.header.questions = 1;
packet.header.recursion_desired = true;
packet
.questions
.push(DnsQuestion::new(question.name.to_string(), question.qtype));
let packet = match self.connection.request_packet(packet, server).await {
Ok(packet) => packet,
Err(e) => {
error!("Failed to complete nameserver request: {e}");
let mut packet = Packet::new();
packet.header.rescode = ResultCode::SERVFAIL;
packet
}
};
packet
}
async fn lookup(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
if let Some(packet) = self.lookup_cache(question).await {
return packet;
};
if let Some(packet) = self.lookup_database(question).await {
return packet;
};
trace!(
"Attempting lookup of {:?} {} with ns {}",
question.qtype,
question.name,
server.0
);
self.lookup_fallback(question, server).await
}
#[async_recursion]
async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet {
let question = DnsQuestion::new(qname.to_string(), qtype);
let mut ns = self.config.dns_fallback.clone();
loop {
let ns_copy = ns;
let server = (ns_copy, 53);
let response = self.lookup(&question, server).await;
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
self.cache
.insert(question, (response.clone(), get_time()))
.await;
return response;
}
if response.header.rescode == ResultCode::NXDOMAIN {
self.cache
.insert(question, (response.clone(), get_time()))
.await;
return response;
}
if let Some(new_ns) = response.get_resolved_ns(qname) {
ns = new_ns;
continue;
}
let new_ns_name = match response.get_unresolved_ns(qname) {
Some(x) => x,
None => {
self.cache
.insert(question, (response.clone(), get_time()))
.await;
return response;
}
};
let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await;
if let Some(new_ns) = recursive_response.get_random_a() {
ns = new_ns;
} else {
self.cache
.insert(question, (response.clone(), get_time()))
.await;
return response;
}
}
}
pub async fn handle_query(mut self) -> Result<()> {
let mut request = self.connection.read_packet().await?;
let mut packet = Packet::new();
packet.header.id = request.header.id;
packet.header.recursion_desired = true;
packet.header.recursion_available = true;
packet.header.response = true;
if let Some(question) = request.questions.pop() {
trace!("Received query: {question:?}");
let result = self.recursive_lookup(&question.name, question.qtype).await;
packet.questions.push(question.clone());
packet.header.rescode = result.header.rescode;
for rec in result.answers {
trace!("Answer: {rec:?}");
packet.answers.push(rec);
}
for rec in result.authorities {
trace!("Authority: {rec:?}");
packet.authorities.push(rec);
}
for rec in result.resources {
trace!("Resource: {rec:?}");
packet.resources.push(rec);
}
} else {
packet.header.rescode = ResultCode::FORMERR;
}
self.connection.write_packet(packet).await?;
Ok(())
}
}

85
src/dns/server.rs Normal file
View file

@ -0,0 +1,85 @@
use super::{
binding::Binding,
packet::{question::DnsQuestion, Packet},
resolver::Resolver,
};
use crate::{config::Config, database::Database, Result};
use moka::future::Cache;
use std::{net::SocketAddr, sync::Arc, time::Duration};
use tokio::task::JoinHandle;
use tracing::{error, info};
pub struct DnsServer {
addr: SocketAddr,
config: Arc<Config>,
database: Arc<Database>,
cache: Cache<DnsQuestion, (Packet, u64)>,
}
impl DnsServer {
pub async fn new(config: Config, database: Database) -> Result<Self> {
let addr = format!("[::]:{}", config.dns_port).parse::<SocketAddr>()?;
let cache = Cache::builder()
.time_to_live(Duration::from_secs(60 * 60))
.max_capacity(config.dns_cache_size)
.build();
info!("Created DNS cache with size of {}", config.dns_cache_size);
Ok(Self {
addr,
config: Arc::new(config),
database: Arc::new(database),
cache,
})
}
pub async fn run(&self) -> Result<(JoinHandle<()>, JoinHandle<()>)> {
let tcp = Binding::tcp(self.addr).await?;
let tcp_handle = self.listen(tcp);
let udp = Binding::udp(self.addr).await?;
let udp_handle = self.listen(udp);
info!(
"Fallback DNS Server is set to: {:?}",
self.config.dns_fallback
);
info!(
"Listening for TCP and UDP traffic on [::]:{}",
self.config.dns_port
);
Ok((udp_handle, tcp_handle))
}
fn listen(&self, mut binding: Binding) -> JoinHandle<()> {
let config = self.config.clone();
let database = self.database.clone();
let cache = self.cache.clone();
tokio::spawn(async move {
let mut id = 0;
loop {
let Ok(connection) = binding.connect().await else { continue };
info!("Received request on {}", binding.name());
let resolver = Resolver::new(
id,
connection,
config.clone(),
database.clone(),
cache.clone(),
);
let name = binding.name().to_string();
tokio::spawn(async move {
if let Err(err) = resolver.handle_query().await {
error!("{} request {} failed: {:?}", name, id, err);
};
});
id += 1;
}
})
}
}

View file

@ -1,525 +0,0 @@
#include <errno.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "config.h"
#include "log.h"
#include "map.h"
#define MAX_LEN 1024
#define BUF(name) char name[MAX_LEN]
static int line = 0;
static bool get_line(FILE* file, const BUF(buf)) {
line++;
return fgets((char*) buf, MAX_LEN, file) != NULL;
}
static bool is_whitespace(const char* buf) {
int i = 0;
char c;
while(c = buf[i], 1) {
if (c == '\n' || c == '\0') return true;
if (c != ' ' && c != '\n') return false;
i++;
}
}
static bool get_words(char* buf, char** words, int count) {
int last = 0;
int offset = 0;
int i = 0;
while(1) {
char c;
while(1) {
if (offset == MAX_LEN) return false;
c = buf[offset];
if (c == '\0' || c == '\n') {
break;
}
if (c == ' ' && i + 1 != count) {
break;
}
offset++;
}
if (offset - last < 1) {
return false;
}
words[i] = buf + last;
buf[offset] = '\0';
offset++;
last = offset;
if (c == '\0' || c == '\n') {
break;
} else if (i + 1 == count) {
break;
}
i++;
}
return i + 1 == count;
}
static bool get_int(const char* word, uint32_t* i) {
char* end;
uint32_t res = (uint32_t) strtol(word, &end, 10);
if (*end == '\0') {
*i = res;
return true;
} else {
return false;
}
}
static bool config_read_qtype(const char* qstr, RecordType* qtype) {
if (strcmp(qstr, "A") == 0) {
*qtype = A;
return true;
} else if (strcmp(qstr, "NS") == 0) {
*qtype = NS;
return true;
} else if (strcmp(qstr, "CNAME") == 0) {
*qtype = CNAME;
return true;
} else if (strcmp(qstr, "SOA") == 0) {
*qtype = SOA;
return true;
} else if (strcmp(qstr, "PTR") == 0) {
*qtype = PTR;
return true;
} else if (strcmp(qstr, "MX") == 0) {
*qtype = MX;
return true;
} else if (strcmp(qstr, "TXT") == 0) {
*qtype = TXT;
return true;
} else if (strcmp(qstr, "AAAA") == 0) {
*qtype = AAAA;
return true;
} else if (strcmp(qstr, "SRV") == 0) {
*qtype = SRV;
return true;
} else if (strcmp(qstr, "CAA") == 0) {
*qtype = CAA;
return true;
} else if (strcmp(qstr, "CMD") == 0) {
*qtype = CMD;
return true;
} else if(strcmp(qstr, "AR") == 0) {
*qtype = AR;
return true;
} else if(strcmp(qstr, "AAAAR") == 0) {
*qtype = AAAAR;
return true;
} else {
return false;
}
}
static bool config_read_class(const char* cstr, uint16_t* class) {
if (strcmp(cstr, "IN") == 0) {
*class = 1;
return true;
} else if (strcmp(cstr, "CH") == 0) {
*class = 3;
return true;
} else if (strcmp(cstr, "HS") == 0) {
*class = 4;
return true;
} else {
return false;
}
}
// Format QTYPE CLASS DOMAIN: A IN google.com
static bool config_read_question(FILE* file, Question* question) {
BUF(buf);
if (!get_line(file, buf)) {
return false;
}
if (is_whitespace(buf)) {
return false;
}
char* words[3];
if (!get_words(&buf[0], &words[0], 3)) {
WARN("Invalid question at line %d", line);
return false;
};
uint16_t class;
if (!config_read_class(words[0], &class)) {
WARN("Invalid question class at line %d", line);
return false;
}
RecordType qtype;
if (!config_read_qtype(words[1], &qtype)) {
WARN("Invalid question qtype at line %d", line);
return false;
}
size_t domain_len = strlen(words[2]);
question->cls = class;
question->qtype = qtype;
question->domain = malloc(domain_len + 1);
question->domain[0] = domain_len;
memcpy(question->domain + 1, words[2], domain_len);
return true;
}
static void copy_str(char* from, uint8_t** to) {
size_t len = strlen(from);
if (len > 255) {
len = 255;
}
uint8_t* new = malloc(len + 1);
new[0] = len;
memcpy(new + 1, from, len);
*to = new;
}
static bool config_read_a_record(char* data, ARecord* record) {
sscanf(data, "%hhu.%hhu.%hhu.%hhu",
&record->addr[0],
&record->addr[1],
&record->addr[2],
&record->addr[3]
);
return true;
}
static bool config_read_ns_record(char* data, NSRecord* record) {
copy_str(data, &record->host);
return true;
}
static bool config_read_cname_record(char* data, CNAMERecord* record) {
copy_str(data, &record->host);
return true;
}
static bool config_read_soa_record(char* data, SOARecord* record) {
char* words[7];
if (!get_words(&data[0], &words[0], 7)) {
WARN("Invalid SOA record data at line %d", line);
record->mname = NULL;
record->nname = NULL;
return false;
}
copy_str(words[0], &record->mname);
copy_str(words[1], &record->nname);
if (!get_int(words[2], &record->serial)) {
WARN("Invalid SOA record data at line %d", line);
return false;
}
if (!get_int(words[3], &record->refresh)) {
WARN("Invalid SOA record data at line %d", line);
return false;
}
if (!get_int(words[4], &record->retry)) {
WARN("Invalid SOA record data at line %d", line);
return false;
}
if (!get_int(words[5], &record->expire)) {
WARN("Invalid SOA record data at line %d", line);
return false;
}
if (!get_int(words[6], &record->minimum)) {
WARN("Invalid SOA record data at line %d", line);
return false;
}
return true;
}
static bool config_read_ptr_record(char* data, PTRRecord* record) {
copy_str(data, &record->pointer);
return true;
}
static bool config_read_mx_record(char* data, MXRecord* record) {
char* words[2];
if (!get_words(&data[0], &words[0], 2)) {
WARN("Invalid MX record data at line %d", line);
record->host = NULL;
return false;
}
copy_str(words[1], &record->host);
uint32_t priority;
if (!get_int(words[0], &priority)) {
WARN("Invalid MX record data at line %d", line);
return false;
}
record->priority = (uint16_t) priority;
return true;
}
static bool config_read_txt_record(char* data, TXTRecord* record) {
int len = strlen(data);
uint8_t count = ((uint8_t)len + 254) / 255;
record->len = count;
record->text = malloc(sizeof(uint8_t*) * count);
for (uint8_t i = 0; i < count; i++) {
uint32_t offset = count * 255;
uint32_t part_len = len - offset;
if (part_len > 255) part_len = 255;
uint8_t* part = malloc(part_len + 1);
part[0] = part_len;
memcpy(part + 1, data + offset, part_len);
record->text[i] = part;
}
return true;
}
static bool config_read_aaaa_record(char* data, AAAARecord* record) {
for(int i = 0; i < 8; i++) {
if (sscanf(data, "%02hhx%02hhx:",
&record->addr[i*2 + 0],
&record->addr[i*2 + 1]
) == EOF) {
return false;
}
}
return true;
}
static bool config_read_srv_record(char* data, SRVRecord* record) {
char* words[4];
if (!get_words(&data[0], &words[0], 4)) {
WARN("Invalid SRV record data at line %d", line);
record->target = NULL;
return false;
}
copy_str(words[3], &record->target);
uint32_t priority;
if (!get_int(words[0], &priority)) {
WARN("Invalid SRV record data at line %d", line);
return false;
}
record->priority = (uint16_t) priority;
uint32_t weight;
if (!get_int(words[1], &weight)) {
WARN("Invalid SRV record data at line %d", line);
return false;
}
record->weight = (uint16_t) weight;
uint32_t port;
if (!get_int(words[2], &port)) {
WARN("Invalid SRV record data at line %d", line);
return false;
}
record->port = (uint16_t) port;
return true;
}
static bool config_read_caa_record(char* data, CAARecord* record) {
char* words[4];
if (!get_words(&data[0], &words[0], 4)) {
WARN("Invalid SRV record data at line %d", line);
record->tag = NULL;
record->value = NULL;
return false;
}
copy_str(words[2], &record->tag);
copy_str(words[3], &record->value);
uint32_t flags;
if (!get_int(words[0], &flags)) {
WARN("Invalid SRV record data at line %d", line);
return false;
}
record->flags = (uint8_t) flags;
uint32_t length;
if (!get_int(words[1], &length)) {
WARN("Invalid SRV record data at line %d", line);
return false;
}
record->length = (uint8_t) length;
return true;
}
static bool config_read_cmd_record(char* data, CMDRecord* record) {
int len = strlen(data);
record->command = malloc(len);
memcpy(record->command, data, len);
return true;
}
static bool config_read_record_data(char* data, Record* record) {
switch (record->type) {
case UNKOWN:
// This can never happend in here so uh do nothing i guess
return false;
case A:
return config_read_a_record(data, &record->data.a);
case NS:
return config_read_ns_record(data, &record->data.ns);
case CNAME:
return config_read_cname_record(data, &record->data.cname);
case SOA:
return config_read_soa_record(data, &record->data.soa);
case PTR:
return config_read_ptr_record(data, &record->data.ptr);
case MX:
return config_read_mx_record(data, &record->data.mx);
case TXT:
return config_read_txt_record(data, &record->data.txt);
case AAAA:
return config_read_aaaa_record(data, &record->data.aaaa);
case SRV:
return config_read_srv_record(data, &record->data.srv);
case CAA:
return config_read_caa_record(data, &record->data.caa);
case CMD:
return config_read_cmd_record(data, &record->data.cmd);
case AR:
case AAAAR:
memset(&record->data, 0, sizeof(record->data));
return true;
}
return false;
}
static bool config_read_record(FILE* file, Record* record, Question* question) {
BUF(buf);
if (!get_line(file, buf)) {
return false;
}
if (is_whitespace(buf)) {
return false;
}
char* words[4];
if (!get_words(&buf[0], &words[0], 4)) {
WARN("Invalid record at line %d", line);
return false;
}
uint16_t class;
if (!config_read_class(words[0], &class)) {
WARN("Invalid question class at line %d", line);
return false;
}
RecordType qtype;
if (!config_read_qtype(words[1], &qtype)) {
WARN("Invalid question qtype at line %d", line);
return false;
}
uint32_t ttl;
if (!get_int(words[2], &ttl)) {
WARN("Invalid record ttl at line %d", line);
return false;
}
record->cls = class;
record->type = qtype;
record->len = 0;
record->ttl = ttl;
record->domain = malloc(question->domain[0] + 1);
memcpy(record->domain, question->domain, question->domain[0] + 1);
if(!config_read_record_data(words[3], record)) {
free_record(record);
return false;
}
return true;
}
static void config_push_record(Record** buf, Record record, uint16_t* capacity, uint16_t* size) {
if (size == capacity) {
*capacity *= 2;
*buf = realloc(*buf, sizeof(Record) * *capacity);
}
(*buf)[*size] = record;
(*size)++;
}
bool load_config(const char* path, RecordMap* map) {
FILE* file = fopen(path, "r");
if (file == NULL) {
return false;
}
INFO("Using config file at path: %s", path);
line = 0;
record_map_init(map);
while (1) {
Question* question = malloc(sizeof(Question));
if (!config_read_question(file, question)) {
free(question);
break;
}
INIT_LOG_BUFFER(log);
LOGONLY(print_question(question, log));
TRACE("Found config question: %s", log);
Packet* packet = malloc(sizeof(Packet));
memset(packet, 0, sizeof(Packet));
packet->authorities = NULL;
packet->resources = NULL;
packet->questions = malloc(sizeof(Question));
packet->questions[0] = *question;
uint16_t capacity = 1;
packet->answers = malloc(sizeof(Record));
while(1) {
Record record;
if (!config_read_record(file, &record, question)) {
break;
}
LOGONLY(print_record(&record, log));
TRACE("Found config record: %s", log);
config_push_record(&packet->answers, record, &capacity, &packet->header.answers);
}
record_map_add(map, question, packet);
}
fclose(file);
return true;
}

View file

@ -1,5 +0,0 @@
#pragma once
#include "map.h"
bool load_config(const char* path, RecordMap* map);

View file

@ -1,49 +0,0 @@
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include "log.h"
#ifdef LOG
void logmsg(LogLevel level, const char* msg, ...) {
INIT_LOG_BOUNDS
INIT_LOG_BUFFER(buffer)
time_t now = time(NULL);
struct tm *tm = localtime(&now);
APPEND(buffer, "\x1b[97m%02d:%02d:%02d ", tm->tm_hour, tm->tm_min, tm->tm_sec);
switch (level) {
case DEBUG:
APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 95, "DEBUG");
break;
case TRACE:
APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 96, "TRACE");
break;
case INFO:
APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 92, "INFO");
break;
case WARN:
APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 93, "WARN");
break;
case ERROR:
APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 91, "ERROR");
break;
break;
}
va_list valist;
va_start(valist, msg);
t += vsnprintf(buffer + t, BUF_LENGTH - t, msg, valist);
va_end(valist);
APPEND(buffer, "\n");
fwrite(&buffer, t, 1, stdout);
}
#endif

View file

@ -1,45 +0,0 @@
#pragma once
#include <stdint.h>
#define LOG
#ifdef LOG
typedef enum {
DEBUG,
TRACE,
INFO,
WARN,
ERROR,
} LogLevel;
#define BUF_LENGTH 256
#define INIT_LOG_BUFFER(name) char name[BUF_LENGTH];
#define INIT_LOG_BOUNDS int t = 0;
#define APPEND(buffer, msg, ...) t += snprintf(buffer + t, BUF_LENGTH - t, msg, ##__VA_ARGS__);
#define LOGONLY(code) code
void logmsg(LogLevel level, const char* msg, ...)
__attribute__ ((__format__(printf, 2, 3)));
#define DEBUG(msg, ...) logmsg(DEBUG, msg, ##__VA_ARGS__)
#define TRACE(msg, ...) logmsg(TRACE, msg, ##__VA_ARGS__)
#define INFO(msg, ...) logmsg(INFO, msg, ##__VA_ARGS__)
#define WARN(msg, ...) logmsg(WARN, msg, ##__VA_ARGS__)
#define ERROR(msg, ...) logmsg(ERROR, msg, ##__VA_ARGS__)
#else
#define BUF_LENGTH
#define INIT_LOG_BUFFER(name)
#define INIT_LOG_BOUNDS
#define APPEND(buffer, msg, ...)
#define LOGONLY(code)
#define DEBUG(msg, ...)
#define TRACE(msg, ...)
#define INFO(msg, ...)
#define WARN(msg, ...)
#define ERROR(msg, ...)
#endif

View file

@ -1,104 +0,0 @@
#include <string.h>
#include <stdlib.h>
#include "map.h"
void record_map_init(RecordMap* map) {
map->capacity = 0;
map->len = 0;
map->entries = NULL;
}
void record_map_free(RecordMap* map) {
for(uint32_t i = 0; i < map->capacity; i++) {
Entry* e = &map->entries[i];
if (e->key != NULL) {
free_question(e->key);
free(e->key);
free_packet(e->value);
free(e->value);
}
}
free(map->entries);
}
static size_t hash_question(const Question* question) {
size_t hash = 5381;
for(int i = 0; i < question->domain[0]; i++) {
uint8_t c = question->domain[i+1];
hash = ((hash << 5) + hash) + c;
}
hash = ((hash << 5) + hash) + (uint8_t)question->cls;
hash = ((hash << 5) + hash) + (uint8_t)question->qtype;
return hash;
}
static bool question_equals(const Question* a, const Question* b) {
if (a->qtype != b->qtype) return false;
if (a->cls != b->cls) return false;
if (a->domain[0] != b->domain[0]) return false;
return memcmp(a->domain+1, b->domain+1, a->domain[0]) == 0;
}
static Entry* record_map_find(Entry* entries, uint32_t capacity, const Question* key) {
uint32_t index = hash_question(key) % capacity;
while(true) {
Entry* entry = &entries[index];
if(entry->key == NULL) {
return entry;
} else if(question_equals(entry->key, key)) {
return entry;
}
index += 1;
index %= capacity;
}
}
static void record_map_grow(RecordMap* map, uint32_t capacity) {
Entry* entries = malloc(capacity * sizeof(Entry));
for(uint32_t i = 0; i < capacity; i++) {
entries[i].key = NULL;
entries[i].value = NULL;
}
map->len = 0;
for(uint32_t i = 0; i < map->capacity; i++) {
Entry* entry = &map->entries[i];
if(entry->key == NULL) continue;
Entry* dest = record_map_find(entries, capacity, entry->key);
dest->key = entry->key;
dest->value = entry->value;
map->len++;
}
free(map->entries);
map->entries = entries;
map->capacity = capacity;
}
bool record_map_get(const RecordMap* map, const Question* key, Packet* value) {
if(map->len == 0) return false;
Entry* e = record_map_find(map->entries, map->capacity, key);
if (e->key == NULL) return false;
*value = *(e->value);
return true;
}
void record_map_add(RecordMap* map, Question* key, Packet* value) {
if(map->len + 1 > map->capacity * 0.75) {
int capacity = (map->capacity == 0 ? 8 : (2 * map->capacity));
record_map_grow(map, capacity);
}
Entry* e = record_map_find(map->entries, map->capacity, key);
bool new_key = e->key == NULL;
if(new_key) {
map->len++;
e->key = key;
}
value->header.z = true;
e->value = value;
}

View file

@ -1,20 +0,0 @@
#pragma once
#include "../packet/packet.h"
typedef struct {
Question* key;
Packet* value;
} Entry;
typedef struct {
uint32_t capacity;
uint32_t len;
Entry* entries;
} RecordMap;
void record_map_init(RecordMap* map);
void record_map_free(RecordMap* map);
bool record_map_get(const RecordMap* map, const Question* key, Packet* value);
void record_map_add(RecordMap* map, Question* key, Packet* value);

View file

@ -1,31 +0,0 @@
#include "server/server.h"
#include <stdlib.h>
#define DEFAULT_PORT 53
static uint16_t get_port(const char* port_str) {
if (port_str == NULL) {
return DEFAULT_PORT;
}
uint16_t port;
if ((port = strtoul(port_str, NULL, 10)) == 0) {
return DEFAULT_PORT;
}
return port;
}
int main(void) {
const char* port_str = getenv("PORT");
uint16_t port = get_port(port_str);
Server server;
server_init(port, &server);
server_run(&server);
server_free(&server);
return EXIT_SUCCESS;
}

64
src/main.rs Normal file
View file

@ -0,0 +1,64 @@
use std::time::{SystemTime, UNIX_EPOCH};
use config::Config;
use database::Database;
use dotenv::dotenv;
use dns::server::DnsServer;
use tracing::{error, metadata::LevelFilter};
use tracing_subscriber::{
filter::filter_fn, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, Layer,
};
use web::WebServer;
mod config;
mod database;
mod dns;
mod web;
type Error = Box<dyn std::error::Error>;
pub type Result<T> = std::result::Result<T, Error>;
#[tokio::main]
async fn main() {
if let Err(err) = run().await {
error!("{err}")
};
}
async fn run() -> Result<()> {
dotenv().ok();
tracing_subscriber::registry()
.with(
tracing_subscriber::fmt::layer()
.with_filter(LevelFilter::TRACE)
.with_filter(filter_fn(|metadata| {
metadata.target().starts_with("wrapper")
})),
)
.init();
let config = Config::new();
let database = Database::new(config.clone()).await?;
let dns_server = DnsServer::new(config.clone(), database.clone()).await?;
let (udp, tcp) = dns_server.run().await?;
let web_server = WebServer::new(config, database).await?;
let web = web_server.run().await?;
tokio::join!(udp).0?;
tokio::join!(tcp).0?;
tokio::join!(web).0?;
Ok(())
}
pub fn get_time() -> u64 {
let start = SystemTime::now();
let since_the_epoch = start
.duration_since(UNIX_EPOCH)
.expect("Time went backwards");
since_the_epoch.as_millis() as u64
}

View file

@ -1,250 +0,0 @@
#include "buffer.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
struct PacketBuffer {
uint8_t* arr;
int capacity;
int index;
int size;
};
PacketBuffer* buffer_create(int capacity) {
PacketBuffer* buffer = malloc(sizeof(PacketBuffer));
buffer->arr = malloc(capacity);
memset(buffer->arr, 0, capacity);
buffer->capacity = capacity;
buffer->size = 0;
buffer->index = 0;
return buffer;
}
void buffer_free(PacketBuffer* buffer) {
free(buffer->arr);
free(buffer);
}
void buffer_seek(PacketBuffer* buffer, int index) {
buffer->index = index;
}
uint8_t buffer_read(PacketBuffer* buffer) {
if (buffer->index >= buffer->size) {
return 0;
}
uint8_t data = buffer->arr[buffer->index];
buffer->index++;
return data;
}
uint16_t buffer_read_short(PacketBuffer* buffer) {
return
(uint16_t) buffer_read(buffer) << 8 |
(uint16_t) buffer_read(buffer);
}
uint32_t buffer_read_int(PacketBuffer* buffer) {
return
(uint32_t) buffer_read(buffer) << 24 |
(uint32_t) buffer_read(buffer) << 16 |
(uint32_t) buffer_read(buffer) << 8 |
(uint32_t) buffer_read(buffer);
}
uint8_t buffer_get(PacketBuffer* buffer, int index) {
if (index >= buffer->size) {
return 0;
}
uint8_t data = buffer->arr[index];
return data;
}
uint8_t* buffer_get_range(PacketBuffer* buffer, int start, int len) {
uint8_t* arr = malloc(len);
for (int i = 0; i < len; i++) {
arr[i] = buffer_get(buffer, start + i);
}
return arr;
}
uint16_t buffer_get_size(PacketBuffer* buffer) {
return (uint16_t) buffer->size + 1;
}
static void write(uint8_t** buffer, uint8_t* size, uint8_t* capacity, uint8_t data) {
if (*size >= *capacity) {
if (*capacity >= 128) {
*capacity = 255;
} else {
*capacity *= 2;
}
*buffer = realloc(*buffer, *capacity);
}
(*buffer)[*size] = data;
(*size)++;
}
void buffer_read_qname(PacketBuffer* buffer, uint8_t** out) {
int index = buffer->index;
int jumped = 0;
int max_jumps = 5;
int jumps_performed = 0;
uint8_t length = 0;
uint8_t capacity = 8;
*out = malloc(capacity);
write(out, &length, &capacity, 0);
while(1) {
if (jumps_performed > max_jumps) {
break;
}
uint8_t len = buffer_get(buffer, index);
if ((len & 0xC0) == 0xC0) {
if (jumped == 0) {
buffer_seek(buffer, index + 2);
}
uint16_t b2 = (uint16_t) buffer_get(buffer, index + 1);
uint16_t offset = ((((uint16_t) len) ^ 0xC0) << 8) | b2;
index = (int) offset;
jumped = 1;
jumps_performed++;
continue;
}
index++;
if (len == 0) {
break;
}
if (length > 1) {
write(out, &length, &capacity, '.');
}
uint8_t* range = buffer_get_range(buffer, index, len);
for (uint8_t i = 0; i < len; i++) {
write(out, &length, &capacity, range[i]);
}
free(range);
index += (int) len;
}
if (jumped == 0) {
buffer_seek(buffer, index);
}
(*out)[0] = length - 1;
}
void buffer_read_string(PacketBuffer* buffer, uint8_t** out) {
uint8_t len = buffer_read(buffer);
buffer_read_n(buffer, out, len);
}
static void buffer_expand(PacketBuffer* buffer, int capacity) {
if (buffer->capacity >= capacity) return;
buffer->arr = realloc(buffer->arr, capacity);
memset(buffer->arr + buffer->capacity, 0, capacity - buffer->capacity);
buffer->capacity = capacity;
}
void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len) {
buffer_expand(buffer, buffer->index + len + 1);
*out = malloc(len + 1);
*out[0] = len;
memcpy(*out + 1, buffer->arr + buffer->index, len);
buffer->index += len;
}
void buffer_write(PacketBuffer* buffer, uint8_t data) {
if(buffer->index >= buffer->capacity) {
buffer->capacity *= 2;
buffer->arr = realloc(buffer->arr, buffer->capacity);
}
if (buffer->size < buffer->index) {
buffer->size = buffer->index;
}
buffer->arr[buffer->index] = data;
buffer->index++;
}
void buffer_write_short(PacketBuffer* buffer, uint16_t data) {
buffer_write(buffer, (uint8_t)(data >> 8));
buffer_write(buffer, (uint8_t)(data & 0xFF));
}
void buffer_write_int(PacketBuffer* buffer, uint32_t data) {
buffer_write(buffer, (uint8_t)(data >> 24));
buffer_write(buffer, (uint8_t)(data >> 16));
buffer_write(buffer, (uint8_t)(data >> 8));
buffer_write(buffer, (uint8_t)(data & 0xFF));
}
void buffer_write_qname(PacketBuffer* buffer, uint8_t* in) {
uint8_t part = 0;
uint8_t len = in[0];
buffer_write(buffer, 0);
if (len == 0) {
return;
}
for(uint8_t i = 0; i < len; i ++) {
if (in[i+1] == '.') {
buffer_set(buffer, part, buffer->index - (int)part - 1);
buffer_write(buffer, 0);
part = 0;
} else {
buffer_write(buffer, in[i+1]);
part++;
}
}
buffer_set(buffer, part, buffer->index - (int)part - 1);
buffer_write(buffer, 0);
}
void buffer_write_string(PacketBuffer* buffer, uint8_t* in) {
buffer_write(buffer, in[0]);
buffer_write_n(buffer, in + 1, in[0]);
}
void buffer_write_n(PacketBuffer* buffer, uint8_t* in, int len) {
buffer_expand(buffer, buffer->index + len + 1);
memcpy(buffer->arr + buffer->index, in, len);
buffer->size += len;
buffer->index += len;
}
void buffer_set(PacketBuffer* buffer, uint8_t data, int index) {
if (index > buffer->size) {
return;
}
buffer->arr[index] = data;
}
void buffer_set_uint16_t(PacketBuffer* buffer, uint16_t data, int index) {
buffer_set(buffer, (uint8_t)(data >> 8), index);
buffer_set(buffer, (uint8_t)(data & 0xFF), index + 1);
}
int buffer_get_index(PacketBuffer* buffer) {
return buffer->index;
}
void buffer_step(PacketBuffer* buffer, int len) {
buffer->index += len;
}
uint8_t* buffer_get_ptr(PacketBuffer* buffer) {
return buffer->arr;
}

View file

@ -1,51 +0,0 @@
#pragma once
#include <stdint.h>
typedef struct PacketBuffer PacketBuffer;
PacketBuffer* buffer_create(int capacity);
void buffer_free(PacketBuffer* buffer);
void buffer_seek(PacketBuffer* buffer, int index);
uint8_t buffer_read(PacketBuffer* buffer);
uint16_t buffer_read_short(PacketBuffer* buffer);
uint32_t buffer_read_int(PacketBuffer* buffer);
uint8_t buffer_get(PacketBuffer* buffer, int index);
uint8_t* buffer_get_range(PacketBuffer* buffer, int start, int len);
uint16_t buffer_get_size(PacketBuffer* buffer);
void buffer_read_qname(PacketBuffer* buffer, uint8_t** out);
void buffer_read_string(PacketBuffer* buffer, uint8_t** out);
void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len);
void buffer_write(PacketBuffer* buffer, uint8_t data);
void buffer_write_short(PacketBuffer* buffer, uint16_t data);
void buffer_write_int(PacketBuffer* buffer, uint32_t data);
void buffer_write_qname(PacketBuffer* buffer, uint8_t* in);
void buffer_write_string(PacketBuffer* buffer, uint8_t* in);
void buffer_write_n(PacketBuffer* buffer, uint8_t* in, int len);
void buffer_set(PacketBuffer* buffer, uint8_t data, int index);
void buffer_set_uint16_t(PacketBuffer* buffer, uint16_t data, int index);
int buffer_get_index(PacketBuffer* buffer);
void buffer_step(PacketBuffer* buffer, int len);
uint8_t* buffer_get_ptr(PacketBuffer* buffer);

View file

@ -1,95 +0,0 @@
#include "header.h"
#include "buffer.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
uint8_t rescode_to_id(ResultCode code) {
switch(code) {
case NOERROR:
return 0;
case FORMERR:
return 1;
case SERVFAIL:
return 2;
case NXDOMAIN:
return 3;
case NOTIMP:
return 4;
case REFUSED:
return 5;
default:
return 2;
}
}
ResultCode rescode_from_id(uint8_t id) {
switch(id) {
case 0:
return NOERROR;
case 1:
return FORMERR;
case 2:
return SERVFAIL;
case 3:
return NXDOMAIN;
case 4:
return NOTIMP;
case 5:
return REFUSED;
default:
return FORMERR;
}
}
#define MAX(var, max) var = var > max ? max : var;
void read_header(PacketBuffer* buffer, Header* header) {
// memset(header, 0, sizeof(Header));
header->id = buffer_read_short(buffer);
uint8_t a = buffer_read(buffer);
header->recursion_desired = (a & (1 << 0)) > 0;
header->truncated_message = (a & (1 << 1)) > 0;
header->authorative_answer = (a & (1 << 2)) > 0;
header->opcode = (a >> 3) & 0x0F;
header->response = (a & (1 << 7)) > 0;
uint8_t b = buffer_read(buffer);
header->rescode = rescode_from_id(b & 0x0F);
header->checking_disabled = (b & (1 << 4)) > 0;
header->authed_data = (b& (1 << 4)) > 0;
header->z = (b & (1 << 6)) > 0;
header->recursion_available = (b & (1 << 7)) > 0;
header->questions = buffer_read_short(buffer);
header->answers = buffer_read_short(buffer);
header->authoritative_entries = buffer_read_short(buffer);
header->resource_entries = buffer_read_short(buffer);
}
void write_header(PacketBuffer* buffer, Header* header) {
buffer_write_short(buffer, header->id);
buffer_write(buffer,
((uint8_t) header->recursion_desired) |
((uint8_t) header->truncated_message << 1) |
((uint8_t) header->authorative_answer << 2) |
(header->opcode << 3) |
((uint8_t) header->response << 7)
);
buffer_write(buffer,
(rescode_to_id(header->rescode)) |
((uint8_t) header->checking_disabled << 4) |
((uint8_t) header->authed_data << 5) |
((uint8_t) header->z << 6) |
((uint8_t) header->recursion_available << 7)
);
buffer_write_short(buffer, header->questions);
buffer_write_short(buffer, header->answers);
buffer_write_short(buffer, header->authoritative_entries);
buffer_write_short(buffer, header->resource_entries);
}

View file

@ -1,41 +0,0 @@
#pragma once
#include "buffer.h"
#include <stdbool.h>
typedef enum {
NOERROR, // 0
FORMERR, // 1
SERVFAIL, // 2
NXDOMAIN, // 3,
NOTIMP, // 4
REFUSED, // 5
} ResultCode;
uint8_t rescode_to_id(ResultCode code);
ResultCode rescode_from_id(uint8_t id);
typedef struct {
uint16_t id;
bool recursion_desired; // 1 bit
bool truncated_message; // 1 bit
bool authorative_answer; // 1 bit
uint8_t opcode; // 4 bits
bool response; // 1 bit
ResultCode rescode; // 4 bits
bool checking_disabled; // 1 bit
bool authed_data; // 1 bit
bool z; // 1 bit
bool recursion_available; // 1 bit
uint16_t questions; // 16 bits
uint16_t answers; // 16 bits
uint16_t authoritative_entries; // 16 bits
uint16_t resource_entries; // 16 bits
} Header;
void read_header(PacketBuffer* buffer, Header* header);
void write_header(PacketBuffer* buffer, Header* header);

View file

@ -1,185 +0,0 @@
#include <stdbool.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include "packet.h"
#include "buffer.h"
#include "header.h"
#include "question.h"
#include "record.h"
#include "../io/log.h"
void read_packet(PacketBuffer* buffer, Packet* packet) {
read_header(buffer, &packet->header);
packet->questions = malloc(sizeof(Question) * packet->header.questions);
for(uint16_t i = 0; i < packet->header.questions; i++) {
if (!read_question(buffer, &packet->questions[i])) {
i--;
packet->header.questions--;
}
}
packet->answers = malloc(sizeof(Record) * packet->header.answers);
for(uint16_t i = 0; i < packet->header.answers; i++) {
if (!read_record(buffer, &packet->answers[i])) {
i--;
packet->header.answers--;
}
}
packet->authorities = malloc(sizeof(Record) * packet->header.authoritative_entries);
for(uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
if (!read_record(buffer, &packet->authorities[i])) {
i--;
packet->header.authoritative_entries--;
}
}
packet->resources = malloc(sizeof(Record) * packet->header.resource_entries);
for(uint16_t i = 0; i < packet->header.resource_entries; i++) {
if (!read_record(buffer, &packet->resources[i])) {
i--;
packet->header.resource_entries--;
}
}
}
void write_packet(PacketBuffer* buffer, Packet* packet) {
write_header(buffer, &packet->header);
for(uint16_t i = 0; i < packet->header.questions; i++) {
write_question(buffer, &packet->questions[i]);
}
for(uint16_t i = 0; i < packet->header.answers; i++) {
write_record(buffer, &packet->answers[i]);
}
for(uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
write_record(buffer, &packet->authorities[i]);
}
for(uint16_t i = 0; i < packet->header.resource_entries; i++) {
write_record(buffer, &packet->resources[i]);
}
}
void free_packet(Packet* packet) {
for(uint16_t i = 0; i < packet->header.questions; i++) {
free_question(&packet->questions[i]);
}
free(packet->questions);
for(uint16_t i = 0; i < packet->header.answers; i++) {
free_record(&packet->answers[i]);
}
free(packet->answers);
for(uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
free_record(&packet->authorities[i]);
}
free(packet->authorities);
for(uint16_t i = 0; i < packet->header.resource_entries; i++) {
free_record(&packet->resources[i]);
}
free(packet->resources);
}
bool get_random_a(Packet* packet, IpAddr* addr) {
for (uint16_t i = 0; i < packet->header.answers; i++) {
Record record = packet->answers[i];
if (record.type == A) {
create_ip_addr(record.data.a.addr, addr);
return true;
} else if (record.type == AAAA) {
create_ip_addr6(record.data.aaaa.addr, addr);
return true;
}
}
return false;
}
static bool ends_with(uint8_t* full, uint8_t* end) {
uint8_t check = end[0];
uint8_t len = full[0];
if (check > len) {
return false;
}
for(uint8_t i = 0; i < check; i++) {
if (end[check - 1 - i] != full[len - 1 - i]) {
return false;
}
}
return true;
}
static bool equals(uint8_t* a, uint8_t* b) {
if (a[0] != b[0]) {
return false;
}
for(uint8_t i = 1; i < a[0] + 1; i++) {
if(a[i] != b[i]) {
return false;
}
}
return true;
}
bool get_resolved_ns(Packet* packet, uint8_t* qname, IpAddr* addr) {
for (uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
Record record = packet->authorities[i];
if (record.type != NS) {
continue;
}
if(!ends_with(qname, record.domain)) {
continue;
}
for (uint16_t i = 0; i < packet->header.resource_entries; i++) {
Record resource = packet->resources[i];
if (!equals(record.data.ns.host, resource.domain)) {
continue;
}
if (resource.type == A) {
create_ip_addr(record.data.a.addr, addr);
return true;
} else if (resource.type == AAAA) {
create_ip_addr6(record.data.aaaa.addr, addr);
return true;
}
}
}
return false;
}
bool get_unresoled_ns(Packet* packet, uint8_t* qname, Question* question) {
for (uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
Record record = packet->authorities[i];
if (record.type != NS) {
continue;
}
if(!ends_with(qname, record.domain)) {
continue;
}
uint8_t* host = record.data.ns.host;
question->qtype = NS;
question->domain = malloc(host[0] + 1);
memcpy(question->domain, host, host[0] + 1);
return true;
}
return false;
}

View file

@ -1,25 +0,0 @@
#pragma once
#include "buffer.h"
#include "question.h"
#include "header.h"
#include "record.h"
#include "../server/addr.h"
#include <stdbool.h>
typedef struct {
Header header;
Question* questions;
Record* answers;
Record* authorities;
Record* resources;
} Packet;
void read_packet(PacketBuffer* buffer, Packet* packet);
void write_packet(PacketBuffer* buffer, Packet* packet);
void free_packet(Packet* packet);
bool get_random_a(Packet* packet, IpAddr* addr);
bool get_resolved_ns(Packet* packet, uint8_t* qname, IpAddr* addr);
bool get_unresoled_ns(Packet* packet, uint8_t* qname, Question* question);

View file

@ -1,110 +0,0 @@
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include "question.h"
#include "buffer.h"
#include "record.h"
#include "../io/log.h"
bool read_question(PacketBuffer* buffer, Question* question) {
buffer_read_qname(buffer, &question->domain);
uint16_t qtype_num = buffer_read_short(buffer);
record_from_id(qtype_num, &question->qtype);
question->cls = buffer_read_short(buffer);
if (question->qtype == UNKOWN) {
free(question->domain);
return false;
}
INIT_LOG_BUFFER(log)
LOGONLY(print_question(question, log);)
TRACE("Reading question: %s", log);
return true;
}
void write_question(PacketBuffer* buffer, Question* question) {
buffer_write_qname(buffer, question->domain);
uint16_t id = record_to_id(question->qtype);
buffer_write_short(buffer, id);
buffer_write_short(buffer, question->cls);
INIT_LOG_BUFFER(log)
LOGONLY(print_question(question, log);)
TRACE("Writing question: %s", log);
}
void free_question(Question* question) {
free(question->domain);
}
void print_question(Question* question, char* buffer) {
INIT_LOG_BOUNDS
switch (question->cls) {
case 1:
APPEND(buffer, "IN ");;
break;
case 3:
APPEND(buffer, "CH ");
break;
case 4:
APPEND(buffer, "HS ");
break;
default:
APPEND(buffer, "?? ");
break;
}
switch(question->qtype) {
case UNKOWN:
APPEND(buffer, "UNKOWN ");
break;
case A:
APPEND(buffer, "A ");
break;
case NS:
APPEND(buffer, "NS ");
break;
case CNAME:
APPEND(buffer, "CNAME ");
break;
case SOA:
APPEND(buffer, "SOA ");
break;
case PTR:
APPEND(buffer, "PTR ");
break;
case MX:
APPEND(buffer, "MX ");
break;
case TXT:
APPEND(buffer, "TXT ");
break;
case AAAA:
APPEND(buffer, "AAAA ");
break;
case SRV:
APPEND(buffer, "SRV ");
break;
case CAA:
APPEND(buffer, "CAA ");
break;
case CMD:
APPEND(buffer, "CMD ");
break;
case AR:
APPEND(buffer, "A ");
break;
case AAAAR:
APPEND(buffer, "AAAAR ");
break;
}
APPEND(buffer, "%.*s",
question->domain[0],
question->domain + 1
);
}

View file

@ -1,15 +0,0 @@
#pragma once
#include "buffer.h"
#include "record.h"
typedef struct {
uint8_t* domain;
RecordType qtype;
uint16_t cls;
} Question;
bool read_question(PacketBuffer* buffer, Question* question);
void write_question(PacketBuffer* buffer, Question* question);
void free_question(Question* question);
void print_question(Question* question, char* buffer);

View file

@ -1,670 +0,0 @@
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#undef _POSIX_C_SOURCE
#include <stdio.h>
#include "record.h"
#include "buffer.h"
#include "../io/log.h"
uint16_t record_to_id(RecordType type) {
switch (type) {
case A:
return 1;
case NS:
return 2;
case CNAME:
return 5;
case SOA:
return 6;
case PTR:
return 12;
case MX:
return 15;
case TXT:
return 16;
case AAAA:
return 28;
case SRV:
return 33;
case CAA:
return 257;
case CMD:
return 16;
case AR:
return 1;
case AAAAR:
return 28;
default:
return 0;
}
}
void record_from_id(uint16_t i, RecordType* type) {
switch (i) {
case 1:
*type = A;
break;
case 2:
*type = NS;
break;
case 5:
*type = CNAME;
break;
case 6:
*type = SOA;
break;
case 12:
*type = PTR;
break;
case 15:
*type = MX;
break;
case 16:
*type = TXT;
break;
case 28:
*type = AAAA;
break;
case 33:
*type = SRV;
break;
case 257:
*type = CAA;
break;
case 1000:
*type = CMD;
break;
default:
*type = UNKOWN;
}
}
static void read_a_record(PacketBuffer* buffer, Record* record) {
ARecord data;
data.addr[0] = buffer_read(buffer);
data.addr[1] = buffer_read(buffer);
data.addr[2] = buffer_read(buffer);
data.addr[3] = buffer_read(buffer);
record->data.a = data;
}
static void read_ns_record(PacketBuffer* buffer, Record* record) {
NSRecord data;
buffer_read_qname(buffer, &data.host);
record->data.ns = data;
}
static void read_cname_record(PacketBuffer* buffer, Record* record) {
CNAMERecord data;
buffer_read_qname(buffer, &data.host);
record->data.cname = data;
}
static void read_soa_record(PacketBuffer* buffer, Record* record) {
SOARecord data;
buffer_read_qname(buffer, &data.mname);
buffer_read_qname(buffer, &data.nname);
data.serial = buffer_read_int(buffer);
data.refresh = buffer_read_int(buffer);
data.retry = buffer_read_int(buffer);
data.expire = buffer_read_int(buffer);
data.minimum = buffer_read_int(buffer);
record->data.soa = data;
}
static void read_ptr_record(PacketBuffer* buffer, Record* record) {
PTRRecord data;
buffer_read_qname(buffer, &data.pointer);
record->data.ptr = data;
}
static void read_mx_record(PacketBuffer* buffer, Record* record) {
MXRecord data;
data.priority = buffer_read_short(buffer);
buffer_read_qname(buffer, &data.host);
record->data.mx = data;
}
static void read_txt_record(PacketBuffer* buffer, Record* record) {
TXTRecord data;
data.len = 0;
data.text = malloc(sizeof(uint8_t*) * 2);
uint8_t capacity = 2;
uint8_t total = record->len;
while (1) {
if (data.len >= capacity) {
if (capacity >= 128) {
capacity = 255;
} else {
capacity *= 2;
}
data.text = realloc(data.text, sizeof(uint8_t*) * capacity);
}
buffer_read_string(buffer, &data.text[data.len]);
if(data.text[data.len][0] == 0) {
free(data.text[data.len]);
break;
}
data.len++;
total -= data.text[data.len - 1][0] + 1;
if (total == 0) break;
}
record->data.txt = data;
}
static void read_aaaa_record(PacketBuffer* buffer, Record* record) {
AAAARecord data;
for (int i = 0; i < 16; i++) {
data.addr[i] = buffer_read(buffer);
}
record->data.aaaa = data;
}
static void read_srv_record(PacketBuffer* buffer, Record* record) {
SRVRecord data;
data.priority = buffer_read_short(buffer);
data.weight = buffer_read_short(buffer);
data.port = buffer_read_short(buffer);
buffer_read_qname(buffer, &data.target);
record->data.srv = data;
}
static void read_caa_record(PacketBuffer* buffer, Record* record, int header_pos) {
CAARecord data;
data.flags = buffer_read(buffer);
data.length = buffer_read(buffer);
buffer_read_n(buffer, &data.tag, data.length);
int value_len = ((int)record->len) + header_pos - buffer_get_index(buffer);
buffer_read_n(buffer, &data.value, (uint8_t)value_len);
record->data.caa = data;
}
bool read_record(PacketBuffer* buffer, Record* record) {
buffer_read_qname(buffer, &record->domain);
uint16_t qtype_num = buffer_read_short(buffer);
record_from_id(qtype_num, &record->type);
record->cls = buffer_read_short(buffer);
record->ttl = buffer_read_int(buffer);
record->len = buffer_read_short(buffer);
int header_pos = buffer_get_index(buffer);
switch (record->type) {
case A:
read_a_record(buffer, record);
break;
case NS:
read_ns_record(buffer, record);
break;
case CNAME:
read_cname_record(buffer, record);
break;
case SOA:
read_soa_record(buffer, record);
break;
case PTR:
read_ptr_record(buffer, record);
break;
case MX:
read_mx_record(buffer, record);
break;
case TXT:
read_txt_record(buffer, record);
break;
case AAAA:
read_aaaa_record(buffer, record);
break;
case SRV:
read_srv_record(buffer, record);
break;
case CAA:
read_caa_record(buffer, record, header_pos);
break;
default:
buffer_step(buffer, record->len);
free(record->domain);
return false;
}
INIT_LOG_BUFFER(log)
LOGONLY(print_record(record, log);)
TRACE("Reading record: %s", log);
return true;
}
static void write_a_record(PacketBuffer* buffer, ARecord* data) {
buffer_write_short(buffer, 4);
buffer_write(buffer, data->addr[0]);
buffer_write(buffer, data->addr[1]);
buffer_write(buffer, data->addr[2]);
buffer_write(buffer, data->addr[3]);
}
static void write_ns_record(PacketBuffer* buffer, NSRecord* data) {
int pos = buffer_get_index(buffer);
buffer_write_short(buffer, 0);
buffer_write_qname(buffer, data->host);
int size = buffer_get_index(buffer) - pos - 2;
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
}
static void write_cname_record(PacketBuffer* buffer, CNAMERecord* data) {
int pos = buffer_get_index(buffer);
buffer_write_short(buffer, 0);
buffer_write_qname(buffer, data->host);
int size = buffer_get_index(buffer) - pos - 2;
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
}
static void write_soa_record(PacketBuffer* buffer, SOARecord* data) {
int pos = buffer_get_index(buffer);
buffer_write_short(buffer, 0);
buffer_write_qname(buffer, data->mname);
buffer_write_qname(buffer, data->nname);
buffer_write_int(buffer, data->serial);
buffer_write_int(buffer, data->refresh);
buffer_write_int(buffer, data->retry);
buffer_write_int(buffer, data->expire);
buffer_write_int(buffer, data->minimum);
int size = buffer_get_index(buffer) - pos - 2;
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
}
static void write_ptr_record(PacketBuffer* buffer, PTRRecord* data) {
int pos = buffer_get_index(buffer);
buffer_write_short(buffer, 0);
buffer_write_qname(buffer, data->pointer);
int size = buffer_get_index(buffer) - pos - 2;
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
}
static void write_mx_record(PacketBuffer* buffer, MXRecord* data) {
int pos = buffer_get_index(buffer);
buffer_write_short(buffer, 0);
buffer_write_short(buffer, data->priority);
buffer_write_qname(buffer, data->host);
int size = buffer_get_index(buffer) - pos - 2;
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
}
static void write_txt_record(PacketBuffer* buffer, TXTRecord* data) {
int pos = buffer_get_index(buffer);
buffer_write_short(buffer, 0);
if(data->len == 0) {
return;
}
for(uint8_t i = 0; i < data->len; i++) {
buffer_write_string(buffer, data->text[i]);
}
int size = buffer_get_index(buffer) - pos - 2;
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
}
static void write_aaaa_record(PacketBuffer* buffer, AAAARecord* data) {
buffer_write_short(buffer, 16);
for (int i = 0; i < 16; i++) {
buffer_write(buffer, data->addr[i]);
}
}
static void write_srv_record(PacketBuffer* buffer, SRVRecord* data) {
int pos = buffer_get_index(buffer);
buffer_write_short(buffer, 0);
buffer_write_short(buffer, data->priority);
buffer_write_short(buffer, data->weight);
buffer_write_short(buffer, data->port);
buffer_write_qname(buffer, data->target);
int size = buffer_get_index(buffer) - pos - 2;
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
}
static void write_caa_record(PacketBuffer* buffer, CAARecord* data) {
int pos = buffer_get_index(buffer);
buffer_write_short(buffer, 0);
buffer_write(buffer, data->flags);
buffer_write(buffer, data->length);
buffer_write_n(buffer, data->tag + 1, data->tag[0]);
buffer_write_n(buffer, data->value + 1, data->value[0]);
int size = buffer_get_index(buffer) - pos - 2;
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
}
static void write_string(TXTRecord* record, uint8_t* capacity, const char* string, uint8_t len) {
if (len < 1) return;
if (record->len >= *capacity) {
if (*capacity >= 128) {
*capacity = 255;
} else {
*capacity *= 2;
}
record->text = realloc(record->text, sizeof(uint8_t*) * *capacity);
}
record->text[record->len] = malloc(len + 1);
record->text[record->len][0] = len;
memcpy(record->text[record->len] + 1, string, len);
record->len++;
}
static void free_text(TXTRecord* record) {
for (uint8_t i = 0; i < record->len; i++) {
free(record->text[i]);
}
free(record->text);
}
static void write_cmd_record(PacketBuffer* buffer, CMDRecord* data) {
FILE* output = popen(data->command, "r");
TXTRecord res;
uint8_t capacity = 1;
res.len = 0;
res.text = malloc(capacity * sizeof(uint8_t*));
if (output == NULL) {
write_string(&res, &capacity, "Failed to execute command", 25);
write_txt_record(buffer, &res);
free_text(&res);
return;
}
char in[255];
char c;
int i = 0;
while (1) {
if (res.len >= 255) break;
c = getc(output);
if (c == EOF || c == '\0') {
write_string(&res, &capacity, in, i);
break;
}
in[i] = c;
i++;
if (i == 255) {
write_string(&res, &capacity, in, i);
i = 0;
}
}
write_txt_record(buffer, &res);
free_text(&res);
pclose(output);
}
static void write_ar_record(PacketBuffer* buffer) {
srand(time(NULL));
ARecord res;
for (int i = 0; i < 4; i++) {
res.addr[i] = (uint8_t) (rand() * 255);
}
write_a_record(buffer, &res);
}
static void write_aaaar_record(PacketBuffer* buffer) {
srand(time(NULL));
AAAARecord res;
for (int i = 0; i < 16; i++) {
res.addr[i] = (uint8_t) (rand() * 255);
}
write_aaaa_record(buffer, &res);
}
static void write_record_header(PacketBuffer* buffer, Record* record) {
buffer_write_qname(buffer, record->domain);
uint16_t id = record_to_id(record->type);
buffer_write_short(buffer, id);
buffer_write_short(buffer, record->cls);
buffer_write_int(buffer, record->ttl);
}
void write_record(PacketBuffer* buffer, Record* record) {
switch(record->type) {
case A:
write_record_header(buffer, record);
write_a_record(buffer, &record->data.a);
break;
case NS:
write_record_header(buffer, record);
write_ns_record(buffer, &record->data.ns);
break;
case CNAME:
write_record_header(buffer, record);
write_cname_record(buffer, &record->data.cname);
break;
case SOA:
write_record_header(buffer, record);
write_soa_record(buffer, &record->data.soa);
break;
case PTR:
write_record_header(buffer, record);
write_ptr_record(buffer, &record->data.ptr);
break;
case MX:
write_record_header(buffer, record);
write_mx_record(buffer, &record->data.mx);
break;
case TXT:
write_record_header(buffer, record);
write_txt_record(buffer, &record->data.txt);
break;
case AAAA:
write_record_header(buffer, record);
write_aaaa_record(buffer, &record->data.aaaa);
break;
case SRV:
write_record_header(buffer, record);
write_srv_record(buffer, &record->data.srv);
break;
case CAA:
write_record_header(buffer, record);
write_caa_record(buffer, &record->data.caa);
break;
case CMD:
write_record_header(buffer, record);
write_cmd_record(buffer, &record->data.cmd);
break;
case AR:
write_record_header(buffer, record);
write_ar_record(buffer);
break;
case AAAAR:
write_record_header(buffer, record);
write_aaaar_record(buffer);
break;
default:
break;
}
INIT_LOG_BUFFER(log)
LOGONLY(print_record(record, log);)
TRACE("Writing record: %s", log);
}
void free_record(Record* record) {
free(record->domain);
switch (record->type) {
case NS:
free(record->data.ns.host);
break;
case CNAME:
free(record->data.cname.host);
break;
case SOA:
free(record->data.soa.mname);
free(record->data.soa.nname);
break;
case PTR:
free(record->data.ptr.pointer);
break;
case MX:
free(record->data.mx.host);
break;
case TXT:
for (uint8_t i = 0; i < record->data.txt.len; i++) {
free(record->data.txt.text[i]);
}
free(record->data.txt.text);
break;
case SRV:
free(record->data.srv.target);
break;
case CAA:
free(record->data.caa.value);
free(record->data.caa.tag);
break;
case CMD:
free(record->data.cmd.command);
break;
default:
break;
}
}
void print_record(Record* record, char* buffer) {
INIT_LOG_BOUNDS
switch(record->type) {
case UNKOWN:
APPEND(buffer, "UNKOWN");
break;
case A:
APPEND(buffer, "A (%hhu.%hhu.%hhu.%hhu)",
record->data.a.addr[0],
record->data.a.addr[1],
record->data.a.addr[2],
record->data.a.addr[3]
);
break;
case NS:
APPEND(buffer, "NS (%.*s)",
record->data.ns.host[0],
record->data.ns.host + 1
);
break;
case CNAME:
APPEND(buffer, "CNAME (%.*s)",
record->data.cname.host[0],
record->data.cname.host + 1
);
break;
case SOA:
APPEND(buffer, "SOA (%.*s %.*s %u %u %u %u %u)",
record->data.soa.mname[0],
record->data.soa.mname + 1,
record->data.soa.nname[0],
record->data.soa.nname + 1,
record->data.soa.serial,
record->data.soa.refresh,
record->data.soa.retry,
record->data.soa.expire,
record->data.soa.minimum
);
break;
case PTR:
APPEND(buffer, "PTR (%.*s)",
record->data.ptr.pointer[0],
record->data.ptr.pointer + 1
);
break;
case MX:
APPEND(buffer, "MX (%.*s %hu)",
record->data.mx.host[0],
record->data.mx.host + 1,
record->data.mx.priority
);
break;
case TXT:
APPEND(buffer, "TXT (");
for(uint8_t i = 0; i < record->data.txt.len; i++) {
APPEND(buffer, "\"%.*s\"",
record->data.txt.text[i][0],
record->data.txt.text[i] + 1
);
}
APPEND(buffer, ")");
break;
case AAAA:
APPEND(buffer, "AAAA (");
for(int i = 0; i < 8; i++) {
APPEND(buffer, "%02hhx%02hhx:",
record->data.a.addr[i*2 + 0],
record->data.a.addr[i*2 + 1]
);
}
APPEND(buffer, ":)");
break;
case SRV:
APPEND(buffer, "SRV (%hu %hu %hu %.*s)",
record->data.srv.priority,
record->data.srv.weight,
record->data.srv.port,
record->data.srv.target[0],
record->data.srv.target + 1
);
break;
case CAA:
APPEND(buffer, "CAA (%hhu %.*s %.*s)",
record->data.caa.flags,
record->data.caa.tag[0],
record->data.caa.tag + 1,
record->data.caa.value[0],
record->data.caa.value + 1
);
break;
case CMD:
APPEND(buffer, "CMD (%s)",
record->data.cmd.command
);
break;
case AR:
APPEND(buffer, "AR");
break;
case AAAAR:
APPEND(buffer, "AAAAR");
break;
}
}

View file

@ -1,110 +0,0 @@
#pragma once
#include "buffer.h"
#include <string.h>
#include <stdbool.h>
typedef enum {
UNKOWN,
A, // 1
NS, // 2
CNAME, // 5
SOA, // 6
PTR, // 12
MX, // 15
TXT, // 16
AAAA, // 28
SRV, // 33
CAA, // 257
CMD, // 1000
AR, // 1001
AAAAR // 1002
} RecordType;
uint16_t record_to_id(RecordType type);
void record_from_id(uint16_t i, RecordType* type);
typedef struct {
uint8_t addr[4];
} ARecord;
typedef struct {
uint8_t* host;
} NSRecord;
typedef struct {
uint8_t* host;
} CNAMERecord;
typedef struct {
uint8_t* mname;
uint8_t* nname;
uint32_t serial;
uint32_t refresh;
uint32_t retry;
uint32_t expire;
uint32_t minimum;
} SOARecord;
typedef struct {
uint8_t* pointer;
} PTRRecord;
typedef struct {
uint16_t priority;
uint8_t* host;
} MXRecord;
typedef struct TXTRecord {
uint8_t** text;
uint8_t len;
} TXTRecord;
typedef struct {
uint8_t addr[16];
} AAAARecord;
typedef struct {
uint16_t priority;
uint16_t weight;
uint16_t port;
uint8_t* target;
} SRVRecord;
typedef struct {
uint8_t flags;
uint8_t length;
uint8_t* tag;
uint8_t* value;
} CAARecord;
typedef struct {
char* command;
} CMDRecord;
typedef struct {
uint32_t ttl;
uint16_t cls;
uint16_t len;
uint8_t* domain;
RecordType type;
union data {
ARecord a;
NSRecord ns;
CNAMERecord cname;
SOARecord soa;
PTRRecord ptr;
MXRecord mx;
TXTRecord txt;
AAAARecord aaaa;
SRVRecord srv;
CAARecord caa;
CMDRecord cmd;
} data;
} Record;
bool read_record(PacketBuffer* buffer, Record* record);
void write_record(PacketBuffer* buffer, Record* record);
void free_record(Record* record);
void print_record(Record* record, char* buffer);

View file

@ -1,257 +0,0 @@
#include <errno.h>
#include <netinet/in.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include "addr.h"
#include "../io/log.h"
void create_ip_addr(uint8_t* domain, IpAddr* addr) {
addr->type = V4;
memcpy(&addr->data.v4, domain, 4);
}
void create_ip_addr6(uint8_t* domain, IpAddr* addr) {
addr->type = V6;
memcpy(&addr->data.v6, domain, 16);
}
void ip_addr_any(IpAddr* addr) {
addr->type = V4;
addr->data.v4.s_addr = htonl(INADDR_ANY);
}
void ip_addr_any6(IpAddr* addr) {
addr->type = V6;
addr->data.v6 = in6addr_any;
}
static struct sockaddr_in create_socket_addr_v4(IpAddr addr, uint16_t port) {
struct sockaddr_in socketaddr;
memset(&socketaddr, 0, sizeof(socketaddr));
socketaddr.sin_family = AF_INET;
socketaddr.sin_port = htons(port);
socketaddr.sin_addr = addr.data.v4;
return socketaddr;
}
static struct sockaddr_in6 create_socket_addr_v6(IpAddr addr, uint16_t port) {
struct sockaddr_in6 socketaddr;
memset(&socketaddr, 0, sizeof(socketaddr));
socketaddr.sin6_family = AF_INET6;
socketaddr.sin6_port = htons(port);
socketaddr.sin6_addr = addr.data.v6;
return socketaddr;
}
static size_t get_addr_len(AddrType type) {
if (type == V4) {
return sizeof(struct sockaddr_in);
} else if (type == V6) {
return sizeof(struct sockaddr_in6);
} else {
return 0;
}
}
void create_socket_addr(uint16_t port, IpAddr addr, SocketAddr* socket) {
socket->type = addr.type;
if (addr.type == V4) {
socket->data.v4 = create_socket_addr_v4(addr, port);
} else if(addr.type == V6) {
socket->data.v6 = create_socket_addr_v6(addr, port);
} else {
ERROR("Tried to create socketaddr with invalid protocol type");
exit(EXIT_FAILURE);
}
socket->len = get_addr_len(addr.type);
}
void print_socket_addr(SocketAddr* addr, char* buffer) {
INIT_LOG_BOUNDS
if(addr->type == V4) {
APPEND(buffer, "%hhu.%hhu.%hhu.%hhu:%hu",
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 24),
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 16),
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 8),
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr),
ntohs(addr->data.v4.sin_port)
);
} else {
uint8_t* a = (uint8_t*) &addr->data.v6.sin6_addr;
for(int i = 0; i < 8; i++) {
APPEND(buffer, "%02hhx%02hhx:",
a[i*2 + 0],
a[i*2 + 1]
);
}
APPEND(buffer, ":[%hu]", ntohs(addr->data.v6.sin6_port));
}
}
#define ADDR_DOMAIN(addr, var) \
struct sockaddr* var; \
if (addr->type == V4) { \
var = (struct sockaddr*) &addr->data.v4; \
} else if (addr->type == V6) { \
var = (struct sockaddr*) &addr->data.v6; \
} else { \
return -1; \
}
#define ADDR_AFNET(type, var) \
int var; \
if (type == V4) { \
var = AF_INET; \
} else if (type == V6) { \
var = AF_INET6; \
} else { \
return -1; \
}
int32_t create_udp_socket(AddrType type, UdpSocket* sock) {
ADDR_AFNET(type, __domain)
sock->type = type;
sock->sockfd = socket(__domain, SOCK_DGRAM, 0);
return sock->sockfd;
}
int32_t bind_udp_socket(SocketAddr* addr, UdpSocket* sock) {
if (addr->type == V6) {
int v6OnlyEnabled = 0;
int32_t res = setsockopt(
sock->sockfd,
IPPROTO_IPV6,
IPV6_V6ONLY,
&v6OnlyEnabled,
sizeof(v6OnlyEnabled)
);
if (res < 0) return res;
}
ADDR_DOMAIN(addr, __addr)
return bind(sock->sockfd, __addr, addr->len);
}
int32_t read_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr) {
clientaddr->type = socket->type;
clientaddr->len = get_addr_len(socket->type);
ADDR_DOMAIN(clientaddr, __addr)
return recvfrom(
socket->sockfd,
buffer,
(size_t) len,
MSG_WAITALL,
__addr,
(uint32_t*) &clientaddr->len
);
}
int32_t write_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr) {
ADDR_DOMAIN(clientaddr, __addr)
return sendto(
socket->sockfd,
buffer,
(size_t) len,
MSG_CONFIRM,
__addr,
(uint32_t) clientaddr->len
);
}
static int get_socket_error(int fd) {
int err = 1;
socklen_t len = sizeof err;
if (-1 == getsockopt(fd, SOL_SOCKET, SO_ERROR, (char *)&err, &len))
return 0;
if (err)
errno = err;
return err;
}
static int close_socket(int fd) {
if (fd >= 0) {
get_socket_error(fd); // first clear any errors, which can cause close to fail
if (shutdown(fd, SHUT_RDWR) < 0) // secondly, terminate the 'reliable' delivery
if (errno != ENOTCONN && errno != EINVAL) // SGI causes EINVAL
perror("shutdown");
if (close(fd) < 0) // finally call close()
perror("close");
}
return 0;
}
int32_t close_udp_socket(UdpSocket* socket) {
return close_socket(socket->sockfd);
}
int32_t create_tcp_socket(AddrType type, TcpSocket* sock) {
ADDR_AFNET(type, __domain)
sock->type = type;
sock->sockfd = socket(__domain, SOCK_STREAM, 0);
return sock->sockfd;
}
int32_t bind_tcp_socket(SocketAddr* addr, TcpSocket* sock) {
if (addr->type == V6) {
int v6OnlyEnabled = 0;
int32_t res = setsockopt(
sock->sockfd,
IPPROTO_IPV6,
IPV6_V6ONLY,
&v6OnlyEnabled,
sizeof(v6OnlyEnabled)
);
if (res < 0) return res;
}
ADDR_DOMAIN(addr, __addr)
return bind(sock->sockfd, __addr, addr->len);
}
int32_t listen_tcp_socket(TcpSocket* socket, uint32_t max) {
return listen(socket->sockfd, max);
}
int32_t accept_tcp_socket(TcpSocket* socket, TcpStream* stream) {
stream->clientaddr.type = socket->type;
memset(&stream->clientaddr, 0, sizeof(SocketAddr));
SocketAddr* addr = &stream->clientaddr;
ADDR_DOMAIN(addr, __addr)
stream->streamfd = accept(
socket->sockfd,
__addr,
(uint32_t*) &stream->clientaddr.len
);
return stream->streamfd;
}
int32_t close_tcp_socket(TcpSocket* socket) {
return close_socket(socket->sockfd);
}
int32_t connect_tcp_stream(SocketAddr* servaddr, TcpStream* stream) {
TcpSocket socket;
int32_t res = create_tcp_socket(servaddr->type, &socket);
if (res < 0) return res;
stream->clientaddr = *servaddr;
stream->streamfd = socket.sockfd;
ADDR_DOMAIN(servaddr, __addr)
return connect(
socket.sockfd,
__addr,
servaddr->len
);
}
int32_t read_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) {
return recv(stream->streamfd, buffer, len, MSG_WAITALL);
}
int32_t write_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) {
return send(stream->streamfd, buffer, len, MSG_NOSIGNAL);
}
int32_t close_tcp_stream(TcpStream* stream) {
return close_socket(stream->streamfd);
}

View file

@ -1,69 +0,0 @@
#pragma once
#include "../packet/record.h"
#include <string.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
typedef enum {
V4,
V6
} AddrType;
typedef struct {
AddrType type;
union {
struct in_addr v4;
struct in6_addr v6;
} data;
} IpAddr;
void create_ip_addr(uint8_t* domain, IpAddr* addr);
void create_ip_addr6(uint8_t* domain, IpAddr* addr);
void ip_addr_any(IpAddr* addr);
void ip_addr_any6(IpAddr* addr);
typedef struct {
AddrType type;
union {
struct sockaddr_in v4;
struct sockaddr_in6 v6;
} data;
size_t len;
} SocketAddr;
void create_socket_addr(uint16_t port, IpAddr addr, SocketAddr* socket);
void print_socket_addr(SocketAddr* addr, char* buffer);
typedef struct {
AddrType type;
uint32_t sockfd;
} UdpSocket;
int32_t create_udp_socket(AddrType type, UdpSocket* socket);
int32_t bind_udp_socket(SocketAddr* addr, UdpSocket* socket);
int32_t read_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr);
int32_t write_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr);
int32_t close_udp_socket(UdpSocket* socket);
typedef struct {
AddrType type;
uint32_t sockfd;
} TcpSocket;
typedef struct {
SocketAddr clientaddr;
uint32_t streamfd;
} TcpStream;
int32_t create_tcp_socket(AddrType type, TcpSocket* socket);
int32_t bind_tcp_socket(SocketAddr* addr, TcpSocket* socket);
int32_t listen_tcp_socket(TcpSocket* socket, uint32_t max);
int32_t accept_tcp_socket(TcpSocket* socket, TcpStream* stream);
int32_t close_tcp_socket(TcpSocket* socket);
int32_t connect_tcp_stream(SocketAddr* servaddr, TcpStream* stream);
int32_t read_tcp_stream(TcpStream* stream, void* buffer, uint16_t len);
int32_t write_tcp_stream(TcpStream* stream, void* buffer, uint16_t len);
int32_t close_tcp_stream(TcpStream* stream);

View file

@ -1,246 +0,0 @@
#include <netinet/in.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>
#include "addr.h"
#include "binding.h"
#include "../io/log.h"
static void create_udp_binding(UdpSocket* socket, uint16_t port) {
if (create_udp_socket(V6, socket) < 0) {
ERROR("Failed to create UDP socket: %s", strerror(errno));
exit(EXIT_FAILURE);
}
IpAddr addr;
ip_addr_any6(&addr);
SocketAddr socketaddr;
create_socket_addr(port, addr, &socketaddr);
if (bind_udp_socket(&socketaddr, socket) < 0) {
ERROR("Failed to bind UDP socket on port %hu: %s", port, strerror(errno));
exit(EXIT_FAILURE);
}
}
static void create_tcp_binding(TcpSocket* socket, uint16_t port) {
if (create_tcp_socket(V6, socket) < 0) {
ERROR("Failed to create TCP socket: %s", strerror(errno));
exit(EXIT_FAILURE);
}
IpAddr addr;
ip_addr_any6(&addr);
SocketAddr socketaddr;
create_socket_addr(port, addr, &socketaddr);
if (bind_tcp_socket(&socketaddr, socket) < 0) {
ERROR("Failed to bind TCP socket on port %hu: %s", port, strerror(errno));
exit(EXIT_FAILURE);
}
if (listen_tcp_socket(socket, 5) < 0) {
ERROR("Failed to listen on TCP socket: %s", strerror(errno));
exit(EXIT_FAILURE);
}
}
void create_binding(BindingType type, uint16_t port, Binding* binding) {
binding->type = type;
if (type == UDP) {
create_udp_binding(&binding->sock.udp, port);
} else if(type == TCP) {
create_tcp_binding(&binding->sock.tcp, port);
} else {
exit(EXIT_FAILURE);
}
}
void free_binding(Binding* binding) {
if (binding->type == UDP) {
close_udp_socket(&binding->sock.udp);
} else if(binding->type == TCP) {
close_tcp_socket(&binding->sock.tcp);
}
}
bool accept_connection(Binding* binding, Connection* connection) {
connection->type = binding->type;
if(binding->type == UDP) {
connection->sock.udp.udp = binding->sock.udp;
memset(&connection->sock.udp.clientaddr, 0, sizeof(SocketAddr));
return true;
}
if (accept_tcp_socket(&binding->sock.tcp, &connection->sock.tcp) < 0) {
ERROR("Failed to accept TCP connection: %s", strerror(errno));
return false;
}
return true;
}
static void read_to_packet(uint8_t* buf, uint16_t len, Packet* packet) {
PacketBuffer* pkbuffer = buffer_create(len);
for (int i = 0; i < len; i++) {
buffer_write(pkbuffer, buf[i]);
}
buffer_seek(pkbuffer, 0);
read_packet(pkbuffer, packet);
buffer_free(pkbuffer);
}
static bool read_udp(Connection* connection, Packet* packet) {
uint8_t buffer[512];
int32_t n = read_udp_socket(
&connection->sock.udp.udp,
buffer,
512,
&connection->sock.udp.clientaddr
);
if (n < 0) {
return false;
}
read_to_packet(buffer, n, packet);
return true;
}
static bool read_tcp(Connection* connection, Packet* packet) {
uint16_t len;
if ( read_tcp_stream(
&connection->sock.tcp,
&len,
sizeof(uint16_t)
) < 2) {
return false;
}
len = ntohs(len);
uint8_t buffer[len];
if (read_tcp_stream(
&connection->sock.tcp,
buffer,
len
) < len) {
return false;
}
read_to_packet(buffer, len, packet);
return true;
}
bool read_connection(Connection* connection, Packet* packet) {
if (connection->type == UDP) {
return read_udp(connection, packet);
} else if (connection->type == TCP) {
return read_tcp(connection, packet);
}
return false;
}
static bool write_udp(Connection* connection, uint8_t* buf, uint16_t len) {
if (len > 512) {
buf[2] = buf[2] | 0x03;
len = 512;
}
return write_udp_socket(
&connection->sock.udp.udp,
buf,
len,
&connection->sock.udp.clientaddr
) == len;
}
static bool write_tcp(Connection* connection, uint8_t* buf, uint16_t len) {
uint16_t net_len = htons(len);
if (write_tcp_stream(
&connection->sock.tcp,
&net_len,
sizeof(uint16_t)
) < 0) {
return false;
}
if (write_tcp_stream(
&connection->sock.tcp,
buf,
len
) < 0) {
return false;
}
return true;
}
bool write_connection(Connection* connection, Packet* packet) {
PacketBuffer* pkbuffer = buffer_create(64);
write_packet(pkbuffer, packet);
uint16_t len = buffer_get_size(pkbuffer);
uint8_t* buffer = buffer_get_ptr(pkbuffer);
bool success = false;
if(connection->type == UDP) {
success = write_udp(connection, buffer, len);
} else if(connection->type == TCP) {
success = write_tcp(connection, buffer, len);
};
buffer_free(pkbuffer);
return success;
}
void free_connection(Connection* connection) {
if (connection->type == TCP) {
close_tcp_stream(&connection->sock.tcp);
}
}
static bool create_udp_request(SocketAddr* addr, Connection* request) {
if ( create_udp_socket(addr->type, &request->sock.udp.udp) < 0) {
ERROR("Failed to connect to UDP socket: %s", strerror(errno));
return false;
}
request->sock.udp.clientaddr = *addr;
return true;
}
static bool create_tcp_request(SocketAddr* addr, Connection* request) {
if( connect_tcp_stream(addr, &request->sock.tcp) < 0) {
ERROR("Failed to connect to TCP socket: %s", strerror(errno));
return false;
}
return true;
}
bool create_request(BindingType type, SocketAddr* addr, Connection* request) {
request->type = type;
if (type == UDP) {
return create_udp_request(addr, request);
} else if (type == TCP) {
return create_tcp_request(addr, request);
} else {
return true;
}
}
bool request_packet(Connection* request, Packet* in, Packet* out) {
if (!write_connection(request, in)) {
return false;
}
if (!read_connection(request, out)) {
return false;
}
return true;
}
void free_request(Connection* connection) {
if (connection->type == UDP) {
close_udp_socket(&connection->sock.udp.udp);
} else if (connection->type == TCP) {
close_tcp_stream(&connection->sock.tcp);
}
}

View file

@ -1,42 +0,0 @@
#pragma once
#include "../packet/packet.h"
#include "addr.h"
#include <netinet/in.h>
typedef enum {
UDP,
TCP
} BindingType;
typedef struct {
BindingType type;
union {
UdpSocket udp;
TcpSocket tcp;
} sock;
} Binding;
void create_binding(BindingType type, uint16_t port, Binding* binding);
void free_binding(Binding* binding);
typedef struct {
BindingType type;
union {
struct {
UdpSocket udp;
SocketAddr clientaddr;
} udp;
TcpStream tcp;
} sock;
} Connection;
bool accept_connection(Binding* binding, Connection* connection);
bool read_connection(Connection* connection, Packet* packet);
bool write_connection(Connection* connection, Packet* packet);
void free_connection(Connection* connection);
bool create_request(BindingType type, SocketAddr* addr, Connection* request);
bool request_packet(Connection* request, Packet* in, Packet* out);
void free_request(Connection* connection);

View file

@ -1,176 +0,0 @@
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include "resolver.h"
#include "addr.h"
#include "binding.h"
#include "../io/log.h"
static bool lookup(
Question* question,
Packet* response,
BindingType type,
SocketAddr addr
) {
INIT_LOG_BUFFER(log)
LOGONLY(print_socket_addr(&addr, log);)
TRACE("Attempting lookup on fallback dns %s", log);
Connection request;
if (!create_request(type, &addr, &request)) {
return false;
}
Packet req;
memset(&req, 0, sizeof(Packet));
req.header.id = response->header.id;
req.header.opcode = response->header.opcode;
req.header.questions = 1;
req.header.recursion_desired = true;
req.questions = malloc(sizeof(Question));
req.questions[0] = *question;
if (!request_packet(&request, &req, response)) {
free_request(&request);
free(req.questions);
ERROR("Failed to request fallback dns: %s", strerror(errno));
return false;
}
free_request(&request);
free(req.questions);
return true;
}
static bool search(Question* question, Packet* result, BindingType type, const RecordMap* map) {
IpAddr addr;
uint8_t ip[4] = {9,9,9,9};
create_ip_addr(ip, &addr);
uint16_t port = 53;
SocketAddr saddr;
create_socket_addr(port, addr, &saddr);
while(1) {
if (record_map_get(map, question, result)) {
TRACE("Found answer in user defined config");
return true;
}
if (!lookup(question, result, type, saddr)) {
return false;
}
if (result->header.answers > 0 && result->header.rescode == NOERROR) {
return true;
}
if (result->header.rescode == NXDOMAIN) {
return true;
}
if (get_resolved_ns(result, question->domain, &addr)) {
continue;
}
Question new_question;
if (!get_unresoled_ns(result, question->domain, &new_question)) {
return true;
}
Packet recurse;
if (!search(&new_question, &recurse, type, map)) {
return false;
}
free_question(&new_question);
IpAddr random;
if (!get_random_a(&recurse, &random)) {
free_packet(&recurse);
return true;
} else {
free_packet(&recurse);
addr = random;
}
}
}
static void push_records(Record* from, uint8_t from_len, Record** to, uint8_t to_len) {
if(from_len < 1) return;
*to = realloc(*to, sizeof(Record) * (from_len + to_len));
memcpy(*to + to_len, from, from_len * sizeof(Record));
}
static void push_questions(Question* from, uint8_t from_len, Question** to, uint8_t to_len) {
if(from_len < 1) return;
*to = realloc(*to, sizeof(Question) * (from_len + to_len));
memcpy(*to + to_len, from, from_len * sizeof(Question));
}
void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map) {
memset(response, 0, sizeof(Packet));
response->header.id = request->header.id;
response->header.opcode = request->header.opcode;
response->header.recursion_desired = true;
response->header.recursion_available = true;
response->header.response = true;
if (request->header.questions < 1) {
response->header.response = FORMERR;
return;
}
for (uint16_t i = 0; i < request->header.questions; i++) {
Packet result;
memset(&result, 0, sizeof(Packet));
result.header.id = response->header.id;
if (!search(&request->questions[i], &result, type, map)) {
response->header.response = SERVFAIL;
break;
}
push_questions(
result.questions,
result.header.questions,
&response->questions,
response->header.questions
);
response->header.questions += result.header.questions;
push_records(
result.answers,
result.header.answers,
&response->answers,
response->header.answers
);
response->header.answers += result.header.answers;
push_records(
result.authorities,
result.header.authoritative_entries,
&response->authorities,
response->header.authoritative_entries
);
response->header.authoritative_entries += result.header.authoritative_entries;
push_records(
result.resources,
result.header.resource_entries,
&response->resources,
response->header.resource_entries
);
response->header.resource_entries += result.header.resource_entries;
if (result.header.z == false) {
// Not from cache
free(result.questions);
free(result.answers);
free(result.authorities);
free(result.resources);
} else {
response->header.z = true;
}
}
}

View file

@ -1,7 +0,0 @@
#pragma once
#include "../packet/packet.h"
#include "../io/map.h"
#include "binding.h"
void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map);

View file

@ -1,142 +0,0 @@
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <sys/wait.h>
#include <signal.h>
#include <pthread.h>
#include "addr.h"
#include "server.h"
#include "resolver.h"
#include "../io/log.h"
#include "../io/map.h"
#include "../io/config.h"
static pthread_t udp, tcp;
static RecordMap map;
void server_init(uint16_t port, Server* server) {
INFO("Server port set to %hu", port);
create_binding(UDP, port, &server->udp);
create_binding(TCP, port, &server->tcp);
if (load_config("/etc/wrapper.conf", &map)) {
return;
}
if (load_config("config", &map)) {
return;
}
WARN("No dns config files were found");
}
struct DnsRequest {
Connection connection;
Packet request;
};
static void* server_respond(void* arg) {
struct DnsRequest req = *(struct DnsRequest*) arg;
INFO("Recieved packet request ID %hu", req.request.header.id);
Packet response;
handle_query(&req.request, &response, req.connection.type, &map);
if (!write_connection(&req.connection, &response)) {
ERROR("Failed to respond to connection ID %hu: %s",
req.request.header.id,
strerror(errno)
);
}
free_packet(&req.request);
free_connection(&req.connection);
if (response.header.z == false) {
free_packet(&response);
} else {
// Dont free from config
free(response.questions);
free(response.answers);
free(response.authorities);
free(response.resources);
}
return NULL;
}
static void* server_listen(void* arg) {
Binding* binding = (Binding*) arg;
while(1) {
Connection connection;
if (!accept_connection(binding, &connection)) {
ERROR("Failed to accept connection");
continue;
}
Packet request;
if (!read_connection(&connection, &request)) {
ERROR("Failed to read connection");
free_connection(&connection);
continue;
}
struct DnsRequest req;
req.connection = connection;
req.request = request;
pthread_t thread;
if(pthread_create(&thread, NULL, &server_respond, &req)) {
ERROR("Failed to create thread for dns request ID %hu: %s",
request.header.id,
strerror(errno)
);
free_packet(&request);
free_connection(&connection);
continue;
}
pthread_detach(thread);
}
pthread_detach(pthread_self());
return NULL;
}
static void signal_handler() {
printf("\n");
pthread_kill(udp, SIGTERM);
pthread_kill(tcp, SIGTERM);
}
void server_run(Server* server) {
if (!pthread_create(&udp, NULL, &server_listen, &server->udp)) {
INFO("Listening for connections on UDP");
} else {
ERROR("Failed to start UDP thread");
exit(EXIT_FAILURE);
}
if (!pthread_create(&tcp, NULL, &server_listen, &server->tcp)) {
INFO("Listening for connections on TCP");
} else {
ERROR("Failed to start TCP thread");
exit(EXIT_FAILURE);
}
signal(SIGINT, signal_handler);
pthread_join(udp, NULL);
pthread_join(tcp, NULL);
pthread_exit(0);
}
void server_free(Server* server) {
free_binding(&server->udp);
free_binding(&server->tcp);
record_map_free(&map);
}

View file

@ -1,12 +0,0 @@
#pragma once
#include "binding.h"
typedef struct {
Binding udp;
Binding tcp;
} Server;
void server_init(uint16_t port, Server* server);
void server_run(Server* server);
void server_free(Server* server);

156
src/web/api.rs Normal file
View file

@ -0,0 +1,156 @@
use std::net::IpAddr;
use axum::{
extract::Query,
response::Response,
routing::{get, post, put, delete},
Extension, Router,
};
use moka::future::Cache;
use rand::distributions::{Alphanumeric, DistString};
use serde::Deserialize;
use tower_cookies::{Cookie, Cookies};
use crate::{config::Config, database::Database, dns::packet::record::DnsRecord};
use super::{
extract::{Authorized, Body, RequestIp},
http::{json, text},
};
pub fn router() -> Router {
Router::new()
.route("/login", post(login))
.route("/domains", get(list_domains))
.route("/domains", delete(delete_domain))
.route("/records", get(get_domain))
.route("/records", put(add_record))
}
async fn list_domains(_: Authorized, Extension(database): Extension<Database>) -> Response {
let domains = match database.get_domains().await {
Ok(domains) => domains,
Err(err) => return text(500, &format!("{err}")),
};
let Ok(domains) = serde_json::to_string(&domains) else {
return text(500, "Failed to fetch domains")
};
json(200, &domains)
}
#[derive(Deserialize)]
struct DomainRequest {
domain: String,
}
async fn get_domain(
_: Authorized,
Extension(database): Extension<Database>,
Query(query): Query<DomainRequest>,
) -> Response {
let records = match database.get_domain(&query.domain).await {
Ok(records) => records,
Err(err) => return text(500, &format!("{err}")),
};
let Ok(records) = serde_json::to_string(&records) else {
return text(500, "Failed to fetch records")
};
json(200, &records)
}
async fn delete_domain(
_: Authorized,
Extension(database): Extension<Database>,
Body(body): Body,
) -> Response {
let Ok(request) = serde_json::from_str::<DomainRequest>(&body) else {
return text(400, "Missing request parameters")
};
let Ok(domains) = database.get_domains().await else {
return text(500, "Failed to delete domain")
};
if !domains.contains(&request.domain) {
return text(400, "Domain does not exist")
}
if database.delete_domain(request.domain).await.is_err() {
return text(500, "Failed to delete domain")
};
return text(204, "Successfully deleted domain")
}
async fn add_record(
_: Authorized,
Extension(database): Extension<Database>,
Body(body): Body,
) -> Response {
let Ok(record) = serde_json::from_str::<DnsRecord>(&body) else {
return text(400, "Invalid DNS record")
};
let allowed = record.get_qtype().allowed_actions();
if !allowed.1 {
return text(400, "Not allowed to create record")
}
let Ok(records) = database.get_records(&record.get_domain(), record.get_qtype()).await else {
return text(500, "Failed to complete record check");
};
if !records.is_empty() && !allowed.0 {
return text(400, "Not allowed to create duplicate record")
};
if records.contains(&record) {
return text(400, "Not allowed to create duplicate record")
}
if let Err(err) = database.add_record(record).await {
return text(500, &format!("{err}"));
}
return text(201, "Added record to database successfully");
}
#[derive(Deserialize)]
struct LoginRequest {
user: String,
pass: String,
}
async fn login(
Extension(config): Extension<Config>,
Extension(cache): Extension<Cache<String, IpAddr>>,
RequestIp(ip): RequestIp,
cookies: Cookies,
Body(body): Body,
) -> Response {
let Ok(request) = serde_json::from_str::<LoginRequest>(&body) else {
return text(400, "Missing request parameters")
};
if request.user != config.web_user || request.pass != config.web_pass {
return text(400, "Invalid credentials");
};
let token = Alphanumeric.sample_string(&mut rand::thread_rng(), 128);
cache.insert(token.clone(), ip).await;
let mut cookie = Cookie::new("auth", token);
cookie.set_secure(true);
cookie.set_http_only(true);
cookie.set_path("/");
cookies.add(cookie);
text(200, "Successfully logged in")
}

139
src/web/extract.rs Normal file
View file

@ -0,0 +1,139 @@
use std::{
io::Read,
net::{IpAddr, SocketAddr},
};
use axum::{
async_trait,
body::HttpBody,
extract::{ConnectInfo, FromRequest, FromRequestParts},
http::{request::Parts, Request},
response::Response,
BoxError,
};
use bytes::Bytes;
use moka::future::Cache;
use tower_cookies::Cookies;
use super::http::text;
pub struct Authorized;
#[async_trait]
impl<S> FromRequestParts<S> for Authorized
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let Ok(Some(cookies)) = Option::<Cookies>::from_request_parts(parts, state).await else {
return Err(text(403, "No cookies provided"))
};
let Some(token) = cookies.get("auth") else {
return Err(text(403, "No auth token provided"))
};
let auth_ip: IpAddr;
{
let Some(cache) = parts.extensions.get::<Cache<String, IpAddr>>() else {
return Err(text(500, "Failed to load auth store"))
};
let Some(ip) = cache.get(token.value()) else {
return Err(text(401, "Unauthorized"))
};
auth_ip = ip
}
let Ok(Some(RequestIp(ip))) = Option::<RequestIp>::from_request_parts(parts, state).await else {
return Err(text(403, "You have no ip"))
};
if auth_ip != ip {
return Err(text(403, "Auth token does not match current ip"));
}
Ok(Self)
}
}
pub struct RequestIp(pub IpAddr);
#[async_trait]
impl<S> FromRequestParts<S> for RequestIp
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let headers = &parts.headers;
let forwardedfor = headers
.get("x-forwarded-for")
.and_then(|h| h.to_str().ok())
.and_then(|h| {
h.split(',')
.rev()
.find_map(|s| s.trim().parse::<IpAddr>().ok())
});
if let Some(forwardedfor) = forwardedfor {
return Ok(Self(forwardedfor));
}
let realip = headers
.get("x-real-ip")
.and_then(|hv| hv.to_str().ok())
.and_then(|s| s.parse::<IpAddr>().ok());
if let Some(realip) = realip {
return Ok(Self(realip));
}
let realip = headers
.get("x-real-ip")
.and_then(|hv| hv.to_str().ok())
.and_then(|s| s.parse::<IpAddr>().ok());
if let Some(realip) = realip {
return Ok(Self(realip));
}
let info = parts.extensions.get::<ConnectInfo<SocketAddr>>();
if let Some(info) = info {
return Ok(Self(info.0.ip()));
}
Err(text(403, "You have no ip"))
}
}
pub struct Body(pub String);
#[async_trait]
impl<S, B> FromRequest<S, B> for Body
where
B: HttpBody + Sync + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send + Sync,
{
type Rejection = Response;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let Ok(bytes) = Bytes::from_request(req, state).await else {
return Err(text(413, "Payload too large"));
};
let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else {
return Err(text(400, "Invalid utf8 body"))
};
Ok(Self(body))
}
}

31
src/web/file.rs Normal file
View file

@ -0,0 +1,31 @@
use axum::{extract::Path, response::Response};
use super::http::serve;
pub async fn js(Path(path): Path<String>) -> Response {
let path = format!("/js/{path}");
serve(&path).await
}
pub async fn css(Path(path): Path<String>) -> Response {
let path = format!("/css/{path}");
serve(&path).await
}
pub async fn fonts(Path(path): Path<String>) -> Response {
let path = format!("/fonts/{path}");
serve(&path).await
}
pub async fn image(Path(path): Path<String>) -> Response {
let path = format!("/image/{path}");
serve(&path).await
}
pub async fn favicon() -> Response {
serve("/favicon.ico").await
}
pub async fn robots() -> Response {
serve("/robots.txt").await
}

50
src/web/http.rs Normal file
View file

@ -0,0 +1,50 @@
use axum::{
body::Body,
http::{header::HeaderName, HeaderValue, Request, StatusCode},
response::{IntoResponse, Response},
};
use std::str;
use tower::ServiceExt;
use tower_http::services::ServeFile;
pub fn text(code: u16, msg: &str) -> Response {
(status_code(code), msg.to_owned()).into_response()
}
pub fn json(code: u16, json: &str) -> Response {
let mut res = (status_code(code), json.to_owned()).into_response();
res.headers_mut().insert(
HeaderName::from_static("content-type"),
HeaderValue::from_static("application/json"),
);
res
}
pub async fn serve(path: &str) -> Response {
if !path.chars().any(|c| c == '.') {
return text(403, "Invalid file path");
}
let path = format!("public{path}");
let file = ServeFile::new(path);
let Ok(mut res) = file.oneshot(Request::new(Body::empty())).await else {
tracing::error!("Error while fetching file");
return text(500, "Error when fetching file")
};
if res.status() != StatusCode::OK {
return text(404, "File not found");
}
res.headers_mut().insert(
HeaderName::from_static("cache-control"),
HeaderValue::from_static("max-age=300"),
);
res.into_response()
}
fn status_code(code: u16) -> StatusCode {
StatusCode::from_u16(code).map_or(StatusCode::OK, |code| code)
}

82
src/web/mod.rs Normal file
View file

@ -0,0 +1,82 @@
use std::net::{IpAddr, SocketAddr, TcpListener};
use std::time::Duration;
use axum::routing::get;
use axum::{Extension, Router};
use moka::future::Cache;
use tokio::task::JoinHandle;
use tower_cookies::CookieManagerLayer;
use tracing::{error, info};
use crate::config::Config;
use crate::database::Database;
use crate::Result;
mod api;
mod extract;
mod file;
mod http;
mod pages;
pub struct WebServer {
config: Config,
database: Database,
addr: SocketAddr,
}
impl WebServer {
pub async fn new(config: Config, database: Database) -> Result<Self> {
let addr = format!("[::]:{}", config.web_port).parse::<SocketAddr>()?;
Ok(Self {
config,
database,
addr,
})
}
pub async fn run(&self) -> Result<JoinHandle<()>> {
let config = self.config.clone();
let database = self.database.clone();
let listener = TcpListener::bind(self.addr)?;
info!(
"Listening for HTTP traffic on [::]:{}",
self.config.web_port
);
let app = Self::router(config, database);
let server = axum::Server::from_tcp(listener)?;
let web_handle = tokio::spawn(async move {
if let Err(err) = server
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await
{
error!("{err}");
}
});
Ok(web_handle)
}
fn router(config: Config, database: Database) -> Router {
let cache: Cache<String, IpAddr> = Cache::builder()
.time_to_live(Duration::from_secs(60 * 15))
.max_capacity(config.dns_cache_size)
.build();
Router::new()
.nest("/", pages::router())
.nest("/api", api::router())
.layer(Extension(config))
.layer(Extension(cache))
.layer(Extension(database))
.layer(CookieManagerLayer::new())
.route("/js/*path", get(file::js))
.route("/css/*path", get(file::css))
.route("/fonts/*path", get(file::fonts))
.route("/image/*path", get(file::image))
.route("/favicon.ico", get(file::favicon))
.route("/robots.txt", get(file::robots))
}
}

31
src/web/pages.rs Normal file
View file

@ -0,0 +1,31 @@
use axum::{response::Response, routing::get, Router};
use super::{extract::Authorized, http::serve};
pub fn router() -> Router {
Router::new()
.route("/", get(root))
.route("/login", get(login))
.route("/home", get(home))
.route("/domain", get(domain))
}
async fn root(user: Option<Authorized>) -> Response {
if user.is_some() {
home().await
} else {
login().await
}
}
async fn login() -> Response {
serve("/login.html").await
}
async fn home() -> Response {
serve("/home.html").await
}
async fn domain() -> Response {
serve("/domain.html").await
}