finish dns and start webserver

This commit is contained in:
Tyler Murphy 2023-03-06 18:50:08 -05:00
parent 0f40ab89e3
commit b1fb410aff
No known key found for this signature in database
GPG key ID: 04CC7A6A289B470F
42 changed files with 3093 additions and 202 deletions

1
.gitignore vendored
View file

@ -1 +1,2 @@
**/target **/target
.env

1321
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -4,8 +4,34 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
# Blazingly fast runtime
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
async-recursion = "1"
tracing = "0.1.37"
tracing-subscriber = "0.3.16" tracing-subscriber = "0.3.16"
tracing = "0.1.37"
# Allow recursion inside tokio async
async-recursion = "1"
# DNS Caching Layer
moka = { version = "0.10.0", features = ["future"] } moka = { version = "0.10.0", features = ["future"] }
# Mongodb
mongodb = { version = "2.4", features = ["tokio-sync"] }
futures = "0.3.26"
# Convert values to json for Mongodb
serde = { version = "1.0", features = ["derive"] }
# Reading env vars from .env
dotenv = "0.15.0"
# For the meme records
rand = "0.8.5"
# For the http web frontend
axum = "0.6.4"
tower-http = { version = "0.4.0", features = ["fs"] }
tower-cookies = "0.9.0"
tower = "0.4.13"
bytes = "1.4.0"
serde_json = "1"

40
public/css/home.css Normal file
View file

@ -0,0 +1,40 @@
span {
margin-top: 5rem;
margin-bottom: 1rem;
width: 45rem;
font-size: 2em;
}
#new {
display: flex;
justify-content: center;
width: 100%;
padding-top: 2rem;
padding-bottom: 1rem;
border-bottom: solid 1px var(--gray);
}
#new input, .block {
border-radius: 1rem 0 0 1rem;
width: 40rem;
}
.block {
width: 33em;
}
#new button {
border-radius: 0 1rem 1rem 0;
}
.domain {
margin-top: 2rem;
}
.domain .delete {
border-radius: 0 1rem 1rem 0;
}
.domain .edit {
border-radius: 0;
}

18
public/css/login.css Normal file
View file

@ -0,0 +1,18 @@
#login {
margin-top: 20em;
}
#logo {
font-size: 6em;
font-weight: 750;
font-family: bold;
margin-bottom: 2rem;
}
form {
width: 30rem;
}
form input {
width: 100%;
}

119
public/css/main.css Normal file
View file

@ -0,0 +1,119 @@
:root {
--dark: #222428;
--dark-alternate: #2b2e36;
--header: #1e1e22;
--accent: #8849f5;
--accent-alternate: #6829d5;
--gray: #2f2f3f;
--main: #ffffff;
--main-alternate: #cccccc;
}
* {
padding: 0;
margin: 0;
}
@font-face {
font-family: main;
src: url("../fonts/helvetica.ttf") format("truetype");
font-display: swap;
}
@font-face {
font-family: bold;
src: url("../fonts/overpass-bold.otf") format("opentype");
font-display: swap;
}
@font-face {
font-family: bold-italic;
src: url("../fonts/overpass-bold-italic.otf") format("opentype");
font-display: swap;
}
html {
background-color: var(--dark);
font-family: main;
color: var(--main);
width: 100%;
height: 100%;
}
body {
width: 100%;
height: 100%;
display: flex;
flex-direction: column;
align-items: center;
}
.accent {
color: var(--accent);
}
.fill {
width: 100%;
height: 100%;
display: flex;
flex-direction: column;
align-items: center;
}
input, button, .block {
all: unset;
display: inline-block;
font: main;
background-color: var(--dark-alternate);
font-size: 1rem;
padding: 1rem;
border-radius: 1rem;
margin-bottom: 20px;
}
button {
background-color: var(--accent);
width: 5em;
text-align: center;
}
button:hover {
cursor: pointer;
background-color: var(--accent-alternate);
}
.delete {
background-color: #f54842;
}
.delete:hover {
cursor: pointer;
background-color: #d52822;
}
form {
display: flex;
flex-direction: column;
}
#header {
width: calc(100% - 4rem);
background-color: var(--header);
border-bottom: solid 1px var(--gray);
padding: 1rem;
padding-left: 3rem;
}
#logo {
font-size: 2em;
font-weight: 500;
font-family: bold;
}
#title {
font-size: 2em;
font-weight: 300;
font-family: sans-serif;
padding-left: 1em;
}

67
public/css/record.css Normal file
View file

@ -0,0 +1,67 @@
#buttons {
margin-top: 2rem;
width: 50rem;
}
#buttons button {
margin: 0;
margin-right: 2rem;
border-radius: 10px;
width: auto;
padding: .75rem 1rem;
}
.record {
width: 50rem;
background-color: var(--header);
padding: 1rem;
margin-top: 2rem;
}
.header {
display: flex;
align-items: center;
margin-bottom: 1rem;
}
.header span {
font-family: bold;
}
.header button {
margin: 0;
margin-left: 2rem;
padding: .5rem 1rem;
width: auto;
border-radius: 5px;
}
.type {
margin-right: 1rem;
background-color: var(--accent);
padding: .25rem .5rem;
border-radius: 5px;
}
.domain {
color: var(--main-alternate);
flex-grow: 1;
}
.properties {
display: flex;
flex-direction: column;
}
.poperty {
display: flex;
flex-direction: row;
border-bottom: solid 1px var(--gray);
margin-top: 1rem;
}
.key {
font-family: bold;
width: 5rem;
}

21
public/domain.html Normal file
View file

@ -0,0 +1,21 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Wrapper - Records</title>
<meta name="author" content="Tyler Murphy">
<meta name="description" content="wrapper records">
<meta property="og:title" content="wrapper">
<meta property="og:description" content="wrapper records">
<link rel="stylesheet" href="/css/main.css">
<link rel="stylesheet" href="/css/record.css">
<script type="module" src="/js/domain.js"></script>
</head>
<body>
</body>
</html>

Binary file not shown.

BIN
public/fonts/helvetica.ttf Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

21
public/home.html Normal file
View file

@ -0,0 +1,21 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Wrapper - Domains</title>
<meta name="author" content="Tyler Murphy">
<meta name="description" content="wrapper domains">
<meta property="og:title" content="wrapper">
<meta property="og:description" content="wrapper domains">
<link rel="stylesheet" href="/css/main.css">
<link rel="stylesheet" href="/css/home.css">
<script type="module" src="/js/home.js"></script>
</head>
<body>
</body>
</html>

51
public/js/api.js Normal file
View file

@ -0,0 +1,51 @@
const endpoint = '/api'
const request = async (url, method, body) => {
let response;
if (method == 'GET') {
response = await fetch(endpoint + url, {
method,
headers: {
'Content-Type': 'application/json'
}
});
} else {
response = await fetch(endpoint + url, {
method,
body: JSON.stringify(body),
headers: {
'Content-Type': 'application/json'
}
});
}
if (response.status == 401) {
location.href = '/login'
}
const contentType = response.headers.get("content-type");
if (contentType && contentType.indexOf("application/json") !== -1) {
const json = await response.json()
return { status: response.status, msg: json.msg, json }
} else {
const msg = await response.text();
return { status: response.status, msg }
}
}
export const login = async (user, pass) => {
return await request('/login', 'POST', {user, pass})
}
export const domains = async () => {
return await request('/domains', 'GET')
}
export const del_domain = async (domain) => {
return await request('/domains', 'DELETE', {domain})
}
export const records = async (domain) => {
return await request(`/records?domain=${domain}`, 'GET')
}

