new c version
This commit is contained in:
parent
b1fb410aff
commit
bb85374b79
75 changed files with 2403 additions and 5582 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1,2 +0,0 @@
|
|||
**/target
|
||||
.env
|
2447
Cargo.lock
generated
2447
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
37
Cargo.toml
37
Cargo.toml
|
@ -1,37 +0,0 @@
|
|||
[package]
|
||||
name = "wrapper"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
# Blazingly fast runtime
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tracing-subscriber = "0.3.16"
|
||||
tracing = "0.1.37"
|
||||
|
||||
# Allow recursion inside tokio async
|
||||
async-recursion = "1"
|
||||
|
||||
# DNS Caching Layer
|
||||
moka = { version = "0.10.0", features = ["future"] }
|
||||
|
||||
# Mongodb
|
||||
mongodb = { version = "2.4", features = ["tokio-sync"] }
|
||||
futures = "0.3.26"
|
||||
|
||||
# Convert values to json for Mongodb
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
|
||||
# Reading env vars from .env
|
||||
dotenv = "0.15.0"
|
||||
|
||||
# For the meme records
|
||||
rand = "0.8.5"
|
||||
|
||||
# For the http web frontend
|
||||
axum = "0.6.4"
|
||||
tower-http = { version = "0.4.0", features = ["fs"] }
|
||||
tower-cookies = "0.9.0"
|
||||
tower = "0.4.13"
|
||||
bytes = "1.4.0"
|
||||
serde_json = "1"
|
35
Makefile
Normal file
35
Makefile
Normal file
|
@ -0,0 +1,35 @@
|
|||
CC = gcc
|
||||
|
||||
INCFLAGS = -Isrc
|
||||
|
||||
CCFLAGS = -std=c17 -Wall -Wextra -O2
|
||||
CCFLAGS += $(INCFLAGS)
|
||||
|
||||
LDFLAGS += $(INCFLAGS)
|
||||
|
||||
BIN = bin
|
||||
APP = $(BIN)/app
|
||||
SRC = $(shell find src -name "*.c")
|
||||
OBJ = $(SRC:%.c=$(BIN)/%.o)
|
||||
|
||||
.PHONY: clean build
|
||||
|
||||
dirs:
|
||||
mkdir -p ./$(BIN)
|
||||
mkdir -p ./$(BIN)/src
|
||||
mkdir -p ./$(BIN)/src/io
|
||||
mkdir -p ./$(BIN)/src/packet
|
||||
mkdir -p ./$(BIN)/src/server
|
||||
|
||||
run: build
|
||||
$(APP)
|
||||
|
||||
build: dirs ${OBJ}
|
||||
${CC} -o $(APP) $(filter %.o,$^) $(LDFLAGS)
|
||||
|
||||
$(BIN)/%.o: %.c
|
||||
$(CC) -o $@ -c $< $(CCFLAGS)
|
||||
|
||||
clean:
|
||||
rm -rf $(APP)
|
||||
rm -rf $(BIN)
|
BIN
bin/app
Executable file
BIN
bin/app
Executable file
Binary file not shown.
BIN
bin/src/io/log.o
Normal file
BIN
bin/src/io/log.o
Normal file
Binary file not shown.
BIN
bin/src/main.o
Normal file
BIN
bin/src/main.o
Normal file
Binary file not shown.
BIN
bin/src/packet/buffer.o
Normal file
BIN
bin/src/packet/buffer.o
Normal file
Binary file not shown.
BIN
bin/src/packet/header.o
Normal file
BIN
bin/src/packet/header.o
Normal file
Binary file not shown.
BIN
bin/src/packet/packet.o
Normal file
BIN
bin/src/packet/packet.o
Normal file
Binary file not shown.
BIN
bin/src/packet/question.o
Normal file
BIN
bin/src/packet/question.o
Normal file
Binary file not shown.
BIN
bin/src/packet/record.o
Normal file
BIN
bin/src/packet/record.o
Normal file
Binary file not shown.
BIN
bin/src/server/addr.o
Normal file
BIN
bin/src/server/addr.o
Normal file
Binary file not shown.
BIN
bin/src/server/binding.o
Normal file
BIN
bin/src/server/binding.o
Normal file
Binary file not shown.
BIN
bin/src/server/resolver.o
Normal file
BIN
bin/src/server/resolver.o
Normal file
Binary file not shown.
BIN
bin/src/server/server.o
Normal file
BIN
bin/src/server/server.o
Normal file
Binary file not shown.
|
@ -1,40 +0,0 @@
|
|||
span {
|
||||
margin-top: 5rem;
|
||||
margin-bottom: 1rem;
|
||||
width: 45rem;
|
||||
font-size: 2em;
|
||||
}
|
||||
|
||||
#new {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
width: 100%;
|
||||
padding-top: 2rem;
|
||||
padding-bottom: 1rem;
|
||||
border-bottom: solid 1px var(--gray);
|
||||
}
|
||||
|
||||
#new input, .block {
|
||||
border-radius: 1rem 0 0 1rem;
|
||||
width: 40rem;
|
||||
}
|
||||
|
||||
.block {
|
||||
width: 33em;
|
||||
}
|
||||
|
||||
#new button {
|
||||
border-radius: 0 1rem 1rem 0;
|
||||
}
|
||||
|
||||
.domain {
|
||||
margin-top: 2rem;
|
||||
}
|
||||
|
||||
.domain .delete {
|
||||
border-radius: 0 1rem 1rem 0;
|
||||
}
|
||||
|
||||
.domain .edit {
|
||||
border-radius: 0;
|
||||
}
|
|
@ -1,18 +0,0 @@
|
|||
#login {
|
||||
margin-top: 20em;
|
||||
}
|
||||
|
||||
#logo {
|
||||
font-size: 6em;
|
||||
font-weight: 750;
|
||||
font-family: bold;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
form {
|
||||
width: 30rem;
|
||||
}
|
||||
|
||||
form input {
|
||||
width: 100%;
|
||||
}
|
|
@ -1,119 +0,0 @@
|
|||
:root {
|
||||
--dark: #222428;
|
||||
--dark-alternate: #2b2e36;
|
||||
--header: #1e1e22;
|
||||
|
||||
--accent: #8849f5;
|
||||
--accent-alternate: #6829d5;
|
||||
--gray: #2f2f3f;
|
||||
--main: #ffffff;
|
||||
--main-alternate: #cccccc;
|
||||
}
|
||||
|
||||
* {
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
@font-face {
|
||||
font-family: main;
|
||||
src: url("../fonts/helvetica.ttf") format("truetype");
|
||||
font-display: swap;
|
||||
}
|
||||
|
||||
@font-face {
|
||||
font-family: bold;
|
||||
src: url("../fonts/overpass-bold.otf") format("opentype");
|
||||
font-display: swap;
|
||||
}
|
||||
|
||||
@font-face {
|
||||
font-family: bold-italic;
|
||||
src: url("../fonts/overpass-bold-italic.otf") format("opentype");
|
||||
font-display: swap;
|
||||
}
|
||||
|
||||
html {
|
||||
background-color: var(--dark);
|
||||
font-family: main;
|
||||
color: var(--main);
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
body {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.accent {
|
||||
color: var(--accent);
|
||||
}
|
||||
|
||||
.fill {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
input, button, .block {
|
||||
all: unset;
|
||||
display: inline-block;
|
||||
font: main;
|
||||
background-color: var(--dark-alternate);
|
||||
font-size: 1rem;
|
||||
padding: 1rem;
|
||||
border-radius: 1rem;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
button {
|
||||
background-color: var(--accent);
|
||||
width: 5em;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
button:hover {
|
||||
cursor: pointer;
|
||||
background-color: var(--accent-alternate);
|
||||
}
|
||||
|
||||
.delete {
|
||||
background-color: #f54842;
|
||||
}
|
||||
|
||||
.delete:hover {
|
||||
cursor: pointer;
|
||||
background-color: #d52822;
|
||||
}
|
||||
|
||||
form {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
#header {
|
||||
width: calc(100% - 4rem);
|
||||
background-color: var(--header);
|
||||
border-bottom: solid 1px var(--gray);
|
||||
padding: 1rem;
|
||||
padding-left: 3rem;
|
||||
}
|
||||
|
||||
#logo {
|
||||
font-size: 2em;
|
||||
font-weight: 500;
|
||||
font-family: bold;
|
||||
}
|
||||
|
||||
#title {
|
||||
font-size: 2em;
|
||||
font-weight: 300;
|
||||
font-family: sans-serif;
|
||||
padding-left: 1em;
|
||||
}
|
|
@ -1,67 +0,0 @@
|
|||
#buttons {
|
||||
margin-top: 2rem;
|
||||
width: 50rem;
|
||||
}
|
||||
|
||||
#buttons button {
|
||||
margin: 0;
|
||||
margin-right: 2rem;
|
||||
border-radius: 10px;
|
||||
width: auto;
|
||||
padding: .75rem 1rem;
|
||||
}
|
||||
|
||||
.record {
|
||||
width: 50rem;
|
||||
background-color: var(--header);
|
||||
padding: 1rem;
|
||||
margin-top: 2rem;
|
||||
}
|
||||
|
||||
.header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.header span {
|
||||
font-family: bold;
|
||||
}
|
||||
|
||||
.header button {
|
||||
margin: 0;
|
||||
margin-left: 2rem;
|
||||
padding: .5rem 1rem;
|
||||
width: auto;
|
||||
border-radius: 5px;
|
||||
}
|
||||
|
||||
.type {
|
||||
margin-right: 1rem;
|
||||
background-color: var(--accent);
|
||||
padding: .25rem .5rem;
|
||||
border-radius: 5px;
|
||||
}
|
||||
|
||||
.domain {
|
||||
color: var(--main-alternate);
|
||||
flex-grow: 1;
|
||||
}
|
||||
|
||||
.properties {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.poperty {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
border-bottom: solid 1px var(--gray);
|
||||
margin-top: 1rem;
|
||||
}
|
||||
|
||||
.key {
|
||||
font-family: bold;
|
||||
width: 5rem;
|
||||
}
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Wrapper - Records</title>
|
||||
|
||||
<meta name="author" content="Tyler Murphy">
|
||||
<meta name="description" content="wrapper records">
|
||||
|
||||
<meta property="og:title" content="wrapper">
|
||||
<meta property="og:description" content="wrapper records">
|
||||
|
||||
<link rel="stylesheet" href="/css/main.css">
|
||||
<link rel="stylesheet" href="/css/record.css">
|
||||
|
||||
<script type="module" src="/js/domain.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
</body>
|
||||
</html>
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,21 +0,0 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Wrapper - Domains</title>
|
||||
|
||||
<meta name="author" content="Tyler Murphy">
|
||||
<meta name="description" content="wrapper domains">
|
||||
|
||||
<meta property="og:title" content="wrapper">
|
||||
<meta property="og:description" content="wrapper domains">
|
||||
|
||||
<link rel="stylesheet" href="/css/main.css">
|
||||
<link rel="stylesheet" href="/css/home.css">
|
||||
|
||||
<script type="module" src="/js/home.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
</body>
|
||||
</html>
|
|
@ -1,51 +0,0 @@
|
|||
const endpoint = '/api'
|
||||
|
||||
const request = async (url, method, body) => {
|
||||
|
||||
let response;
|
||||
|
||||
if (method == 'GET') {
|
||||
response = await fetch(endpoint + url, {
|
||||
method,
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
});
|
||||
} else {
|
||||
response = await fetch(endpoint + url, {
|
||||
method,
|
||||
body: JSON.stringify(body),
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (response.status == 401) {
|
||||
location.href = '/login'
|
||||
}
|
||||
const contentType = response.headers.get("content-type");
|
||||
if (contentType && contentType.indexOf("application/json") !== -1) {
|
||||
const json = await response.json()
|
||||
return { status: response.status, msg: json.msg, json }
|
||||
} else {
|
||||
const msg = await response.text();
|
||||
return { status: response.status, msg }
|
||||
}
|
||||
}
|
||||
|
||||
export const login = async (user, pass) => {
|
||||
return await request('/login', 'POST', {user, pass})
|
||||
}
|
||||
|
||||
export const domains = async () => {
|
||||
return await request('/domains', 'GET')
|
||||
}
|
||||
|
||||
export const del_domain = async (domain) => {
|
||||
return await request('/domains', 'DELETE', {domain})
|
||||
}
|
||||
|
||||
export const records = async (domain) => {
|
||||
return await request(`/records?domain=${domain}`, 'GET')
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
import { div, parse, span } from './main.js';
|
||||
|
||||
export function header(title) {
|
||||
return div({id: 'header'},
|
||||
span({id: 'logo', class: 'accent'},
|
||||
parse("Wrapper")
|
||||
),
|
||||
span({id: 'title'},
|
||||
parse(title)
|
||||
),
|
||||
)
|
||||
}
|
|
@ -1,95 +0,0 @@
|
|||
import { del_domain, domains, records } from './api.js'
|
||||
import { header } from './components.js'
|
||||
import { body, parse, div, input, button, span, is_domain } from './main.js';
|
||||
|
||||
function render(domain, records) {
|
||||
|
||||
let divs = []
|
||||
for (const record of records) {
|
||||
divs.push(gen_record(record))
|
||||
}
|
||||
|
||||
document.body.replaceWith(
|
||||
body({},
|
||||
header(domain),
|
||||
div({id: 'buttons'},
|
||||
button({onclick: (event) => {
|
||||
location.href = '/home'
|
||||
}}, parse("Home")),
|
||||
button({}, parse("New Record")),
|
||||
),
|
||||
...divs
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
function gen_record(record) {
|
||||
let domain = record.domain
|
||||
let prefix = record.prefix
|
||||
|
||||
if (prefix.length > 0) {
|
||||
prefix = prefix + '.'
|
||||
}
|
||||
|
||||
let type = Object.keys(record.record)[0]
|
||||
let data = record.record[type]
|
||||
|
||||
let divs = []
|
||||
for (const key in data) {
|
||||
let disp_key;
|
||||
if (key == 'ttl') {
|
||||
disp_key = 'TTL'
|
||||
} else {
|
||||
disp_key = upper(key)
|
||||
}
|
||||
divs.push(
|
||||
div({class: 'poperty'},
|
||||
div({class: 'key'}, parse(disp_key)),
|
||||
div({class: 'value'}, parse(data[key])),
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
return div({class: 'record'},
|
||||
div({class: 'header'},
|
||||
span({class: 'type'}, parse(type)),
|
||||
span({class: 'prefix'}, parse(prefix)),
|
||||
span({class: 'domain'}, parse(domain)),
|
||||
button({}, parse("Edit")),
|
||||
button({class: 'delete'}, parse("Delete"))
|
||||
),
|
||||
div({class: 'properties'},
|
||||
...divs
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
function upper(string) {
|
||||
return string.charAt(0).toUpperCase() + string.slice(1);
|
||||
}
|
||||
|
||||
async function init() {
|
||||
|
||||
const params = new Proxy(new URLSearchParams(window.location.search), {
|
||||
get: (searchParams, prop) => searchParams.get(prop),
|
||||
});
|
||||
|
||||
let domain = params.domain;
|
||||
|
||||
if (!is_domain(domain)) {
|
||||
location.href = '/home'
|
||||
return
|
||||
}
|
||||
|
||||
let res = await records(domain);
|
||||
|
||||
if (res.status !== 200) {
|
||||
alert(res.msg)
|
||||
return
|
||||
}
|
||||
|
||||
render(domain, res.json)
|
||||
|
||||
}
|
||||
|
||||
init()
|
|
@ -1,91 +0,0 @@
|
|||
import { del_domain, domains } from './api.js'
|
||||
import { header } from './components.js'
|
||||
import { body, parse, div, input, button, span, is_domain } from './main.js';
|
||||
|
||||
function render(domains) {
|
||||
document.body.replaceWith(
|
||||
body({},
|
||||
header('domains'),
|
||||
div({id: 'new'},
|
||||
input({
|
||||
type: 'text',
|
||||
name: 'domain',
|
||||
id: 'domain',
|
||||
placeholder: 'Type domain name to create new records',
|
||||
autocomplete: "off",
|
||||
}),
|
||||
button({onclick: () => {
|
||||
let domain = document.getElementById('domain').value
|
||||
|
||||
if (!is_domain(domain)) {
|
||||
alert("Invalid domain")
|
||||
return
|
||||
}
|
||||
|
||||
location.href = '/domain?domain='+domain
|
||||
}},
|
||||
parse("Create")
|
||||
)
|
||||
),
|
||||
...domain(domains)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
function domain(domains) {
|
||||
let divs = []
|
||||
for (const domain of domains) {
|
||||
divs.push(
|
||||
div({class: 'domain'},
|
||||
div({class: 'block'},
|
||||
parse(domain)
|
||||
),
|
||||
button({class: 'edit', onclick: (event) => {
|
||||
console.log(event.target.parentElement)
|
||||
let domain = event
|
||||
.target
|
||||
.parentElement
|
||||
.getElementsByClassName('block')[0]
|
||||
.innerText
|
||||
|
||||
if (!is_domain(domain)) {
|
||||
alert("Invalid domain")
|
||||
return
|
||||
}
|
||||
|
||||
location.href = '/domain?domain='+domain
|
||||
}},
|
||||
parse("Edit")
|
||||
),
|
||||
button({class: 'delete', onclick: async () => {
|
||||
let res = await del_domain(domain)
|
||||
|
||||
if (res.status != 204) {
|
||||
alert(res.msg)
|
||||
return
|
||||
}
|
||||
|
||||
location.reload()
|
||||
}},
|
||||
parse("Delete")
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
return divs
|
||||
}
|
||||
|
||||
async function init() {
|
||||
|
||||
let res = await domains();
|
||||
|
||||
if (res.status !== 200) {
|
||||
alert(res.msg)
|
||||
return
|
||||
}
|
||||
|
||||
render(res.json)
|
||||
|
||||
}
|
||||
|
||||
init()
|
|
@ -1,44 +0,0 @@
|
|||
import { body, div, form, input, p, parse, span} from './main.js'
|
||||
import { login } from './api.js'
|
||||
|
||||
function render() {
|
||||
document.body.replaceWith(
|
||||
body({},
|
||||
div({id: 'login', class: 'fill'},
|
||||
span({id: 'logo'},
|
||||
span({class: 'accent'}, parse('Wrapper'))
|
||||
),
|
||||
form({autocomplete: "off"},
|
||||
input({
|
||||
type: 'text',
|
||||
name: 'user',
|
||||
id: 'user',
|
||||
placeholder: 'Username',
|
||||
autofocus: 1
|
||||
}),
|
||||
input({
|
||||
type: 'password',
|
||||
name: 'pass',
|
||||
id: 'pass',
|
||||
placeholder: 'Password',
|
||||
onkeydown: async (event) => {
|
||||
if (event.key == 'Enter') {
|
||||
event.preventDefault()
|
||||
let user = document.getElementById('user').value
|
||||
let pass = document.getElementById('pass').value
|
||||
|
||||
let res = await login(user, pass)
|
||||
|
||||
if (res.status === 200) {
|
||||
location.href = '/home'
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
render()
|
|
@ -1,136 +0,0 @@
|
|||
function createElement(name, attrs, ...children) {
|
||||
const el = document.createElement(name);
|
||||
|
||||
for (const attr in attrs) {
|
||||
if(attr.startsWith("on")) {
|
||||
el[attr] = attrs[attr];
|
||||
} else {
|
||||
el.setAttribute(attr, attrs[attr])
|
||||
}
|
||||
}
|
||||
|
||||
for (const child of children) {
|
||||
if (child == null) {
|
||||
continue
|
||||
}
|
||||
el.appendChild(child)
|
||||
}
|
||||
|
||||
return el
|
||||
}
|
||||
|
||||
export function createElementNS(name, attrs, ...children) {
|
||||
var svgns = "http://www.w3.org/2000/svg";
|
||||
var el = document.createElementNS(svgns, name);
|
||||
|
||||
for (const attr in attrs) {
|
||||
if(attr.startsWith("on")) {
|
||||
el[attr] = attrs[attr];
|
||||
} else {
|
||||
el.setAttribute(attr, attrs[attr])
|
||||
}
|
||||
}
|
||||
|
||||
for (const child of children) {
|
||||
if (child == null) {
|
||||
continue
|
||||
}
|
||||
el.appendChild(child)
|
||||
}
|
||||
|
||||
return el
|
||||
}
|
||||
|
||||
export function p(attrs, ...children) {
|
||||
return createElement("p", attrs, ...children)
|
||||
}
|
||||
|
||||
export function span(attrs, ...children) {
|
||||
return createElement("span", attrs, ...children)
|
||||
}
|
||||
|
||||
export function div(attrs, ...children) {
|
||||
return createElement("div", attrs, ...children)
|
||||
}
|
||||
|
||||
export function a(attrs, ...children) {
|
||||
return createElement("a", attrs, ...children)
|
||||
}
|
||||
|
||||
export function i(attrs, ...children) {
|
||||
return createElement("i", attrs, ...children)
|
||||
}
|
||||
|
||||
export function form(attrs, ...children) {
|
||||
return createElement("form", attrs, ...children)
|
||||
}
|
||||
|
||||
export function img(alt, attrs, ...children) {
|
||||
attrs['onerror'] = (event) => event.target.remove()
|
||||
attrs['alt'] = alt
|
||||
return createElement("img", attrs, ...children)
|
||||
}
|
||||
|
||||
export function input(attrs, ...children) {
|
||||
return createElement("input", attrs, ...children)
|
||||
}
|
||||
|
||||
export function button(attrs, ...children) {
|
||||
return createElement("button", attrs, ...children)
|
||||
}
|
||||
|
||||
export function path(attrs, ...children) {
|
||||
return createElementNS("path", attrs, ...children)
|
||||
}
|
||||
|
||||
export function svg(attrs, ...children) {
|
||||
return createElementNS("svg", attrs, ...children)
|
||||
}
|
||||
|
||||
export function body(attrs, ...children) {
|
||||
return createElement("body", attrs, ...children)
|
||||
}
|
||||
|
||||
export function textarea(attrs, ...children) {
|
||||
return createElement("textarea", attrs, ...children)
|
||||
}
|
||||
|
||||
export function parse(input) {
|
||||
const pattern = /^[ a-zA-Z0-9!@#$%^&*()_+\-=\[\]{};':"\\|,.<>\/?]*$/;
|
||||
|
||||
input = input + '';
|
||||
|
||||
if (!pattern.test(input)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const sanitized = input.replace(/</g, '<').replace(/>/g, '>');
|
||||
return document.createRange().createContextualFragment(sanitized);
|
||||
}
|
||||
|
||||
export function is_domain(domain) {
|
||||
domain = domain.toLowerCase()
|
||||
|
||||
const pattern = /^[a-z0-9_\-.]*$/;
|
||||
if (!pattern.test(domain)) {
|
||||
return false
|
||||
}
|
||||
|
||||
let parts = domain.split('.').reverse()
|
||||
for (const part of parts) {
|
||||
if (part.length < 1) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if (parts.length < 2 || parts[0].length < 2) {
|
||||
return false
|
||||
}
|
||||
|
||||
const tld_pattern = /^[a-z]*$/;
|
||||
if (!tld_pattern.test(parts[0])) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Wrapper - Login</title>
|
||||
|
||||
<meta name="author" content="Tyler Murphy">
|
||||
<meta name="description" content="wrapper dns login">
|
||||
|
||||
<meta property="og:title" content="wrapper">
|
||||
<meta property="og:description" content="wrapper dns login">
|
||||
|
||||
<link rel="stylesheet" href="/css/main.css">
|
||||
<link rel="stylesheet" href="/css/login.css">
|
||||
|
||||
<script type="module" src="/js/login.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
</body>
|
||||
</html>
|
|
@ -1,9 +0,0 @@
|
|||
User-agent: Googlebot
|
||||
Disallow: /api
|
||||
|
||||
User-agent: Googlebot
|
||||
User-agent: AdsBot-Google
|
||||
Disallow: /api
|
||||
|
||||
User-agent: *
|
||||
Disallow: /api
|
|
@ -1,57 +0,0 @@
|
|||
use std::{env, net::IpAddr, str::FromStr, fmt::Display};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
pub dns_fallback: IpAddr,
|
||||
pub dns_port: u16,
|
||||
pub dns_cache_size: u64,
|
||||
|
||||
pub db_host: String,
|
||||
pub db_port: u16,
|
||||
pub db_user: String,
|
||||
pub db_pass: String,
|
||||
|
||||
pub web_user: String,
|
||||
pub web_pass: String,
|
||||
pub web_port: u16,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn new() -> Self {
|
||||
let dns_port = Self::get_var::<u16>("WRAPPER_DNS_PORT", 53);
|
||||
let dns_fallback = Self::get_var::<IpAddr>("WRAPPER_FALLBACK_DNS", [9, 9, 9, 9].into());
|
||||
let dns_cache_size = Self::get_var::<u64>("WRAPPER_CACHE_SIZE", 1000);
|
||||
|
||||
let db_host = Self::get_var::<String>("WRAPPER_DB_HOST", String::from("localhost"));
|
||||
let db_port = Self::get_var::<u16>("WRAPPER_DB_PORT", 27017);
|
||||
let db_user = Self::get_var::<String>("WRAPPER_DB_USER", String::from("root"));
|
||||
let db_pass = Self::get_var::<String>("WRAPPER_DB_PASS", String::from(""));
|
||||
|
||||
let web_user = Self::get_var::<String>("WRAPPER_WEB_USER", String::from("admin"));
|
||||
let web_pass = Self::get_var::<String>("WRAPPER_WEB_PASS", String::from("wrapper"));
|
||||
let web_port = Self::get_var::<u16>("WRAPPER_WEB_PORT", 80);
|
||||
|
||||
Self {
|
||||
dns_fallback,
|
||||
dns_port,
|
||||
dns_cache_size,
|
||||
|
||||
db_host,
|
||||
db_port,
|
||||
db_user,
|
||||
db_pass,
|
||||
|
||||
web_user,
|
||||
web_pass,
|
||||
web_port,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_var<T>(name: &str, default: T) -> T
|
||||
where
|
||||
T: FromStr + Display,
|
||||
{
|
||||
let env = env::var(name).unwrap_or(format!("{default}"));
|
||||
env.parse::<T>().unwrap_or(default)
|
||||
}
|
||||
}
|
|
@ -1,146 +0,0 @@
|
|||
use futures::TryStreamExt;
|
||||
use mongodb::{
|
||||
bson::doc,
|
||||
options::{ClientOptions, Credential, ServerAddress},
|
||||
Client,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
config::Config,
|
||||
dns::packet::{query::QueryType, record::DnsRecord},
|
||||
};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Database {
|
||||
client: Client,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct StoredRecord {
|
||||
record: DnsRecord,
|
||||
domain: String,
|
||||
prefix: String,
|
||||
}
|
||||
|
||||
impl StoredRecord {
|
||||
fn get_domain_parts(domain: &str) -> (String, String) {
|
||||
let parts: Vec<&str> = domain.split(".").collect();
|
||||
let len = parts.len();
|
||||
if len == 1 {
|
||||
(String::new(), String::from(parts[0]))
|
||||
} else if len == 2 {
|
||||
(String::new(), String::from(parts.join(".")))
|
||||
} else {
|
||||
(
|
||||
String::from(parts[0..len - 2].join(".")),
|
||||
String::from(parts[len - 2..len].join(".")),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DnsRecord> for StoredRecord {
|
||||
fn from(record: DnsRecord) -> Self {
|
||||
let (prefix, domain) = Self::get_domain_parts(&record.get_domain());
|
||||
Self {
|
||||
record,
|
||||
domain,
|
||||
prefix,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<DnsRecord> for StoredRecord {
|
||||
fn into(self) -> DnsRecord {
|
||||
self.record
|
||||
}
|
||||
}
|
||||
|
||||
impl Database {
|
||||
pub async fn new(config: Config) -> Result<Self> {
|
||||
let options = ClientOptions::builder()
|
||||
.hosts(vec![ServerAddress::Tcp {
|
||||
host: config.db_host,
|
||||
port: Some(config.db_port),
|
||||
}])
|
||||
.credential(
|
||||
Credential::builder()
|
||||
.username(config.db_user)
|
||||
.password(config.db_pass)
|
||||
.build(),
|
||||
)
|
||||
.max_pool_size(100)
|
||||
.app_name(String::from("wrapper"))
|
||||
.build();
|
||||
|
||||
let client = Client::with_options(options)?;
|
||||
|
||||
client
|
||||
.database("wrapper")
|
||||
.run_command(doc! {"ping": 1}, None)
|
||||
.await?;
|
||||
|
||||
info!("Connection to mongodb successfully");
|
||||
|
||||
Ok(Database { client })
|
||||
}
|
||||
|
||||
pub async fn get_records(&self, domain: &str, qtype: QueryType) -> Result<Vec<DnsRecord>> {
|
||||
let (prefix, domain) = StoredRecord::get_domain_parts(domain);
|
||||
Ok(self
|
||||
.get_domain(&domain)
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter(|r| r.prefix == prefix)
|
||||
.filter(|r| {
|
||||
let rqtype = r.record.get_qtype();
|
||||
if qtype == QueryType::A {
|
||||
return rqtype == QueryType::A || rqtype == QueryType::AR;
|
||||
} else if qtype == QueryType::AAAA {
|
||||
return rqtype == QueryType::AAAA || rqtype == QueryType::AAAAR;
|
||||
} else {
|
||||
r.record.get_qtype() == qtype
|
||||
}
|
||||
})
|
||||
.map(|r| r.into())
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn get_domain(&self, domain: &str) -> Result<Vec<StoredRecord>> {
|
||||
let db = self.client.database("wrapper");
|
||||
let col = db.collection::<StoredRecord>(domain);
|
||||
|
||||
let filter = doc! { "domain": domain };
|
||||
let mut cursor = col.find(filter, None).await?;
|
||||
|
||||
let mut records = Vec::new();
|
||||
while let Some(record) = cursor.try_next().await? {
|
||||
records.push(record);
|
||||
}
|
||||
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
pub async fn add_record(&self, record: DnsRecord) -> Result<()> {
|
||||
let record = StoredRecord::from(record);
|
||||
let db = self.client.database("wrapper");
|
||||
let col = db.collection::<StoredRecord>(&record.domain);
|
||||
col.insert_one(record, None).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_domains(&self) -> Result<Vec<String>> {
|
||||
let db = self.client.database("wrapper");
|
||||
Ok(db.list_collection_names(None).await?)
|
||||
}
|
||||
|
||||
pub async fn delete_domain(&self, domain: String) -> Result<()> {
|
||||
let db = self.client.database("wrapper");
|
||||
let col = db.collection::<StoredRecord>(&domain);
|
||||
Ok(col.drop(None).await?)
|
||||
}
|
||||
}
|
|
@ -1,144 +0,0 @@
|
|||
use std::{
|
||||
net::{IpAddr, SocketAddr},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use super::packet::{buffer::PacketBuffer, Packet};
|
||||
use crate::Result;
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::{TcpListener, TcpStream, UdpSocket},
|
||||
};
|
||||
use tracing::trace;
|
||||
|
||||
pub enum Binding {
|
||||
UDP(Arc<UdpSocket>),
|
||||
TCP(TcpListener),
|
||||
}
|
||||
|
||||
impl Binding {
|
||||
pub async fn udp(addr: SocketAddr) -> Result<Self> {
|
||||
let socket = UdpSocket::bind(addr).await?;
|
||||
Ok(Self::UDP(Arc::new(socket)))
|
||||
}
|
||||
|
||||
pub async fn tcp(addr: SocketAddr) -> Result<Self> {
|
||||
let socket = TcpListener::bind(addr).await?;
|
||||
Ok(Self::TCP(socket))
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
match self {
|
||||
Binding::UDP(_) => "UDP",
|
||||
Binding::TCP(_) => "TCP",
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect(&mut self) -> Result<Connection> {
|
||||
match self {
|
||||
Self::UDP(socket) => {
|
||||
let mut buf = [0; 512];
|
||||
let (_, addr) = socket.recv_from(&mut buf).await?;
|
||||
Ok(Connection::UDP(socket.clone(), addr, buf))
|
||||
}
|
||||
Self::TCP(socket) => {
|
||||
let (stream, _) = socket.accept().await?;
|
||||
Ok(Connection::TCP(stream))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum Connection {
|
||||
UDP(Arc<UdpSocket>, SocketAddr, [u8; 512]),
|
||||
TCP(TcpStream),
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub async fn read_packet(&mut self) -> Result<Packet> {
|
||||
let data = self.read().await?;
|
||||
let mut packet_buffer = PacketBuffer::new(data);
|
||||
|
||||
let packet = Packet::from_buffer(&mut packet_buffer)?;
|
||||
Ok(packet)
|
||||
}
|
||||
|
||||
pub async fn write_packet(self, mut packet: Packet) -> Result<()> {
|
||||
let mut packet_buffer = PacketBuffer::new(Vec::new());
|
||||
packet.write(&mut packet_buffer)?;
|
||||
|
||||
self.write(packet_buffer.buf).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn request_packet(&self, mut packet: Packet, dest: (IpAddr, u16)) -> Result<Packet> {
|
||||
let mut packet_buffer = PacketBuffer::new(Vec::new());
|
||||
packet.write(&mut packet_buffer)?;
|
||||
|
||||
let data = self.request(packet_buffer.buf, dest).await?;
|
||||
let mut packet_buffer = PacketBuffer::new(data);
|
||||
|
||||
let packet = Packet::from_buffer(&mut packet_buffer)?;
|
||||
Ok(packet)
|
||||
}
|
||||
|
||||
async fn read(&mut self) -> Result<Vec<u8>> {
|
||||
trace!("Reading DNS packet");
|
||||
match self {
|
||||
Self::UDP(_, _, src) => Ok(Vec::from(*src)),
|
||||
Self::TCP(stream) => {
|
||||
let size = stream.read_u16().await?;
|
||||
let mut buf = Vec::with_capacity(size as usize);
|
||||
stream.read_buf(&mut buf).await?;
|
||||
Ok(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn write(self, mut buf: Vec<u8>) -> Result<()> {
|
||||
trace!("Returning DNS packet");
|
||||
match self {
|
||||
Self::UDP(socket, addr, _) => {
|
||||
if buf.len() > 512 {
|
||||
buf[2] = buf[2] | 0x03;
|
||||
socket.send_to(&buf[0..512], addr).await?;
|
||||
} else {
|
||||
socket.send_to(&buf, addr).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Self::TCP(mut stream) => {
|
||||
stream.write_u16(buf.len() as u16).await?;
|
||||
stream.write(&buf[0..buf.len()]).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn request(&self, buf: Vec<u8>, dest: (IpAddr, u16)) -> Result<Vec<u8>> {
|
||||
match self {
|
||||
Self::UDP(_socket, _addr, _src) => {
|
||||
let local_addr = "[::]:0".parse::<SocketAddr>()?;
|
||||
let socket = UdpSocket::bind(local_addr).await?;
|
||||
socket.send_to(&buf, dest).await?;
|
||||
|
||||
let mut buf = [0; 512];
|
||||
socket.recv_from(&mut buf).await?;
|
||||
|
||||
Ok(Vec::from(buf))
|
||||
}
|
||||
Self::TCP(_stream) => {
|
||||
let mut stream = TcpStream::connect(dest).await?;
|
||||
stream.write_u16((buf.len()) as u16).await?;
|
||||
stream.write_all(&buf[0..buf.len()]).await?;
|
||||
|
||||
stream.readable().await?;
|
||||
let size = stream.read_u16().await?;
|
||||
let mut buf = Vec::with_capacity(size as usize);
|
||||
stream.read_buf(&mut buf).await?;
|
||||
|
||||
Ok(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,4 +0,0 @@
|
|||
mod binding;
|
||||
pub mod packet;
|
||||
mod resolver;
|
||||
pub mod server;
|
|
@ -1,227 +0,0 @@
|
|||
use crate::Result;
|
||||
|
||||
pub struct PacketBuffer {
|
||||
pub buf: Vec<u8>,
|
||||
pub pos: usize,
|
||||
pub size: usize,
|
||||
}
|
||||
|
||||
impl PacketBuffer {
|
||||
pub fn new(buf: Vec<u8>) -> Self {
|
||||
Self {
|
||||
size: buf.len(),
|
||||
buf,
|
||||
pos: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn pos(&self) -> usize {
|
||||
self.pos
|
||||
}
|
||||
|
||||
pub fn step(&mut self, steps: usize) -> Result<()> {
|
||||
self.pos += steps;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn seek(&mut self, pos: usize) -> Result<()> {
|
||||
self.pos = pos;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn read(&mut self) -> Result<u8> {
|
||||
if self.pos >= self.size {
|
||||
return Err("Tried to read past end of buffer".into());
|
||||
}
|
||||
let res = self.buf[self.pos];
|
||||
self.pos += 1;
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub fn get(&mut self, pos: usize) -> Result<u8> {
|
||||
if pos >= self.size {
|
||||
return Err("Tried to read past end of buffer".into());
|
||||
}
|
||||
Ok(self.buf[pos])
|
||||
}
|
||||
|
||||
pub fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
|
||||
if start + len >= self.size {
|
||||
return Err("Tried to read past end of buffer".into());
|
||||
}
|
||||
Ok(&self.buf[start..start + len])
|
||||
}
|
||||
|
||||
pub fn read_u16(&mut self) -> Result<u16> {
|
||||
let res = ((self.read()? as u16) << 8) | (self.read()? as u16);
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub fn read_u32(&mut self) -> Result<u32> {
|
||||
let res = ((self.read()? as u32) << 24)
|
||||
| ((self.read()? as u32) << 16)
|
||||
| ((self.read()? as u32) << 8)
|
||||
| (self.read()? as u32);
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub fn read_qname(&mut self, outstr: &mut String) -> Result<()> {
|
||||
let mut pos = self.pos();
|
||||
let mut jumped = false;
|
||||
|
||||
let mut delim = "";
|
||||
let max_jumps = 5;
|
||||
let mut jumps_performed = 0;
|
||||
loop {
|
||||
// Dns Packets are untrusted data, so we need to be paranoid. Someone
|
||||
// can craft a packet with a cycle in the jump instructions. This guards
|
||||
// against such packets.
|
||||
if jumps_performed > max_jumps {
|
||||
return Err(format!("Limit of {max_jumps} jumps exceeded").into());
|
||||
}
|
||||
|
||||
let len = self.get(pos)?;
|
||||
|
||||
if (len & 0xC0) == 0xC0 {
|
||||
if !jumped {
|
||||
self.seek(pos + 2)?;
|
||||
}
|
||||
|
||||
let b2 = self.get(pos + 1)? as u16;
|
||||
let offset = (((len as u16) ^ 0xC0) << 8) | b2;
|
||||
pos = offset as usize;
|
||||
jumped = true;
|
||||
jumps_performed += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
pos += 1;
|
||||
|
||||
if len == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
outstr.push_str(delim);
|
||||
|
||||
let str_buffer = self.get_range(pos, len as usize)?;
|
||||
outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase());
|
||||
|
||||
delim = ".";
|
||||
|
||||
pos += len as usize;
|
||||
}
|
||||
|
||||
if !jumped {
|
||||
self.seek(pos)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn read_string(&mut self, outstr: &mut String) -> Result<()> {
|
||||
let len = self.read()?;
|
||||
|
||||
self.read_string_n(outstr, len)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn read_string_n(&mut self, outstr: &mut String, len: u8) -> Result<()> {
|
||||
let mut pos = self.pos;
|
||||
|
||||
let str_buffer = self.get_range(pos, len as usize)?;
|
||||
|
||||
let mut i = 0;
|
||||
for b in str_buffer {
|
||||
let c = *b as char;
|
||||
if c == '\0' {
|
||||
break;
|
||||
}
|
||||
outstr.push(c);
|
||||
i += 1;
|
||||
}
|
||||
|
||||
pos += i;
|
||||
self.seek(pos)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write(&mut self, val: u8) -> Result<()> {
|
||||
if self.size < self.pos {
|
||||
self.size = self.pos;
|
||||
}
|
||||
|
||||
if self.buf.len() <= self.size {
|
||||
self.buf.resize(self.size + 1, 0x00);
|
||||
}
|
||||
|
||||
self.buf[self.pos] = val;
|
||||
self.pos += 1;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write_u8(&mut self, val: u8) -> Result<()> {
|
||||
self.write(val)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write_u16(&mut self, val: u16) -> Result<()> {
|
||||
self.write((val >> 8) as u8)?;
|
||||
self.write((val & 0xFF) as u8)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write_u32(&mut self, val: u32) -> Result<()> {
|
||||
self.write(((val >> 24) & 0xFF) as u8)?;
|
||||
self.write(((val >> 16) & 0xFF) as u8)?;
|
||||
self.write(((val >> 8) & 0xFF) as u8)?;
|
||||
self.write((val & 0xFF) as u8)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write_qname(&mut self, qname: &str) -> Result<()> {
|
||||
for label in qname.split('.') {
|
||||
let len = label.len();
|
||||
|
||||
self.write_u8(len as u8)?;
|
||||
for b in label.as_bytes() {
|
||||
self.write_u8(*b)?;
|
||||
}
|
||||
}
|
||||
|
||||
if !qname.is_empty() {
|
||||
self.write_u8(0)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write_string(&mut self, text: &str) -> Result<()> {
|
||||
for b in text.as_bytes() {
|
||||
self.write_u8(*b)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn set(&mut self, pos: usize, val: u8) -> Result<()> {
|
||||
self.buf[pos] = val;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> {
|
||||
self.set(pos, (val >> 8) as u8)?;
|
||||
self.set(pos + 1, (val & 0xFF) as u8)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,102 +0,0 @@
|
|||
use super::{buffer::PacketBuffer, result::ResultCode};
|
||||
use crate::Result;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DnsHeader {
|
||||
pub id: u16, // 16 bits
|
||||
|
||||
pub recursion_desired: bool, // 1 bit
|
||||
pub truncated_message: bool, // 1 bit
|
||||
pub authoritative_answer: bool, // 1 bit
|
||||
pub opcode: u8, // 4 bits
|
||||
pub response: bool, // 1 bit
|
||||
|
||||
pub rescode: ResultCode, // 4 bits
|
||||
pub checking_disabled: bool, // 1 bit
|
||||
pub authed_data: bool, // 1 bit
|
||||
pub z: bool, // 1 bit
|
||||
pub recursion_available: bool, // 1 bit
|
||||
|
||||
pub questions: u16, // 16 bits
|
||||
pub answers: u16, // 16 bits
|
||||
pub authoritative_entries: u16, // 16 bits
|
||||
pub resource_entries: u16, // 16 bits
|
||||
}
|
||||
|
||||
impl DnsHeader {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
id: 0,
|
||||
|
||||
recursion_desired: false,
|
||||
truncated_message: false,
|
||||
authoritative_answer: false,
|
||||
opcode: 0,
|
||||
response: false,
|
||||
|
||||
rescode: ResultCode::NOERROR,
|
||||
checking_disabled: false,
|
||||
authed_data: false,
|
||||
z: false,
|
||||
recursion_available: false,
|
||||
|
||||
questions: 0,
|
||||
answers: 0,
|
||||
authoritative_entries: 0,
|
||||
resource_entries: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read(&mut self, buffer: &mut PacketBuffer) -> Result<()> {
|
||||
self.id = buffer.read_u16()?;
|
||||
let flags = buffer.read_u16()?;
|
||||
let a = (flags >> 8) as u8;
|
||||
let b = (flags & 0xFF) as u8;
|
||||
self.recursion_desired = (a & (1 << 0)) > 0;
|
||||
self.truncated_message = (a & (1 << 1)) > 0;
|
||||
self.authoritative_answer = (a & (1 << 2)) > 0;
|
||||
self.opcode = (a >> 3) & 0x0F;
|
||||
self.response = (a & (1 << 7)) > 0;
|
||||
|
||||
self.rescode = ResultCode::from_num(b & 0x0F);
|
||||
self.checking_disabled = (b & (1 << 4)) > 0;
|
||||
self.authed_data = (b & (1 << 5)) > 0;
|
||||
self.z = (b & (1 << 6)) > 0;
|
||||
self.recursion_available = (b & (1 << 7)) > 0;
|
||||
|
||||
self.questions = buffer.read_u16()?;
|
||||
self.answers = buffer.read_u16()?;
|
||||
self.authoritative_entries = buffer.read_u16()?;
|
||||
self.resource_entries = buffer.read_u16()?;
|
||||
|
||||
// Return the constant header size
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write(&self, buffer: &mut PacketBuffer) -> Result<()> {
|
||||
buffer.write_u16(self.id)?;
|
||||
|
||||
buffer.write_u8(
|
||||
(self.recursion_desired as u8)
|
||||
| ((self.truncated_message as u8) << 1)
|
||||
| ((self.authoritative_answer as u8) << 2)
|
||||
| (self.opcode << 3)
|
||||
| ((self.response as u8) << 7),
|
||||
)?;
|
||||
|
||||
buffer.write_u8(
|
||||
(self.rescode as u8)
|
||||
| ((self.checking_disabled as u8) << 4)
|
||||
| ((self.authed_data as u8) << 5)
|
||||
| ((self.z as u8) << 6)
|
||||
| ((self.recursion_available as u8) << 7),
|
||||
)?;
|
||||
|
||||
buffer.write_u16(self.questions)?;
|
||||
buffer.write_u16(self.answers)?;
|
||||
buffer.write_u16(self.authoritative_entries)?;
|
||||
buffer.write_u16(self.resource_entries)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,128 +0,0 @@
|
|||
use std::net::IpAddr;
|
||||
|
||||
use self::{
|
||||
buffer::PacketBuffer, header::DnsHeader, query::QueryType, question::DnsQuestion,
|
||||
record::DnsRecord,
|
||||
};
|
||||
use crate::Result;
|
||||
|
||||
pub mod buffer;
|
||||
pub mod header;
|
||||
pub mod query;
|
||||
pub mod question;
|
||||
pub mod record;
|
||||
pub mod result;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Packet {
|
||||
pub header: DnsHeader,
|
||||
pub questions: Vec<DnsQuestion>,
|
||||
pub answers: Vec<DnsRecord>,
|
||||
pub authorities: Vec<DnsRecord>,
|
||||
pub resources: Vec<DnsRecord>,
|
||||
}
|
||||
|
||||
impl Packet {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
header: DnsHeader::new(),
|
||||
questions: Vec::new(),
|
||||
answers: Vec::new(),
|
||||
authorities: Vec::new(),
|
||||
resources: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_buffer(buffer: &mut PacketBuffer) -> Result<Self> {
|
||||
let mut result = Self::new();
|
||||
result.header.read(buffer)?;
|
||||
|
||||
for _ in 0..result.header.questions {
|
||||
let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0));
|
||||
question.read(buffer)?;
|
||||
result.questions.push(question);
|
||||
}
|
||||
|
||||
for _ in 0..result.header.answers {
|
||||
let rec = DnsRecord::read(buffer)?;
|
||||
result.answers.push(rec);
|
||||
}
|
||||
for _ in 0..result.header.authoritative_entries {
|
||||
let rec = DnsRecord::read(buffer)?;
|
||||
result.authorities.push(rec);
|
||||
}
|
||||
for _ in 0..result.header.resource_entries {
|
||||
let rec = DnsRecord::read(buffer)?;
|
||||
result.resources.push(rec);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn write(&mut self, buffer: &mut PacketBuffer) -> Result<()> {
|
||||
self.header.questions = self.questions.len() as u16;
|
||||
self.header.answers = self.answers.len() as u16;
|
||||
self.header.authoritative_entries = self.authorities.len() as u16;
|
||||
self.header.resource_entries = self.resources.len() as u16;
|
||||
|
||||
self.header.write(buffer)?;
|
||||
|
||||
for question in &self.questions {
|
||||
question.write(buffer)?;
|
||||
}
|
||||
for rec in &self.answers {
|
||||
rec.write(buffer)?;
|
||||
}
|
||||
for rec in &self.authorities {
|
||||
rec.write(buffer)?;
|
||||
}
|
||||
for rec in &self.resources {
|
||||
rec.write(buffer)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_random_a(&self) -> Option<IpAddr> {
|
||||
self.answers
|
||||
.iter()
|
||||
.filter_map(|record| match record {
|
||||
DnsRecord::A { addr, .. } => Some(IpAddr::V4(*addr)),
|
||||
DnsRecord::AAAA { addr, .. } => Some(IpAddr::V6(*addr)),
|
||||
_ => None,
|
||||
})
|
||||
.next()
|
||||
}
|
||||
|
||||
fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator<Item = (&'a str, &'a str)> {
|
||||
self.authorities
|
||||
.iter()
|
||||
.filter_map(|record| match record {
|
||||
DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())),
|
||||
_ => None,
|
||||
})
|
||||
.filter(move |(domain, _)| qname.ends_with(*domain))
|
||||
}
|
||||
|
||||
pub fn get_resolved_ns(&self, qname: &str) -> Option<IpAddr> {
|
||||
self.get_ns(qname)
|
||||
.flat_map(|(_, host)| {
|
||||
self.resources
|
||||
.iter()
|
||||
.filter_map(move |record| match record {
|
||||
DnsRecord::A { domain, addr, .. } if domain == host => {
|
||||
Some(IpAddr::V4(*addr))
|
||||
}
|
||||
DnsRecord::AAAA { domain, addr, .. } if domain == host => {
|
||||
Some(IpAddr::V6(*addr))
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
})
|
||||
.next()
|
||||
}
|
||||
|
||||
pub fn get_unresolved_ns<'a>(&'a self, qname: &'a str) -> Option<&'a str> {
|
||||
self.get_ns(qname).map(|(_, host)| host).next()
|
||||
}
|
||||
}
|
|
@ -1,78 +0,0 @@
|
|||
#[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)]
|
||||
pub enum QueryType {
|
||||
UNKNOWN(u16),
|
||||
A, // 1
|
||||
NS, // 2
|
||||
CNAME, // 5
|
||||
SOA, // 6
|
||||
PTR, // 12
|
||||
MX, // 15
|
||||
TXT, // 16
|
||||
AAAA, // 28
|
||||
SRV, // 33
|
||||
OPT, // 41
|
||||
CAA, // 257
|
||||
AR, // 1000
|
||||
AAAAR, // 1001
|
||||
}
|
||||
|
||||
impl QueryType {
|
||||
pub fn to_num(&self) -> u16 {
|
||||
match *self {
|
||||
Self::UNKNOWN(x) => x,
|
||||
Self::A => 1,
|
||||
Self::NS => 2,
|
||||
Self::CNAME => 5,
|
||||
Self::SOA => 6,
|
||||
Self::PTR => 12,
|
||||
Self::MX => 15,
|
||||
Self::TXT => 16,
|
||||
Self::AAAA => 28,
|
||||
Self::SRV => 33,
|
||||
Self::OPT => 41,
|
||||
Self::CAA => 257,
|
||||
Self::AR => 1000,
|
||||
Self::AAAAR => 1001,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_num(num: u16) -> Self {
|
||||
match num {
|
||||
1 => Self::A,
|
||||
2 => Self::NS,
|
||||
5 => Self::CNAME,
|
||||
6 => Self::SOA,
|
||||
12 => Self::PTR,
|
||||
15 => Self::MX,
|
||||
16 => Self::TXT,
|
||||
28 => Self::AAAA,
|
||||
33 => Self::SRV,
|
||||
41 => Self::OPT,
|
||||
257 => Self::CAA,
|
||||
1000 => Self::AR,
|
||||
1001 => Self::AAAAR,
|
||||
_ => Self::UNKNOWN(num),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn allowed_actions(&self) -> (bool, bool) {
|
||||
// 0. duplicates allowed
|
||||
// 1. allowed to be created by database
|
||||
match self {
|
||||
QueryType::UNKNOWN(_) => (false, false),
|
||||
QueryType::A => (true, true),
|
||||
QueryType::NS => (false, true),
|
||||
QueryType::CNAME => (false, true),
|
||||
QueryType::SOA => (false, false),
|
||||
QueryType::PTR => (false, true),
|
||||
QueryType::MX => (false, true),
|
||||
QueryType::TXT => (true, true),
|
||||
QueryType::AAAA => (true, true),
|
||||
QueryType::SRV => (false, true),
|
||||
QueryType::OPT => (false, false),
|
||||
QueryType::CAA => (false, true),
|
||||
QueryType::AR => (false, true),
|
||||
QueryType::AAAAR => (false, true),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
use super::{buffer::PacketBuffer, query::QueryType, Result};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct DnsQuestion {
|
||||
pub name: String,
|
||||
pub qtype: QueryType,
|
||||
}
|
||||
|
||||
impl DnsQuestion {
|
||||
pub fn new(name: String, qtype: QueryType) -> Self {
|
||||
Self { name, qtype }
|
||||
}
|
||||
|
||||
pub fn read(&mut self, buffer: &mut PacketBuffer) -> Result<()> {
|
||||
buffer.read_qname(&mut self.name)?;
|
||||
self.qtype = QueryType::from_num(buffer.read_u16()?); // qtype
|
||||
let _ = buffer.read_u16()?; // class
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write(&self, buffer: &mut PacketBuffer) -> Result<()> {
|
||||
buffer.write_qname(&self.name)?;
|
||||
|
||||
let typenum = self.qtype.to_num();
|
||||
buffer.write_u16(typenum)?;
|
||||
buffer.write_u16(1)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,544 +0,0 @@
|
|||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
|
||||
use rand::RngCore;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{trace, warn};
|
||||
|
||||
use super::{buffer::PacketBuffer, query::QueryType, Result};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub enum DnsRecord {
|
||||
UNKNOWN {
|
||||
domain: String,
|
||||
qtype: u16,
|
||||
data_len: u16,
|
||||
ttl: u32,
|
||||
}, // 0
|
||||
A {
|
||||
domain: String,
|
||||
addr: Ipv4Addr,
|
||||
ttl: u32,
|
||||
}, // 1
|
||||
NS {
|
||||
domain: String,
|
||||
host: String,
|
||||
ttl: u32,
|
||||
}, // 2
|
||||
CNAME {
|
||||
domain: String,
|
||||
host: String,
|
||||
ttl: u32,
|
||||
}, // 5
|
||||
SOA {
|
||||
domain: String,
|
||||
mname: String,
|
||||
nname: String,
|
||||
serial: u32,
|
||||
refresh: u32,
|
||||
retry: u32,
|
||||
expire: u32,
|
||||
minimum: u32,
|
||||
ttl: u32,
|
||||
}, // 6
|
||||
PTR {
|
||||
domain: String,
|
||||
pointer: String,
|
||||
ttl: u32,
|
||||
}, // 12
|
||||
MX {
|
||||
domain: String,
|
||||
priority: u16,
|
||||
host: String,
|
||||
ttl: u32,
|
||||
}, // 15
|
||||
TXT {
|
||||
domain: String,
|
||||
text: Vec<String>,
|
||||
ttl: u32,
|
||||
}, //16
|
||||
AAAA {
|
||||
domain: String,
|
||||
addr: Ipv6Addr,
|
||||
ttl: u32,
|
||||
}, // 28
|
||||
SRV {
|
||||
domain: String,
|
||||
priority: u16,
|
||||
weight: u16,
|
||||
port: u16,
|
||||
target: String,
|
||||
ttl: u32,
|
||||
}, // 33
|
||||
CAA {
|
||||
domain: String,
|
||||
flags: u8,
|
||||
length: u8,
|
||||
tag: String,
|
||||
value: String,
|
||||
ttl: u32,
|
||||
}, // 257
|
||||
AR {
|
||||
domain: String,
|
||||
ttl: u32,
|
||||
},
|
||||
AAAAR {
|
||||
domain: String,
|
||||
ttl: u32,
|
||||
},
|
||||
}
|
||||
|
||||
impl DnsRecord {
|
||||
pub fn read(buffer: &mut PacketBuffer) -> Result<Self> {
|
||||
let mut domain = String::new();
|
||||
buffer.read_qname(&mut domain)?;
|
||||
|
||||
let qtype_num = buffer.read_u16()?;
|
||||
let qtype = QueryType::from_num(qtype_num);
|
||||
let _ = buffer.read_u16()?;
|
||||
let ttl = buffer.read_u32()?;
|
||||
let data_len = buffer.read_u16()?;
|
||||
|
||||
trace!("Reading DNS Record TYPE: {:?}", qtype);
|
||||
|
||||
let header_pos = buffer.pos();
|
||||
|
||||
match qtype {
|
||||
QueryType::A => {
|
||||
let raw_addr = buffer.read_u32()?;
|
||||
let addr = Ipv4Addr::new(
|
||||
((raw_addr >> 24) & 0xFF) as u8,
|
||||
((raw_addr >> 16) & 0xFF) as u8,
|
||||
((raw_addr >> 8) & 0xFF) as u8,
|
||||
(raw_addr & 0xFF) as u8,
|
||||
);
|
||||
|
||||
Ok(Self::A { domain, addr, ttl })
|
||||
}
|
||||
QueryType::AAAA => {
|
||||
let raw_addr1 = buffer.read_u32()?;
|
||||
let raw_addr2 = buffer.read_u32()?;
|
||||
let raw_addr3 = buffer.read_u32()?;
|
||||
let raw_addr4 = buffer.read_u32()?;
|
||||
let addr = Ipv6Addr::new(
|
||||
((raw_addr1 >> 16) & 0xFFFF) as u16,
|
||||
(raw_addr1 & 0xFFFF) as u16,
|
||||
((raw_addr2 >> 16) & 0xFFFF) as u16,
|
||||
(raw_addr2 & 0xFFFF) as u16,
|
||||
((raw_addr3 >> 16) & 0xFFFF) as u16,
|
||||
(raw_addr3 & 0xFFFF) as u16,
|
||||
((raw_addr4 >> 16) & 0xFFFF) as u16,
|
||||
(raw_addr4 & 0xFFFF) as u16,
|
||||
);
|
||||
|
||||
Ok(Self::AAAA { domain, addr, ttl })
|
||||
}
|
||||
QueryType::NS => {
|
||||
let mut ns = String::new();
|
||||
buffer.read_qname(&mut ns)?;
|
||||
|
||||
Ok(Self::NS {
|
||||
domain,
|
||||
host: ns,
|
||||
ttl,
|
||||
})
|
||||
}
|
||||
QueryType::CNAME => {
|
||||
let mut cname = String::new();
|
||||
buffer.read_qname(&mut cname)?;
|
||||
|
||||
Ok(Self::CNAME {
|
||||
domain,
|
||||
host: cname,
|
||||
ttl,
|
||||
})
|
||||
}
|
||||
QueryType::SOA => {
|
||||
let mut mname = String::new();
|
||||
buffer.read_qname(&mut mname)?;
|
||||
|
||||
let mut nname = String::new();
|
||||
buffer.read_qname(&mut nname)?;
|
||||
|
||||
let serial = buffer.read_u32()?;
|
||||
let refresh = buffer.read_u32()?;
|
||||
let retry = buffer.read_u32()?;
|
||||
let expire = buffer.read_u32()?;
|
||||
let minimum = buffer.read_u32()?;
|
||||
|
||||
Ok(Self::SOA {
|
||||
domain,
|
||||
mname,
|
||||
nname,
|
||||
serial,
|
||||
refresh,
|
||||
retry,
|
||||
expire,
|
||||
minimum,
|
||||
ttl,
|
||||
})
|
||||
}
|
||||
QueryType::PTR => {
|
||||
let mut pointer = String::new();
|
||||
buffer.read_qname(&mut pointer)?;
|
||||
|
||||
Ok(Self::PTR {
|
||||
domain,
|
||||
pointer,
|
||||
ttl,
|
||||
})
|
||||
}
|
||||
QueryType::MX => {
|
||||
let priority = buffer.read_u16()?;
|
||||
let mut mx = String::new();
|
||||
buffer.read_qname(&mut mx)?;
|
||||
|
||||
Ok(Self::MX {
|
||||
domain,
|
||||
priority,
|
||||
host: mx,
|
||||
ttl,
|
||||
})
|
||||
}
|
||||
QueryType::TXT => {
|
||||
let mut text = Vec::new();
|
||||
|
||||
loop {
|
||||
let mut s = String::new();
|
||||
buffer.read_string(&mut s)?;
|
||||
|
||||
if s.len() == 0 {
|
||||
break;
|
||||
} else {
|
||||
text.push(s);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self::TXT { domain, text, ttl })
|
||||
}
|
||||
QueryType::SRV => {
|
||||
let priority = buffer.read_u16()?;
|
||||
let weight = buffer.read_u16()?;
|
||||
let port = buffer.read_u16()?;
|
||||
|
||||
let mut target = String::new();
|
||||
buffer.read_qname(&mut target)?;
|
||||
|
||||
Ok(Self::SRV {
|
||||
domain,
|
||||
priority,
|
||||
weight,
|
||||
port,
|
||||
target,
|
||||
ttl,
|
||||
})
|
||||
}
|
||||
QueryType::CAA => {
|
||||
let flags = buffer.read()?;
|
||||
let length = buffer.read()?;
|
||||
|
||||
let mut tag = String::new();
|
||||
buffer.read_string_n(&mut tag, length)?;
|
||||
|
||||
let value_len = (data_len as usize) + header_pos - buffer.pos;
|
||||
let mut value = String::new();
|
||||
buffer.read_string_n(&mut value, value_len as u8)?;
|
||||
|
||||
Ok(Self::CAA {
|
||||
domain,
|
||||
flags,
|
||||
length,
|
||||
tag,
|
||||
value,
|
||||
ttl,
|
||||
})
|
||||
}
|
||||
QueryType::UNKNOWN(_) | _ => {
|
||||
buffer.step(data_len as usize)?;
|
||||
|
||||
Ok(Self::UNKNOWN {
|
||||
domain,
|
||||
qtype: qtype_num,
|
||||
data_len,
|
||||
ttl,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write(&self, buffer: &mut PacketBuffer) -> Result<usize> {
|
||||
let start_pos = buffer.pos();
|
||||
|
||||
trace!("Writing DNS Record {:?}", self);
|
||||
|
||||
match *self {
|
||||
Self::A {
|
||||
ref domain,
|
||||
ref addr,
|
||||
ttl,
|
||||
} => {
|
||||
buffer.write_qname(domain)?;
|
||||
buffer.write_u16(QueryType::A.to_num())?;
|
||||
buffer.write_u16(1)?;
|
||||
buffer.write_u32(ttl)?;
|
||||
buffer.write_u16(4)?;
|
||||
|
||||
let octets = addr.octets();
|
||||
buffer.write_u8(octets[0])?;
|
||||
buffer.write_u8(octets[1])?;
|
||||
buffer.write_u8(octets[2])?;
|
||||
buffer.write_u8(octets[3])?;
|
||||
}
|
||||
Self::NS {
|
||||
ref domain,
|
||||
ref host,
|
||||
ttl,
|
||||
} => {
|
||||
buffer.write_qname(domain)?;
|
||||
buffer.write_u16(QueryType::NS.to_num())?;
|
||||
buffer.write_u16(1)?;
|
||||
buffer.write_u32(ttl)?;
|
||||
|
||||
let pos = buffer.pos();
|
||||
buffer.write_u16(0)?;
|
||||
|
||||
buffer.write_qname(host)?;
|
||||
|
||||
let size = buffer.pos() - (pos + 2);
|
||||
buffer.set_u16(pos, size as u16)?;
|
||||
}
|
||||
Self::CNAME {
|
||||
ref domain,
|
||||
ref host,
|
||||
ttl,
|
||||
} => {
|
||||
buffer.write_qname(domain)?;
|
||||
buffer.write_u16(QueryType::CNAME.to_num())?;
|
||||
buffer.write_u16(1)?;
|
||||
buffer.write_u32(ttl)?;
|
||||
|
||||
let pos = buffer.pos();
|
||||
buffer.write_u16(0)?;
|
||||
|
||||
buffer.write_qname(host)?;
|
||||
|
||||
let size = buffer.pos() - (pos + 2);
|
||||
buffer.set_u16(pos, size as u16)?;
|
||||
}
|
||||
Self::SOA {
|
||||
ref domain,
|
||||
ref mname,
|
||||
ref nname,
|
||||
serial,
|
||||
refresh,
|
||||
retry,
|
||||
expire,
|
||||
minimum,
|
||||
ttl,
|
||||
} => {
|
||||
buffer.write_qname(domain)?;
|
||||
buffer.write_u16(QueryType::SOA.to_num())?;
|
||||
buffer.write_u16(1)?;
|
||||
buffer.write_u32(ttl)?;
|
||||
|
||||
let pos = buffer.pos();
|
||||
buffer.write_u16(0)?;
|
||||
|
||||
buffer.write_qname(mname)?;
|
||||
buffer.write_qname(nname)?;
|
||||
buffer.write_u32(serial)?;
|
||||
buffer.write_u32(refresh)?;
|
||||
buffer.write_u32(retry)?;
|
||||
buffer.write_u32(expire)?;
|
||||
buffer.write_u32(minimum)?;
|
||||
|
||||
let size = buffer.pos() - (pos + 2);
|
||||
buffer.set_u16(pos, size as u16)?;
|
||||
}
|
||||
Self::PTR {
|
||||
ref domain,
|
||||
ref pointer,
|
||||
ttl,
|
||||
} => {
|
||||
buffer.write_qname(domain)?;
|
||||
buffer.write_u16(QueryType::NS.to_num())?;
|
||||
buffer.write_u16(1)?;
|
||||
buffer.write_u32(ttl)?;
|
||||
|
||||
let pos = buffer.pos();
|
||||
buffer.write_u16(0)?;
|
||||
|
||||
buffer.write_qname(&pointer)?;
|
||||
|
||||
let size = buffer.pos() - (pos + 2);
|
||||
buffer.set_u16(pos, size as u16)?;
|
||||
}
|
||||
Self::MX {
|
||||
ref domain,
|
||||
priority,
|
||||
ref host,
|
||||
ttl,
|
||||
} => {
|
||||
buffer.write_qname(domain)?;
|
||||
buffer.write_u16(QueryType::MX.to_num())?;
|
||||
buffer.write_u16(1)?;
|
||||
buffer.write_u32(ttl)?;
|
||||
|
||||
let pos = buffer.pos();
|
||||
buffer.write_u16(0)?;
|
||||
|
||||
buffer.write_u16(priority)?;
|
||||
buffer.write_qname(host)?;
|
||||
|
||||
let size = buffer.pos() - (pos + 2);
|
||||
buffer.set_u16(pos, size as u16)?;
|
||||
}
|
||||
Self::TXT {
|
||||
ref domain,
|
||||
ref text,
|
||||
ttl,
|
||||
} => {
|
||||
buffer.write_qname(&domain)?;
|
||||
buffer.write_u16(QueryType::TXT.to_num())?;
|
||||
buffer.write_u16(1)?;
|
||||
buffer.write_u32(ttl)?;
|
||||
|
||||
let pos = buffer.pos();
|
||||
buffer.write_u16(0)?;
|
||||
|
||||
if text.is_empty() {
|
||||
return Ok(buffer.pos() - start_pos);
|
||||
}
|
||||
|
||||
for s in text {
|
||||
buffer.write_u8(s.len() as u8)?;
|
||||
buffer.write_string(&s)?;
|
||||
}
|
||||
let size = buffer.pos() - (pos + 2);
|
||||
buffer.set_u16(pos, size as u16)?;
|
||||
}
|
||||
Self::AAAA {
|
||||
ref domain,
|
||||
ref addr,
|
||||
ttl,
|
||||
} => {
|
||||
buffer.write_qname(domain)?;
|
||||
buffer.write_u16(QueryType::AAAA.to_num())?;
|
||||
buffer.write_u16(1)?;
|
||||
buffer.write_u32(ttl)?;
|
||||
buffer.write_u16(16)?;
|
||||
|
||||
for octet in &addr.segments() {
|
||||
buffer.write_u16(*octet)?;
|
||||
}
|
||||
}
|
||||
Self::SRV {
|
||||
ref domain,
|
||||
priority,
|
||||
weight,
|
||||
port,
|
||||
ref target,
|
||||
ttl,
|
||||
} => {
|
||||
buffer.write_qname(domain)?;
|
||||
buffer.write_u16(QueryType::SRV.to_num())?;
|
||||
buffer.write_u16(1)?;
|
||||
buffer.write_u32(ttl)?;
|
||||
|
||||
let pos = buffer.pos();
|
||||
buffer.write_u16(0)?;
|
||||
|
||||
buffer.write_u16(priority)?;
|
||||
buffer.write_u16(weight)?;
|
||||
buffer.write_u16(port)?;
|
||||
buffer.write_qname(target)?;
|
||||
|
||||
let size = buffer.pos() - (pos + 2);
|
||||
buffer.set_u16(pos, size as u16)?;
|
||||
}
|
||||
Self::CAA {
|
||||
ref domain,
|
||||
flags,
|
||||
length,
|
||||
ref tag,
|
||||
ref value,
|
||||
ttl,
|
||||
} => {
|
||||
buffer.write_qname(domain)?;
|
||||
buffer.write_u16(QueryType::CAA.to_num())?;
|
||||
buffer.write_u16(1)?;
|
||||
buffer.write_u32(ttl)?;
|
||||
|
||||
let pos = buffer.pos();
|
||||
buffer.write_u16(0)?;
|
||||
|
||||
buffer.write_u8(flags)?;
|
||||
buffer.write_u8(length)?;
|
||||
buffer.write_string(tag)?;
|
||||
buffer.write_string(value)?;
|
||||
|
||||
let size = buffer.pos() - (pos + 2);
|
||||
buffer.set_u16(pos, size as u16)?;
|
||||
}
|
||||
Self::AR { ref domain, ttl } => {
|
||||
buffer.write_qname(domain)?;
|
||||
buffer.write_u16(QueryType::A.to_num())?;
|
||||
buffer.write_u16(1)?;
|
||||
buffer.write_u32(ttl)?;
|
||||
buffer.write_u16(4)?;
|
||||
|
||||
let mut rand = rand::thread_rng();
|
||||
buffer.write_u32(rand.next_u32())?;
|
||||
}
|
||||
Self::AAAAR { ref domain, ttl } => {
|
||||
buffer.write_qname(domain)?;
|
||||
buffer.write_u16(QueryType::A.to_num())?;
|
||||
buffer.write_u16(1)?;
|
||||
buffer.write_u32(ttl)?;
|
||||
buffer.write_u16(4)?;
|
||||
|
||||
let mut rand = rand::thread_rng();
|
||||
buffer.write_u32(rand.next_u32())?;
|
||||
buffer.write_u32(rand.next_u32())?;
|
||||
buffer.write_u32(rand.next_u32())?;
|
||||
buffer.write_u32(rand.next_u32())?;
|
||||
}
|
||||
Self::UNKNOWN { .. } => {
|
||||
warn!("Skipping record: {self:?}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(buffer.pos() - start_pos)
|
||||
}
|
||||
|
||||
pub fn get_domain(&self) -> String {
|
||||
self.get_shared_domain().0
|
||||
}
|
||||
|
||||
pub fn get_qtype(&self) -> QueryType {
|
||||
self.get_shared_domain().1
|
||||
}
|
||||
|
||||
pub fn get_ttl(&self) -> u32 {
|
||||
self.get_shared_domain().2
|
||||
}
|
||||
|
||||
fn get_shared_domain(&self) -> (String, QueryType, u32) {
|
||||
match self {
|
||||
DnsRecord::UNKNOWN {
|
||||
domain, ttl, qtype, ..
|
||||
} => (domain.clone(), QueryType::UNKNOWN(*qtype), *ttl),
|
||||
DnsRecord::AAAA { domain, ttl, .. } => (domain.clone(), QueryType::AAAA, *ttl),
|
||||
DnsRecord::A { domain, ttl, .. } => (domain.clone(), QueryType::A, *ttl),
|
||||
DnsRecord::NS { domain, ttl, .. } => (domain.clone(), QueryType::NS, *ttl),
|
||||
DnsRecord::CNAME { domain, ttl, .. } => (domain.clone(), QueryType::CNAME, *ttl),
|
||||
DnsRecord::SOA { domain, ttl, .. } => (domain.clone(), QueryType::SOA, *ttl),
|
||||
DnsRecord::PTR { domain, ttl, .. } => (domain.clone(), QueryType::PTR, *ttl),
|
||||
DnsRecord::MX { domain, ttl, .. } => (domain.clone(), QueryType::MX, *ttl),
|
||||
DnsRecord::TXT { domain, ttl, .. } => (domain.clone(), QueryType::TXT, *ttl),
|
||||
DnsRecord::SRV { domain, ttl, .. } => (domain.clone(), QueryType::SRV, *ttl),
|
||||
DnsRecord::CAA { domain, ttl, .. } => (domain.clone(), QueryType::CAA, *ttl),
|
||||
DnsRecord::AR { domain, ttl, .. } => (domain.clone(), QueryType::AR, *ttl),
|
||||
DnsRecord::AAAAR { domain, ttl, .. } => (domain.clone(), QueryType::AAAAR, *ttl),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,22 +0,0 @@
|
|||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub enum ResultCode {
|
||||
NOERROR = 0,
|
||||
FORMERR = 1,
|
||||
SERVFAIL = 2,
|
||||
NXDOMAIN = 3,
|
||||
NOTIMP = 4,
|
||||
REFUSED = 5,
|
||||
}
|
||||
|
||||
impl ResultCode {
|
||||
pub fn from_num(num: u8) -> Self {
|
||||
match num {
|
||||
1 => Self::FORMERR,
|
||||
2 => Self::SERVFAIL,
|
||||
3 => Self::NXDOMAIN,
|
||||
4 => Self::NOTIMP,
|
||||
5 => Self::REFUSED,
|
||||
0 | _ => Self::NOERROR,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,230 +0,0 @@
|
|||
use super::binding::Connection;
|
||||
use super::packet::{query::QueryType, question::DnsQuestion, result::ResultCode, Packet};
|
||||
use crate::Result;
|
||||
use crate::{config::Config, database::Database, get_time};
|
||||
use async_recursion::async_recursion;
|
||||
use moka::future::Cache;
|
||||
use std::{net::IpAddr, sync::Arc, time::Duration};
|
||||
use tracing::{error, trace};
|
||||
|
||||
pub struct Resolver {
|
||||
request_id: u16,
|
||||
connection: Connection,
|
||||
config: Arc<Config>,
|
||||
database: Arc<Database>,
|
||||
cache: Cache<DnsQuestion, (Packet, u64)>,
|
||||
}
|
||||
|
||||
impl Resolver {
|
||||
pub fn new(
|
||||
request_id: u16,
|
||||
connection: Connection,
|
||||
config: Arc<Config>,
|
||||
database: Arc<Database>,
|
||||
cache: Cache<DnsQuestion, (Packet, u64)>,
|
||||
) -> Self {
|
||||
Self {
|
||||
request_id,
|
||||
connection,
|
||||
config,
|
||||
database,
|
||||
cache,
|
||||
}
|
||||
}
|
||||
|
||||
async fn lookup_database(&self, question: &DnsQuestion) -> Option<Packet> {
|
||||
let records = match self
|
||||
.database
|
||||
.get_records(&question.name, question.qtype)
|
||||
.await
|
||||
{
|
||||
Ok(record) => record,
|
||||
Err(err) => {
|
||||
error!("{err}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if records.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut packet = Packet::new();
|
||||
|
||||
packet.header.id = self.request_id;
|
||||
packet.header.questions = 1;
|
||||
packet.header.answers = records.len() as u16;
|
||||
packet.header.recursion_desired = true;
|
||||
packet
|
||||
.questions
|
||||
.push(DnsQuestion::new(question.name.to_string(), question.qtype));
|
||||
|
||||
for record in records {
|
||||
packet.answers.push(record);
|
||||
}
|
||||
|
||||
trace!(
|
||||
"Found stored value for {:?} {}",
|
||||
question.qtype,
|
||||
question.name
|
||||
);
|
||||
|
||||
Some(packet)
|
||||
}
|
||||
|
||||
async fn lookup_cache(&self, question: &DnsQuestion) -> Option<Packet> {
|
||||
let Some((packet, date)) = self.cache.get(&question) else {
|
||||
return None
|
||||
};
|
||||
|
||||
let now = get_time();
|
||||
let diff = Duration::from_millis(now - date).as_secs() as u32;
|
||||
|
||||
for answer in &packet.answers {
|
||||
let ttl = answer.get_ttl();
|
||||
if diff > ttl {
|
||||
self.cache.invalidate(&question).await;
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
trace!(
|
||||
"Found cached value for {:?} {}",
|
||||
question.qtype,
|
||||
question.name
|
||||
);
|
||||
|
||||
Some(packet)
|
||||
}
|
||||
|
||||
async fn lookup_fallback(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
|
||||
let mut packet = Packet::new();
|
||||
|
||||
packet.header.id = self.request_id;
|
||||
packet.header.questions = 1;
|
||||
packet.header.recursion_desired = true;
|
||||
packet
|
||||
.questions
|
||||
.push(DnsQuestion::new(question.name.to_string(), question.qtype));
|
||||
|
||||
let packet = match self.connection.request_packet(packet, server).await {
|
||||
Ok(packet) => packet,
|
||||
Err(e) => {
|
||||
error!("Failed to complete nameserver request: {e}");
|
||||
let mut packet = Packet::new();
|
||||
packet.header.rescode = ResultCode::SERVFAIL;
|
||||
packet
|
||||
}
|
||||
};
|
||||
|
||||
packet
|
||||
}
|
||||
|
||||
async fn lookup(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
|
||||
if let Some(packet) = self.lookup_cache(question).await {
|
||||
return packet;
|
||||
};
|
||||
|
||||
if let Some(packet) = self.lookup_database(question).await {
|
||||
return packet;
|
||||
};
|
||||
|
||||
trace!(
|
||||
"Attempting lookup of {:?} {} with ns {}",
|
||||
question.qtype,
|
||||
question.name,
|
||||
server.0
|
||||
);
|
||||
|
||||
self.lookup_fallback(question, server).await
|
||||
}
|
||||
|
||||
#[async_recursion]
|
||||
async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet {
|
||||
let question = DnsQuestion::new(qname.to_string(), qtype);
|
||||
let mut ns = self.config.dns_fallback.clone();
|
||||
|
||||
loop {
|
||||
let ns_copy = ns;
|
||||
|
||||
let server = (ns_copy, 53);
|
||||
let response = self.lookup(&question, server).await;
|
||||
|
||||
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
|
||||
self.cache
|
||||
.insert(question, (response.clone(), get_time()))
|
||||
.await;
|
||||
return response;
|
||||
}
|
||||
|
||||
if response.header.rescode == ResultCode::NXDOMAIN {
|
||||
self.cache
|
||||
.insert(question, (response.clone(), get_time()))
|
||||
.await;
|
||||
return response;
|
||||
}
|
||||
|
||||
if let Some(new_ns) = response.get_resolved_ns(qname) {
|
||||
ns = new_ns;
|
||||
continue;
|
||||
}
|
||||
|
||||
let new_ns_name = match response.get_unresolved_ns(qname) {
|
||||
Some(x) => x,
|
||||
None => {
|
||||
self.cache
|
||||
.insert(question, (response.clone(), get_time()))
|
||||
.await;
|
||||
return response;
|
||||
}
|
||||
};
|
||||
|
||||
let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await;
|
||||
|
||||
if let Some(new_ns) = recursive_response.get_random_a() {
|
||||
ns = new_ns;
|
||||
} else {
|
||||
self.cache
|
||||
.insert(question, (response.clone(), get_time()))
|
||||
.await;
|
||||
return response;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_query(mut self) -> Result<()> {
|
||||
let mut request = self.connection.read_packet().await?;
|
||||
|
||||
let mut packet = Packet::new();
|
||||
packet.header.id = request.header.id;
|
||||
packet.header.recursion_desired = true;
|
||||
packet.header.recursion_available = true;
|
||||
packet.header.response = true;
|
||||
|
||||
if let Some(question) = request.questions.pop() {
|
||||
trace!("Received query: {question:?}");
|
||||
|
||||
let result = self.recursive_lookup(&question.name, question.qtype).await;
|
||||
packet.questions.push(question.clone());
|
||||
packet.header.rescode = result.header.rescode;
|
||||
|
||||
for rec in result.answers {
|
||||
trace!("Answer: {rec:?}");
|
||||
packet.answers.push(rec);
|
||||
}
|
||||
for rec in result.authorities {
|
||||
trace!("Authority: {rec:?}");
|
||||
packet.authorities.push(rec);
|
||||
}
|
||||
for rec in result.resources {
|
||||
trace!("Resource: {rec:?}");
|
||||
packet.resources.push(rec);
|
||||
}
|
||||
} else {
|
||||
packet.header.rescode = ResultCode::FORMERR;
|
||||
}
|
||||
|
||||
self.connection.write_packet(packet).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,85 +0,0 @@
|
|||
use super::{
|
||||
binding::Binding,
|
||||
packet::{question::DnsQuestion, Packet},
|
||||
resolver::Resolver,
|
||||
};
|
||||
use crate::{config::Config, database::Database, Result};
|
||||
use moka::future::Cache;
|
||||
use std::{net::SocketAddr, sync::Arc, time::Duration};
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::{error, info};
|
||||
|
||||
pub struct DnsServer {
|
||||
addr: SocketAddr,
|
||||
config: Arc<Config>,
|
||||
database: Arc<Database>,
|
||||
cache: Cache<DnsQuestion, (Packet, u64)>,
|
||||
}
|
||||
|
||||
impl DnsServer {
|
||||
pub async fn new(config: Config, database: Database) -> Result<Self> {
|
||||
let addr = format!("[::]:{}", config.dns_port).parse::<SocketAddr>()?;
|
||||
let cache = Cache::builder()
|
||||
.time_to_live(Duration::from_secs(60 * 60))
|
||||
.max_capacity(config.dns_cache_size)
|
||||
.build();
|
||||
|
||||
info!("Created DNS cache with size of {}", config.dns_cache_size);
|
||||
|
||||
Ok(Self {
|
||||
addr,
|
||||
config: Arc::new(config),
|
||||
database: Arc::new(database),
|
||||
cache,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn run(&self) -> Result<(JoinHandle<()>, JoinHandle<()>)> {
|
||||
let tcp = Binding::tcp(self.addr).await?;
|
||||
let tcp_handle = self.listen(tcp);
|
||||
|
||||
let udp = Binding::udp(self.addr).await?;
|
||||
let udp_handle = self.listen(udp);
|
||||
|
||||
info!(
|
||||
"Fallback DNS Server is set to: {:?}",
|
||||
self.config.dns_fallback
|
||||
);
|
||||
info!(
|
||||
"Listening for TCP and UDP traffic on [::]:{}",
|
||||
self.config.dns_port
|
||||
);
|
||||
|
||||
Ok((udp_handle, tcp_handle))
|
||||
}
|
||||
|
||||
fn listen(&self, mut binding: Binding) -> JoinHandle<()> {
|
||||
let config = self.config.clone();
|
||||
let database = self.database.clone();
|
||||
let cache = self.cache.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut id = 0;
|
||||
loop {
|
||||
let Ok(connection) = binding.connect().await else { continue };
|
||||
info!("Received request on {}", binding.name());
|
||||
|
||||
let resolver = Resolver::new(
|
||||
id,
|
||||
connection,
|
||||
config.clone(),
|
||||
database.clone(),
|
||||
cache.clone(),
|
||||
);
|
||||
|
||||
let name = binding.name().to_string();
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = resolver.handle_query().await {
|
||||
error!("{} request {} failed: {:?}", name, id, err);
|
||||
};
|
||||
});
|
||||
|
||||
id += 1;
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
49
src/io/log.c
Normal file
49
src/io/log.c
Normal file
|
@ -0,0 +1,49 @@
|
|||
#include <stdarg.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
|
||||
#include "log.h"
|
||||
|
||||
#ifdef LOG
|
||||
|
||||
void logmsg(LogLevel level, const char* msg, ...) {
|
||||
|
||||
INIT_LOG_BOUNDS
|
||||
INIT_LOG_BUFFER(buffer)
|
||||
|
||||
time_t now = time(NULL);
|
||||
struct tm *tm = localtime(&now);
|
||||
APPEND(buffer, "\x1b[97m%02d:%02d:%02d ", tm->tm_hour, tm->tm_min, tm->tm_sec);
|
||||
|
||||
switch (level) {
|
||||
case DEBUG:
|
||||
APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 95, "DEBUG");
|
||||
break;
|
||||
case TRACE:
|
||||
APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 96, "TRACE");
|
||||
break;
|
||||
case INFO:
|
||||
APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 92, "INFO");
|
||||
break;
|
||||
case WARN:
|
||||
APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 93, "WARN");
|
||||
break;
|
||||
case ERROR:
|
||||
APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 91, "ERROR");
|
||||
break;
|
||||
break;
|
||||
}
|
||||
|
||||
va_list valist;
|
||||
va_start(valist, msg);
|
||||
t += vsnprintf(buffer + t, BUF_LENGTH - t, msg, valist);
|
||||
va_end(valist);
|
||||
|
||||
APPEND(buffer, "\n");
|
||||
|
||||
fwrite(&buffer, t, 1, stdout);
|
||||
}
|
||||
|
||||
#endif
|
45
src/io/log.h
Normal file
45
src/io/log.h
Normal file
|
@ -0,0 +1,45 @@
|
|||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#define LOG
|
||||
#ifdef LOG
|
||||
|
||||
typedef enum {
|
||||
DEBUG,
|
||||
TRACE,
|
||||
INFO,
|
||||
WARN,
|
||||
ERROR,
|
||||
} LogLevel;
|
||||
|
||||
#define BUF_LENGTH 256
|
||||
#define INIT_LOG_BUFFER(name) char name[BUF_LENGTH];
|
||||
#define INIT_LOG_BOUNDS int t = 0;
|
||||
#define APPEND(buffer, msg, ...) t += snprintf(buffer + t, BUF_LENGTH - t, msg, ##__VA_ARGS__);
|
||||
#define LOGONLY(code) code
|
||||
|
||||
void logmsg(LogLevel level, const char* msg, ...)
|
||||
__attribute__ ((__format__(printf, 2, 3)));
|
||||
|
||||
#define DEBUG(msg, ...) logmsg(DEBUG, msg, ##__VA_ARGS__)
|
||||
#define TRACE(msg, ...) logmsg(TRACE, msg, ##__VA_ARGS__)
|
||||
#define INFO(msg, ...) logmsg(INFO, msg, ##__VA_ARGS__)
|
||||
#define WARN(msg, ...) logmsg(WARN, msg, ##__VA_ARGS__)
|
||||
#define ERROR(msg, ...) logmsg(ERROR, msg, ##__VA_ARGS__)
|
||||
|
||||
#else
|
||||
|
||||
#define BUF_LENGTH
|
||||
#define INIT_LOG_BUFFER(name)
|
||||
#define INIT_LOG_BOUNDS
|
||||
#define APPEND(buffer, msg, ...)
|
||||
#define LOGONLY(code)
|
||||
|
||||
#define DEBUG(msg, ...)
|
||||
#define TRACE(msg, ...)
|
||||
#define INFO(msg, ...)
|
||||
#define WARN(msg, ...)
|
||||
#define ERROR(msg, ...)
|
||||
|
||||
#endif
|
32
src/main.c
Normal file
32
src/main.c
Normal file
|
@ -0,0 +1,32 @@
|
|||
#include "server/server.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <sys/select.h>
|
||||
|
||||
#define DEFAULT_PORT 53
|
||||
|
||||
static uint16_t get_port(const char* port_str) {
|
||||
if (port_str == NULL) {
|
||||
return DEFAULT_PORT;
|
||||
}
|
||||
|
||||
uint16_t port;
|
||||
if ((port = strtoul(port_str, NULL, 10)) == 0) {
|
||||
return DEFAULT_PORT;
|
||||
}
|
||||
|
||||
return port;
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
|
||||
const char* port_str = getenv("PORT");
|
||||
uint16_t port = get_port(port_str);
|
||||
|
||||
Server server;
|
||||
server_init(port, &server);
|
||||
server_run(&server);
|
||||
server_free(&server);
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
}
|
64
src/main.rs
64
src/main.rs
|
@ -1,64 +0,0 @@
|
|||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use config::Config;
|
||||
|
||||
use database::Database;
|
||||
use dotenv::dotenv;
|
||||
use dns::server::DnsServer;
|
||||
use tracing::{error, metadata::LevelFilter};
|
||||
use tracing_subscriber::{
|
||||
filter::filter_fn, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, Layer,
|
||||
};
|
||||
use web::WebServer;
|
||||
|
||||
mod config;
|
||||
mod database;
|
||||
mod dns;
|
||||
mod web;
|
||||
|
||||
type Error = Box<dyn std::error::Error>;
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
if let Err(err) = run().await {
|
||||
error!("{err}")
|
||||
};
|
||||
}
|
||||
|
||||
async fn run() -> Result<()> {
|
||||
dotenv().ok();
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::fmt::layer()
|
||||
.with_filter(LevelFilter::TRACE)
|
||||
.with_filter(filter_fn(|metadata| {
|
||||
metadata.target().starts_with("wrapper")
|
||||
})),
|
||||
)
|
||||
.init();
|
||||
|
||||
let config = Config::new();
|
||||
let database = Database::new(config.clone()).await?;
|
||||
|
||||
let dns_server = DnsServer::new(config.clone(), database.clone()).await?;
|
||||
let (udp, tcp) = dns_server.run().await?;
|
||||
|
||||
let web_server = WebServer::new(config, database).await?;
|
||||
let web = web_server.run().await?;
|
||||
|
||||
tokio::join!(udp).0?;
|
||||
tokio::join!(tcp).0?;
|
||||
tokio::join!(web).0?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_time() -> u64 {
|
||||
let start = SystemTime::now();
|
||||
let since_the_epoch = start
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("Time went backwards");
|
||||
since_the_epoch.as_millis() as u64
|
||||
}
|
240
src/packet/buffer.c
Normal file
240
src/packet/buffer.c
Normal file
|
@ -0,0 +1,240 @@
|
|||
#include "buffer.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
struct PacketBuffer {
|
||||
uint8_t* arr;
|
||||
int capacity;
|
||||
int index;
|
||||
int size;
|
||||
};
|
||||
|
||||
PacketBuffer* buffer_create(int capacity) {
|
||||
PacketBuffer* buffer = malloc(sizeof(PacketBuffer));
|
||||
buffer->arr = malloc(capacity);
|
||||
buffer->capacity = capacity;
|
||||
buffer->size = 0;
|
||||
buffer->index = 0;
|
||||
return buffer;
|
||||
}
|
||||
|
||||
void buffer_free(PacketBuffer* buffer) {
|
||||
free(buffer->arr);
|
||||
free(buffer);
|
||||
}
|
||||
|
||||
void buffer_seek(PacketBuffer* buffer, int index) {
|
||||
buffer->index = index;
|
||||
}
|
||||
|
||||
uint8_t buffer_read(PacketBuffer* buffer) {
|
||||
if (buffer->index > buffer->size) {
|
||||
return 0;
|
||||
}
|
||||
uint8_t data = buffer->arr[buffer->index];
|
||||
buffer->index++;
|
||||
return data;
|
||||
}
|
||||
|
||||
uint16_t buffer_read_short(PacketBuffer* buffer) {
|
||||
return
|
||||
(uint16_t) buffer_read(buffer) << 8 |
|
||||
(uint16_t) buffer_read(buffer);
|
||||
}
|
||||
|
||||
uint32_t buffer_read_int(PacketBuffer* buffer) {
|
||||
return
|
||||
(uint32_t) buffer_read(buffer) << 24 |
|
||||
(uint32_t) buffer_read(buffer) << 16 |
|
||||
(uint32_t) buffer_read(buffer) << 8 |
|
||||
(uint32_t) buffer_read(buffer);
|
||||
}
|
||||
|
||||
uint8_t buffer_get(PacketBuffer* buffer, int index) {
|
||||
if (index > buffer->size) {
|
||||
return 0;
|
||||
}
|
||||
uint8_t data = buffer->arr[index];
|
||||
return data;
|
||||
}
|
||||
|
||||
uint8_t* buffer_get_range(PacketBuffer* buffer, int start, int len) {
|
||||
uint8_t* arr = malloc(len);
|
||||
for (int i = 0; i < len; i++) {
|
||||
arr[i] = buffer_get(buffer, start + i);
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
|
||||
uint16_t buffer_get_size(PacketBuffer* buffer) {
|
||||
return (uint16_t) buffer->size + 1;
|
||||
}
|
||||
|
||||
static void write(uint8_t** buffer, uint8_t* size, uint8_t* capacity, uint8_t data) {
|
||||
if (*size == *capacity) {
|
||||
*capacity *= 2;
|
||||
*buffer = realloc(*buffer, *capacity);
|
||||
}
|
||||
(*buffer)[*size] = data;
|
||||
(*size)++;
|
||||
}
|
||||
|
||||
void buffer_read_qname(PacketBuffer* buffer, uint8_t** out) {
|
||||
int index = buffer->index;
|
||||
int jumped = 0;
|
||||
|
||||
int max_jumps = 5;
|
||||
int jumps_performed = 0;
|
||||
|
||||
uint8_t length = 0;
|
||||
uint8_t capacity = 8;
|
||||
*out = malloc(capacity);
|
||||
write(out, &length, &capacity, 0);
|
||||
|
||||
while(1) {
|
||||
if (jumps_performed > max_jumps) {
|
||||
break;
|
||||
}
|
||||
|
||||
uint8_t len = buffer_get(buffer, index);
|
||||
|
||||
if ((len & 0xC0) == 0xC0) {
|
||||
if (jumped == 0) {
|
||||
buffer_seek(buffer, index + 2);
|
||||
}
|
||||
|
||||
uint16_t b2 = (uint16_t) buffer_get(buffer, index + 1);
|
||||
uint16_t offset = ((((uint16_t) len) ^ 0xC0) << 8) | b2;
|
||||
index = (int) offset;
|
||||
jumped = 1;
|
||||
jumps_performed++;
|
||||
continue;
|
||||
}
|
||||
|
||||
index++;
|
||||
|
||||
if (len == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (length > 1) {
|
||||
write(out, &length, &capacity, '.');
|
||||
}
|
||||
|
||||
uint8_t* range = buffer_get_range(buffer, index, len);
|
||||
for (uint8_t i = 0; i < len; i++) {
|
||||
write(out, &length, &capacity, range[i]);
|
||||
}
|
||||
free(range);
|
||||
|
||||
index += (int) len;
|
||||
}
|
||||
|
||||
if (jumped == 0) {
|
||||
buffer_seek(buffer, index);
|
||||
}
|
||||
|
||||
(*out)[0] = length - 1;
|
||||
}
|
||||
|
||||
void buffer_read_string(PacketBuffer* buffer, uint8_t** out) {
|
||||
uint8_t len = buffer_read(buffer);
|
||||
buffer_read_n(buffer, out, len);
|
||||
}
|
||||
|
||||
void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len) {
|
||||
*out = malloc(len + 1);
|
||||
*out[0] = len;
|
||||
memcpy(*out + 1, buffer->arr + buffer->index, len);
|
||||
buffer->index += len;
|
||||
}
|
||||
|
||||
void buffer_write(PacketBuffer* buffer, uint8_t data) {
|
||||
if(buffer->index == buffer->capacity) {
|
||||
buffer->capacity *= 2;
|
||||
buffer->arr = realloc(buffer->arr, buffer->capacity);
|
||||
}
|
||||
if (buffer->size < buffer->index) {
|
||||
buffer->size = buffer->index;
|
||||
}
|
||||
buffer->arr[buffer->index] = data;
|
||||
buffer->index++;
|
||||
}
|
||||
|
||||
void buffer_write_short(PacketBuffer* buffer, uint16_t data) {
|
||||
buffer_write(buffer, (uint8_t)(data >> 8));
|
||||
buffer_write(buffer, (uint8_t)(data & 0xFF));
|
||||
}
|
||||
|
||||
void buffer_write_int(PacketBuffer* buffer, uint32_t data) {
|
||||
buffer_write(buffer, (uint8_t)(data >> 24));
|
||||
buffer_write(buffer, (uint8_t)(data >> 16));
|
||||
buffer_write(buffer, (uint8_t)(data >> 8));
|
||||
buffer_write(buffer, (uint8_t)(data & 0xFF));
|
||||
}
|
||||
|
||||
void buffer_write_qname(PacketBuffer* buffer, uint8_t* in) {
|
||||
uint8_t part = 0;
|
||||
uint8_t len = in[0];
|
||||
|
||||
buffer_write(buffer, 0);
|
||||
|
||||
if (len == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
for(uint8_t i = 0; i < len; i ++) {
|
||||
if (in[i+1] == '.') {
|
||||
buffer_set(buffer, part, buffer->index - (int)part - 1);
|
||||
buffer_write(buffer, 0);
|
||||
part = 0;
|
||||
} else {
|
||||
buffer_write(buffer, in[i+1]);
|
||||
part++;
|
||||
}
|
||||
}
|
||||
buffer_set(buffer, part, buffer->index - (int)part - 1);
|
||||
buffer_write(buffer, 0);
|
||||
}
|
||||
|
||||
void buffer_write_string(PacketBuffer* buffer, uint8_t* in) {
|
||||
buffer_write(buffer, in[0]);
|
||||
buffer_write_n(buffer, in + 1, in[0]);
|
||||
}
|
||||
|
||||
void buffer_write_n(PacketBuffer* buffer, uint8_t* in, int len) {
|
||||
if (buffer->size + len >= buffer->capacity) {
|
||||
buffer->capacity *= 2;
|
||||
buffer->capacity += len;
|
||||
buffer->arr = realloc(buffer->arr, buffer->capacity);
|
||||
}
|
||||
memcpy(buffer->arr + buffer->index, in, len);
|
||||
buffer->size += len;
|
||||
buffer->index += len;
|
||||
}
|
||||
|
||||
void buffer_set(PacketBuffer* buffer, uint8_t data, int index) {
|
||||
if (index > buffer->size) {
|
||||
return;
|
||||
}
|
||||
buffer->arr[index] = data;
|
||||
}
|
||||
|
||||
void buffer_set_uint16_t(PacketBuffer* buffer, uint16_t data, int index) {
|
||||
buffer_set(buffer, (uint8_t)(data >> 8), index);
|
||||
buffer_set(buffer, (uint8_t)(data & 0xFF), index + 1);
|
||||
}
|
||||
|
||||
int buffer_get_index(PacketBuffer* buffer) {
|
||||
return buffer->index;
|
||||
}
|
||||
|
||||
void buffer_step(PacketBuffer* buffer, int len) {
|
||||
buffer->index += len;
|
||||
}
|
||||
|
||||
uint8_t* buffer_get_ptr(PacketBuffer* buffer) {
|
||||
return buffer->arr;
|
||||
}
|
51
src/packet/buffer.h
Normal file
51
src/packet/buffer.h
Normal file
|
@ -0,0 +1,51 @@
|
|||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
typedef struct PacketBuffer PacketBuffer;
|
||||
|
||||
PacketBuffer* buffer_create(int capacity);
|
||||
|
||||
void buffer_free(PacketBuffer* buffer);
|
||||
|
||||
void buffer_seek(PacketBuffer* buffer, int index);
|
||||
|
||||
uint8_t buffer_read(PacketBuffer* buffer);
|
||||
|
||||
uint16_t buffer_read_short(PacketBuffer* buffer);
|
||||
|
||||
uint32_t buffer_read_int(PacketBuffer* buffer);
|
||||
|
||||
uint8_t buffer_get(PacketBuffer* buffer, int index);
|
||||
|
||||
uint8_t* buffer_get_range(PacketBuffer* buffer, int start, int len);
|
||||
|
||||
uint16_t buffer_get_size(PacketBuffer* buffer);
|
||||
|
||||
void buffer_read_qname(PacketBuffer* buffer, uint8_t** out);
|
||||
|
||||
void buffer_read_string(PacketBuffer* buffer, uint8_t** out);
|
||||
|
||||
void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len);
|
||||
|
||||
void buffer_write(PacketBuffer* buffer, uint8_t data);
|
||||
|
||||
void buffer_write_short(PacketBuffer* buffer, uint16_t data);
|
||||
|
||||
void buffer_write_int(PacketBuffer* buffer, uint32_t data);
|
||||
|
||||
void buffer_write_qname(PacketBuffer* buffer, uint8_t* in);
|
||||
|
||||
void buffer_write_string(PacketBuffer* buffer, uint8_t* in);
|
||||
|
||||
void buffer_write_n(PacketBuffer* buffer, uint8_t* in, int len);
|
||||
|
||||
void buffer_set(PacketBuffer* buffer, uint8_t data, int index);
|
||||
|
||||
void buffer_set_uint16_t(PacketBuffer* buffer, uint16_t data, int index);
|
||||
|
||||
int buffer_get_index(PacketBuffer* buffer);
|
||||
|
||||
void buffer_step(PacketBuffer* buffer, int len);
|
||||
|
||||
uint8_t* buffer_get_ptr(PacketBuffer* buffer);
|
93
src/packet/header.c
Normal file
93
src/packet/header.c
Normal file
|
@ -0,0 +1,93 @@
|
|||
#include "header.h"
|
||||
#include "buffer.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
uint8_t rescode_to_id(ResultCode code) {
|
||||
switch(code) {
|
||||
case NOERROR:
|
||||
return 0;
|
||||
case FORMERR:
|
||||
return 1;
|
||||
case SERVFAIL:
|
||||
return 2;
|
||||
case NXDOMAIN:
|
||||
return 3;
|
||||
case NOTIMP:
|
||||
return 4;
|
||||
case REFUSED:
|
||||
return 5;
|
||||
default:
|
||||
return 2;
|
||||
}
|
||||
}
|
||||
|
||||
ResultCode rescode_from_id(uint8_t id) {
|
||||
switch(id) {
|
||||
case 0:
|
||||
return NOERROR;
|
||||
case 1:
|
||||
return FORMERR;
|
||||
case 2:
|
||||
return SERVFAIL;
|
||||
case 3:
|
||||
return NXDOMAIN;
|
||||
case 4:
|
||||
return NOTIMP;
|
||||
case 5:
|
||||
return REFUSED;
|
||||
default:
|
||||
return FORMERR;
|
||||
}
|
||||
}
|
||||
|
||||
void read_header(PacketBuffer* buffer, Header* header) {
|
||||
// memset(header, 0, sizeof(Header));
|
||||
header->id = buffer_read_short(buffer);
|
||||
|
||||
uint8_t a = buffer_read(buffer);
|
||||
header->recursion_desired = (a & (1 << 0)) > 0;
|
||||
header->truncated_message = (a & (1 << 1)) > 0;
|
||||
header->authorative_answer = (a & (1 << 2)) > 0;
|
||||
header->opcode = (a >> 3) & 0x0F;
|
||||
header->response = (a & (1 << 7)) > 0;
|
||||
|
||||
uint8_t b = buffer_read(buffer);
|
||||
header->rescode = rescode_from_id(b & 0x0F);
|
||||
header->checking_disabled = (b & (1 << 4)) > 0;
|
||||
header->authed_data = (b& (1 << 4)) > 0;
|
||||
header->z = (b & (1 << 6)) > 0;
|
||||
header->recursion_available = (b & (1 << 7)) > 0;
|
||||
|
||||
header->questions = buffer_read_short(buffer);
|
||||
header->answers = buffer_read_short(buffer);
|
||||
header->authoritative_entries = buffer_read_short(buffer);
|
||||
header->resource_entries = buffer_read_short(buffer);
|
||||
}
|
||||
|
||||
void write_header(PacketBuffer* buffer, Header* header) {
|
||||
buffer_write_short(buffer, header->id);
|
||||
|
||||
buffer_write(buffer,
|
||||
((uint8_t) header->recursion_desired) |
|
||||
((uint8_t) header->truncated_message << 1) |
|
||||
((uint8_t) header->authorative_answer << 2) |
|
||||
(header->opcode << 3) |
|
||||
((uint8_t) header->response << 7)
|
||||
);
|
||||
|
||||
buffer_write(buffer,
|
||||
(rescode_to_id(header->rescode)) |
|
||||
((uint8_t) header->checking_disabled << 4) |
|
||||
((uint8_t) header->authed_data << 5) |
|
||||
((uint8_t) header->z << 6) |
|
||||
((uint8_t) header->recursion_available << 7)
|
||||
);
|
||||
|
||||
buffer_write_short(buffer, header->questions);
|
||||
buffer_write_short(buffer, header->answers);
|
||||
buffer_write_short(buffer, header->authoritative_entries);
|
||||
buffer_write_short(buffer, header->resource_entries);
|
||||
}
|
41
src/packet/header.h
Normal file
41
src/packet/header.h
Normal file
|
@ -0,0 +1,41 @@
|
|||
#pragma once
|
||||
|
||||
#include "buffer.h"
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
typedef enum {
|
||||
NOERROR, // 0
|
||||
FORMERR, // 1
|
||||
SERVFAIL, // 2
|
||||
NXDOMAIN, // 3,
|
||||
NOTIMP, // 4
|
||||
REFUSED, // 5
|
||||
} ResultCode;
|
||||
|
||||
uint8_t rescode_to_id(ResultCode code);
|
||||
ResultCode rescode_from_id(uint8_t id);
|
||||
|
||||
typedef struct {
|
||||
uint16_t id;
|
||||
|
||||
bool recursion_desired; // 1 bit
|
||||
bool truncated_message; // 1 bit
|
||||
bool authorative_answer; // 1 bit
|
||||
uint8_t opcode; // 4 bits
|
||||
bool response; // 1 bit
|
||||
|
||||
ResultCode rescode; // 4 bits
|
||||
bool checking_disabled; // 1 bit
|
||||
bool authed_data; // 1 bit
|
||||
bool z; // 1 bit
|
||||
bool recursion_available; // 1 bit
|
||||
|
||||
uint16_t questions; // 16 bits
|
||||
uint16_t answers; // 16 bits
|
||||
uint16_t authoritative_entries; // 16 bits
|
||||
uint16_t resource_entries; // 16 bits
|
||||
} Header;
|
||||
|
||||
void read_header(PacketBuffer* buffer, Header* header);
|
||||
void write_header(PacketBuffer* buffer, Header* header);
|
171
src/packet/packet.c
Normal file
171
src/packet/packet.c
Normal file
|
@ -0,0 +1,171 @@
|
|||
#include "packet.h"
|
||||
#include "buffer.h"
|
||||
#include "header.h"
|
||||
#include "question.h"
|
||||
#include "record.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
void read_packet(PacketBuffer* buffer, Packet* packet) {
|
||||
read_header(buffer, &packet->header);
|
||||
|
||||
packet->questions = malloc(sizeof(Question) * packet->header.questions);
|
||||
for(uint16_t i = 0; i < packet->header.questions; i++) {
|
||||
read_question(buffer, &packet->questions[i]);
|
||||
}
|
||||
|
||||
packet->answers = malloc(sizeof(Record) * packet->header.answers);
|
||||
for(uint16_t i = 0; i < packet->header.answers; i++) {
|
||||
read_record(buffer, &packet->answers[i]);
|
||||
}
|
||||
|
||||
packet->authorities = malloc(sizeof(Record) * packet->header.authoritative_entries);
|
||||
for(uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
|
||||
read_record(buffer, &packet->authorities[i]);
|
||||
}
|
||||
|
||||
packet->resources = malloc(sizeof(Record) * packet->header.resource_entries);
|
||||
for(uint16_t i = 0; i < packet->header.resource_entries; i++) {
|
||||
read_record(buffer, &packet->resources[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void write_packet(PacketBuffer* buffer, Packet* packet) {
|
||||
write_header(buffer, &packet->header);
|
||||
|
||||
for(uint16_t i = 0; i < packet->header.questions; i++) {
|
||||
write_question(buffer, &packet->questions[i]);
|
||||
}
|
||||
|
||||
for(uint16_t i = 0; i < packet->header.answers; i++) {
|
||||
write_record(buffer, &packet->answers[i]);
|
||||
}
|
||||
|
||||
for(uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
|
||||
write_record(buffer, &packet->authorities[i]);
|
||||
}
|
||||
|
||||
for(uint16_t i = 0; i < packet->header.resource_entries; i++) {
|
||||
write_record(buffer, &packet->resources[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void free_packet(Packet* packet) {
|
||||
|
||||
for(uint16_t i = 0; i < packet->header.questions; i++) {
|
||||
free_question(&packet->questions[i]);
|
||||
}
|
||||
free(packet->questions);
|
||||
|
||||
for(uint16_t i = 0; i < packet->header.answers; i++) {
|
||||
free_record(&packet->answers[i]);
|
||||
}
|
||||
free(packet->answers);
|
||||
|
||||
for(uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
|
||||
free_record(&packet->authorities[i]);
|
||||
}
|
||||
free(packet->authorities);
|
||||
|
||||
for(uint16_t i = 0; i < packet->header.resource_entries; i++) {
|
||||
free_record(&packet->resources[i]);
|
||||
}
|
||||
free(packet->resources);
|
||||
}
|
||||
|
||||
bool get_random_a(Packet* packet, IpAddr* addr) {
|
||||
for (uint16_t i = 0; i < packet->header.answers; i++) {
|
||||
Record record = packet->answers[i];
|
||||
if (record.type == A) {
|
||||
create_ip_addr((char*) &record.data.a.addr, addr);
|
||||
return true;
|
||||
} else if (record.type == AAAA) {
|
||||
create_ip_addr6((char*) &record.data.aaaa.addr, addr);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool ends_with(uint8_t* full, uint8_t* end) {
|
||||
uint8_t check = end[0];
|
||||
uint8_t len = full[0];
|
||||
|
||||
if (check > len) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for(uint8_t i = 0; i < check; i++) {
|
||||
if (end[check - 1 - i] != full[len - 1 - i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool equals(uint8_t* a, uint8_t* b) {
|
||||
if (a[0] != b[0]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for(uint8_t i = 1; i < a[0] + 1; i++) {
|
||||
if(a[i] != b[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool get_resolved_ns(Packet* packet, uint8_t* qname, IpAddr* addr) {
|
||||
for (uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
|
||||
Record record = packet->authorities[i];
|
||||
if (record.type != NS) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if(!ends_with(qname, record.domain)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (uint16_t i = 0; i < packet->header.resource_entries; i++) {
|
||||
Record resource = packet->resources[i];
|
||||
if (!equals(record.data.ns.host, resource.domain)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (resource.type == A) {
|
||||
create_ip_addr((char*) &record.data.a.addr, addr);
|
||||
return true;
|
||||
} else if (resource.type == AAAA) {
|
||||
create_ip_addr6((char*) &record.data.aaaa.addr, addr);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool get_unresoled_ns(Packet* packet, uint8_t* qname, Question* question) {
|
||||
for (uint16_t i = 0; i < packet->header.authoritative_entries; i++) {
|
||||
Record record = packet->authorities[i];
|
||||
if (record.type != NS) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if(!ends_with(qname, record.domain)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t* host = record.data.ns.host;
|
||||
|
||||
question->qtype = NS;
|
||||
question->domain = malloc(host[0] + 1);
|
||||
memcpy(question->domain, host, host[0] + 1);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
23
src/packet/packet.h
Normal file
23
src/packet/packet.h
Normal file
|
@ -0,0 +1,23 @@
|
|||
#pragma once
|
||||
|
||||
#include "buffer.h"
|
||||
#include "question.h"
|
||||
#include "header.h"
|
||||
#include "record.h"
|
||||
#include "../server/addr.h"
|
||||
|
||||
typedef struct {
|
||||
Header header;
|
||||
Question* questions;
|
||||
Record* answers;
|
||||
Record* authorities;
|
||||
Record* resources;
|
||||
} Packet;
|
||||
|
||||
void read_packet(PacketBuffer* buffer, Packet* packet);
|
||||
void write_packet(PacketBuffer* buffer, Packet* packet);
|
||||
void free_packet(Packet* packet);
|
||||
|
||||
bool get_random_a(Packet* packet, IpAddr* addr);
|
||||
bool get_resolved_ns(Packet* packet, uint8_t* qname, IpAddr* addr);
|
||||
bool get_unresoled_ns(Packet* packet, uint8_t* qname, Question* question);
|
94
src/packet/question.c
Normal file
94
src/packet/question.c
Normal file
|
@ -0,0 +1,94 @@
|
|||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "question.h"
|
||||
#include "buffer.h"
|
||||
#include "record.h"
|
||||
#include "../io/log.h"
|
||||
|
||||
void read_question(PacketBuffer* buffer, Question* question) {
|
||||
buffer_read_qname(buffer, &question->domain);
|
||||
|
||||
uint16_t qtype_num = buffer_read_short(buffer);
|
||||
record_from_id(qtype_num, &question->qtype);
|
||||
question->cls = buffer_read_short(buffer);
|
||||
|
||||
INIT_LOG_BUFFER(log)
|
||||
LOGONLY(print_question(question, log);)
|
||||
TRACE("Reading question: %s", log);
|
||||
}
|
||||
|
||||
void write_question(PacketBuffer* buffer, Question* question) {
|
||||
buffer_write_qname(buffer, question->domain);
|
||||
|
||||
uint16_t id = record_to_id(question->qtype);
|
||||
buffer_write_short(buffer, id);
|
||||
|
||||
buffer_write_short(buffer, question->cls);
|
||||
|
||||
INIT_LOG_BUFFER(log)
|
||||
LOGONLY(print_question(question, log);)
|
||||
TRACE("Writing question: %s", log);
|
||||
}
|
||||
|
||||
void free_question(Question* question) {
|
||||
free(question->domain);
|
||||
}
|
||||
|
||||
void print_question(Question* question, char* buffer) {
|
||||
INIT_LOG_BOUNDS
|
||||
switch (question->cls) {
|
||||
case 1:
|
||||
APPEND(buffer, "IN ");;
|
||||
break;
|
||||
case 3:
|
||||
APPEND(buffer, "CH ");
|
||||
break;
|
||||
case 4:
|
||||
APPEND(buffer, "HS ");
|
||||
break;
|
||||
default:
|
||||
APPEND(buffer, "?? ");
|
||||
break;
|
||||
}
|
||||
switch(question->qtype) {
|
||||
case UNKOWN:
|
||||
APPEND(buffer, "UNKOWN ");
|
||||
break;
|
||||
case A:
|
||||
APPEND(buffer, "A ");
|
||||
break;
|
||||
case NS:
|
||||
APPEND(buffer, "NS ");
|
||||
break;
|
||||
case CNAME:
|
||||
APPEND(buffer, "CNAME ");
|
||||
break;
|
||||
case SOA:
|
||||
APPEND(buffer, "SOA ");
|
||||
break;
|
||||
case PTR:
|
||||
APPEND(buffer, "PTR ");
|
||||
break;
|
||||
case MX:
|
||||
APPEND(buffer, "MX ");
|
||||
break;
|
||||
case TXT:
|
||||
APPEND(buffer, "TXT ");
|
||||
break;
|
||||
case AAAA:
|
||||
APPEND(buffer, "AAAA ");
|
||||
break;
|
||||
case SRV:
|
||||
APPEND(buffer, "SRV ");
|
||||
break;
|
||||
case CAA:
|
||||
APPEND(buffer, "CAA ");
|
||||
break;
|
||||
break;
|
||||
}
|
||||
APPEND(buffer, "%.*s",
|
||||
question->domain[0],
|
||||
question->domain + 1
|
||||
);
|
||||
}
|
15
src/packet/question.h
Normal file
15
src/packet/question.h
Normal file
|
@ -0,0 +1,15 @@
|
|||
#pragma once
|
||||
|
||||
#include "buffer.h"
|
||||
#include "record.h"
|
||||
|
||||
typedef struct {
|
||||
uint8_t* domain;
|
||||
RecordType qtype;
|
||||
uint16_t cls;
|
||||
} Question;
|
||||
|
||||
void read_question(PacketBuffer* buffer, Question* question);
|
||||
void write_question(PacketBuffer* buffer, Question* question);
|
||||
void free_question(Question* question);
|
||||
void print_question(Question* question, char* buffer);
|
540
src/packet/record.c
Normal file
540
src/packet/record.c
Normal file
|
@ -0,0 +1,540 @@
|
|||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "record.h"
|
||||
#include "buffer.h"
|
||||
#include "../io/log.h"
|
||||
|
||||
uint16_t record_to_id(RecordType type) {
|
||||
switch (type) {
|
||||
case A:
|
||||
return 1;
|
||||
case NS:
|
||||
return 2;
|
||||
case CNAME:
|
||||
return 5;
|
||||
case SOA:
|
||||
return 6;
|
||||
case PTR:
|
||||
return 12;
|
||||
case MX:
|
||||
return 15;
|
||||
case TXT:
|
||||
return 16;
|
||||
case AAAA:
|
||||
return 28;
|
||||
case SRV:
|
||||
return 33;
|
||||
case CAA:
|
||||
return 257;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void record_from_id(uint16_t i, RecordType* type) {
|
||||
switch (i) {
|
||||
case 1:
|
||||
*type = A;
|
||||
break;
|
||||
case 2:
|
||||
*type = NS;
|
||||
break;
|
||||
case 5:
|
||||
*type = CNAME;
|
||||
break;
|
||||
case 6:
|
||||
*type = SOA;
|
||||
break;
|
||||
case 12:
|
||||
*type = PTR;
|
||||
break;
|
||||
case 15:
|
||||
*type = MX;
|
||||
break;
|
||||
case 16:
|
||||
*type = TXT;
|
||||
break;
|
||||
case 28:
|
||||
*type = AAAA;
|
||||
break;
|
||||
case 33:
|
||||
*type = SRV;
|
||||
break;
|
||||
case 257:
|
||||
*type = CAA;
|
||||
break;
|
||||
default:
|
||||
*type = UNKOWN;
|
||||
}
|
||||
}
|
||||
|
||||
static void read_a_record(PacketBuffer* buffer, Record* record) {
|
||||
ARecord data;
|
||||
data.addr[0] = buffer_read(buffer);
|
||||
data.addr[1] = buffer_read(buffer);
|
||||
data.addr[2] = buffer_read(buffer);
|
||||
data.addr[3] = buffer_read(buffer);
|
||||
|
||||
record->data.a = data;
|
||||
}
|
||||
|
||||
static void read_ns_record(PacketBuffer* buffer, Record* record) {
|
||||
NSRecord data;
|
||||
buffer_read_qname(buffer, &data.host);
|
||||
|
||||
record->data.ns = data;
|
||||
}
|
||||
|
||||
static void read_cname_record(PacketBuffer* buffer, Record* record) {
|
||||
CNAMERecord data;
|
||||
buffer_read_qname(buffer, &data.host);
|
||||
|
||||
record->data.cname = data;
|
||||
}
|
||||
|
||||
static void read_soa_record(PacketBuffer* buffer, Record* record) {
|
||||
SOARecord data;
|
||||
buffer_read_qname(buffer, &data.mname);
|
||||
buffer_read_qname(buffer, &data.nname);
|
||||
data.serial = buffer_read_int(buffer);
|
||||
data.refresh = buffer_read_int(buffer);
|
||||
data.retry = buffer_read_int(buffer);
|
||||
data.expire = buffer_read_int(buffer);
|
||||
data.minimum = buffer_read_int(buffer);
|
||||
|
||||
record->data.soa = data;
|
||||
}
|
||||
|
||||
static void read_ptr_record(PacketBuffer* buffer, Record* record) {
|
||||
PTRRecord data;
|
||||
buffer_read_qname(buffer, &data.pointer);
|
||||
|
||||
record->data.ptr = data;
|
||||
}
|
||||
|
||||
static void read_mx_record(PacketBuffer* buffer, Record* record) {
|
||||
MXRecord data;
|
||||
data.priority = buffer_read_short(buffer);
|
||||
buffer_read_qname(buffer, &data.host);
|
||||
|
||||
record->data.mx = data;
|
||||
}
|
||||
|
||||
static void read_txt_record(PacketBuffer* buffer, Record* record) {
|
||||
TXTRecord data;
|
||||
data.len = 0;
|
||||
data.text = malloc(sizeof(uint8_t*) * 2);
|
||||
|
||||
uint8_t capacity = 2;
|
||||
while (1) {
|
||||
if (data.len >= capacity) {
|
||||
capacity *= 2;
|
||||
data.text = realloc(data.text, sizeof(uint8_t*) * capacity);
|
||||
}
|
||||
|
||||
buffer_read_string(buffer, &data.text[data.len]);
|
||||
if(data.text[data.len][0] == 0) break;
|
||||
data.len++;
|
||||
}
|
||||
|
||||
record->data.txt = data;
|
||||
}
|
||||
|
||||
static void read_aaaa_record(PacketBuffer* buffer, Record* record) {
|
||||
AAAARecord data;
|
||||
for (int i = 0; i < 16; i++) {
|
||||
data.addr[i] = buffer_read(buffer);
|
||||
}
|
||||
|
||||
record->data.aaaa = data;
|
||||
}
|
||||
|
||||
static void read_srv_record(PacketBuffer* buffer, Record* record) {
|
||||
SRVRecord data;
|
||||
data.priority = buffer_read_short(buffer);
|
||||
data.weight = buffer_read_short(buffer);
|
||||
data.port = buffer_read_short(buffer);
|
||||
buffer_read_qname(buffer, &data.target);
|
||||
|
||||
record->data.srv = data;
|
||||
}
|
||||
|
||||
static void read_caa_record(PacketBuffer* buffer, Record* record, int header_pos) {
|
||||
CAARecord data;
|
||||
data.flags = buffer_read(buffer);
|
||||
data.length = buffer_read(buffer);
|
||||
buffer_read_n(buffer, &data.tag, data.length);
|
||||
int value_len = ((int)record->len) + header_pos - buffer_get_index(buffer);
|
||||
buffer_read_n(buffer, &data.value, (uint8_t)value_len);
|
||||
|
||||
record->data.caa = data;
|
||||
}
|
||||
|
||||
void read_record(PacketBuffer* buffer, Record* record) {
|
||||
buffer_read_qname(buffer, &record->domain);
|
||||
|
||||
uint16_t qtype_num = buffer_read_short(buffer);
|
||||
record_from_id(qtype_num, &record->type);
|
||||
|
||||
record->cls = buffer_read_short(buffer);
|
||||
record->ttl = buffer_read_int(buffer);
|
||||
record->len = buffer_read_short(buffer);
|
||||
|
||||
int header_pos = buffer_get_index(buffer);
|
||||
|
||||
switch (record->type) {
|
||||
case A:
|
||||
read_a_record(buffer, record);
|
||||
break;
|
||||
case NS:
|
||||
read_ns_record(buffer, record);
|
||||
break;
|
||||
case CNAME:
|
||||
read_cname_record(buffer, record);
|
||||
break;
|
||||
case SOA:
|
||||
read_soa_record(buffer, record);
|
||||
break;
|
||||
case PTR:
|
||||
read_ptr_record(buffer, record);
|
||||
break;
|
||||
case MX:
|
||||
read_mx_record(buffer, record);
|
||||
break;
|
||||
case TXT:
|
||||
read_txt_record(buffer, record);
|
||||
break;
|
||||
case AAAA:
|
||||
read_aaaa_record(buffer, record);
|
||||
break;
|
||||
case SRV:
|
||||
read_srv_record(buffer, record);
|
||||
break;
|
||||
case CAA:
|
||||
read_caa_record(buffer, record, header_pos);
|
||||
break;
|
||||
default:
|
||||
buffer_step(buffer, record->len);
|
||||
return;
|
||||
}
|
||||
|
||||
INIT_LOG_BUFFER(log)
|
||||
LOGONLY(print_record(record, log);)
|
||||
TRACE("Reading record: %s", log);
|
||||
}
|
||||
|
||||
static void write_a_record(PacketBuffer* buffer, Record* record) {
|
||||
ARecord data = record->data.a;
|
||||
buffer_write_short(buffer, 4);
|
||||
buffer_write(buffer, record->data.a.addr[0]);
|
||||
buffer_write(buffer, data.addr[1]);
|
||||
buffer_write(buffer, data.addr[2]);
|
||||
buffer_write(buffer, data.addr[3]);
|
||||
}
|
||||
|
||||
static void write_ns_record(PacketBuffer* buffer, Record* record) {
|
||||
NSRecord data = record->data.ns;
|
||||
int pos = buffer_get_index(buffer);
|
||||
buffer_write_short(buffer, 0);
|
||||
|
||||
buffer_write_qname(buffer, data.host);
|
||||
|
||||
int size = buffer_get_index(buffer) - pos - 2;
|
||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
||||
}
|
||||
|
||||
static void write_cname_record(PacketBuffer* buffer, Record* record) {
|
||||
CNAMERecord data = record->data.cname;
|
||||
int pos = buffer_get_index(buffer);
|
||||
buffer_write_short(buffer, 0);
|
||||
|
||||
buffer_write_qname(buffer, data.host);
|
||||
|
||||
int size = buffer_get_index(buffer) - pos - 2;
|
||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
||||
}
|
||||
|
||||
static void write_soa_record(PacketBuffer* buffer, Record* record) {
|
||||
SOARecord data = record->data.soa;
|
||||
int pos = buffer_get_index(buffer);
|
||||
buffer_write_short(buffer, 0);
|
||||
|
||||
buffer_write_qname(buffer, data.mname);
|
||||
buffer_write_qname(buffer, data.nname);
|
||||
buffer_write_int(buffer, data.serial);
|
||||
buffer_write_int(buffer, data.refresh);
|
||||
buffer_write_int(buffer, data.retry);
|
||||
buffer_write_int(buffer, data.expire);
|
||||
buffer_write_int(buffer, data.minimum);
|
||||
|
||||
int size = buffer_get_index(buffer) - pos - 2;
|
||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
||||
}
|
||||
|
||||
static void write_ptr_record(PacketBuffer* buffer, Record* record) {
|
||||
PTRRecord data = record->data.ptr;
|
||||
int pos = buffer_get_index(buffer);
|
||||
buffer_write_short(buffer, 0);
|
||||
|
||||
buffer_write_qname(buffer, data.pointer);
|
||||
|
||||
int size = buffer_get_index(buffer) - pos - 2;
|
||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
||||
}
|
||||
|
||||
static void write_mx_record(PacketBuffer* buffer, Record* record) {
|
||||
MXRecord data = record->data.mx;
|
||||
int pos = buffer_get_index(buffer);
|
||||
buffer_write_short(buffer, 0);
|
||||
|
||||
buffer_write_short(buffer, data.priority);
|
||||
buffer_write_qname(buffer, data.host);
|
||||
|
||||
int size = buffer_get_index(buffer) - pos - 2;
|
||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
||||
}
|
||||
|
||||
static void write_txt_record(PacketBuffer* buffer, Record* record) {
|
||||
TXTRecord data = record->data.txt;
|
||||
int pos = buffer_get_index(buffer);
|
||||
buffer_write_short(buffer, 0);
|
||||
|
||||
if(data.len == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
for(uint8_t i = 0; i < data.len; i++) {
|
||||
buffer_write_string(buffer, data.text[i]);
|
||||
}
|
||||
|
||||
int size = buffer_get_index(buffer) - pos - 2;
|
||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
||||
}
|
||||
|
||||
static void write_aaaa_record(PacketBuffer* buffer, Record* record) {
|
||||
AAAARecord data = record->data.aaaa;
|
||||
|
||||
buffer_write_short(buffer, 16);
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
buffer_write(buffer, data.addr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
static void write_srv_record(PacketBuffer* buffer, Record* record) {
|
||||
SRVRecord data = record->data.srv;
|
||||
int pos = buffer_get_index(buffer);
|
||||
buffer_write_short(buffer, 0);
|
||||
|
||||
buffer_write_short(buffer, data.priority);
|
||||
buffer_write_short(buffer, data.weight);
|
||||
buffer_write_short(buffer, data.port);
|
||||
buffer_write_qname(buffer, data.target);
|
||||
|
||||
int size = buffer_get_index(buffer) - pos - 2;
|
||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
||||
}
|
||||
|
||||
static void write_caa_record(PacketBuffer* buffer, Record* record) {
|
||||
CAARecord data = record->data.caa;
|
||||
int pos = buffer_get_index(buffer);
|
||||
buffer_write_short(buffer, 0);
|
||||
buffer_write(buffer, data.flags);
|
||||
buffer_write(buffer, data.length);
|
||||
buffer_write_n(buffer, data.tag + 1, data.tag[0]);
|
||||
buffer_write_n(buffer, data.value + 1, data.value[0]);
|
||||
|
||||
int size = buffer_get_index(buffer) - pos - 2;
|
||||
buffer_set_uint16_t(buffer, (uint16_t)size, pos);
|
||||
}
|
||||
|
||||
static void write_record_header(PacketBuffer* buffer, Record* record) {
|
||||
buffer_write_qname(buffer, record->domain);
|
||||
uint16_t id = record_to_id(record->type);
|
||||
buffer_write_short(buffer, id);
|
||||
buffer_write_short(buffer, record->cls);
|
||||
buffer_write_int(buffer, record->ttl);
|
||||
}
|
||||
|
||||
void write_record(PacketBuffer* buffer, Record* record) {
|
||||
switch(record->type) {
|
||||
case A:
|
||||
write_record_header(buffer, record);
|
||||
write_a_record(buffer, record);
|
||||
break;
|
||||
case NS:
|
||||
write_record_header(buffer, record);
|
||||
write_ns_record(buffer, record);
|
||||
break;
|
||||
case CNAME:
|
||||
write_record_header(buffer, record);
|
||||
write_cname_record(buffer, record);
|
||||
break;
|
||||
case SOA:
|
||||
write_record_header(buffer, record);
|
||||
write_soa_record(buffer, record);
|
||||
break;
|
||||
case PTR:
|
||||
write_record_header(buffer, record);
|
||||
write_ptr_record(buffer, record);
|
||||
break;
|
||||
case MX:
|
||||
write_record_header(buffer, record);
|
||||
write_mx_record(buffer, record);
|
||||
break;
|
||||
case TXT:
|
||||
write_record_header(buffer, record);
|
||||
write_txt_record(buffer, record);
|
||||
break;
|
||||
case AAAA:
|
||||
write_record_header(buffer, record);
|
||||
write_aaaa_record(buffer, record);
|
||||
break;
|
||||
case SRV:
|
||||
write_record_header(buffer, record);
|
||||
write_srv_record(buffer, record);
|
||||
break;
|
||||
case CAA:
|
||||
write_record_header(buffer, record);
|
||||
write_caa_record(buffer, record);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
INIT_LOG_BUFFER(log)
|
||||
LOGONLY(print_record(record, log);)
|
||||
TRACE("Writing record: %s", log);
|
||||
}
|
||||
|
||||
void free_record(Record* record) {
|
||||
free(record->domain);
|
||||
switch (record->type) {
|
||||
case NS:
|
||||
free(record->data.ns.host);
|
||||
break;
|
||||
case CNAME:
|
||||
free(record->data.cname.host);
|
||||
break;
|
||||
case SOA:
|
||||
free(record->data.soa.mname);
|
||||
free(record->data.soa.nname);
|
||||
break;
|
||||
case PTR:
|
||||
free(record->data.ptr.pointer);
|
||||
break;
|
||||
case MX:
|
||||
free(record->data.mx.host);
|
||||
break;
|
||||
case TXT:
|
||||
for (uint8_t i = 0; i < record->data.txt.len; i++) {
|
||||
free(record->data.txt.text[i]);
|
||||
}
|
||||
free(record->data.txt.text);
|
||||
break;
|
||||
case SRV:
|
||||
free(record->data.srv.target);
|
||||
break;
|
||||
case CAA:
|
||||
free(record->data.caa.value);
|
||||
free(record->data.caa.tag);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void print_record(Record* record, char* buffer) {
|
||||
INIT_LOG_BOUNDS
|
||||
switch(record->type) {
|
||||
case UNKOWN:
|
||||
APPEND(buffer, "UNKOWN");
|
||||
break;
|
||||
case A:
|
||||
APPEND(buffer, "A (%hhu.%hhu.%hhu.%hhu)",
|
||||
record->data.a.addr[0],
|
||||
record->data.a.addr[1],
|
||||
record->data.a.addr[2],
|
||||
record->data.a.addr[3]
|
||||
);
|
||||
break;
|
||||
case NS:
|
||||
APPEND(buffer, "NS (%.*s)",
|
||||
record->data.ns.host[0],
|
||||
record->data.ns.host + 1
|
||||
);
|
||||
break;
|
||||
case CNAME:
|
||||
APPEND(buffer, "CNAME (%.*s)",
|
||||
record->data.cname.host[0],
|
||||
record->data.cname.host + 1
|
||||
);
|
||||
break;
|
||||
case SOA:
|
||||
APPEND(buffer, "SOA (%.*s %.*s %u %u %u %u %u)",
|
||||
record->data.soa.mname[0],
|
||||
record->data.soa.mname + 1,
|
||||
record->data.soa.nname[0],
|
||||
record->data.soa.nname + 1,
|
||||
record->data.soa.serial,
|
||||
record->data.soa.refresh,
|
||||
record->data.soa.retry,
|
||||
record->data.soa.expire,
|
||||
record->data.soa.minimum
|
||||
);
|
||||
break;
|
||||
case PTR:
|
||||
APPEND(buffer, "PTR (%.*s)",
|
||||
record->data.ptr.pointer[0],
|
||||
record->data.ptr.pointer + 1
|
||||
);
|
||||
break;
|
||||
case MX:
|
||||
APPEND(buffer, "MX (%.*s %hu)",
|
||||
record->data.mx.host[0],
|
||||
record->data.mx.host + 1,
|
||||
record->data.mx.priority
|
||||
);
|
||||
break;
|
||||
case TXT:
|
||||
APPEND(buffer, "TXT (");
|
||||
for(uint8_t i = 0; i < record->data.txt.len; i++) {
|
||||
APPEND(buffer, "\"%.*s\"",
|
||||
record->data.txt.text[i][0],
|
||||
record->data.txt.text[i] + 1
|
||||
);
|
||||
}
|
||||
APPEND(buffer, ")");
|
||||
break;
|
||||
case AAAA:
|
||||
APPEND(buffer, "AAAA (");
|
||||
for(int i = 0; i < 8; i++) {
|
||||
APPEND(buffer, "%02hhx%02hhx:",
|
||||
record->data.a.addr[i*2 + 0],
|
||||
record->data.a.addr[i*2 + 1]
|
||||
);
|
||||
}
|
||||
APPEND(buffer, ":)");
|
||||
break;
|
||||
case SRV:
|
||||
APPEND(buffer, "SRV (%hu %hu %hu %.*s)",
|
||||
record->data.srv.priority,
|
||||
record->data.srv.weight,
|
||||
record->data.srv.port,
|
||||
record->data.srv.target[0],
|
||||
record->data.srv.target + 1
|
||||
);
|
||||
break;
|
||||
case CAA:
|
||||
APPEND(buffer, "CAA (%hhu %.*s %.*s)",
|
||||
record->data.caa.flags,
|
||||
record->data.caa.tag[0],
|
||||
record->data.caa.tag + 1,
|
||||
record->data.caa.value[0],
|
||||
record->data.caa.value + 1
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
101
src/packet/record.h
Normal file
101
src/packet/record.h
Normal file
|
@ -0,0 +1,101 @@
|
|||
#pragma once
|
||||
|
||||
#include "buffer.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
typedef enum {
|
||||
UNKOWN,
|
||||
A, // 1
|
||||
NS, // 2
|
||||
CNAME, // 5
|
||||
SOA, // 6
|
||||
PTR, // 12
|
||||
MX, // 15
|
||||
TXT, // 16
|
||||
AAAA, // 28
|
||||
SRV, // 33
|
||||
CAA // 257
|
||||
} RecordType;
|
||||
|
||||
uint16_t record_to_id(RecordType type);
|
||||
void record_from_id(uint16_t i, RecordType* type);
|
||||
|
||||
typedef struct {
|
||||
uint8_t addr[4];
|
||||
} ARecord;
|
||||
|
||||
typedef struct {
|
||||
uint8_t* host;
|
||||
} NSRecord;
|
||||
|
||||
typedef struct {
|
||||
uint8_t* host;
|
||||
} CNAMERecord;
|
||||
|
||||
typedef struct {
|
||||
uint8_t* mname;
|
||||
uint8_t* nname;
|
||||
uint32_t serial;
|
||||
uint32_t refresh;
|
||||
uint32_t retry;
|
||||
uint32_t expire;
|
||||
uint32_t minimum;
|
||||
} SOARecord;
|
||||
|
||||
typedef struct {
|
||||
uint8_t* pointer;
|
||||
} PTRRecord;
|
||||
|
||||
typedef struct {
|
||||
uint16_t priority;
|
||||
uint8_t* host;
|
||||
} MXRecord;
|
||||
|
||||
typedef struct TXTRecord {
|
||||
uint8_t** text;
|
||||
uint8_t len;
|
||||
} TXTRecord;
|
||||
|
||||
typedef struct {
|
||||
uint8_t addr[16];
|
||||
} AAAARecord;
|
||||
|
||||
typedef struct {
|
||||
uint16_t priority;
|
||||
uint16_t weight;
|
||||
uint16_t port;
|
||||
uint8_t* target;
|
||||
} SRVRecord;
|
||||
|
||||
typedef struct {
|
||||
uint8_t flags;
|
||||
uint8_t length;
|
||||
uint8_t* tag;
|
||||
uint8_t* value;
|
||||
} CAARecord;
|
||||
|
||||
typedef struct {
|
||||
uint32_t ttl;
|
||||
uint16_t cls;
|
||||
uint16_t len;
|
||||
uint8_t* domain;
|
||||
RecordType type;
|
||||
union data {
|
||||
ARecord a;
|
||||
NSRecord ns;
|
||||
CNAMERecord cname;
|
||||
SOARecord soa;
|
||||
PTRRecord ptr;
|
||||
MXRecord mx;
|
||||
TXTRecord txt;
|
||||
AAAARecord aaaa;
|
||||
SRVRecord srv;
|
||||
CAARecord caa;
|
||||
} data;
|
||||
} Record;
|
||||
|
||||
void read_record(PacketBuffer* buffer, Record* record);
|
||||
void write_record(PacketBuffer* buffer, Record* record);
|
||||
void free_record(Record* record);
|
||||
void print_record(Record* record, char* buffer);
|
233
src/server/addr.c
Normal file
233
src/server/addr.c
Normal file
|
@ -0,0 +1,233 @@
|
|||
#include <netinet/in.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "addr.h"
|
||||
#include "../io/log.h"
|
||||
|
||||
void create_ip_addr(char* domain, IpAddr* addr) {
|
||||
addr->type = V4;
|
||||
memcpy(&addr->data.v4.s_addr, domain, 4);
|
||||
}
|
||||
|
||||
void create_ip_addr6(char* domain, IpAddr* addr) {
|
||||
addr->type = V6;
|
||||
memcpy(&addr->data.v6.__in6_u.__u6_addr8, domain, 16);
|
||||
}
|
||||
|
||||
void ip_addr_any(IpAddr* addr) {
|
||||
addr->type = V4;
|
||||
addr->data.v4.s_addr = htonl(INADDR_ANY);
|
||||
}
|
||||
|
||||
void ip_addr_any6(IpAddr* addr) {
|
||||
addr->type = V6;
|
||||
addr->data.v6 = in6addr_any;
|
||||
}
|
||||
|
||||
static struct sockaddr_in create_socket_addr_v4(IpAddr addr, uint16_t port) {
|
||||
struct sockaddr_in socketaddr;
|
||||
memset(&socketaddr, 0, sizeof(socketaddr));
|
||||
socketaddr.sin_family = AF_INET;
|
||||
socketaddr.sin_port = htons(port);
|
||||
socketaddr.sin_addr = addr.data.v4;
|
||||
return socketaddr;
|
||||
}
|
||||
|
||||
static struct sockaddr_in6 create_socket_addr_v6(IpAddr addr, uint16_t port) {
|
||||
struct sockaddr_in6 socketaddr;
|
||||
memset(&socketaddr, 0, sizeof(socketaddr));
|
||||
socketaddr.sin6_family = AF_INET6;
|
||||
socketaddr.sin6_port = htons(port);
|
||||
socketaddr.sin6_addr = addr.data.v6;
|
||||
return socketaddr;
|
||||
}
|
||||
|
||||
static size_t get_addr_len(AddrType type) {
|
||||
if (type == V4) {
|
||||
return sizeof(struct sockaddr_in);
|
||||
} else if (type == V6) {
|
||||
return sizeof(struct sockaddr_in6);
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void create_socket_addr(uint16_t port, IpAddr addr, SocketAddr* socket) {
|
||||
socket->type = addr.type;
|
||||
if (addr.type == V4) {
|
||||
socket->data.v4 = create_socket_addr_v4(addr, port);
|
||||
} else if(addr.type == V6) {
|
||||
socket->data.v6 = create_socket_addr_v6(addr, port);
|
||||
} else {
|
||||
ERROR("Tried to create socketaddr with invalid protocol type");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
socket->len = get_addr_len(addr.type);
|
||||
}
|
||||
|
||||
void print_socket_addr(SocketAddr* addr, char* buffer) {
|
||||
INIT_LOG_BOUNDS
|
||||
if(addr->type == V4) {
|
||||
APPEND(buffer, "%hhu.%hhu.%hhu.%hhu:%hu",
|
||||
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 24),
|
||||
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 16),
|
||||
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 8),
|
||||
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr),
|
||||
addr->data.v4.sin_port
|
||||
);
|
||||
} else {
|
||||
for(int i = 0; i < 8; i++) {
|
||||
APPEND(buffer, "%02hhx%02hhx:",
|
||||
addr->data.v6.sin6_addr.__in6_u.__u6_addr8[i*2 + 0],
|
||||
addr->data.v6.sin6_addr.__in6_u.__u6_addr8[i*2 + 1]
|
||||
);
|
||||
}
|
||||
APPEND(buffer, ":[%hu]", addr->data.v6.sin6_port);
|
||||
}
|
||||
}
|
||||
|
||||
#define ADDR_DOMAIN(addr, var) \
|
||||
struct sockaddr* var; \
|
||||
if (addr->type == V4) { \
|
||||
var = (struct sockaddr*) &addr->data.v4; \
|
||||
} else if (addr->type == V6) { \
|
||||
var = (struct sockaddr*) &addr->data.v6; \
|
||||
} else { \
|
||||
return -1; \
|
||||
}
|
||||
|
||||
#define ADDR_AFNET(type, var) \
|
||||
int var; \
|
||||
if (type == V4) { \
|
||||
var = AF_INET; \
|
||||
} else if (type == V6) { \
|
||||
var = AF_INET6; \
|
||||
} else { \
|
||||
return -1; \
|
||||
}
|
||||
|
||||
int32_t create_udp_socket(AddrType type, UdpSocket* sock) {
|
||||
ADDR_AFNET(type, __domain)
|
||||
sock->type = type;
|
||||
sock->sockfd = socket(__domain, SOCK_DGRAM, 0);
|
||||
return sock->sockfd;
|
||||
}
|
||||
|
||||
int32_t bind_udp_socket(SocketAddr* addr, UdpSocket* sock) {
|
||||
if (addr->type == V6) {
|
||||
int v6OnlyEnabled = 0;
|
||||
int32_t res = setsockopt(
|
||||
sock->sockfd,
|
||||
IPPROTO_IPV6,
|
||||
IPV6_V6ONLY,
|
||||
&v6OnlyEnabled,
|
||||
sizeof(v6OnlyEnabled)
|
||||
);
|
||||
if (res < 0) return res;
|
||||
}
|
||||
ADDR_DOMAIN(addr, __addr)
|
||||
return bind(sock->sockfd, __addr, addr->len);
|
||||
}
|
||||
|
||||
int32_t read_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr) {
|
||||
clientaddr->type = socket->type;
|
||||
clientaddr->len = get_addr_len(socket->type);
|
||||
ADDR_DOMAIN(clientaddr, __addr)
|
||||
return recvfrom(
|
||||
socket->sockfd,
|
||||
buffer,
|
||||
(size_t) len,
|
||||
MSG_WAITALL,
|
||||
__addr,
|
||||
(uint32_t*) &clientaddr->len
|
||||
);
|
||||
}
|
||||
|
||||
int32_t write_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr) {
|
||||
ADDR_DOMAIN(clientaddr, __addr)
|
||||
return sendto(
|
||||
socket->sockfd,
|
||||
buffer,
|
||||
(size_t) len,
|
||||
MSG_CONFIRM,
|
||||
__addr,
|
||||
(uint32_t) clientaddr->len
|
||||
);
|
||||
}
|
||||
|
||||
int32_t close_udp_socket(UdpSocket* socket) {
|
||||
return close(socket->sockfd);
|
||||
}
|
||||
|
||||
int32_t create_tcp_socket(AddrType type, TcpSocket* sock) {
|
||||
ADDR_AFNET(type, __domain)
|
||||
sock->type = type;
|
||||
sock->sockfd = socket(__domain, SOCK_STREAM, 0);
|
||||
return sock->sockfd;
|
||||
}
|
||||
|
||||
int32_t bind_tcp_socket(SocketAddr* addr, TcpSocket* sock) {
|
||||
if (addr->type == V6) {
|
||||
int v6OnlyEnabled = 0;
|
||||
int32_t res = setsockopt(
|
||||
sock->sockfd,
|
||||
IPPROTO_IPV6,
|
||||
IPV6_V6ONLY,
|
||||
&v6OnlyEnabled,
|
||||
sizeof(v6OnlyEnabled)
|
||||
);
|
||||
if (res < 0) return res;
|
||||
}
|
||||
ADDR_DOMAIN(addr, __addr)
|
||||
return bind(sock->sockfd, __addr, addr->len);
|
||||
}
|
||||
|
||||
int32_t listen_tcp_socket(TcpSocket* socket, uint32_t max) {
|
||||
return listen(socket->sockfd, max);
|
||||
}
|
||||
|
||||
int32_t accept_tcp_socket(TcpSocket* socket, TcpStream* stream) {
|
||||
stream->clientaddr.type = socket->type;
|
||||
memset(&stream->clientaddr, 0, sizeof(SocketAddr));
|
||||
SocketAddr* addr = &stream->clientaddr;
|
||||
ADDR_DOMAIN(addr, __addr)
|
||||
stream->streamfd = accept(
|
||||
socket->sockfd,
|
||||
__addr,
|
||||
(uint32_t*) &stream->clientaddr.len
|
||||
);
|
||||
return stream->streamfd;
|
||||
}
|
||||
|
||||
int32_t close_tcp_socket(TcpSocket* socket) {
|
||||
return close(socket->sockfd);
|
||||
}
|
||||
|
||||
int32_t connect_tcp_stream(SocketAddr* servaddr, TcpStream* stream) {
|
||||
TcpSocket socket;
|
||||
int32_t res = create_tcp_socket(servaddr->type, &socket);
|
||||
if (res < 0) return res;
|
||||
stream->clientaddr = *servaddr;
|
||||
stream->streamfd = socket.sockfd;
|
||||
ADDR_DOMAIN(servaddr, __addr)
|
||||
return connect(
|
||||
socket.sockfd,
|
||||
__addr,
|
||||
servaddr->len
|
||||
);
|
||||
}
|
||||
|
||||
int32_t read_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) {
|
||||
return recv(stream->streamfd, buffer, len, 0);
|
||||
}
|
||||
|
||||
int32_t write_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) {
|
||||
return send(stream->streamfd, buffer, len, MSG_NOSIGNAL);
|
||||
}
|
||||
|
||||
int32_t close_tcp_stream(TcpStream* stream) {
|
||||
return close(stream->streamfd);
|
||||
}
|
69
src/server/addr.h
Normal file
69
src/server/addr.h
Normal file
|
@ -0,0 +1,69 @@
|
|||
#pragma once
|
||||
|
||||
#include "../packet/record.h"
|
||||
|
||||
#include <string.h>
|
||||
#include <netinet/in.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
|
||||
typedef enum {
|
||||
V4,
|
||||
V6
|
||||
} AddrType;
|
||||
|
||||
typedef struct {
|
||||
AddrType type;
|
||||
union {
|
||||
struct in_addr v4;
|
||||
struct in6_addr v6;
|
||||
} data;
|
||||
} IpAddr;
|
||||
|
||||
void create_ip_addr(char* domain, IpAddr* addr);
|
||||
void create_ip_addr6(char* domain, IpAddr* addr);
|
||||
void ip_addr_any(IpAddr* addr);
|
||||
void ip_addr_any6(IpAddr* addr);
|
||||
|
||||
typedef struct {
|
||||
AddrType type;
|
||||
union {
|
||||
struct sockaddr_in v4;
|
||||
struct sockaddr_in6 v6;
|
||||
} data;
|
||||
size_t len;
|
||||
} SocketAddr;
|
||||
|
||||
void create_socket_addr(uint16_t port, IpAddr addr, SocketAddr* socket);
|
||||
void print_socket_addr(SocketAddr* addr, char* buffer);
|
||||
|
||||
typedef struct {
|
||||
AddrType type;
|
||||
uint32_t sockfd;
|
||||
} UdpSocket;
|
||||
|
||||
int32_t create_udp_socket(AddrType type, UdpSocket* socket);
|
||||
int32_t bind_udp_socket(SocketAddr* addr, UdpSocket* socket);
|
||||
int32_t read_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr);
|
||||
int32_t write_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr);
|
||||
int32_t close_udp_socket(UdpSocket* socket);
|
||||
|
||||
typedef struct {
|
||||
AddrType type;
|
||||
uint32_t sockfd;
|
||||
} TcpSocket;
|
||||
|
||||
typedef struct {
|
||||
SocketAddr clientaddr;
|
||||
uint32_t streamfd;
|
||||
} TcpStream;
|
||||
|
||||
int32_t create_tcp_socket(AddrType type, TcpSocket* socket);
|
||||
int32_t bind_tcp_socket(SocketAddr* addr, TcpSocket* socket);
|
||||
int32_t listen_tcp_socket(TcpSocket* socket, uint32_t max);
|
||||
int32_t accept_tcp_socket(TcpSocket* socket, TcpStream* stream);
|
||||
int32_t close_tcp_socket(TcpSocket* socket);
|
||||
int32_t connect_tcp_stream(SocketAddr* servaddr, TcpStream* stream);
|
||||
int32_t read_tcp_stream(TcpStream* stream, void* buffer, uint16_t len);
|
||||
int32_t write_tcp_stream(TcpStream* stream, void* buffer, uint16_t len);
|
||||
int32_t close_tcp_stream(TcpStream* stream);
|
245
src/server/binding.c
Normal file
245
src/server/binding.c
Normal file
|
@ -0,0 +1,245 @@
|
|||
#include <netinet/in.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <unistd.h>
|
||||
#include <errno.h>
|
||||
|
||||
#include "addr.h"
|
||||
#include "binding.h"
|
||||
#include "../io/log.h"
|
||||
|
||||
static void create_udp_binding(UdpSocket* socket, uint16_t port) {
|
||||
if (create_udp_socket(V6, socket) < 0) {
|
||||
ERROR("Failed to create UDP socket: %s", strerror(errno));
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
IpAddr addr;
|
||||
ip_addr_any6(&addr);
|
||||
|
||||
SocketAddr socketaddr;
|
||||
create_socket_addr(port, addr, &socketaddr);
|
||||
|
||||
if (bind_udp_socket(&socketaddr, socket) < 0) {
|
||||
ERROR("Failed to bind UDP socket on port %hu: %s", port, strerror(errno));
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
static void create_tcp_binding(TcpSocket* socket, uint16_t port) {
|
||||
if (create_tcp_socket(V6, socket) < 0) {
|
||||
ERROR("Failed to create TCP socket: %s", strerror(errno));
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
IpAddr addr;
|
||||
ip_addr_any6(&addr);
|
||||
|
||||
SocketAddr socketaddr;
|
||||
create_socket_addr(port, addr, &socketaddr);
|
||||
|
||||
if (bind_tcp_socket(&socketaddr, socket) < 0) {
|
||||
ERROR("Failed to bind TCP socket on port %hu: %s", port, strerror(errno));
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (listen_tcp_socket(socket, 5) < 0) {
|
||||
ERROR("Failed to listen on TCP socket: %s", strerror(errno));
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
void create_binding(BindingType type, uint16_t port, Binding* binding) {
|
||||
binding->type = type;
|
||||
if (type == UDP) {
|
||||
create_udp_binding(&binding->sock.udp, port);
|
||||
} else if(type == TCP) {
|
||||
create_tcp_binding(&binding->sock.tcp, port);
|
||||
} else {
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
void free_binding(Binding* binding) {
|
||||
if (binding->type == UDP) {
|
||||
close_udp_socket(&binding->sock.udp);
|
||||
} else if(binding->type == TCP) {
|
||||
close_tcp_socket(&binding->sock.tcp);
|
||||
}
|
||||
}
|
||||
|
||||
bool accept_connection(Binding* binding, Connection* connection) {
|
||||
connection->type = binding->type;
|
||||
|
||||
if(binding->type == UDP) {
|
||||
connection->sock.udp.udp = binding->sock.udp;
|
||||
memset(&connection->sock.udp.clientaddr, 0, sizeof(SocketAddr));
|
||||
return true;
|
||||
}
|
||||
|
||||
if (accept_tcp_socket(&binding->sock.tcp, &connection->sock.tcp) < 0) {
|
||||
ERROR("Failed to accept TCP connection: %s", strerror(errno));
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static void read_to_packet(uint8_t* buf, uint16_t len, Packet* packet) {
|
||||
PacketBuffer* pkbuffer = buffer_create(len);
|
||||
for (int i = 0; i < len; i++) {
|
||||
buffer_write(pkbuffer, buf[i]);
|
||||
}
|
||||
buffer_seek(pkbuffer, 0);
|
||||
read_packet(pkbuffer, packet);
|
||||
buffer_free(pkbuffer);
|
||||
}
|
||||
|
||||
static bool read_udp(Connection* connection, Packet* packet) {
|
||||
uint8_t buffer[512];
|
||||
int32_t n = read_udp_socket(
|
||||
&connection->sock.udp.udp,
|
||||
buffer,
|
||||
512,
|
||||
&connection->sock.udp.clientaddr
|
||||
);
|
||||
if (n < 0) {
|
||||
return false;
|
||||
}
|
||||
read_to_packet(buffer, n, packet);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool read_tcp(Connection* connection, Packet* packet) {
|
||||
uint16_t len;
|
||||
if ( read_tcp_stream(
|
||||
&connection->sock.tcp,
|
||||
&len,
|
||||
sizeof(uint16_t)
|
||||
) < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint8_t buffer[len];
|
||||
if ( read_tcp_stream(
|
||||
&connection->sock.tcp,
|
||||
buffer,
|
||||
len
|
||||
) < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
read_to_packet(buffer, len, packet);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool read_connection(Connection* connection, Packet* packet) {
|
||||
if (connection->type == UDP) {
|
||||
return read_udp(connection, packet);
|
||||
} else if (connection->type == TCP) {
|
||||
return read_tcp(connection, packet);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool write_udp(Connection* connection, uint8_t* buf, uint16_t len) {
|
||||
//if (len > 512) {
|
||||
buf[2] = buf[2] | 0x03;
|
||||
// len = 512;
|
||||
// }
|
||||
return write_udp_socket(
|
||||
&connection->sock.udp.udp,
|
||||
buf,
|
||||
len,
|
||||
&connection->sock.udp.clientaddr
|
||||
) == len;
|
||||
}
|
||||
|
||||
static bool write_tcp(Connection* connection, uint8_t* buf, uint16_t len) {
|
||||
len = htons(len);
|
||||
if (write_tcp_stream(
|
||||
&connection->sock.tcp,
|
||||
&len,
|
||||
sizeof(uint16_t)
|
||||
) < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (write_tcp_stream(
|
||||
&connection->sock.tcp,
|
||||
buf,
|
||||
len
|
||||
) < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool write_connection(Connection* connection, Packet* packet) {
|
||||
PacketBuffer* pkbuffer = buffer_create(64);
|
||||
write_packet(pkbuffer, packet);
|
||||
uint16_t len = buffer_get_size(pkbuffer);
|
||||
uint8_t* buffer = buffer_get_ptr(pkbuffer);
|
||||
bool success = false;
|
||||
if(connection->type == UDP) {
|
||||
success = write_udp(connection, buffer, len);
|
||||
} else if(connection->type == TCP) {
|
||||
success = write_tcp(connection, buffer, len);
|
||||
};
|
||||
buffer_free(pkbuffer);
|
||||
return success;
|
||||
}
|
||||
|
||||
void free_connection(Connection* connection) {
|
||||
if (connection->type == TCP) {
|
||||
close_tcp_stream(&connection->sock.tcp);
|
||||
}
|
||||
}
|
||||
|
||||
static bool create_udp_request(SocketAddr* addr, Connection* request) {
|
||||
if ( create_udp_socket(addr->type, &request->sock.udp.udp) < 0) {
|
||||
ERROR("Failed to connect to UDP socket: %s", strerror(errno));
|
||||
return false;
|
||||
}
|
||||
request->sock.udp.clientaddr = *addr;
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool create_tcp_request(SocketAddr* addr, Connection* request) {
|
||||
if( connect_tcp_stream(addr, &request->sock.tcp) < 0) {
|
||||
ERROR("Failed to connect to TCP socket: %s", strerror(errno));
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool create_request(BindingType type, SocketAddr* addr, Connection* request) {
|
||||
request->type = type;
|
||||
if (type == UDP) {
|
||||
return create_udp_request(addr, request);
|
||||
} else if (type == TCP) {
|
||||
return create_tcp_request(addr, request);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
bool request_packet(Connection* request, Packet* in, Packet* out) {
|
||||
if (!write_connection(request, in)) {
|
||||
return false;
|
||||
}
|
||||
if (!read_connection(request, out)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void free_request(Connection* connection) {
|
||||
if (connection->type == UDP) {
|
||||
close_udp_socket(&connection->sock.udp.udp);
|
||||
} else if (connection->type == TCP) {
|
||||
close_tcp_stream(&connection->sock.tcp);
|
||||
}
|
||||
}
|
42
src/server/binding.h
Normal file
42
src/server/binding.h
Normal file
|
@ -0,0 +1,42 @@
|
|||
#pragma once
|
||||
|
||||
#include "../packet/packet.h"
|
||||
#include "addr.h"
|
||||
|
||||
#include <netinet/in.h>
|
||||
|
||||
typedef enum {
|
||||
UDP,
|
||||
TCP
|
||||
} BindingType;
|
||||
|
||||
typedef struct {
|
||||
BindingType type;
|
||||
union {
|
||||
UdpSocket udp;
|
||||
TcpSocket tcp;
|
||||
} sock;
|
||||
} Binding;
|
||||
|
||||
void create_binding(BindingType type, uint16_t port, Binding* binding);
|
||||
void free_binding(Binding* binding);
|
||||
|
||||
typedef struct {
|
||||
BindingType type;
|
||||
union {
|
||||
struct {
|
||||
UdpSocket udp;
|
||||
SocketAddr clientaddr;
|
||||
} udp;
|
||||
TcpStream tcp;
|
||||
} sock;
|
||||
} Connection;
|
||||
|
||||
bool accept_connection(Binding* binding, Connection* connection);
|
||||
bool read_connection(Connection* connection, Packet* packet);
|
||||
bool write_connection(Connection* connection, Packet* packet);
|
||||
void free_connection(Connection* connection);
|
||||
|
||||
bool create_request(BindingType type, SocketAddr* addr, Connection* request);
|
||||
bool request_packet(Connection* request, Packet* in, Packet* out);
|
||||
void free_request(Connection* connection);
|
166
src/server/resolver.c
Normal file
166
src/server/resolver.c
Normal file
|
@ -0,0 +1,166 @@
|
|||
#include <errno.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "resolver.h"
|
||||
#include "addr.h"
|
||||
#include "binding.h"
|
||||
#include "../io/log.h"
|
||||
|
||||
static bool lookup(
|
||||
Question* question,
|
||||
Packet* response,
|
||||
BindingType type,
|
||||
SocketAddr addr
|
||||
) {
|
||||
INIT_LOG_BUFFER(log)
|
||||
LOGONLY(print_socket_addr(&addr, log);)
|
||||
TRACE("Attempting lookup on fallback dns %s", log);
|
||||
|
||||
Connection request;
|
||||
if (!create_request(type, &addr, &request)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Packet req;
|
||||
memset(&req, 0, sizeof(Packet));
|
||||
req.header.id = response->header.id;
|
||||
req.header.opcode = response->header.opcode;
|
||||
req.header.questions = 1;
|
||||
req.header.recursion_desired = true;
|
||||
req.questions = malloc(sizeof(Question));
|
||||
req.questions[0] = *question;
|
||||
|
||||
if (!request_packet(&request, &req, response)) {
|
||||
free_request(&request);
|
||||
free(req.questions);
|
||||
ERROR("Failed to request fallback dns: %s", strerror(errno));
|
||||
return false;
|
||||
}
|
||||
|
||||
free_request(&request);
|
||||
free(req.questions);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool search(Question* question, Packet* result, BindingType type) {
|
||||
IpAddr addr;
|
||||
char ip[4] = {1, 1, 1, 1};
|
||||
create_ip_addr(ip, &addr);
|
||||
|
||||
uint16_t port = 53;
|
||||
SocketAddr saddr;
|
||||
create_socket_addr(port, addr, &saddr);
|
||||
|
||||
while(1) {
|
||||
if (!lookup(question, result, type, saddr)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (result->header.answers > 0 && result->header.rescode == NOERROR) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (result->header.rescode == NXDOMAIN) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (get_resolved_ns(result, question->domain, &addr)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Question new_question;
|
||||
if (!get_unresoled_ns(result, question->domain, &new_question)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
Packet recurse;
|
||||
if (!search(&new_question, &recurse, type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
free_question(&new_question);
|
||||
|
||||
IpAddr random;
|
||||
if (!get_random_a(&recurse, &random)) {
|
||||
free_packet(&recurse);
|
||||
return true;
|
||||
} else {
|
||||
free_packet(&recurse);
|
||||
addr = random;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void push_records(Record* from, uint8_t from_len, Record** to, uint8_t to_len) {
|
||||
if(from_len < 1) return;
|
||||
*to = realloc(*to, sizeof(Record) * (from_len + to_len));
|
||||
memcpy(*to + to_len, from, from_len * sizeof(Record));
|
||||
}
|
||||
|
||||
static void push_questions(Question* from, uint8_t from_len, Question** to, uint8_t to_len) {
|
||||
if(from_len < 1) return;
|
||||
*to = realloc(*to, sizeof(Question) * (from_len + to_len));
|
||||
memcpy(*to + to_len, from, from_len * sizeof(Question));
|
||||
}
|
||||
|
||||
void handle_query(Packet* request, Packet* response, BindingType type) {
|
||||
memset(response, 0, sizeof(Packet));
|
||||
response->header.id = request->header.id;
|
||||
response->header.opcode = request->header.opcode;
|
||||
response->header.recursion_desired = true;
|
||||
response->header.recursion_available = true;
|
||||
response->header.response = true;
|
||||
|
||||
if (request->header.questions < 1) {
|
||||
response->header.response = FORMERR;
|
||||
return;
|
||||
}
|
||||
|
||||
for (uint16_t i = 0; i < request->header.questions; i++) {
|
||||
Packet result;
|
||||
memset(&result, 0, sizeof(Packet));
|
||||
result.header.id = response->header.id;
|
||||
if (!search(&request->questions[i], &result, type)) {
|
||||
response->header.response = SERVFAIL;
|
||||
break;
|
||||
}
|
||||
|
||||
push_questions(
|
||||
result.questions,
|
||||
result.header.questions,
|
||||
&response->questions,
|
||||
response->header.questions
|
||||
);
|
||||
response->header.questions += result.header.questions;
|
||||
|
||||
push_records(
|
||||
result.answers,
|
||||
result.header.answers,
|
||||
&response->answers,
|
||||
response->header.answers
|
||||
);
|
||||
response->header.answers += result.header.answers;
|
||||
|
||||
push_records(
|
||||
result.authorities,
|
||||
result.header.authoritative_entries,
|
||||
&response->authorities,
|
||||
response->header.authoritative_entries
|
||||
);
|
||||
response->header.authoritative_entries += result.header.authoritative_entries;
|
||||
|
||||
push_records(
|
||||
result.resources,
|
||||
result.header.resource_entries,
|
||||
&response->resources,
|
||||
response->header.resource_entries
|
||||
);
|
||||
response->header.resource_entries += result.header.resource_entries;
|
||||
|
||||
free(result.questions);
|
||||
free(result.answers);
|
||||
free(result.authorities);
|
||||
free(result.resources);
|
||||
}
|
||||
}
|
6
src/server/resolver.h
Normal file
6
src/server/resolver.h
Normal file
|
@ -0,0 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include "../packet/packet.h"
|
||||
#include "binding.h"
|
||||
|
||||
void handle_query(Packet* request, Packet* response, BindingType type);
|
100
src/server/server.c
Normal file
100
src/server/server.c
Normal file
|
@ -0,0 +1,100 @@
|
|||
#define _POSIX_SOURCE
|
||||
#include <errno.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <unistd.h>
|
||||
#include <sys/wait.h>
|
||||
#include <signal.h>
|
||||
|
||||
#include "addr.h"
|
||||
#include "server.h"
|
||||
#include "resolver.h"
|
||||
#include "../io/log.h"
|
||||
|
||||
static pid_t udp, tcp;
|
||||
|
||||
void server_init(uint16_t port, Server* server) {
|
||||
INFO("Server port set to %hu", port);
|
||||
create_binding(UDP, port, &server->udp);
|
||||
create_binding(TCP, port, &server->tcp);
|
||||
}
|
||||
|
||||
static void server_listen(Binding* binding) {
|
||||
while(1) {
|
||||
|
||||
Connection connection;
|
||||
if (!accept_connection(binding, &connection)) {
|
||||
ERROR("Failed to accept connection");
|
||||
continue;
|
||||
}
|
||||
|
||||
Packet request;
|
||||
if (!read_connection(&connection, &request)) {
|
||||
ERROR("Failed to read connection");
|
||||
free_connection(&connection);
|
||||
continue;
|
||||
}
|
||||
|
||||
if(fork() != 0) {
|
||||
free_packet(&request);
|
||||
free_connection(&connection);
|
||||
continue;
|
||||
}
|
||||
|
||||
INFO("Recieved packet request ID %hu", request.header.id);
|
||||
|
||||
Packet response;
|
||||
handle_query(&request, &response, connection.type);
|
||||
|
||||
if (!write_connection(&connection, &response)) {
|
||||
ERROR("Failed to respond to connection ID %hu: %s", request.header.id, strerror(errno));
|
||||
}
|
||||
|
||||
free_packet(&request);
|
||||
free_packet(&response);
|
||||
free_connection(&connection);
|
||||
exit(EXIT_SUCCESS);
|
||||
}
|
||||
}
|
||||
|
||||
static void signal_handler() {
|
||||
printf("\n");
|
||||
kill(udp, SIGTERM);
|
||||
kill(tcp, SIGTERM);
|
||||
}
|
||||
|
||||
void server_run(Server* server) {
|
||||
if ((udp = fork()) == 0) {
|
||||
INFO("Listening for connections on UDP");
|
||||
server_listen(&server->udp);
|
||||
exit(EXIT_SUCCESS);
|
||||
}
|
||||
|
||||
if ((tcp = fork()) == 0) {
|
||||
INFO("Listening for connections on TCP");
|
||||
server_listen(&server->tcp);
|
||||
exit(EXIT_SUCCESS);
|
||||
}
|
||||
|
||||
signal(SIGINT, signal_handler);
|
||||
|
||||
int status;
|
||||
waitpid(udp, &status, 0);
|
||||
if (status == 0) {
|
||||
INFO("UDP process closed successfully");
|
||||
} else {
|
||||
ERROR("UDP process failed with error code %d", status);
|
||||
}
|
||||
|
||||
waitpid(tcp, &status, 0);
|
||||
if (status == 0) {
|
||||
INFO("TCP process closed successfully");
|
||||
} else {
|
||||
ERROR("TCP process failed with error code %d", status);
|
||||
}
|
||||
}
|
||||
|
||||
void server_free(Server* server) {
|
||||
free_binding(&server->udp);
|
||||
free_binding(&server->tcp);
|
||||
}
|
12
src/server/server.h
Normal file
12
src/server/server.h
Normal file
|
@ -0,0 +1,12 @@
|
|||
#pragma once
|
||||
|
||||
#include "binding.h"
|
||||
|
||||
typedef struct {
|
||||
Binding udp;
|
||||
Binding tcp;
|
||||
} Server;
|
||||
|
||||
void server_init(uint16_t port, Server* server);
|
||||
void server_run(Server* server);
|
||||
void server_free(Server* server);
|
156
src/web/api.rs
156
src/web/api.rs
|
@ -1,156 +0,0 @@
|
|||
use std::net::IpAddr;
|
||||
|
||||
use axum::{
|
||||
extract::Query,
|
||||
response::Response,
|
||||
routing::{get, post, put, delete},
|
||||
Extension, Router,
|
||||
};
|
||||
use moka::future::Cache;
|
||||
use rand::distributions::{Alphanumeric, DistString};
|
||||
use serde::Deserialize;
|
||||
use tower_cookies::{Cookie, Cookies};
|
||||
|
||||
use crate::{config::Config, database::Database, dns::packet::record::DnsRecord};
|
||||
|
||||
use super::{
|
||||
extract::{Authorized, Body, RequestIp},
|
||||
http::{json, text},
|
||||
};
|
||||
|
||||
pub fn router() -> Router {
|
||||
Router::new()
|
||||
.route("/login", post(login))
|
||||
.route("/domains", get(list_domains))
|
||||
.route("/domains", delete(delete_domain))
|
||||
.route("/records", get(get_domain))
|
||||
.route("/records", put(add_record))
|
||||
}
|
||||
|
||||
async fn list_domains(_: Authorized, Extension(database): Extension<Database>) -> Response {
|
||||
let domains = match database.get_domains().await {
|
||||
Ok(domains) => domains,
|
||||
Err(err) => return text(500, &format!("{err}")),
|
||||
};
|
||||
|
||||
let Ok(domains) = serde_json::to_string(&domains) else {
|
||||
return text(500, "Failed to fetch domains")
|
||||
};
|
||||
|
||||
json(200, &domains)
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct DomainRequest {
|
||||
domain: String,
|
||||
}
|
||||
|
||||
async fn get_domain(
|
||||
_: Authorized,
|
||||
Extension(database): Extension<Database>,
|
||||
Query(query): Query<DomainRequest>,
|
||||
) -> Response {
|
||||
let records = match database.get_domain(&query.domain).await {
|
||||
Ok(records) => records,
|
||||
Err(err) => return text(500, &format!("{err}")),
|
||||
};
|
||||
|
||||
let Ok(records) = serde_json::to_string(&records) else {
|
||||
return text(500, "Failed to fetch records")
|
||||
};
|
||||
|
||||
json(200, &records)
|
||||
}
|
||||
|
||||
async fn delete_domain(
|
||||
_: Authorized,
|
||||
Extension(database): Extension<Database>,
|
||||
Body(body): Body,
|
||||
) -> Response {
|
||||
|
||||
let Ok(request) = serde_json::from_str::<DomainRequest>(&body) else {
|
||||
return text(400, "Missing request parameters")
|
||||
};
|
||||
|
||||
let Ok(domains) = database.get_domains().await else {
|
||||
return text(500, "Failed to delete domain")
|
||||
};
|
||||
|
||||
if !domains.contains(&request.domain) {
|
||||
return text(400, "Domain does not exist")
|
||||
}
|
||||
|
||||
if database.delete_domain(request.domain).await.is_err() {
|
||||
return text(500, "Failed to delete domain")
|
||||
};
|
||||
|
||||
return text(204, "Successfully deleted domain")
|
||||
}
|
||||
|
||||
async fn add_record(
|
||||
_: Authorized,
|
||||
Extension(database): Extension<Database>,
|
||||
Body(body): Body,
|
||||
) -> Response {
|
||||
let Ok(record) = serde_json::from_str::<DnsRecord>(&body) else {
|
||||
return text(400, "Invalid DNS record")
|
||||
};
|
||||
|
||||
let allowed = record.get_qtype().allowed_actions();
|
||||
if !allowed.1 {
|
||||
return text(400, "Not allowed to create record")
|
||||
}
|
||||
|
||||
let Ok(records) = database.get_records(&record.get_domain(), record.get_qtype()).await else {
|
||||
return text(500, "Failed to complete record check");
|
||||
};
|
||||
|
||||
if !records.is_empty() && !allowed.0 {
|
||||
return text(400, "Not allowed to create duplicate record")
|
||||
};
|
||||
|
||||
if records.contains(&record) {
|
||||
return text(400, "Not allowed to create duplicate record")
|
||||
}
|
||||
|
||||
if let Err(err) = database.add_record(record).await {
|
||||
return text(500, &format!("{err}"));
|
||||
}
|
||||
|
||||
return text(201, "Added record to database successfully");
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LoginRequest {
|
||||
user: String,
|
||||
pass: String,
|
||||
}
|
||||
|
||||
async fn login(
|
||||
Extension(config): Extension<Config>,
|
||||
Extension(cache): Extension<Cache<String, IpAddr>>,
|
||||
RequestIp(ip): RequestIp,
|
||||
cookies: Cookies,
|
||||
Body(body): Body,
|
||||
) -> Response {
|
||||
let Ok(request) = serde_json::from_str::<LoginRequest>(&body) else {
|
||||
return text(400, "Missing request parameters")
|
||||
};
|
||||
|
||||
if request.user != config.web_user || request.pass != config.web_pass {
|
||||
return text(400, "Invalid credentials");
|
||||
};
|
||||
|
||||
let token = Alphanumeric.sample_string(&mut rand::thread_rng(), 128);
|
||||
|
||||
cache.insert(token.clone(), ip).await;
|
||||
|
||||
let mut cookie = Cookie::new("auth", token);
|
||||
cookie.set_secure(true);
|
||||
cookie.set_http_only(true);
|
||||
cookie.set_path("/");
|
||||
|
||||
cookies.add(cookie);
|
||||
|
||||
text(200, "Successfully logged in")
|
||||
}
|
|
@ -1,139 +0,0 @@
|
|||
use std::{
|
||||
io::Read,
|
||||
net::{IpAddr, SocketAddr},
|
||||
};
|
||||
|
||||
use axum::{
|
||||
async_trait,
|
||||
body::HttpBody,
|
||||
extract::{ConnectInfo, FromRequest, FromRequestParts},
|
||||
http::{request::Parts, Request},
|
||||
response::Response,
|
||||
BoxError,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use moka::future::Cache;
|
||||
use tower_cookies::Cookies;
|
||||
|
||||
use super::http::text;
|
||||
|
||||
pub struct Authorized;
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for Authorized
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Response;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let Ok(Some(cookies)) = Option::<Cookies>::from_request_parts(parts, state).await else {
|
||||
return Err(text(403, "No cookies provided"))
|
||||
};
|
||||
|
||||
let Some(token) = cookies.get("auth") else {
|
||||
return Err(text(403, "No auth token provided"))
|
||||
};
|
||||
|
||||
let auth_ip: IpAddr;
|
||||
{
|
||||
let Some(cache) = parts.extensions.get::<Cache<String, IpAddr>>() else {
|
||||
return Err(text(500, "Failed to load auth store"))
|
||||
};
|
||||
|
||||
let Some(ip) = cache.get(token.value()) else {
|
||||
return Err(text(401, "Unauthorized"))
|
||||
};
|
||||
|
||||
auth_ip = ip
|
||||
}
|
||||
|
||||
let Ok(Some(RequestIp(ip))) = Option::<RequestIp>::from_request_parts(parts, state).await else {
|
||||
return Err(text(403, "You have no ip"))
|
||||
};
|
||||
|
||||
if auth_ip != ip {
|
||||
return Err(text(403, "Auth token does not match current ip"));
|
||||
}
|
||||
|
||||
Ok(Self)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RequestIp(pub IpAddr);
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for RequestIp
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Response;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
let headers = &parts.headers;
|
||||
|
||||
let forwardedfor = headers
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| {
|
||||
h.split(',')
|
||||
.rev()
|
||||
.find_map(|s| s.trim().parse::<IpAddr>().ok())
|
||||
});
|
||||
|
||||
if let Some(forwardedfor) = forwardedfor {
|
||||
return Ok(Self(forwardedfor));
|
||||
}
|
||||
|
||||
let realip = headers
|
||||
.get("x-real-ip")
|
||||
.and_then(|hv| hv.to_str().ok())
|
||||
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||
|
||||
if let Some(realip) = realip {
|
||||
return Ok(Self(realip));
|
||||
}
|
||||
|
||||
let realip = headers
|
||||
.get("x-real-ip")
|
||||
.and_then(|hv| hv.to_str().ok())
|
||||
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||
|
||||
if let Some(realip) = realip {
|
||||
return Ok(Self(realip));
|
||||
}
|
||||
|
||||
let info = parts.extensions.get::<ConnectInfo<SocketAddr>>();
|
||||
|
||||
if let Some(info) = info {
|
||||
return Ok(Self(info.0.ip()));
|
||||
}
|
||||
|
||||
Err(text(403, "You have no ip"))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Body(pub String);
|
||||
|
||||
#[async_trait]
|
||||
impl<S, B> FromRequest<S, B> for Body
|
||||
where
|
||||
B: HttpBody + Sync + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: Into<BoxError>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Response;
|
||||
|
||||
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let Ok(bytes) = Bytes::from_request(req, state).await else {
|
||||
return Err(text(413, "Payload too large"));
|
||||
};
|
||||
|
||||
let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else {
|
||||
return Err(text(400, "Invalid utf8 body"))
|
||||
};
|
||||
|
||||
Ok(Self(body))
|
||||
}
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
use axum::{extract::Path, response::Response};
|
||||
|
||||
use super::http::serve;
|
||||
|
||||
pub async fn js(Path(path): Path<String>) -> Response {
|
||||
let path = format!("/js/{path}");
|
||||
serve(&path).await
|
||||
}
|
||||
|
||||
pub async fn css(Path(path): Path<String>) -> Response {
|
||||
let path = format!("/css/{path}");
|
||||
serve(&path).await
|
||||
}
|
||||
|
||||
pub async fn fonts(Path(path): Path<String>) -> Response {
|
||||
let path = format!("/fonts/{path}");
|
||||
serve(&path).await
|
||||
}
|
||||
|
||||
pub async fn image(Path(path): Path<String>) -> Response {
|
||||
let path = format!("/image/{path}");
|
||||
serve(&path).await
|
||||
}
|
||||
|
||||
pub async fn favicon() -> Response {
|
||||
serve("/favicon.ico").await
|
||||
}
|
||||
|
||||
pub async fn robots() -> Response {
|
||||
serve("/robots.txt").await
|
||||
}
|
|
@ -1,50 +0,0 @@
|
|||
use axum::{
|
||||
body::Body,
|
||||
http::{header::HeaderName, HeaderValue, Request, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::str;
|
||||
use tower::ServiceExt;
|
||||
use tower_http::services::ServeFile;
|
||||
|
||||
pub fn text(code: u16, msg: &str) -> Response {
|
||||
(status_code(code), msg.to_owned()).into_response()
|
||||
}
|
||||
|
||||
pub fn json(code: u16, json: &str) -> Response {
|
||||
let mut res = (status_code(code), json.to_owned()).into_response();
|
||||
res.headers_mut().insert(
|
||||
HeaderName::from_static("content-type"),
|
||||
HeaderValue::from_static("application/json"),
|
||||
);
|
||||
res
|
||||
}
|
||||
|
||||
pub async fn serve(path: &str) -> Response {
|
||||
if !path.chars().any(|c| c == '.') {
|
||||
return text(403, "Invalid file path");
|
||||
}
|
||||
|
||||
let path = format!("public{path}");
|
||||
let file = ServeFile::new(path);
|
||||
|
||||
let Ok(mut res) = file.oneshot(Request::new(Body::empty())).await else {
|
||||
tracing::error!("Error while fetching file");
|
||||
return text(500, "Error when fetching file")
|
||||
};
|
||||
|
||||
if res.status() != StatusCode::OK {
|
||||
return text(404, "File not found");
|
||||
}
|
||||
|
||||
res.headers_mut().insert(
|
||||
HeaderName::from_static("cache-control"),
|
||||
HeaderValue::from_static("max-age=300"),
|
||||
);
|
||||
|
||||
res.into_response()
|
||||
}
|
||||
|
||||
fn status_code(code: u16) -> StatusCode {
|
||||
StatusCode::from_u16(code).map_or(StatusCode::OK, |code| code)
|
||||
}
|
|
@ -1,82 +0,0 @@
|
|||
use std::net::{IpAddr, SocketAddr, TcpListener};
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::routing::get;
|
||||
use axum::{Extension, Router};
|
||||
use moka::future::Cache;
|
||||
use tokio::task::JoinHandle;
|
||||
use tower_cookies::CookieManagerLayer;
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::database::Database;
|
||||
use crate::Result;
|
||||
|
||||
mod api;
|
||||
mod extract;
|
||||
mod file;
|
||||
mod http;
|
||||
mod pages;
|
||||
|
||||
pub struct WebServer {
|
||||
config: Config,
|
||||
database: Database,
|
||||
addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl WebServer {
|
||||
pub async fn new(config: Config, database: Database) -> Result<Self> {
|
||||
let addr = format!("[::]:{}", config.web_port).parse::<SocketAddr>()?;
|
||||
Ok(Self {
|
||||
config,
|
||||
database,
|
||||
addr,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn run(&self) -> Result<JoinHandle<()>> {
|
||||
let config = self.config.clone();
|
||||
let database = self.database.clone();
|
||||
let listener = TcpListener::bind(self.addr)?;
|
||||
|
||||
info!(
|
||||
"Listening for HTTP traffic on [::]:{}",
|
||||
self.config.web_port
|
||||
);
|
||||
|
||||
let app = Self::router(config, database);
|
||||
let server = axum::Server::from_tcp(listener)?;
|
||||
|
||||
let web_handle = tokio::spawn(async move {
|
||||
if let Err(err) = server
|
||||
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
|
||||
.await
|
||||
{
|
||||
error!("{err}");
|
||||
}
|
||||
});
|
||||
|
||||
Ok(web_handle)
|
||||
}
|
||||
|
||||
fn router(config: Config, database: Database) -> Router {
|
||||
let cache: Cache<String, IpAddr> = Cache::builder()
|
||||
.time_to_live(Duration::from_secs(60 * 15))
|
||||
.max_capacity(config.dns_cache_size)
|
||||
.build();
|
||||
|
||||
Router::new()
|
||||
.nest("/", pages::router())
|
||||
.nest("/api", api::router())
|
||||
.layer(Extension(config))
|
||||
.layer(Extension(cache))
|
||||
.layer(Extension(database))
|
||||
.layer(CookieManagerLayer::new())
|
||||
.route("/js/*path", get(file::js))
|
||||
.route("/css/*path", get(file::css))
|
||||
.route("/fonts/*path", get(file::fonts))
|
||||
.route("/image/*path", get(file::image))
|
||||
.route("/favicon.ico", get(file::favicon))
|
||||
.route("/robots.txt", get(file::robots))
|
||||
}
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
use axum::{response::Response, routing::get, Router};
|
||||
|
||||
use super::{extract::Authorized, http::serve};
|
||||
|
||||
pub fn router() -> Router {
|
||||
Router::new()
|
||||
.route("/", get(root))
|
||||
.route("/login", get(login))
|
||||
.route("/home", get(home))
|
||||
.route("/domain", get(domain))
|
||||
}
|
||||
|
||||
async fn root(user: Option<Authorized>) -> Response {
|
||||
if user.is_some() {
|
||||
home().await
|
||||
} else {
|
||||
login().await
|
||||
}
|
||||
}
|
||||
|
||||
async fn login() -> Response {
|
||||
serve("/login.html").await
|
||||
}
|
||||
|
||||
async fn home() -> Response {
|
||||
serve("/home.html").await
|
||||
}
|
||||
|
||||
async fn domain() -> Response {
|
||||
serve("/domain.html").await
|
||||
}
|
Loading…
Reference in a new issue