12
public/js/components.js Normal file
View file

@ -0,0 +1,12 @@
import { div, parse, span } from './main.js';
export function header(title) {
return div({id: 'header'},
span({id: 'logo', class: 'accent'},
parse("Wrapper")
),
span({id: 'title'},
parse(title)
),
)
}

95
public/js/domain.js Normal file
View file

@ -0,0 +1,95 @@
import { del_domain, domains, records } from './api.js'
import { header } from './components.js'
import { body, parse, div, input, button, span, is_domain } from './main.js';
function render(domain, records) {
let divs = []
for (const record of records) {
divs.push(gen_record(record))
}
document.body.replaceWith(
body({},
header(domain),
div({id: 'buttons'},
button({onclick: (event) => {
location.href = '/home'
}}, parse("Home")),
button({}, parse("New Record")),
),
...divs
)
)
}
function gen_record(record) {
let domain = record.domain
let prefix = record.prefix
if (prefix.length > 0) {
prefix = prefix + '.'
}
let type = Object.keys(record.record)[0]
let data = record.record[type]
let divs = []
for (const key in data) {
let disp_key;
if (key == 'ttl') {
disp_key = 'TTL'
} else {
disp_key = upper(key)
}
divs.push(
div({class: 'poperty'},
div({class: 'key'}, parse(disp_key)),
div({class: 'value'}, parse(data[key])),
)
)
}
return div({class: 'record'},
div({class: 'header'},
span({class: 'type'}, parse(type)),
span({class: 'prefix'}, parse(prefix)),
span({class: 'domain'}, parse(domain)),
button({}, parse("Edit")),
button({class: 'delete'}, parse("Delete"))
),
div({class: 'properties'},
...divs
)
)
}
function upper(string) {
return string.charAt(0).toUpperCase() + string.slice(1);
}
async function init() {
const params = new Proxy(new URLSearchParams(window.location.search), {
get: (searchParams, prop) => searchParams.get(prop),
});
let domain = params.domain;
if (!is_domain(domain)) {
location.href = '/home'
return
}
let res = await records(domain);
if (res.status !== 200) {
alert(res.msg)
return
}
render(domain, res.json)
}
init()

91
public/js/home.js Normal file
View file

@ -0,0 +1,91 @@
import { del_domain, domains } from './api.js'
import { header } from './components.js'
import { body, parse, div, input, button, span, is_domain } from './main.js';
function render(domains) {
document.body.replaceWith(
body({},
header('domains'),
div({id: 'new'},
input({
type: 'text',
name: 'domain',
id: 'domain',
placeholder: 'Type domain name to create new records',
autocomplete: "off",
}),
button({onclick: () => {
let domain = document.getElementById('domain').value
if (!is_domain(domain)) {
alert("Invalid domain")
return
}
location.href = '/domain?domain='+domain
}},
parse("Create")
)
),
...domain(domains)
)
)
}
function domain(domains) {
let divs = []
for (const domain of domains) {
divs.push(
div({class: 'domain'},
div({class: 'block'},
parse(domain)
),
button({class: 'edit', onclick: (event) => {
console.log(event.target.parentElement)
let domain = event
.target
.parentElement
.getElementsByClassName('block')[0]
.innerText
if (!is_domain(domain)) {
alert("Invalid domain")
return
}
location.href = '/domain?domain='+domain
}},
parse("Edit")
),
button({class: 'delete', onclick: async () => {
let res = await del_domain(domain)
if (res.status != 204) {
alert(res.msg)
return
}
location.reload()
}},
parse("Delete")
)
)
)
}
return divs
}
async function init() {
let res = await domains();
if (res.status !== 200) {
alert(res.msg)
return
}
render(res.json)
}
init()

44
public/js/login.js Normal file
View file

@ -0,0 +1,44 @@
import { body, div, form, input, p, parse, span} from './main.js'
import { login } from './api.js'
function render() {
document.body.replaceWith(
body({},
div({id: 'login', class: 'fill'},
span({id: 'logo'},
span({class: 'accent'}, parse('Wrapper'))
),
form({autocomplete: "off"},
input({
type: 'text',
name: 'user',
id: 'user',
placeholder: 'Username',
autofocus: 1
}),
input({
type: 'password',
name: 'pass',
id: 'pass',
placeholder: 'Password',
onkeydown: async (event) => {
if (event.key == 'Enter') {
event.preventDefault()
let user = document.getElementById('user').value
let pass = document.getElementById('pass').value
let res = await login(user, pass)
if (res.status === 200) {
location.href = '/home'
}
}
}
})
)
)
)
)
}
render()

136
public/js/main.js Normal file
View file

@ -0,0 +1,136 @@
function createElement(name, attrs, ...children) {
const el = document.createElement(name);
for (const attr in attrs) {
if(attr.startsWith("on")) {
el[attr] = attrs[attr];
} else {
el.setAttribute(attr, attrs[attr])
}
}
for (const child of children) {
if (child == null) {
continue
}
el.appendChild(child)
}
return el
}
export function createElementNS(name, attrs, ...children) {
var svgns = "http://www.w3.org/2000/svg";
var el = document.createElementNS(svgns, name);
for (const attr in attrs) {
if(attr.startsWith("on")) {
el[attr] = attrs[attr];
} else {
el.setAttribute(attr, attrs[attr])
}
}
for (const child of children) {
if (child == null) {
continue
}
el.appendChild(child)
}
return el
}
export function p(attrs, ...children) {
return createElement("p", attrs, ...children)
}
export function span(attrs, ...children) {
return createElement("span", attrs, ...children)
}
export function div(attrs, ...children) {
return createElement("div", attrs, ...children)
}
export function a(attrs, ...children) {
return createElement("a", attrs, ...children)
}
export function i(attrs, ...children) {
return createElement("i", attrs, ...children)
}
export function form(attrs, ...children) {
return createElement("form", attrs, ...children)
}
export function img(alt, attrs, ...children) {
attrs['onerror'] = (event) => event.target.remove()
attrs['alt'] = alt
return createElement("img", attrs, ...children)
}
export function input(attrs, ...children) {
return createElement("input", attrs, ...children)
}
export function button(attrs, ...children) {
return createElement("button", attrs, ...children)
}
export function path(attrs, ...children) {
return createElementNS("path", attrs, ...children)
}
export function svg(attrs, ...children) {
return createElementNS("svg", attrs, ...children)
}
export function body(attrs, ...children) {
return createElement("body", attrs, ...children)
}
export function textarea(attrs, ...children) {
return createElement("textarea", attrs, ...children)
}
export function parse(input) {
const pattern = /^[ a-zA-Z0-9!@#$%^&*()_+\-=\[\]{};':"\\|,.<>\/?]*$/;
input = input + '';
if (!pattern.test(input)) {
return null;
}
const sanitized = input.replace(/</g, '&lt;').replace(/>/g, '&gt;');
return document.createRange().createContextualFragment(sanitized);
}
export function is_domain(domain) {
domain = domain.toLowerCase()
const pattern = /^[a-z0-9_\-.]*$/;
if (!pattern.test(domain)) {
return false
}
let parts = domain.split('.').reverse()
for (const part of parts) {
if (part.length < 1) {
return false
}
}
if (parts.length < 2 || parts[0].length < 2) {
return false
}
const tld_pattern = /^[a-z]*$/;
if (!tld_pattern.test(parts[0])) {
return false
}
return true
}

21
public/login.html Normal file
View file

@ -0,0 +1,21 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Wrapper - Login</title>
<meta name="author" content="Tyler Murphy">
<meta name="description" content="wrapper dns login">
<meta property="og:title" content="wrapper">
<meta property="og:description" content="wrapper dns login">
<link rel="stylesheet" href="/css/main.css">
<link rel="stylesheet" href="/css/login.css">
<script type="module" src="/js/login.js"></script>
</head>
<body>
</body>
</html>

9
public/robots.txt Normal file
View file

@ -0,0 +1,9 @@
User-agent: Googlebot
Disallow: /api
User-agent: Googlebot
User-agent: AdsBot-Google
Disallow: /api
User-agent: *
Disallow: /api

View file

@ -1,35 +1,57 @@
use std::net::IpAddr; use std::{env, net::IpAddr, str::FromStr, fmt::Display};
#[derive(Clone)] #[derive(Clone)]
pub struct Config { pub struct Config {
fallback: IpAddr, pub dns_fallback: IpAddr,
port: u16, pub dns_port: u16,
pub dns_cache_size: u64,
pub db_host: String,
pub db_port: u16,
pub db_user: String,
pub db_pass: String,
pub web_user: String,
pub web_pass: String,
pub web_port: u16,
} }
impl Config { impl Config {
pub fn new() -> Self { pub fn new() -> Self {
let fallback = "9.9.9.9" let dns_port = Self::get_var::<u16>("WRAPPER_DNS_PORT", 53);
.parse::<IpAddr>() let dns_fallback = Self::get_var::<IpAddr>("WRAPPER_FALLBACK_DNS", [9, 9, 9, 9].into());
.expect("Failed to create default ns fallback"); let dns_cache_size = Self::get_var::<u64>("WRAPPER_CACHE_SIZE", 1000);
let db_host = Self::get_var::<String>("WRAPPER_DB_HOST", String::from("localhost"));
let db_port = Self::get_var::<u16>("WRAPPER_DB_PORT", 27017);
let db_user = Self::get_var::<String>("WRAPPER_DB_USER", String::from("root"));
let db_pass = Self::get_var::<String>("WRAPPER_DB_PASS", String::from(""));
let web_user = Self::get_var::<String>("WRAPPER_WEB_USER", String::from("admin"));
let web_pass = Self::get_var::<String>("WRAPPER_WEB_PASS", String::from("wrapper"));
let web_port = Self::get_var::<u16>("WRAPPER_WEB_PORT", 80);
Self { Self {
fallback, dns_fallback,
port: 2000, dns_port,
dns_cache_size,
db_host,
db_port,
db_user,
db_pass,
web_user,
web_pass,
web_port,
} }
} }
pub fn get_fallback_ns(&self) -> &IpAddr { fn get_var<T>(name: &str, default: T) -> T
&self.fallback where
} T: FromStr + Display,
{
pub fn get_port(&self) -> u16 { let env = env::var(name).unwrap_or(format!("{default}"));
self.port env.parse::<T>().unwrap_or(default)
}
pub fn set_fallback_ns(&mut self, addr: &IpAddr) {
self.fallback = *addr;
}
pub fn set_port(&mut self, port: u16) {
self.port = port;
} }
} }

146
src/database/mod.rs Normal file
View file

@ -0,0 +1,146 @@
use futures::TryStreamExt;
use mongodb::{
bson::doc,
options::{ClientOptions, Credential, ServerAddress},
Client,
};
use serde::{Deserialize, Serialize};
use tracing::info;
use crate::{
config::Config,
dns::packet::{query::QueryType, record::DnsRecord},
};
use crate::Result;
#[derive(Clone)]
pub struct Database {
client: Client,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct StoredRecord {
record: DnsRecord,
domain: String,
prefix: String,
}
impl StoredRecord {
fn get_domain_parts(domain: &str) -> (String, String) {
let parts: Vec<&str> = domain.split(".").collect();
let len = parts.len();
if len == 1 {
(String::new(), String::from(parts[0]))
} else if len == 2 {
(String::new(), String::from(parts.join(".")))
} else {
(
String::from(parts[0..len - 2].join(".")),
String::from(parts[len - 2..len].join(".")),
)
}
}
}
impl From<DnsRecord> for StoredRecord {
fn from(record: DnsRecord) -> Self {
let (prefix, domain) = Self::get_domain_parts(&record.get_domain());
Self {
record,
domain,
prefix,
}
}
}
impl Into<DnsRecord> for StoredRecord {
fn into(self) -> DnsRecord {
self.record
}
}
impl Database {
pub async fn new(config: Config) -> Result<Self> {
let options = ClientOptions::builder()
.hosts(vec![ServerAddress::Tcp {
host: config.db_host,
port: Some(config.db_port),
}])
.credential(
Credential::builder()
.username(config.db_user)
.password(config.db_pass)
.build(),
)
.max_pool_size(100)
.app_name(String::from("wrapper"))
.build();
let client = Client::with_options(options)?;
client
.database("wrapper")
.run_command(doc! {"ping": 1}, None)
.await?;
info!("Connection to mongodb successfully");
Ok(Database { client })
}
pub async fn get_records(&self, domain: &str, qtype: QueryType) -> Result<Vec<DnsRecord>> {
let (prefix, domain) = StoredRecord::get_domain_parts(domain);
Ok(self
.get_domain(&domain)
.await?
.into_iter()
.filter(|r| r.prefix == prefix)
.filter(|r| {
let rqtype = r.record.get_qtype();
if qtype == QueryType::A {
return rqtype == QueryType::A || rqtype == QueryType::AR;
} else if qtype == QueryType::AAAA {
return rqtype == QueryType::AAAA || rqtype == QueryType::AAAAR;
} else {
r.record.get_qtype() == qtype
}
})
.map(|r| r.into())
.collect())
}
pub async fn get_domain(&self, domain: &str) -> Result<Vec<StoredRecord>> {
let db = self.client.database("wrapper");
let col = db.collection::<StoredRecord>(domain);
let filter = doc! { "domain": domain };
let mut cursor = col.find(filter, None).await?;
let mut records = Vec::new();
while let Some(record) = cursor.try_next().await? {
records.push(record);
}
Ok(records)
}
pub async fn add_record(&self, record: DnsRecord) -> Result<()> {
let record = StoredRecord::from(record);
let db = self.client.database("wrapper");
let col = db.collection::<StoredRecord>(&record.domain);
col.insert_one(record, None).await?;
Ok(())
}
pub async fn get_domains(&self) -> Result<Vec<String>> {
let db = self.client.database("wrapper");
Ok(db.list_collection_names(None).await?)
}
pub async fn delete_domain(&self, domain: String) -> Result<()> {
let db = self.client.database("wrapper");
let col = db.collection::<StoredRecord>(&domain);
Ok(col.drop(None).await?)
}
}

View file

@ -3,7 +3,8 @@ use std::{
sync::Arc, sync::Arc,
}; };
use crate::packet::{buffer::PacketBuffer, Packet, Result}; use super::packet::{buffer::PacketBuffer, Packet};
use crate::Result;
use tokio::{ use tokio::{
io::{AsyncReadExt, AsyncWriteExt}, io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream, UdpSocket}, net::{TcpListener, TcpStream, UdpSocket},
@ -140,11 +141,4 @@ impl Connection {
} }
} }
} }
// fn pb(buf: &[u8]) {
// for i in 0..buf.len() {
// print!("{:02X?} ", buf[i]);
// }
// println!("");
// }
} }

View file

@ -1,3 +1,4 @@
mod binding; mod binding;
pub mod packet;
mod resolver; mod resolver;
pub mod server; pub mod server;

View file

@ -1,4 +1,4 @@
use super::Result; use crate::Result;
pub struct PacketBuffer { pub struct PacketBuffer {
pub buf: Vec<u8>, pub buf: Vec<u8>,
@ -9,19 +9,9 @@ pub struct PacketBuffer {
impl PacketBuffer { impl PacketBuffer {
pub fn new(buf: Vec<u8>) -> Self { pub fn new(buf: Vec<u8>) -> Self {
Self { Self {
size: buf.len(),
buf, buf,
pos: 0, pos: 0,
size: 0,
}
}
fn check(&mut self, pos: usize) {
if self.size < pos {
self.size = pos;
}
if self.buf.len() <= self.size {
self.buf.resize(self.size + 1, 0x00);
} }
} }
@ -42,32 +32,25 @@ impl PacketBuffer {
} }
pub fn read(&mut self) -> Result<u8> { pub fn read(&mut self) -> Result<u8> {
// if self.pos >= 512 { if self.pos >= self.size {
// error!("Tried to read past end of buffer"); return Err("Tried to read past end of buffer".into());
// return Err("End of buffer".into()); }
// }
self.check(self.pos);
let res = self.buf[self.pos]; let res = self.buf[self.pos];
self.pos += 1; self.pos += 1;
Ok(res) Ok(res)
} }
pub fn get(&mut self, pos: usize) -> Result<u8> { pub fn get(&mut self, pos: usize) -> Result<u8> {
// if pos >= 512 { if pos >= self.size {
// error!("Tried to read past end of buffer"); return Err("Tried to read past end of buffer".into());
// return Err("End of buffer".into()); }
// }
self.check(pos);
Ok(self.buf[pos]) Ok(self.buf[pos])
} }
pub fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> { pub fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
// if start + len >= 512 { if start + len >= self.size {
// error!("Tried to read past end of buffer"); return Err("Tried to read past end of buffer".into());
// return Err("End of buffer".into()); }
// }
self.check(start + len);
Ok(&self.buf[start..start + len]) Ok(&self.buf[start..start + len])
} }
@ -169,7 +152,13 @@ impl PacketBuffer {
} }
pub fn write(&mut self, val: u8) -> Result<()> { pub fn write(&mut self, val: u8) -> Result<()> {
self.check(self.pos); if self.size < self.pos {
self.size = self.pos;
}
if self.buf.len() <= self.size {
self.buf.resize(self.size + 1, 0x00);
}
self.buf[self.pos] = val; self.buf[self.pos] = val;
self.pos += 1; self.pos += 1;
@ -208,7 +197,9 @@ impl PacketBuffer {
} }
} }
self.write_u8(0)?; if !qname.is_empty() {
self.write_u8(0)?;
}
Ok(()) Ok(())
} }

View file

@ -1,4 +1,5 @@
use super::{buffer::PacketBuffer, result::ResultCode, Result}; use super::{buffer::PacketBuffer, result::ResultCode};
use crate::Result;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct DnsHeader { pub struct DnsHeader {

View file

@ -4,9 +4,7 @@ use self::{
buffer::PacketBuffer, header::DnsHeader, query::QueryType, question::DnsQuestion, buffer::PacketBuffer, header::DnsHeader, query::QueryType, question::DnsQuestion,
record::DnsRecord, record::DnsRecord,
}; };
use crate::Result;
type Error = Box<dyn std::error::Error>;
pub type Result<T> = std::result::Result<T, Error>;
pub mod buffer; pub mod buffer;
pub mod header; pub mod header;

View file

@ -12,6 +12,8 @@ pub enum QueryType {
SRV, // 33 SRV, // 33
OPT, // 41 OPT, // 41
CAA, // 257 CAA, // 257
AR, // 1000
AAAAR, // 1001
} }
impl QueryType { impl QueryType {
@ -29,6 +31,8 @@ impl QueryType {
Self::SRV => 33, Self::SRV => 33,
Self::OPT => 41, Self::OPT => 41,
Self::CAA => 257, Self::CAA => 257,
Self::AR => 1000,
Self::AAAAR => 1001,
} }
} }
@ -45,7 +49,30 @@ impl QueryType {
33 => Self::SRV, 33 => Self::SRV,
41 => Self::OPT, 41 => Self::OPT,
257 => Self::CAA, 257 => Self::CAA,
1000 => Self::AR,
1001 => Self::AAAAR,
_ => Self::UNKNOWN(num), _ => Self::UNKNOWN(num),
} }
} }
pub fn allowed_actions(&self) -> (bool, bool) {
// 0. duplicates allowed
// 1. allowed to be created by database
match self {
QueryType::UNKNOWN(_) => (false, false),
QueryType::A => (true, true),
QueryType::NS => (false, true),
QueryType::CNAME => (false, true),
QueryType::SOA => (false, false),
QueryType::PTR => (false, true),
QueryType::MX => (false, true),
QueryType::TXT => (true, true),
QueryType::AAAA => (true, true),
QueryType::SRV => (false, true),
QueryType::OPT => (false, false),
QueryType::CAA => (false, true),
QueryType::AR => (false, true),
QueryType::AAAAR => (false, true),
}
}
} }

View file

@ -1,11 +1,12 @@
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use tracing::{trace, warn}; use tracing::{trace, warn};
use super::{buffer::PacketBuffer, query::QueryType, Result}; use super::{buffer::PacketBuffer, query::QueryType, Result};
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[allow(dead_code)]
pub enum DnsRecord { pub enum DnsRecord {
UNKNOWN { UNKNOWN {
domain: String, domain: String,
@ -76,10 +77,17 @@ pub enum DnsRecord {
value: String, value: String,
ttl: u32, ttl: u32,
}, // 257 }, // 257
AR {
domain: String,
ttl: u32,
},
AAAAR {
domain: String,
ttl: u32,
},
} }
impl DnsRecord { impl DnsRecord {
pub fn read(buffer: &mut PacketBuffer) -> Result<Self> { pub fn read(buffer: &mut PacketBuffer) -> Result<Self> {
let mut domain = String::new(); let mut domain = String::new();
buffer.read_qname(&mut domain)?; buffer.read_qname(&mut domain)?;
@ -90,10 +98,10 @@ impl DnsRecord {
let ttl = buffer.read_u32()?; let ttl = buffer.read_u32()?;
let data_len = buffer.read_u16()?; let data_len = buffer.read_u16()?;
let header_pos = buffer.pos();
trace!("Reading DNS Record TYPE: {:?}", qtype); trace!("Reading DNS Record TYPE: {:?}", qtype);
let header_pos = buffer.pos();
match qtype { match qtype {
QueryType::A => { QueryType::A => {
let raw_addr = buffer.read_u32()?; let raw_addr = buffer.read_u32()?;
@ -471,6 +479,29 @@ impl DnsRecord {
let size = buffer.pos() - (pos + 2); let size = buffer.pos() - (pos + 2);
buffer.set_u16(pos, size as u16)?; buffer.set_u16(pos, size as u16)?;
} }
Self::AR { ref domain, ttl } => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::A.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
buffer.write_u16(4)?;
let mut rand = rand::thread_rng();
buffer.write_u32(rand.next_u32())?;
}
Self::AAAAR { ref domain, ttl } => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::A.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
buffer.write_u16(4)?;
let mut rand = rand::thread_rng();
buffer.write_u32(rand.next_u32())?;
buffer.write_u32(rand.next_u32())?;
buffer.write_u32(rand.next_u32())?;
buffer.write_u32(rand.next_u32())?;
}
Self::UNKNOWN { .. } => { Self::UNKNOWN { .. } => {
warn!("Skipping record: {self:?}"); warn!("Skipping record: {self:?}");
} }
@ -479,20 +510,35 @@ impl DnsRecord {
Ok(buffer.pos() - start_pos) Ok(buffer.pos() - start_pos)
} }
pub fn get_ttl(&self) -> u32 { pub fn get_domain(&self) -> String {
match *self { self.get_shared_domain().0
DnsRecord::UNKNOWN { .. } => 0,
DnsRecord::AAAA { ttl, .. } => ttl,
DnsRecord::A { ttl, .. } => ttl,
DnsRecord::NS { ttl, .. } => ttl,
DnsRecord::CNAME { ttl, .. } => ttl,
DnsRecord::SOA { ttl, .. } => ttl,
DnsRecord::PTR { ttl, .. } => ttl,
DnsRecord::MX { ttl, .. } => ttl,
DnsRecord::TXT { ttl, .. } => ttl,
DnsRecord::SRV { ttl, .. } => ttl,
DnsRecord::CAA { ttl, .. } => ttl,
}
} }
pub fn get_qtype(&self) -> QueryType {
self.get_shared_domain().1
}
pub fn get_ttl(&self) -> u32 {
self.get_shared_domain().2
}
fn get_shared_domain(&self) -> (String, QueryType, u32) {
match self {
DnsRecord::UNKNOWN {
domain, ttl, qtype, ..
} => (domain.clone(), QueryType::UNKNOWN(*qtype), *ttl),
DnsRecord::AAAA { domain, ttl, .. } => (domain.clone(), QueryType::AAAA, *ttl),
DnsRecord::A { domain, ttl, .. } => (domain.clone(), QueryType::A, *ttl),
DnsRecord::NS { domain, ttl, .. } => (domain.clone(), QueryType::NS, *ttl),
DnsRecord::CNAME { domain, ttl, .. } => (domain.clone(), QueryType::CNAME, *ttl),
DnsRecord::SOA { domain, ttl, .. } => (domain.clone(), QueryType::SOA, *ttl),
DnsRecord::PTR { domain, ttl, .. } => (domain.clone(), QueryType::PTR, *ttl),
DnsRecord::MX { domain, ttl, .. } => (domain.clone(), QueryType::MX, *ttl),
DnsRecord::TXT { domain, ttl, .. } => (domain.clone(), QueryType::TXT, *ttl),
DnsRecord::SRV { domain, ttl, .. } => (domain.clone(), QueryType::SRV, *ttl),
DnsRecord::CAA { domain, ttl, .. } => (domain.clone(), QueryType::CAA, *ttl),
DnsRecord::AR { domain, ttl, .. } => (domain.clone(), QueryType::AR, *ttl),
DnsRecord::AAAAR { domain, ttl, .. } => (domain.clone(), QueryType::AAAAR, *ttl),
}
}
} }

View file

@ -1,11 +1,7 @@
use super::binding::Connection; use super::binding::Connection;
use crate::{ use super::packet::{query::QueryType, question::DnsQuestion, result::ResultCode, Packet};
config::Config, use crate::Result;
packet::{ use crate::{config::Config, database::Database, get_time};
query::QueryType, question::DnsQuestion, result::ResultCode, Packet,
Result,
}, get_time,
};
use async_recursion::async_recursion; use async_recursion::async_recursion;
use moka::future::Cache; use moka::future::Cache;
use std::{net::IpAddr, sync::Arc, time::Duration}; use std::{net::IpAddr, sync::Arc, time::Duration};
@ -15,6 +11,7 @@ pub struct Resolver {
request_id: u16, request_id: u16,
connection: Connection, connection: Connection,
config: Arc<Config>, config: Arc<Config>,
database: Arc<Database>,
cache: Cache<DnsQuestion, (Packet, u64)>, cache: Cache<DnsQuestion, (Packet, u64)>,
} }
@ -23,18 +20,59 @@ impl Resolver {
request_id: u16, request_id: u16,
connection: Connection, connection: Connection,
config: Arc<Config>, config: Arc<Config>,
database: Arc<Database>,
cache: Cache<DnsQuestion, (Packet, u64)>, cache: Cache<DnsQuestion, (Packet, u64)>,
) -> Self { ) -> Self {
Self { Self {
request_id, request_id,
connection, connection,
config, config,
database,
cache, cache,
} }
} }
async fn lookup_cache(&mut self, qname: &str, qtype: QueryType) -> Option<Packet> { async fn lookup_database(&self, question: &DnsQuestion) -> Option<Packet> {
let question = DnsQuestion::new(qname.to_string(), qtype); let records = match self
.database
.get_records(&question.name, question.qtype)
.await
{
Ok(record) => record,
Err(err) => {
error!("{err}");
return None;
}
};
if records.is_empty() {
return None;
}
let mut packet = Packet::new();
packet.header.id = self.request_id;
packet.header.questions = 1;
packet.header.answers = records.len() as u16;
packet.header.recursion_desired = true;
packet
.questions
.push(DnsQuestion::new(question.name.to_string(), question.qtype));
for record in records {
packet.answers.push(record);
}
trace!(
"Found stored value for {:?} {}",
question.qtype,
question.name
);
Some(packet)
}
async fn lookup_cache(&self, question: &DnsQuestion) -> Option<Packet> {
let Some((packet, date)) = self.cache.get(&question) else { let Some((packet, date)) = self.cache.get(&question) else {
return None return None
}; };
@ -46,16 +84,20 @@ impl Resolver {
let ttl = answer.get_ttl(); let ttl = answer.get_ttl();
if diff > ttl { if diff > ttl {
self.cache.invalidate(&question).await; self.cache.invalidate(&question).await;
return None return None;
} }
} }
trace!("Found cached value for {qtype:?} {qname}"); trace!(
"Found cached value for {:?} {}",
question.qtype,
question.name
);
Some(packet) Some(packet)
} }
async fn lookup(&mut self, qname: &str, qtype: QueryType, server: (IpAddr, u16)) -> Packet { async fn lookup_fallback(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
let mut packet = Packet::new(); let mut packet = Packet::new();
packet.header.id = self.request_id; packet.header.id = self.request_id;
@ -63,7 +105,7 @@ impl Resolver {
packet.header.recursion_desired = true; packet.header.recursion_desired = true;
packet packet
.questions .questions
.push(DnsQuestion::new(qname.to_string(), qtype)); .push(DnsQuestion::new(question.name.to_string(), question.qtype));
let packet = match self.connection.request_packet(packet, server).await { let packet = match self.connection.request_packet(packet, server).await {
Ok(packet) => packet, Ok(packet) => packet,
@ -78,28 +120,47 @@ impl Resolver {
packet packet
} }
async fn lookup(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
if let Some(packet) = self.lookup_cache(question).await {
return packet;
};
if let Some(packet) = self.lookup_database(question).await {
return packet;
};
trace!(
"Attempting lookup of {:?} {} with ns {}",
question.qtype,
question.name,
server.0
);
self.lookup_fallback(question, server).await
}
#[async_recursion] #[async_recursion]
async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet { async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet {
let question = DnsQuestion::new(qname.to_string(), qtype); let question = DnsQuestion::new(qname.to_string(), qtype);
let mut ns = self.config.get_fallback_ns().clone(); let mut ns = self.config.dns_fallback.clone();
if let Some(packet) = self.lookup_cache(qname, qtype).await { return packet }
loop { loop {
trace!("Attempting lookup of {qtype:?} {qname} with ns {ns}");
let ns_copy = ns; let ns_copy = ns;
let server = (ns_copy, 53); let server = (ns_copy, 53);
let response = self.lookup(qname, qtype, server).await; let response = self.lookup(&question, server).await;
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
self.cache.insert(question, (response.clone(), get_time())).await; self.cache
.insert(question, (response.clone(), get_time()))
.await;
return response; return response;
} }
if response.header.rescode == ResultCode::NXDOMAIN { if response.header.rescode == ResultCode::NXDOMAIN {
self.cache.insert(question, (response.clone(), get_time())).await; self.cache
.insert(question, (response.clone(), get_time()))
.await;
return response; return response;
} }
@ -111,9 +172,11 @@ impl Resolver {
let new_ns_name = match response.get_unresolved_ns(qname) { let new_ns_name = match response.get_unresolved_ns(qname) {
Some(x) => x, Some(x) => x,
None => { None => {
self.cache.insert(question, (response.clone(), get_time())).await; self.cache
return response .insert(question, (response.clone(), get_time()))
}, .await;
return response;
}
}; };
let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await; let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await;
@ -121,7 +184,9 @@ impl Resolver {
if let Some(new_ns) = recursive_response.get_random_a() { if let Some(new_ns) = recursive_response.get_random_a() {
ns = new_ns; ns = new_ns;
} else { } else {
self.cache.insert(question, (response.clone(), get_time())).await; self.cache
.insert(question, (response.clone(), get_time()))
.await;
return response; return response;
} }
} }

85
src/dns/server.rs Normal file
View file

@ -0,0 +1,85 @@
use super::{
binding::Binding,
packet::{question::DnsQuestion, Packet},
resolver::Resolver,
};
use crate::{config::Config, database::Database, Result};
use moka::future::Cache;
use std::{net::SocketAddr, sync::Arc, time::Duration};
use tokio::task::JoinHandle;
use tracing::{error, info};
pub struct DnsServer {
addr: SocketAddr,
config: Arc<Config>,
database: Arc<Database>,
cache: Cache<DnsQuestion, (Packet, u64)>,
}
impl DnsServer {
pub async fn new(config: Config, database: Database) -> Result<Self> {
let addr = format!("[::]:{}", config.dns_port).parse::<SocketAddr>()?;
let cache = Cache::builder()
.time_to_live(Duration::from_secs(60 * 60))
.max_capacity(config.dns_cache_size)
.build();
info!("Created DNS cache with size of {}", config.dns_cache_size);
Ok(Self {
addr,
config: Arc::new(config),
database: Arc::new(database),
cache,
})
}
pub async fn run(&self) -> Result<(JoinHandle<()>, JoinHandle<()>)> {
let tcp = Binding::tcp(self.addr).await?;
let tcp_handle = self.listen(tcp);
let udp = Binding::udp(self.addr).await?;
let udp_handle = self.listen(udp);
info!(
"Fallback DNS Server is set to: {:?}",
self.config.dns_fallback
);
info!(
"Listening for TCP and UDP traffic on [::]:{}",
self.config.dns_port
);
Ok((udp_handle, tcp_handle))
}
fn listen(&self, mut binding: Binding) -> JoinHandle<()> {
let config = self.config.clone();
let database = self.database.clone();
let cache = self.cache.clone();
tokio::spawn(async move {
let mut id = 0;
loop {
let Ok(connection) = binding.connect().await else { continue };
info!("Received request on {}", binding.name());
let resolver = Resolver::new(
id,
connection,
config.clone(),
database.clone(),
cache.clone(),
);
let name = binding.name().to_string();
tokio::spawn(async move {
if let Err(err) = resolver.handle_query().await {
error!("{} request {} failed: {:?}", name, id, err);
};
});
id += 1;
}
})
}
}

View file

@ -1,19 +1,34 @@
use std::{time::{UNIX_EPOCH, SystemTime}, env, net::IpAddr}; use std::time::{SystemTime, UNIX_EPOCH};
use config::Config; use config::Config;
use server::server::Server; use database::Database;
use tracing::metadata::LevelFilter; use dotenv::dotenv;
use dns::server::DnsServer;
use tracing::{error, metadata::LevelFilter};
use tracing_subscriber::{ use tracing_subscriber::{
filter::filter_fn, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, Layer, filter::filter_fn, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, Layer,
}; };
use web::WebServer;
mod config; mod config;
mod packet; mod database;
mod server; mod dns;
mod web;
type Error = Box<dyn std::error::Error>;
pub type Result<T> = std::result::Result<T, Error>;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
if let Err(err) = run().await {
error!("{err}")
};
}
async fn run() -> Result<()> {
dotenv().ok();
tracing_subscriber::registry() tracing_subscriber::registry()
.with( .with(
tracing_subscriber::fmt::layer() tracing_subscriber::fmt::layer()
@ -24,19 +39,20 @@ async fn main() {
) )
.init(); .init();
let mut config = Config::new(); let config = Config::new();
let database = Database::new(config.clone()).await?;
if let Ok(port) = env::var("PORT").unwrap_or(String::new()).parse::<u16>() { let dns_server = DnsServer::new(config.clone(), database.clone()).await?;
config.set_port(port); let (udp, tcp) = dns_server.run().await?;
}
if let Ok(fallback) = env::var("FALLBACK_DNS").unwrap_or(String::new()).parse::<IpAddr>() { let web_server = WebServer::new(config, database).await?;
config.set_fallback_ns(&fallback); let web = web_server.run().await?;
}
let server = Server::new(config).await.expect("Failed to bind server"); tokio::join!(udp).0?;
tokio::join!(tcp).0?;
tokio::join!(web).0?;
server.run().await.unwrap(); Ok(())
} }
pub fn get_time() -> u64 { pub fn get_time() -> u64 {

View file

@ -1,73 +0,0 @@
use moka::future::Cache;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::task::JoinHandle;
use tracing::{error, info};
use crate::config::Config;
use crate::packet::question::DnsQuestion;
use crate::packet::{Result, Packet};
use super::binding::Binding;
use super::resolver::Resolver;
pub struct Server {
addr: SocketAddr,
config: Arc<Config>,
cache: Cache<DnsQuestion, (Packet, u64)>,
}
impl Server {
pub async fn new(config: Config) -> Result<Self> {
let addr = format!("[::]:{}", config.get_port()).parse::<SocketAddr>()?;
let cache = Cache::builder()
.time_to_live(Duration::from_secs(60 * 60))
.max_capacity(1_000)
.build();
Ok(Self {
addr,
config: Arc::new(config),
cache,
})
}
pub async fn run(&self) -> Result<()> {
let tcp = Binding::tcp(self.addr).await?;
let tcp_handle = self.listen(tcp);
let udp = Binding::udp(self.addr).await?;
let udp_handle = self.listen(udp);
info!("Fallback DNS Server is set to: {:?}", self.config.get_fallback_ns());
info!("Listening for TCP and UDP traffic on [::]:{}", self.config.get_port());
tokio::join!(tcp_handle)
.0
.expect("Failed to join tcp thread");
tokio::join!(udp_handle)
.0
.expect("Failed to join udp thread");
Ok(())
}
fn listen(&self, mut binding: Binding) -> JoinHandle<()> {
let config = self.config.clone();
let cache = self.cache.clone();
tokio::spawn(async move {
let mut id = 0;
loop {
let Ok(connection) = binding.connect().await else { continue };
info!("Received request on {}", binding.name());
let resolver = Resolver::new(id, connection, config.clone(), cache.clone());
if let Err(err) = resolver.handle_query().await {
error!("{} request {} failed: {:?}", binding.name(), id, err);
};
id += 1;
}
})
}
}

156
src/web/api.rs Normal file
View file

@ -0,0 +1,156 @@
use std::net::IpAddr;
use axum::{
extract::Query,
response::Response,
routing::{get, post, put, delete},
Extension, Router,
};
use moka::future::Cache;
use rand::distributions::{Alphanumeric, DistString};
use serde::Deserialize;
use tower_cookies::{Cookie, Cookies};
use crate::{config::Config, database::Database, dns::packet::record::DnsRecord};
use super::{
extract::{Authorized, Body, RequestIp},
http::{json, text},
};
pub fn router() -> Router {
Router::new()
.route("/login", post(login))
.route("/domains", get(list_domains))
.route("/domains", delete(delete_domain))
.route("/records", get(get_domain))
.route("/records", put(add_record))
}
async fn list_domains(_: Authorized, Extension(database): Extension<Database>) -> Response {
let domains = match database.get_domains().await {
Ok(domains) => domains,
Err(err) => return text(500, &format!("{err}")),
};
let Ok(domains) = serde_json::to_string(&domains) else {
return text(500, "Failed to fetch domains")
};
json(200, &domains)
}
#[derive(Deserialize)]
struct DomainRequest {
domain: String,
}
async fn get_domain(
_: Authorized,
Extension(database): Extension<Database>,
Query(query): Query<DomainRequest>,
) -> Response {
let records = match database.get_domain(&query.domain).await {
Ok(records) => records,
Err(err) => return text(500, &format!("{err}")),
};
let Ok(records) = serde_json::to_string(&records) else {
return text(500, "Failed to fetch records")
};
json(200, &records)
}
async fn delete_domain(
_: Authorized,
Extension(database): Extension<Database>,
Body(body): Body,
) -> Response {
let Ok(request) = serde_json::from_str::<DomainRequest>(&body) else {
return text(400, "Missing request parameters")
};
let Ok(domains) = database.get_domains().await else {
return text(500, "Failed to delete domain")
};
if !domains.contains(&request.domain) {
return text(400, "Domain does not exist")
}
if database.delete_domain(request.domain).await.is_err() {
return text(500, "Failed to delete domain")
};
return text(204, "Successfully deleted domain")
}
async fn add_record(
_: Authorized,
Extension(database): Extension<Database>,
Body(body): Body,
) -> Response {
let Ok(record) = serde_json::from_str::<DnsRecord>(&body) else {
return text(400, "Invalid DNS record")
};
let allowed = record.get_qtype().allowed_actions();
if !allowed.1 {
return text(400, "Not allowed to create record")
}
let Ok(records) = database.get_records(&record.get_domain(), record.get_qtype()).await else {
return text(500, "Failed to complete record check");
};
if !records.is_empty() && !allowed.0 {
return text(400, "Not allowed to create duplicate record")
};
if records.contains(&record) {
return text(400, "Not allowed to create duplicate record")
}
if let Err(err) = database.add_record(record).await {
return text(500, &format!("{err}"));
}
return text(201, "Added record to database successfully");
}
#[derive(Deserialize)]
struct LoginRequest {
user: String,
pass: String,
}
async fn login(
Extension(config): Extension<Config>,
Extension(cache): Extension<Cache<String, IpAddr>>,
RequestIp(ip): RequestIp,
cookies: Cookies,
Body(body): Body,
) -> Response {
let Ok(request) = serde_json::from_str::<LoginRequest>(&body) else {
return text(400, "Missing request parameters")
};
if request.user != config.web_user || request.pass != config.web_pass {
return text(400, "Invalid credentials");
};
let token = Alphanumeric.sample_string(&mut rand::thread_rng(), 128);
cache.insert(token.clone(), ip).await;
let mut cookie = Cookie::new("auth", token);
cookie.set_secure(true);
cookie.set_http_only(true);
cookie.set_path("/");
cookies.add(cookie);
text(200, "Successfully logged in")
}

139
src/web/extract.rs Normal file
View file

@ -0,0 +1,139 @@
use std::{
io::Read,
net::{IpAddr, SocketAddr},
};
use axum::{
async_trait,
body::HttpBody,
extract::{ConnectInfo, FromRequest, FromRequestParts},
http::{request::Parts, Request},
response::Response,
BoxError,
};
use bytes::Bytes;
use moka::future::Cache;
use tower_cookies::Cookies;
use super::http::text;
pub struct Authorized;
#[async_trait]
impl<S> FromRequestParts<S> for Authorized
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let Ok(Some(cookies)) = Option::<Cookies>::from_request_parts(parts, state).await else {
return Err(text(403, "No cookies provided"))
};
let Some(token) = cookies.get("auth") else {
return Err(text(403, "No auth token provided"))
};
let auth_ip: IpAddr;
{
let Some(cache) = parts.extensions.get::<Cache<String, IpAddr>>() else {
return Err(text(500, "Failed to load auth store"))
};
let Some(ip) = cache.get(token.value()) else {
return Err(text(401, "Unauthorized"))
};
auth_ip = ip
}
let Ok(Some(RequestIp(ip))) = Option::<RequestIp>::from_request_parts(parts, state).await else {
return Err(text(403, "You have no ip"))
};
if auth_ip != ip {
return Err(text(403, "Auth token does not match current ip"));
}
Ok(Self)
}
}
pub struct RequestIp(pub IpAddr);
#[async_trait]
impl<S> FromRequestParts<S> for RequestIp
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let headers = &parts.headers;
let forwardedfor = headers
.get("x-forwarded-for")
.and_then(|h| h.to_str().ok())
.and_then(|h| {
h.split(',')
.rev()
.find_map(|s| s.trim().parse::<IpAddr>().ok())
});
if let Some(forwardedfor) = forwardedfor {
return Ok(Self(forwardedfor));
}
let realip = headers
.get("x-real-ip")
.and_then(|hv| hv.to_str().ok())
.and_then(|s| s.parse::<IpAddr>().ok());
if let Some(realip) = realip {
return Ok(Self(realip));
}
let realip = headers
.get("x-real-ip")
.and_then(|hv| hv.to_str().ok())
.and_then(|s| s.parse::<IpAddr>().ok());
if let Some(realip) = realip {
return Ok(Self(realip));
}
let info = parts.extensions.get::<ConnectInfo<SocketAddr>>();
if let Some(info) = info {
return Ok(Self(info.0.ip()));
}
Err(text(403, "You have no ip"))
}
}
pub struct Body(pub String);
#[async_trait]
impl<S, B> FromRequest<S, B> for Body
where
B: HttpBody + Sync + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send + Sync,
{
type Rejection = Response;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let Ok(bytes) = Bytes::from_request(req, state).await else {
return Err(text(413, "Payload too large"));
};
let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else {
return Err(text(400, "Invalid utf8 body"))
};
Ok(Self(body))
}
}

31
src/web/file.rs Normal file
View file

@ -0,0 +1,31 @@
use axum::{extract::Path, response::Response};
use super::http::serve;
pub async fn js(Path(path): Path<String>) -> Response {
let path = format!("/js/{path}");
serve(&path).await
}
pub async fn css(Path(path): Path<String>) -> Response {
let path = format!("/css/{path}");
serve(&path).await
}
pub async fn fonts(Path(path): Path<String>) -> Response {
let path = format!("/fonts/{path}");
serve(&path).await
}
pub async fn image(Path(path): Path<String>) -> Response {
let path = format!("/image/{path}");
serve(&path).await
}
pub async fn favicon() -> Response {
serve("/favicon.ico").await
}
pub async fn robots() -> Response {
serve("/robots.txt").await
}

50
src/web/http.rs Normal file
View file

@ -0,0 +1,50 @@
use axum::{
body::Body,
http::{header::HeaderName, HeaderValue, Request, StatusCode},
response::{IntoResponse, Response},
};
use std::str;
use tower::ServiceExt;
use tower_http::services::ServeFile;
pub fn text(code: u16, msg: &str) -> Response {
(status_code(code), msg.to_owned()).into_response()
}
pub fn json(code: u16, json: &str) -> Response {
let mut res = (status_code(code), json.to_owned()).into_response();
res.headers_mut().insert(
HeaderName::from_static("content-type"),
HeaderValue::from_static("application/json"),
);
res
}
pub async fn serve(path: &str) -> Response {
if !path.chars().any(|c| c == '.') {
return text(403, "Invalid file path");
}
let path = format!("public{path}");
let file = ServeFile::new(path);
let Ok(mut res) = file.oneshot(Request::new(Body::empty())).await else {
tracing::error!("Error while fetching file");
return text(500, "Error when fetching file")
};
if res.status() != StatusCode::OK {
return text(404, "File not found");
}
res.headers_mut().insert(
HeaderName::from_static("cache-control"),
HeaderValue::from_static("max-age=300"),
);
res.into_response()
}
fn status_code(code: u16) -> StatusCode {
StatusCode::from_u16(code).map_or(StatusCode::OK, |code| code)
}

82
src/web/mod.rs Normal file
View file

@ -0,0 +1,82 @@
use std::net::{IpAddr, SocketAddr, TcpListener};
use std::time::Duration;
use axum::routing::get;
use axum::{Extension, Router};
use moka::future::Cache;
use tokio::task::JoinHandle;
use tower_cookies::CookieManagerLayer;
use tracing::{error, info};
use crate::config::Config;
use crate::database::Database;
use crate::Result;
mod api;
mod extract;
mod file;
mod http;
mod pages;
pub struct WebServer {
config: Config,
database: Database,
addr: SocketAddr,
}
impl WebServer {
pub async fn new(config: Config, database: Database) -> Result<Self> {
let addr = format!("[::]:{}", config.web_port).parse::<SocketAddr>()?;
Ok(Self {
config,
database,
addr,
})
}
pub async fn run(&self) -> Result<JoinHandle<()>> {
let config = self.config.clone();
let database = self.database.clone();
let listener = TcpListener::bind(self.addr)?;
info!(
"Listening for HTTP traffic on [::]:{}",
self.config.web_port
);
let app = Self::router(config, database);
let server = axum::Server::from_tcp(listener)?;
let web_handle = tokio::spawn(async move {
if let Err(err) = server
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await
{
error!("{err}");
}
});
Ok(web_handle)
}
fn router(config: Config, database: Database) -> Router {
let cache: Cache<String, IpAddr> = Cache::builder()
.time_to_live(Duration::from_secs(60 * 15))
.max_capacity(config.dns_cache_size)
.build();
Router::new()
.nest("/", pages::router())
.nest("/api", api::router())
.layer(Extension(config))
.layer(Extension(cache))
.layer(Extension(database))
.layer(CookieManagerLayer::new())
.route("/js/*path", get(file::js))
.route("/css/*path", get(file::css))
.route("/fonts/*path", get(file::fonts))
.route("/image/*path", get(file::image))
.route("/favicon.ico", get(file::favicon))
.route("/robots.txt", get(file::robots))
}
}

31
src/web/pages.rs Normal file
View file

@ -0,0 +1,31 @@
use axum::{response::Response, routing::get, Router};
use super::{extract::Authorized, http::serve};
pub fn router() -> Router {
Router::new()
.route("/", get(root))
.route("/login", get(login))
.route("/home", get(home))
.route("/domain", get(domain))
}
async fn root(user: Option<Authorized>) -> Response {
if user.is_some() {
home().await
} else {
login().await
}
}
async fn login() -> Response {
serve("/login.html").await
}
async fn home() -> Response {
serve("/home.html").await
}
async fn domain() -> Response {
serve("/domain.html").await
}