finish dns and start webserver
This commit is contained in:
parent
0f40ab89e3
commit
b1fb410aff
42 changed files with 3093 additions and 202 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1 +1,2 @@
|
|||
**/target
|
||||
.env
|
1321
Cargo.lock
generated
1321
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
30
Cargo.toml
30
Cargo.toml
|
@ -4,8 +4,34 @@ version = "0.1.0"
|
|||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
# Blazingly fast runtime
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
async-recursion = "1"
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = "0.3.16"
|
||||
tracing = "0.1.37"
|
||||
|
||||
# Allow recursion inside tokio async
|
||||
async-recursion = "1"
|
||||
|
||||
# DNS Caching Layer
|
||||
moka = { version = "0.10.0", features = ["future"] }
|
||||
|
||||
# Mongodb
|
||||
mongodb = { version = "2.4", features = ["tokio-sync"] }
|
||||
futures = "0.3.26"
|
||||
|
||||
# Convert values to json for Mongodb
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
|
||||
# Reading env vars from .env
|
||||
dotenv = "0.15.0"
|
||||
|
||||
# For the meme records
|
||||
rand = "0.8.5"
|
||||
|
||||
# For the http web frontend
|
||||
axum = "0.6.4"
|
||||
tower-http = { version = "0.4.0", features = ["fs"] }
|
||||
tower-cookies = "0.9.0"
|
||||
tower = "0.4.13"
|
||||
bytes = "1.4.0"
|
||||
serde_json = "1"
|
40
public/css/home.css
Normal file
40
public/css/home.css
Normal file
|
@ -0,0 +1,40 @@
|
|||
span {
|
||||
margin-top: 5rem;
|
||||
margin-bottom: 1rem;
|
||||
width: 45rem;
|
||||
font-size: 2em;
|
||||
}
|
||||
|
||||
#new {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
width: 100%;
|
||||
padding-top: 2rem;
|
||||
padding-bottom: 1rem;
|
||||
border-bottom: solid 1px var(--gray);
|
||||
}
|
||||
|
||||
#new input, .block {
|
||||
border-radius: 1rem 0 0 1rem;
|
||||
width: 40rem;
|
||||
}
|
||||
|
||||
.block {
|
||||
width: 33em;
|
||||
}
|
||||
|
||||
#new button {
|
||||
border-radius: 0 1rem 1rem 0;
|
||||
}
|
||||
|
||||
.domain {
|
||||
margin-top: 2rem;
|
||||
}
|
||||
|
||||
.domain .delete {
|
||||
border-radius: 0 1rem 1rem 0;
|
||||
}
|
||||
|
||||
.domain .edit {
|
||||
border-radius: 0;
|
||||
}
|
18
public/css/login.css
Normal file
18
public/css/login.css
Normal file
|
@ -0,0 +1,18 @@
|
|||
#login {
|
||||
margin-top: 20em;
|
||||
}
|
||||
|
||||
#logo {
|
||||
font-size: 6em;
|
||||
font-weight: 750;
|
||||
font-family: bold;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
form {
|
||||
width: 30rem;
|
||||
}
|
||||
|
||||
form input {
|
||||
width: 100%;
|
||||
}
|
119
public/css/main.css
Normal file
119
public/css/main.css
Normal file
|
@ -0,0 +1,119 @@
|
|||
:root {
|
||||
--dark: #222428;
|
||||
--dark-alternate: #2b2e36;
|
||||
--header: #1e1e22;
|
||||
|
||||
--accent: #8849f5;
|
||||
--accent-alternate: #6829d5;
|
||||
--gray: #2f2f3f;
|
||||
--main: #ffffff;
|
||||
--main-alternate: #cccccc;
|
||||
}
|
||||
|
||||
* {
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
@font-face {
|
||||
font-family: main;
|
||||
src: url("../fonts/helvetica.ttf") format("truetype");
|
||||
font-display: swap;
|
||||
}
|
||||
|
||||
@font-face {
|
||||
font-family: bold;
|
||||
src: url("../fonts/overpass-bold.otf") format("opentype");
|
||||
font-display: swap;
|
||||
}
|
||||
|
||||
@font-face {
|
||||
font-family: bold-italic;
|
||||
src: url("../fonts/overpass-bold-italic.otf") format("opentype");
|
||||
font-display: swap;
|
||||
}
|
||||
|
||||
html {
|
||||
background-color: var(--dark);
|
||||
font-family: main;
|
||||
color: var(--main);
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
body {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.accent {
|
||||
color: var(--accent);
|
||||
}
|
||||
|
||||
.fill {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
input, button, .block {
|
||||
all: unset;
|
||||
display: inline-block;
|
||||
font: main;
|
||||
background-color: var(--dark-alternate);
|
||||
font-size: 1rem;
|
||||
padding: 1rem;
|
||||
border-radius: 1rem;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
button {
|
||||
background-color: var(--accent);
|
||||
width: 5em;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
button:hover {
|
||||
cursor: pointer;
|
||||
background-color: var(--accent-alternate);
|
||||
}
|
||||
|
||||
.delete {
|
||||
background-color: #f54842;
|
||||
}
|
||||
|
||||
.delete:hover {
|
||||
cursor: pointer;
|
||||
background-color: #d52822;
|
||||
}
|
||||
|
||||
form {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
#header {
|
||||
width: calc(100% - 4rem);
|
||||
background-color: var(--header);
|
||||
border-bottom: solid 1px var(--gray);
|
||||
padding: 1rem;
|
||||
padding-left: 3rem;
|
||||
}
|
||||
|
||||
#logo {
|
||||
font-size: 2em;
|
||||
font-weight: 500;
|
||||
font-family: bold;
|
||||
}
|
||||
|
||||
#title {
|
||||
font-size: 2em;
|
||||
font-weight: 300;
|
||||
font-family: sans-serif;
|
||||
padding-left: 1em;
|
||||
}
|
67
public/css/record.css
Normal file
67
public/css/record.css
Normal file
|
@ -0,0 +1,67 @@
|
|||
#buttons {
|
||||
margin-top: 2rem;
|
||||
width: 50rem;
|
||||
}
|
||||
|
||||
#buttons button {
|
||||
margin: 0;
|
||||
margin-right: 2rem;
|
||||
border-radius: 10px;
|
||||
width: auto;
|
||||
padding: .75rem 1rem;
|
||||
}
|
||||
|
||||
.record {
|
||||
width: 50rem;
|
||||
background-color: var(--header);
|
||||
padding: 1rem;
|
||||
margin-top: 2rem;
|
||||
}
|
||||
|
||||
.header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.header span {
|
||||
font-family: bold;
|
||||
}
|
||||
|
||||
.header button {
|
||||
margin: 0;
|
||||
margin-left: 2rem;
|
||||
padding: .5rem 1rem;
|
||||
width: auto;
|
||||
border-radius: 5px;
|
||||
}
|
||||
|
||||
.type {
|
||||
margin-right: 1rem;
|
||||
background-color: var(--accent);
|
||||
padding: .25rem .5rem;
|
||||
border-radius: 5px;
|
||||
}
|
||||
|
||||
.domain {
|
||||
color: var(--main-alternate);
|
||||
flex-grow: 1;
|
||||
}
|
||||
|
||||
.properties {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.poperty {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
border-bottom: solid 1px var(--gray);
|
||||
margin-top: 1rem;
|
||||
}
|
||||
|
||||
.key {
|
||||
font-family: bold;
|
||||
width: 5rem;
|
||||
}
|
||||
|
21
public/domain.html
Normal file
21
public/domain.html
Normal file
|
@ -0,0 +1,21 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Wrapper - Records</title>
|
||||
|
||||
<meta name="author" content="Tyler Murphy">
|
||||
<meta name="description" content="wrapper records">
|
||||
|
||||
<meta property="og:title" content="wrapper">
|
||||
<meta property="og:description" content="wrapper records">
|
||||
|
||||
<link rel="stylesheet" href="/css/main.css">
|
||||
<link rel="stylesheet" href="/css/record.css">
|
||||
|
||||
<script type="module" src="/js/domain.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
</body>
|
||||
</html>
|
BIN
public/fonts/helvetica-bold.ttf
Normal file
BIN
public/fonts/helvetica-bold.ttf
Normal file
Binary file not shown.
BIN
public/fonts/helvetica.ttf
Normal file
BIN
public/fonts/helvetica.ttf
Normal file
Binary file not shown.
BIN
public/fonts/overpass-bold-italic.otf
Normal file
BIN
public/fonts/overpass-bold-italic.otf
Normal file
Binary file not shown.
BIN
public/fonts/overpass-bold.otf
Normal file
BIN
public/fonts/overpass-bold.otf
Normal file
Binary file not shown.
21
public/home.html
Normal file
21
public/home.html
Normal file
|
@ -0,0 +1,21 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Wrapper - Domains</title>
|
||||
|
||||
<meta name="author" content="Tyler Murphy">
|
||||
<meta name="description" content="wrapper domains">
|
||||
|
||||
<meta property="og:title" content="wrapper">
|
||||
<meta property="og:description" content="wrapper domains">
|
||||
|
||||
<link rel="stylesheet" href="/css/main.css">
|
||||
<link rel="stylesheet" href="/css/home.css">
|
||||
|
||||
<script type="module" src="/js/home.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
</body>
|
||||
</html>
|
51
public/js/api.js
Normal file
51
public/js/api.js
Normal file
|
@ -0,0 +1,51 @@
|
|||
const endpoint = '/api'
|
||||
|
||||
const request = async (url, method, body) => {
|
||||
|
||||
let response;
|
||||
|
||||
if (method == 'GET') {
|
||||
response = await fetch(endpoint + url, {
|
||||
method,
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
});
|
||||
} else {
|
||||
response = await fetch(endpoint + url, {
|
||||
method,
|
||||
body: JSON.stringify(body),
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (response.status == 401) {
|
||||
location.href = '/login'
|
||||
}
|
||||
const contentType = response.headers.get("content-type");
|
||||
if (contentType && contentType.indexOf("application/json") !== -1) {
|
||||
const json = await response.json()
|
||||
return { status: response.status, msg: json.msg, json }
|
||||
} else {
|
||||
const msg = await response.text();
|
||||
return { status: response.status, msg }
|
||||
}
|
||||
}
|
||||
|
||||
export const login = async (user, pass) => {
|
||||
return await request('/login', 'POST', {user, pass})
|
||||
}
|
||||
|
||||
export const domains = async () => {
|
||||
return await request('/domains', 'GET')
|
||||
}
|
||||
|
||||
export const del_domain = async (domain) => {
|
||||
return await request('/domains', 'DELETE', {domain})
|
||||
}
|
||||
|
||||
export const records = async (domain) => {
|
||||
return await request(`/records?domain=${domain}`, 'GET')
|
||||
}
|
12
public/js/components.js
Normal file
12
public/js/components.js
Normal file
|
@ -0,0 +1,12 @@
|
|||
import { div, parse, span } from './main.js';
|
||||
|
||||
export function header(title) {
|
||||
return div({id: 'header'},
|
||||
span({id: 'logo', class: 'accent'},
|
||||
parse("Wrapper")
|
||||
),
|
||||
span({id: 'title'},
|
||||
parse(title)
|
||||
),
|
||||
)
|
||||
}
|
95
public/js/domain.js
Normal file
95
public/js/domain.js
Normal file
|
@ -0,0 +1,95 @@
|
|||
import { del_domain, domains, records } from './api.js'
|
||||
import { header } from './components.js'
|
||||
import { body, parse, div, input, button, span, is_domain } from './main.js';
|
||||
|
||||
function render(domain, records) {
|
||||
|
||||
let divs = []
|
||||
for (const record of records) {
|
||||
divs.push(gen_record(record))
|
||||
}
|
||||
|
||||
document.body.replaceWith(
|
||||
body({},
|
||||
header(domain),
|
||||
div({id: 'buttons'},
|
||||
button({onclick: (event) => {
|
||||
location.href = '/home'
|
||||
}}, parse("Home")),
|
||||
button({}, parse("New Record")),
|
||||
),
|
||||
...divs
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
function gen_record(record) {
|
||||
let domain = record.domain
|
||||
let prefix = record.prefix
|
||||
|
||||
if (prefix.length > 0) {
|
||||
prefix = prefix + '.'
|
||||
}
|
||||
|
||||
let type = Object.keys(record.record)[0]
|
||||
let data = record.record[type]
|
||||
|
||||
let divs = []
|
||||
for (const key in data) {
|
||||
let disp_key;
|
||||
if (key == 'ttl') {
|
||||
disp_key = 'TTL'
|
||||
} else {
|
||||
disp_key = upper(key)
|
||||
}
|
||||
divs.push(
|
||||
div({class: 'poperty'},
|
||||
div({class: 'key'}, parse(disp_key)),
|
||||
div({class: 'value'}, parse(data[key])),
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
return div({class: 'record'},
|
||||
div({class: 'header'},
|
||||
span({class: 'type'}, parse(type)),
|
||||
span({class: 'prefix'}, parse(prefix)),
|
||||
span({class: 'domain'}, parse(domain)),
|
||||
button({}, parse("Edit")),
|
||||
button({class: 'delete'}, parse("Delete"))
|
||||
),
|
||||
div({class: 'properties'},
|
||||
...divs
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
function upper(string) {
|
||||
return string.charAt(0).toUpperCase() + string.slice(1);
|
||||
}
|
||||
|
||||
async function init() {
|
||||
|
||||
const params = new Proxy(new URLSearchParams(window.location.search), {
|
||||
get: (searchParams, prop) => searchParams.get(prop),
|
||||
});
|
||||
|
||||
let domain = params.domain;
|
||||
|
||||
if (!is_domain(domain)) {
|
||||
location.href = '/home'
|
||||
return
|
||||
}
|
||||
|
||||
let res = await records(domain);
|
||||
|
||||
if (res.status !== 200) {
|
||||
alert(res.msg)
|
||||
return
|
||||
}
|
||||
|
||||
render(domain, res.json)
|
||||
|
||||
}
|
||||
|
||||
init()
|
91
public/js/home.js
Normal file
91
public/js/home.js
Normal file
|
@ -0,0 +1,91 @@
|
|||
import { del_domain, domains } from './api.js'
|
||||
import { header } from './components.js'
|
||||
import { body, parse, div, input, button, span, is_domain } from './main.js';
|
||||
|
||||
function render(domains) {
|
||||
document.body.replaceWith(
|
||||
body({},
|
||||
header('domains'),
|
||||
div({id: 'new'},
|
||||
input({
|
||||
type: 'text',
|
||||
name: 'domain',
|
||||
id: 'domain',
|
||||
placeholder: 'Type domain name to create new records',
|
||||
autocomplete: "off",
|
||||
}),
|
||||
button({onclick: () => {
|
||||
let domain = document.getElementById('domain').value
|
||||
|
||||
if (!is_domain(domain)) {
|
||||
alert("Invalid domain")
|
||||
return
|
||||
}
|
||||
|
||||
location.href = '/domain?domain='+domain
|
||||
}},
|
||||
parse("Create")
|
||||
)
|
||||
),
|
||||
...domain(domains)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
function domain(domains) {
|
||||
let divs = []
|
||||
for (const domain of domains) {
|
||||
divs.push(
|
||||
div({class: 'domain'},
|
||||
div({class: 'block'},
|
||||
parse(domain)
|
||||
),
|
||||
button({class: 'edit', onclick: (event) => {
|
||||
console.log(event.target.parentElement)
|
||||
let domain = event
|
||||
.target
|
||||
.parentElement
|
||||
.getElementsByClassName('block')[0]
|
||||
.innerText
|
||||
|
||||
if (!is_domain(domain)) {
|
||||
alert("Invalid domain")
|
||||
return
|
||||
}
|
||||
|
||||
location.href = '/domain?domain='+domain
|
||||
}},
|
||||
parse("Edit")
|
||||
),
|
||||
button({class: 'delete', onclick: async () => {
|
||||
let res = await del_domain(domain)
|
||||
|
||||
if (res.status != 204) {
|
||||
alert(res.msg)
|
||||
return
|
||||
}
|
||||
|
||||
location.reload()
|
||||
}},
|
||||
parse("Delete")
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
return divs
|
||||
}
|
||||
|
||||
async function init() {
|
||||
|
||||
let res = await domains();
|
||||
|
||||
if (res.status !== 200) {
|
||||
alert(res.msg)
|
||||
return
|
||||
}
|
||||
|
||||
render(res.json)
|
||||
|
||||
}
|
||||
|
||||
init()
|
44
public/js/login.js
Normal file
44
public/js/login.js
Normal file
|
@ -0,0 +1,44 @@
|
|||
import { body, div, form, input, p, parse, span} from './main.js'
|
||||
import { login } from './api.js'
|
||||
|
||||
function render() {
|
||||
document.body.replaceWith(
|
||||
body({},
|
||||
div({id: 'login', class: 'fill'},
|
||||
span({id: 'logo'},
|
||||
span({class: 'accent'}, parse('Wrapper'))
|
||||
),
|
||||
form({autocomplete: "off"},
|
||||
input({
|
||||
type: 'text',
|
||||
name: 'user',
|
||||
id: 'user',
|
||||
placeholder: 'Username',
|
||||
autofocus: 1
|
||||
}),
|
||||
input({
|
||||
type: 'password',
|
||||
name: 'pass',
|
||||
id: 'pass',
|
||||
placeholder: 'Password',
|
||||
onkeydown: async (event) => {
|
||||
if (event.key == 'Enter') {
|
||||
event.preventDefault()
|
||||
let user = document.getElementById('user').value
|
||||
let pass = document.getElementById('pass').value
|
||||
|
||||
let res = await login(user, pass)
|
||||
|
||||
if (res.status === 200) {
|
||||
location.href = '/home'
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
render()
|
136
public/js/main.js
Normal file
136
public/js/main.js
Normal file
|
@ -0,0 +1,136 @@
|
|||
function createElement(name, attrs, ...children) {
|
||||
const el = document.createElement(name);
|
||||
|
||||
for (const attr in attrs) {
|
||||
if(attr.startsWith("on")) {
|
||||
el[attr] = attrs[attr];
|
||||
} else {
|
||||
el.setAttribute(attr, attrs[attr])
|
||||
}
|
||||
}
|
||||
|
||||
for (const child of children) {
|
||||
if (child == null) {
|
||||
continue
|
||||
}
|
||||
el.appendChild(child)
|
||||
}
|
||||
|
||||
return el
|
||||
}
|
||||
|
||||
export function createElementNS(name, attrs, ...children) {
|
||||
var svgns = "http://www.w3.org/2000/svg";
|
||||
var el = document.createElementNS(svgns, name);
|
||||
|
||||
for (const attr in attrs) {
|
||||
if(attr.startsWith("on")) {
|
||||
el[attr] = attrs[attr];
|
||||
} else {
|
||||
el.setAttribute(attr, attrs[attr])
|
||||
}
|
||||
}
|
||||
|
||||
for (const child of children) {
|
||||
if (child == null) {
|
||||
continue
|
||||
}
|
||||
el.appendChild(child)
|
||||
}
|
||||
|
||||
return el
|
||||
}
|
||||
|
||||
export function p(attrs, ...children) {
|
||||
return createElement("p", attrs, ...children)
|
||||
}
|
||||
|
||||
export function span(attrs, ...children) {
|
||||
return createElement("span", attrs, ...children)
|
||||
}
|
||||
|
||||
export function div(attrs, ...children) {
|
||||
return createElement("div", attrs, ...children)
|
||||
}
|
||||
|
||||
export function a(attrs, ...children) {
|
||||
return createElement("a", attrs, ...children)
|
||||
}
|
||||
|
||||
export function i(attrs, ...children) {
|
||||
return createElement("i", attrs, ...children)
|
||||
}
|
||||
|
||||
export function form(attrs, ...children) {
|
||||
return createElement("form", attrs, ...children)
|
||||
}
|
||||
|
||||
export function img(alt, attrs, ...children) {
|
||||
attrs['onerror'] = (event) => event.target.remove()
|
||||
attrs['alt'] = alt
|
||||
return createElement("img", attrs, ...children)
|
||||
}
|
||||
|
||||
export function input(attrs, ...children) {
|
||||
return createElement("input", attrs, ...children)
|
||||
}
|
||||
|
||||
export function button(attrs, ...children) {
|
||||
return createElement("button", attrs, ...children)
|
||||
}
|
||||
|
||||
export function path(attrs, ...children) {
|
||||
return createElementNS("path", attrs, ...children)
|
||||
}
|
||||
|
||||
export function svg(attrs, ...children) {
|
||||
return createElementNS("svg", attrs, ...children)
|
||||
}
|
||||
|
||||
export function body(attrs, ...children) {
|
||||
return createElement("body", attrs, ...children)
|
||||
}
|
||||
|
||||
export function textarea(attrs, ...children) {
|
||||
return createElement("textarea", attrs, ...children)
|
||||
}
|
||||
|
||||
export function parse(input) {
|
||||
const pattern = /^[ a-zA-Z0-9!@#$%^&*()_+\-=\[\]{};':"\\|,.<>\/?]*$/;
|
||||
|
||||
input = input + '';
|
||||
|
||||
if (!pattern.test(input)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const sanitized = input.replace(/</g, '<').replace(/>/g, '>');
|
||||
return document.createRange().createContextualFragment(sanitized);
|
||||
}
|
||||
|
||||
export function is_domain(domain) {
|
||||
domain = domain.toLowerCase()
|
||||
|
||||
const pattern = /^[a-z0-9_\-.]*$/;
|
||||
if (!pattern.test(domain)) {
|
||||
return false
|
||||
}
|
||||
|
||||
let parts = domain.split('.').reverse()
|
||||
for (const part of parts) {
|
||||
if (part.length < 1) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if (parts.length < 2 || parts[0].length < 2) {
|
||||
return false
|
||||
}
|
||||
|
||||
const tld_pattern = /^[a-z]*$/;
|
||||
if (!tld_pattern.test(parts[0])) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
21
public/login.html
Normal file
21
public/login.html
Normal file
|
@ -0,0 +1,21 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Wrapper - Login</title>
|
||||
|
||||
<meta name="author" content="Tyler Murphy">
|
||||
<meta name="description" content="wrapper dns login">
|
||||
|
||||
<meta property="og:title" content="wrapper">
|
||||
<meta property="og:description" content="wrapper dns login">
|
||||
|
||||
<link rel="stylesheet" href="/css/main.css">
|
||||
<link rel="stylesheet" href="/css/login.css">
|
||||
|
||||
<script type="module" src="/js/login.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
</body>
|
||||
</html>
|
9
public/robots.txt
Normal file
9
public/robots.txt
Normal file
|
@ -0,0 +1,9 @@
|
|||
User-agent: Googlebot
|
||||
Disallow: /api
|
||||
|
||||
User-agent: Googlebot
|
||||
User-agent: AdsBot-Google
|
||||
Disallow: /api
|
||||
|
||||
User-agent: *
|
||||
Disallow: /api
|
|
@ -1,35 +1,57 @@
|
|||
use std::net::IpAddr;
|
||||
use std::{env, net::IpAddr, str::FromStr, fmt::Display};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
fallback: IpAddr,
|
||||
port: u16,
|
||||
pub dns_fallback: IpAddr,
|
||||
pub dns_port: u16,
|
||||
pub dns_cache_size: u64,
|
||||
|
||||
pub db_host: String,
|
||||
pub db_port: u16,
|
||||
pub db_user: String,
|
||||
pub db_pass: String,
|
||||
|
||||
pub web_user: String,
|
||||
pub web_pass: String,
|
||||
pub web_port: u16,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn new() -> Self {
|
||||
let fallback = "9.9.9.9"
|
||||
.parse::<IpAddr>()
|
||||
.expect("Failed to create default ns fallback");
|
||||
let dns_port = Self::get_var::<u16>("WRAPPER_DNS_PORT", 53);
|
||||
let dns_fallback = Self::get_var::<IpAddr>("WRAPPER_FALLBACK_DNS", [9, 9, 9, 9].into());
|
||||
let dns_cache_size = Self::get_var::<u64>("WRAPPER_CACHE_SIZE", 1000);
|
||||
|
||||
let db_host = Self::get_var::<String>("WRAPPER_DB_HOST", String::from("localhost"));
|
||||
let db_port = Self::get_var::<u16>("WRAPPER_DB_PORT", 27017);
|
||||
let db_user = Self::get_var::<String>("WRAPPER_DB_USER", String::from("root"));
|
||||
let db_pass = Self::get_var::<String>("WRAPPER_DB_PASS", String::from(""));
|
||||
|
||||
let web_user = Self::get_var::<String>("WRAPPER_WEB_USER", String::from("admin"));
|
||||
let web_pass = Self::get_var::<String>("WRAPPER_WEB_PASS", String::from("wrapper"));
|
||||
let web_port = Self::get_var::<u16>("WRAPPER_WEB_PORT", 80);
|
||||
|
||||
Self {
|
||||
fallback,
|
||||
port: 2000,
|
||||
dns_fallback,
|
||||
dns_port,
|
||||
dns_cache_size,
|
||||
|
||||
db_host,
|
||||
db_port,
|
||||
db_user,
|
||||
db_pass,
|
||||
|
||||
web_user,
|
||||
web_pass,
|
||||
web_port,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_fallback_ns(&self) -> &IpAddr {
|
||||
&self.fallback
|
||||
}
|
||||
|
||||
pub fn get_port(&self) -> u16 {
|
||||
self.port
|
||||
}
|
||||
|
||||
pub fn set_fallback_ns(&mut self, addr: &IpAddr) {
|
||||
self.fallback = *addr;
|
||||
}
|
||||
|
||||
pub fn set_port(&mut self, port: u16) {
|
||||
self.port = port;
|
||||
fn get_var<T>(name: &str, default: T) -> T
|
||||
where
|
||||
T: FromStr + Display,
|
||||
{
|
||||
let env = env::var(name).unwrap_or(format!("{default}"));
|
||||
env.parse::<T>().unwrap_or(default)
|
||||
}
|
||||
}
|
||||
|
|
146
src/database/mod.rs
Normal file
146
src/database/mod.rs
Normal file
|
@ -0,0 +1,146 @@
|
|||
use futures::TryStreamExt;
|
||||
use mongodb::{
|
||||
bson::doc,
|
||||
options::{ClientOptions, Credential, ServerAddress},
|
||||
Client,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
config::Config,
|
||||
dns::packet::{query::QueryType, record::DnsRecord},
|
||||
};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Database {
|
||||
client: Client,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct StoredRecord {
|
||||
record: DnsRecord,
|
||||
domain: String,
|
||||
prefix: String,
|
||||
}
|
||||
|
||||
impl StoredRecord {
|
||||
fn get_domain_parts(domain: &str) -> (String, String) {
|
||||
let parts: Vec<&str> = domain.split(".").collect();
|
||||
let len = parts.len();
|
||||
if len == 1 {
|
||||
(String::new(), String::from(parts[0]))
|
||||
} else if len == 2 {
|
||||
(String::new(), String::from(parts.join(".")))
|
||||
} else {
|
||||
(
|
||||
String::from(parts[0..len - 2].join(".")),
|
||||
String::from(parts[len - 2..len].join(".")),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DnsRecord> for StoredRecord {
|
||||
fn from(record: DnsRecord) -> Self {
|
||||
let (prefix, domain) = Self::get_domain_parts(&record.get_domain());
|
||||
Self {
|
||||
record,
|
||||
domain,
|
||||
prefix,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<DnsRecord> for StoredRecord {
|
||||
fn into(self) -> DnsRecord {
|
||||
self.record
|
||||
}
|
||||
}
|
||||
|
||||
impl Database {
|
||||
pub async fn new(config: Config) -> Result<Self> {
|
||||
let options = ClientOptions::builder()
|
||||
.hosts(vec![ServerAddress::Tcp {
|
||||
host: config.db_host,
|
||||
port: Some(config.db_port),
|
||||
}])
|
||||
.credential(
|
||||
Credential::builder()
|
||||
.username(config.db_user)
|
||||
.password(config.db_pass)
|
||||
.build(),
|
||||
)
|
||||
.max_pool_size(100)
|
||||
.app_name(String::from("wrapper"))
|
||||
.build();
|
||||
|
||||
let client = Client::with_options(options)?;
|
||||
|
||||
client
|
||||
.database("wrapper")
|
||||
.run_command(doc! {"ping": 1}, None)
|
||||
.await?;
|
||||
|
||||
info!("Connection to mongodb successfully");
|
||||
|
||||
Ok(Database { client })
|
||||
}
|
||||
|
||||
pub async fn get_records(&self, domain: &str, qtype: QueryType) -> Result<Vec<DnsRecord>> {
|
||||
let (prefix, domain) = StoredRecord::get_domain_parts(domain);
|
||||
Ok(self
|
||||
.get_domain(&domain)
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter(|r| r.prefix == prefix)
|
||||
.filter(|r| {
|
||||
let rqtype = r.record.get_qtype();
|
||||
if qtype == QueryType::A {
|
||||
return rqtype == QueryType::A || rqtype == QueryType::AR;
|
||||
} else if qtype == QueryType::AAAA {
|
||||
return rqtype == QueryType::AAAA || rqtype == QueryType::AAAAR;
|
||||
} else {
|
||||
r.record.get_qtype() == qtype
|
||||
}
|
||||
})
|
||||
.map(|r| r.into())
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn get_domain(&self, domain: &str) -> Result<Vec<StoredRecord>> {
|
||||
let db = self.client.database("wrapper");
|
||||
let col = db.collection::<StoredRecord>(domain);
|
||||
|
||||
let filter = doc! { "domain": domain };
|
||||
let mut cursor = col.find(filter, None).await?;
|
||||
|
||||
let mut records = Vec::new();
|
||||
while let Some(record) = cursor.try_next().await? {
|
||||
records.push(record);
|
||||
}
|
||||
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
pub async fn add_record(&self, record: DnsRecord) -> Result<()> {
|
||||
let record = StoredRecord::from(record);
|
||||
let db = self.client.database("wrapper");
|
||||
let col = db.collection::<StoredRecord>(&record.domain);
|
||||
col.insert_one(record, None).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_domains(&self) -> Result<Vec<String>> {
|
||||
let db = self.client.database("wrapper");
|
||||
Ok(db.list_collection_names(None).await?)
|
||||
}
|
||||
|
||||
pub async fn delete_domain(&self, domain: String) -> Result<()> {
|
||||
let db = self.client.database("wrapper");
|
||||
let col = db.collection::<StoredRecord>(&domain);
|
||||
Ok(col.drop(None).await?)
|
||||
}
|
||||
}
|
|
@ -3,7 +3,8 @@ use std::{
|
|||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::packet::{buffer::PacketBuffer, Packet, Result};
|
||||
use super::packet::{buffer::PacketBuffer, Packet};
|
||||
use crate::Result;
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::{TcpListener, TcpStream, UdpSocket},
|
||||
|
@ -140,11 +141,4 @@ impl Connection {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fn pb(buf: &[u8]) {
|
||||
// for i in 0..buf.len() {
|
||||
// print!("{:02X?} ", buf[i]);
|
||||
// }
|
||||
// println!("");
|
||||
// }
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
mod binding;
|
||||
pub mod packet;
|
||||
mod resolver;
|
||||
pub mod server;
|
|
@ -1,4 +1,4 @@
|
|||
use super::Result;
|
||||
use crate::Result;
|
||||
|
||||
pub struct PacketBuffer {
|
||||
pub buf: Vec<u8>,
|
||||
|
@ -9,19 +9,9 @@ pub struct PacketBuffer {
|
|||
impl PacketBuffer {
|
||||
pub fn new(buf: Vec<u8>) -> Self {
|
||||
Self {
|
||||
size: buf.len(),
|
||||
buf,
|
||||
pos: 0,
|
||||
size: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn check(&mut self, pos: usize) {
|
||||
if self.size < pos {
|
||||
self.size = pos;
|
||||
}
|
||||
|
||||
if self.buf.len() <= self.size {
|
||||
self.buf.resize(self.size + 1, 0x00);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -42,32 +32,25 @@ impl PacketBuffer {
|
|||
}
|
||||
|
||||
pub fn read(&mut self) -> Result<u8> {
|
||||
// if self.pos >= 512 {
|
||||
// error!("Tried to read past end of buffer");
|
||||
// return Err("End of buffer".into());
|
||||
// }
|
||||
self.check(self.pos);
|
||||
if self.pos >= self.size {
|
||||
return Err("Tried to read past end of buffer".into());
|
||||
}
|
||||
let res = self.buf[self.pos];
|
||||
self.pos += 1;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub fn get(&mut self, pos: usize) -> Result<u8> {
|
||||
// if pos >= 512 {
|
||||
// error!("Tried to read past end of buffer");
|
||||
// return Err("End of buffer".into());
|
||||
// }
|
||||
self.check(pos);
|
||||
if pos >= self.size {
|
||||
return Err("Tried to read past end of buffer".into());
|
||||
}
|
||||
Ok(self.buf[pos])
|
||||
}
|
||||
|
||||
pub fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
|
||||
// if start + len >= 512 {
|
||||
// error!("Tried to read past end of buffer");
|
||||
// return Err("End of buffer".into());
|
||||
// }
|
||||
self.check(start + len);
|
||||
if start + len >= self.size {
|
||||
return Err("Tried to read past end of buffer".into());
|
||||
}
|
||||
Ok(&self.buf[start..start + len])
|
||||
}
|
||||
|
||||
|
@ -169,7 +152,13 @@ impl PacketBuffer {
|
|||
}
|
||||
|
||||
pub fn write(&mut self, val: u8) -> Result<()> {
|
||||
self.check(self.pos);
|
||||
if self.size < self.pos {
|
||||
self.size = self.pos;
|
||||
}
|
||||
|
||||
if self.buf.len() <= self.size {
|
||||
self.buf.resize(self.size + 1, 0x00);
|
||||
}
|
||||
|
||||
self.buf[self.pos] = val;
|
||||
self.pos += 1;
|
||||
|
@ -208,7 +197,9 @@ impl PacketBuffer {
|
|||
}
|
||||
}
|
||||
|
||||
self.write_u8(0)?;
|
||||
if !qname.is_empty() {
|
||||
self.write_u8(0)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
use super::{buffer::PacketBuffer, result::ResultCode, Result};
|
||||
use super::{buffer::PacketBuffer, result::ResultCode};
|
||||
use crate::Result;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DnsHeader {
|
|
@ -4,9 +4,7 @@ use self::{
|
|||
buffer::PacketBuffer, header::DnsHeader, query::QueryType, question::DnsQuestion,
|
||||
record::DnsRecord,
|
||||
};
|
||||
|
||||
type Error = Box<dyn std::error::Error>;
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
use crate::Result;
|
||||
|
||||
pub mod buffer;
|
||||
pub mod header;
|
|
@ -12,6 +12,8 @@ pub enum QueryType {
|
|||
SRV, // 33
|
||||
OPT, // 41
|
||||
CAA, // 257
|
||||
AR, // 1000
|
||||
AAAAR, // 1001
|
||||
}
|
||||
|
||||
impl QueryType {
|
||||
|
@ -29,6 +31,8 @@ impl QueryType {
|
|||
Self::SRV => 33,
|
||||
Self::OPT => 41,
|
||||
Self::CAA => 257,
|
||||
Self::AR => 1000,
|
||||
Self::AAAAR => 1001,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -45,7 +49,30 @@ impl QueryType {
|
|||
33 => Self::SRV,
|
||||
41 => Self::OPT,
|
||||
257 => Self::CAA,
|
||||
1000 => Self::AR,
|
||||
1001 => Self::AAAAR,
|
||||
_ => Self::UNKNOWN(num),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn allowed_actions(&self) -> (bool, bool) {
|
||||
// 0. duplicates allowed
|
||||
// 1. allowed to be created by database
|
||||
match self {
|
||||
QueryType::UNKNOWN(_) => (false, false),
|
||||
QueryType::A => (true, true),
|
||||
QueryType::NS => (false, true),
|
||||
QueryType::CNAME => (false, true),
|
||||
QueryType::SOA => (false, false),
|
||||
QueryType::PTR => (false, true),
|
||||
QueryType::MX => (false, true),
|
||||
QueryType::TXT => (true, true),
|
||||
QueryType::AAAA => (true, true),
|
||||
QueryType::SRV => (false, true),
|
||||
QueryType::OPT => (false, false),
|
||||
QueryType::CAA => (false, true),
|
||||
QueryType::AR => (false, true),
|
||||
QueryType::AAAAR => (false, true),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,11 +1,12 @@
|
|||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
|
||||
use rand::RngCore;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{trace, warn};
|
||||
|
||||
use super::{buffer::PacketBuffer, query::QueryType, Result};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub enum DnsRecord {
|
||||
UNKNOWN {
|
||||
domain: String,
|
||||
|
@ -76,10 +77,17 @@ pub enum DnsRecord {
|
|||
value: String,
|
||||
ttl: u32,
|
||||
}, // 257
|
||||
AR {
|
||||
domain: String,
|
||||
ttl: u32,
|
||||
},
|
||||
AAAAR {
|
||||
domain: String,
|
||||
ttl: u32,
|
||||
},
|
||||
}
|
||||
|
||||
impl DnsRecord {
|
||||
|
||||
pub fn read(buffer: &mut PacketBuffer) -> Result<Self> {
|
||||
let mut domain = String::new();
|
||||
buffer.read_qname(&mut domain)?;
|
||||
|
@ -90,10 +98,10 @@ impl DnsRecord {
|
|||
let ttl = buffer.read_u32()?;
|
||||
let data_len = buffer.read_u16()?;
|
||||
|
||||
let header_pos = buffer.pos();
|
||||
|
||||
trace!("Reading DNS Record TYPE: {:?}", qtype);
|
||||
|
||||
let header_pos = buffer.pos();
|
||||
|
||||
match qtype {
|
||||
QueryType::A => {
|
||||
let raw_addr = buffer.read_u32()?;
|
||||
|
@ -471,6 +479,29 @@ impl DnsRecord {
|
|||
let size = buffer.pos() - (pos + 2);
|
||||
buffer.set_u16(pos, size as u16)?;
|
||||
}
|
||||
Self::AR { ref domain, ttl } => {
|
||||
buffer.write_qname(domain)?;
|
||||
buffer.write_u16(QueryType::A.to_num())?;
|
||||
buffer.write_u16(1)?;
|
||||
buffer.write_u32(ttl)?;
|
||||
buffer.write_u16(4)?;
|
||||
|
||||
let mut rand = rand::thread_rng();
|
||||
buffer.write_u32(rand.next_u32())?;
|
||||
}
|
||||
Self::AAAAR { ref domain, ttl } => {
|
||||
buffer.write_qname(domain)?;
|
||||
buffer.write_u16(QueryType::A.to_num())?;
|
||||
buffer.write_u16(1)?;
|
||||
buffer.write_u32(ttl)?;
|
||||
buffer.write_u16(4)?;
|
||||
|
||||
let mut rand = rand::thread_rng();
|
||||
buffer.write_u32(rand.next_u32())?;
|
||||
buffer.write_u32(rand.next_u32())?;
|
||||
buffer.write_u32(rand.next_u32())?;
|
||||
buffer.write_u32(rand.next_u32())?;
|
||||
}
|
||||
Self::UNKNOWN { .. } => {
|
||||
warn!("Skipping record: {self:?}");
|
||||
}
|
||||
|
@ -479,20 +510,35 @@ impl DnsRecord {
|
|||
Ok(buffer.pos() - start_pos)
|
||||
}
|
||||
|
||||
pub fn get_ttl(&self) -> u32 {
|
||||
match *self {
|
||||
DnsRecord::UNKNOWN { .. } => 0,
|
||||
DnsRecord::AAAA { ttl, .. } => ttl,
|
||||
DnsRecord::A { ttl, .. } => ttl,
|
||||
DnsRecord::NS { ttl, .. } => ttl,
|
||||
DnsRecord::CNAME { ttl, .. } => ttl,
|
||||
DnsRecord::SOA { ttl, .. } => ttl,
|
||||
DnsRecord::PTR { ttl, .. } => ttl,
|
||||
DnsRecord::MX { ttl, .. } => ttl,
|
||||
DnsRecord::TXT { ttl, .. } => ttl,
|
||||
DnsRecord::SRV { ttl, .. } => ttl,
|
||||
DnsRecord::CAA { ttl, .. } => ttl,
|
||||
}
|
||||
pub fn get_domain(&self) -> String {
|
||||
self.get_shared_domain().0
|
||||
}
|
||||
|
||||
pub fn get_qtype(&self) -> QueryType {
|
||||
self.get_shared_domain().1
|
||||
}
|
||||
|
||||
pub fn get_ttl(&self) -> u32 {
|
||||
self.get_shared_domain().2
|
||||
}
|
||||
|
||||
fn get_shared_domain(&self) -> (String, QueryType, u32) {
|
||||
match self {
|
||||
DnsRecord::UNKNOWN {
|
||||
domain, ttl, qtype, ..
|
||||
} => (domain.clone(), QueryType::UNKNOWN(*qtype), *ttl),
|
||||
DnsRecord::AAAA { domain, ttl, .. } => (domain.clone(), QueryType::AAAA, *ttl),
|
||||
DnsRecord::A { domain, ttl, .. } => (domain.clone(), QueryType::A, *ttl),
|
||||
DnsRecord::NS { domain, ttl, .. } => (domain.clone(), QueryType::NS, *ttl),
|
||||
DnsRecord::CNAME { domain, ttl, .. } => (domain.clone(), QueryType::CNAME, *ttl),
|
||||
DnsRecord::SOA { domain, ttl, .. } => (domain.clone(), QueryType::SOA, *ttl),
|
||||
DnsRecord::PTR { domain, ttl, .. } => (domain.clone(), QueryType::PTR, *ttl),
|
||||
DnsRecord::MX { domain, ttl, .. } => (domain.clone(), QueryType::MX, *ttl),
|
||||
DnsRecord::TXT { domain, ttl, .. } => (domain.clone(), QueryType::TXT, *ttl),
|
||||
DnsRecord::SRV { domain, ttl, .. } => (domain.clone(), QueryType::SRV, *ttl),
|
||||
DnsRecord::CAA { domain, ttl, .. } => (domain.clone(), QueryType::CAA, *ttl),
|
||||
DnsRecord::AR { domain, ttl, .. } => (domain.clone(), QueryType::AR, *ttl),
|
||||
DnsRecord::AAAAR { domain, ttl, .. } => (domain.clone(), QueryType::AAAAR, *ttl),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,11 +1,7 @@
|
|||
use super::binding::Connection;
|
||||
use crate::{
|
||||
config::Config,
|
||||
packet::{
|
||||
query::QueryType, question::DnsQuestion, result::ResultCode, Packet,
|
||||
Result,
|
||||
}, get_time,
|
||||
};
|
||||
use super::packet::{query::QueryType, question::DnsQuestion, result::ResultCode, Packet};
|
||||
use crate::Result;
|
||||
use crate::{config::Config, database::Database, get_time};
|
||||
use async_recursion::async_recursion;
|
||||
use moka::future::Cache;
|
||||
use std::{net::IpAddr, sync::Arc, time::Duration};
|
||||
|
@ -15,6 +11,7 @@ pub struct Resolver {
|
|||
request_id: u16,
|
||||
connection: Connection,
|
||||
config: Arc<Config>,
|
||||
database: Arc<Database>,
|
||||
cache: Cache<DnsQuestion, (Packet, u64)>,
|
||||
}
|
||||
|
||||
|
@ -23,18 +20,59 @@ impl Resolver {
|
|||
request_id: u16,
|
||||
connection: Connection,
|
||||
config: Arc<Config>,
|
||||
database: Arc<Database>,
|
||||
cache: Cache<DnsQuestion, (Packet, u64)>,
|
||||
) -> Self {
|
||||
Self {
|
||||
request_id,
|
||||
connection,
|
||||
config,
|
||||
database,
|
||||
cache,
|
||||
}
|
||||
}
|
||||
|
||||
async fn lookup_cache(&mut self, qname: &str, qtype: QueryType) -> Option<Packet> {
|
||||
let question = DnsQuestion::new(qname.to_string(), qtype);
|
||||
async fn lookup_database(&self, question: &DnsQuestion) -> Option<Packet> {
|
||||
let records = match self
|
||||
.database
|
||||
.get_records(&question.name, question.qtype)
|
||||
.await
|
||||
{
|
||||
Ok(record) => record,
|
||||
Err(err) => {
|
||||
error!("{err}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if records.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut packet = Packet::new();
|
||||
|
||||
packet.header.id = self.request_id;
|
||||
packet.header.questions = 1;
|
||||
packet.header.answers = records.len() as u16;
|
||||
packet.header.recursion_desired = true;
|
||||
packet
|
||||
.questions
|
||||
.push(DnsQuestion::new(question.name.to_string(), question.qtype));
|
||||
|
||||
for record in records {
|
||||
packet.answers.push(record);
|
||||
}
|
||||
|
||||
trace!(
|
||||
"Found stored value for {:?} {}",
|
||||
question.qtype,
|
||||
question.name
|
||||
);
|
||||
|
||||
Some(packet)
|
||||
}
|
||||
|
||||
async fn lookup_cache(&self, question: &DnsQuestion) -> Option<Packet> {
|
||||
let Some((packet, date)) = self.cache.get(&question) else {
|
||||
return None
|
||||
};
|
||||
|
@ -46,16 +84,20 @@ impl Resolver {
|
|||
let ttl = answer.get_ttl();
|
||||
if diff > ttl {
|
||||
self.cache.invalidate(&question).await;
|
||||
return None
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
trace!("Found cached value for {qtype:?} {qname}");
|
||||
trace!(
|
||||
"Found cached value for {:?} {}",
|
||||
question.qtype,
|
||||
question.name
|
||||
);
|
||||
|
||||
Some(packet)
|
||||
}
|
||||
|
||||
async fn lookup(&mut self, qname: &str, qtype: QueryType, server: (IpAddr, u16)) -> Packet {
|
||||
async fn lookup_fallback(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
|
||||
let mut packet = Packet::new();
|
||||
|
||||
packet.header.id = self.request_id;
|
||||
|
@ -63,7 +105,7 @@ impl Resolver {
|
|||
packet.header.recursion_desired = true;
|
||||
packet
|
||||
.questions
|
||||
.push(DnsQuestion::new(qname.to_string(), qtype));
|
||||
.push(DnsQuestion::new(question.name.to_string(), question.qtype));
|
||||
|
||||
let packet = match self.connection.request_packet(packet, server).await {
|
||||
Ok(packet) => packet,
|
||||
|
@ -78,28 +120,47 @@ impl Resolver {
|
|||
packet
|
||||
}
|
||||
|
||||
async fn lookup(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
|
||||
if let Some(packet) = self.lookup_cache(question).await {
|
||||
return packet;
|
||||
};
|
||||
|
||||
if let Some(packet) = self.lookup_database(question).await {
|
||||
return packet;
|
||||
};
|
||||
|
||||
trace!(
|
||||
"Attempting lookup of {:?} {} with ns {}",
|
||||
question.qtype,
|
||||
question.name,
|
||||
server.0
|
||||
);
|
||||
|
||||
self.lookup_fallback(question, server).await
|
||||
}
|
||||
|
||||
#[async_recursion]
|
||||
async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet {
|
||||
let question = DnsQuestion::new(qname.to_string(), qtype);
|
||||
let mut ns = self.config.get_fallback_ns().clone();
|
||||
|
||||
if let Some(packet) = self.lookup_cache(qname, qtype).await { return packet }
|
||||
let mut ns = self.config.dns_fallback.clone();
|
||||
|
||||
loop {
|
||||
trace!("Attempting lookup of {qtype:?} {qname} with ns {ns}");
|
||||
|
||||
let ns_copy = ns;
|
||||
|
||||
let server = (ns_copy, 53);
|
||||
let response = self.lookup(qname, qtype, server).await;
|
||||
let response = self.lookup(&question, server).await;
|
||||
|
||||
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
|
||||
self.cache.insert(question, (response.clone(), get_time())).await;
|
||||
self.cache
|
||||
.insert(question, (response.clone(), get_time()))
|
||||
.await;
|
||||
return response;
|
||||
}
|
||||
|
||||
if response.header.rescode == ResultCode::NXDOMAIN {
|
||||
self.cache.insert(question, (response.clone(), get_time())).await;
|
||||
self.cache
|
||||
.insert(question, (response.clone(), get_time()))
|
||||
.await;
|
||||
return response;
|
||||
}
|
||||
|
||||
|
@ -111,9 +172,11 @@ impl Resolver {
|
|||
let new_ns_name = match response.get_unresolved_ns(qname) {
|
||||
Some(x) => x,
|
||||
None => {
|
||||
self.cache.insert(question, (response.clone(), get_time())).await;
|
||||
return response
|
||||
},
|
||||
self.cache
|
||||
.insert(question, (response.clone(), get_time()))
|
||||
.await;
|
||||
return response;
|
||||
}
|
||||
};
|
||||
|
||||
let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await;
|
||||
|
@ -121,7 +184,9 @@ impl Resolver {
|
|||
if let Some(new_ns) = recursive_response.get_random_a() {
|
||||
ns = new_ns;
|
||||
} else {
|
||||
self.cache.insert(question, (response.clone(), get_time())).await;
|
||||
self.cache
|
||||
.insert(question, (response.clone(), get_time()))
|
||||
.await;
|
||||
return response;
|
||||
}
|
||||
}
|
85
src/dns/server.rs
Normal file
85
src/dns/server.rs
Normal file
|
@ -0,0 +1,85 @@
|
|||
use super::{
|
||||
binding::Binding,
|
||||
packet::{question::DnsQuestion, Packet},
|
||||
resolver::Resolver,
|
||||
};
|
||||
use crate::{config::Config, database::Database, Result};
|
||||
use moka::future::Cache;
|
||||
use std::{net::SocketAddr, sync::Arc, time::Duration};
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::{error, info};
|
||||
|
||||
pub struct DnsServer {
|
||||
addr: SocketAddr,
|
||||
config: Arc<Config>,
|
||||
database: Arc<Database>,
|
||||
cache: Cache<DnsQuestion, (Packet, u64)>,
|
||||
}
|
||||
|
||||
impl DnsServer {
|
||||
pub async fn new(config: Config, database: Database) -> Result<Self> {
|
||||
let addr = format!("[::]:{}", config.dns_port).parse::<SocketAddr>()?;
|
||||
let cache = Cache::builder()
|
||||
.time_to_live(Duration::from_secs(60 * 60))
|
||||
.max_capacity(config.dns_cache_size)
|
||||
.build();
|
||||
|
||||
info!("Created DNS cache with size of {}", config.dns_cache_size);
|
||||
|
||||
Ok(Self {
|
||||
addr,
|
||||
config: Arc::new(config),
|
||||
database: Arc::new(database),
|
||||
cache,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn run(&self) -> Result<(JoinHandle<()>, JoinHandle<()>)> {
|
||||
let tcp = Binding::tcp(self.addr).await?;
|
||||
let tcp_handle = self.listen(tcp);
|
||||
|
||||
let udp = Binding::udp(self.addr).await?;
|
||||
let udp_handle = self.listen(udp);
|
||||
|
||||
info!(
|
||||
"Fallback DNS Server is set to: {:?}",
|
||||
self.config.dns_fallback
|
||||
);
|
||||
info!(
|
||||
"Listening for TCP and UDP traffic on [::]:{}",
|
||||
self.config.dns_port
|
||||
);
|
||||
|
||||
Ok((udp_handle, tcp_handle))
|
||||
}
|
||||
|
||||
fn listen(&self, mut binding: Binding) -> JoinHandle<()> {
|
||||
let config = self.config.clone();
|
||||
let database = self.database.clone();
|
||||
let cache = self.cache.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut id = 0;
|
||||
loop {
|
||||
let Ok(connection) = binding.connect().await else { continue };
|
||||
info!("Received request on {}", binding.name());
|
||||
|
||||
let resolver = Resolver::new(
|
||||
id,
|
||||
connection,
|
||||
config.clone(),
|
||||
database.clone(),
|
||||
cache.clone(),
|
||||
);
|
||||
|
||||
let name = binding.name().to_string();
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = resolver.handle_query().await {
|
||||
error!("{} request {} failed: {:?}", name, id, err);
|
||||
};
|
||||
});
|
||||
|
||||
id += 1;
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
44
src/main.rs
44
src/main.rs
|
@ -1,19 +1,34 @@
|
|||
use std::{time::{UNIX_EPOCH, SystemTime}, env, net::IpAddr};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use config::Config;
|
||||
|
||||
use server::server::Server;
|
||||
use tracing::metadata::LevelFilter;
|
||||
use database::Database;
|
||||
use dotenv::dotenv;
|
||||
use dns::server::DnsServer;
|
||||
use tracing::{error, metadata::LevelFilter};
|
||||
use tracing_subscriber::{
|
||||
filter::filter_fn, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, Layer,
|
||||
};
|
||||
use web::WebServer;
|
||||
|
||||
mod config;
|
||||
mod packet;
|
||||
mod server;
|
||||
mod database;
|
||||
mod dns;
|
||||
mod web;
|
||||
|
||||
type Error = Box<dyn std::error::Error>;
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
if let Err(err) = run().await {
|
||||
error!("{err}")
|
||||
};
|
||||
}
|
||||
|
||||
async fn run() -> Result<()> {
|
||||
dotenv().ok();
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::fmt::layer()
|
||||
|
@ -24,19 +39,20 @@ async fn main() {
|
|||
)
|
||||
.init();
|
||||
|
||||
let mut config = Config::new();
|
||||
let config = Config::new();
|
||||
let database = Database::new(config.clone()).await?;
|
||||
|
||||
if let Ok(port) = env::var("PORT").unwrap_or(String::new()).parse::<u16>() {
|
||||
config.set_port(port);
|
||||
}
|
||||
let dns_server = DnsServer::new(config.clone(), database.clone()).await?;
|
||||
let (udp, tcp) = dns_server.run().await?;
|
||||
|
||||
if let Ok(fallback) = env::var("FALLBACK_DNS").unwrap_or(String::new()).parse::<IpAddr>() {
|
||||
config.set_fallback_ns(&fallback);
|
||||
}
|
||||
let web_server = WebServer::new(config, database).await?;
|
||||
let web = web_server.run().await?;
|
||||
|
||||
let server = Server::new(config).await.expect("Failed to bind server");
|
||||
tokio::join!(udp).0?;
|
||||
tokio::join!(tcp).0?;
|
||||
tokio::join!(web).0?;
|
||||
|
||||
server.run().await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_time() -> u64 {
|
||||
|
|
|
@ -1,73 +0,0 @@
|
|||
use moka::future::Cache;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::packet::question::DnsQuestion;
|
||||
use crate::packet::{Result, Packet};
|
||||
|
||||
use super::binding::Binding;
|
||||
use super::resolver::Resolver;
|
||||
|
||||
pub struct Server {
|
||||
addr: SocketAddr,
|
||||
config: Arc<Config>,
|
||||
cache: Cache<DnsQuestion, (Packet, u64)>,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
pub async fn new(config: Config) -> Result<Self> {
|
||||
let addr = format!("[::]:{}", config.get_port()).parse::<SocketAddr>()?;
|
||||
let cache = Cache::builder()
|
||||
.time_to_live(Duration::from_secs(60 * 60))
|
||||
.max_capacity(1_000)
|
||||
.build();
|
||||
Ok(Self {
|
||||
addr,
|
||||
config: Arc::new(config),
|
||||
cache,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn run(&self) -> Result<()> {
|
||||
let tcp = Binding::tcp(self.addr).await?;
|
||||
let tcp_handle = self.listen(tcp);
|
||||
|
||||
let udp = Binding::udp(self.addr).await?;
|
||||
let udp_handle = self.listen(udp);
|
||||
|
||||
info!("Fallback DNS Server is set to: {:?}", self.config.get_fallback_ns());
|
||||
info!("Listening for TCP and UDP traffic on [::]:{}", self.config.get_port());
|
||||
|
||||
tokio::join!(tcp_handle)
|
||||
.0
|
||||
.expect("Failed to join tcp thread");
|
||||
tokio::join!(udp_handle)
|
||||
.0
|
||||
.expect("Failed to join udp thread");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn listen(&self, mut binding: Binding) -> JoinHandle<()> {
|
||||
let config = self.config.clone();
|
||||
let cache = self.cache.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut id = 0;
|
||||
loop {
|
||||
let Ok(connection) = binding.connect().await else { continue };
|
||||
info!("Received request on {}", binding.name());
|
||||
|
||||
let resolver = Resolver::new(id, connection, config.clone(), cache.clone());
|
||||
|
||||
if let Err(err) = resolver.handle_query().await {
|
||||
error!("{} request {} failed: {:?}", binding.name(), id, err);
|
||||
};
|
||||
|
||||
id += 1;
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
156
src/web/api.rs
Normal file
156
src/web/api.rs
Normal file
|
@ -0,0 +1,156 @@
|
|||
use std::net::IpAddr;
|
||||
|
||||
use axum::{
|
||||
extract::Query,
|
||||
response::Response,
|
||||
routing::{get, post, put, delete},
|
||||
Extension, Router,
|
||||
};
|
||||
use moka::future::Cache;
|
||||
use rand::distributions::{Alphanumeric, DistString};
|
||||
use serde::Deserialize;
|
||||
use tower_cookies::{Cookie, Cookies};
|
||||
|
||||
use crate::{config::Config, database::Database, dns::packet::record::DnsRecord};
|
||||
|
||||
use super::{
|
||||
extract::{Authorized, Body, RequestIp},
|
||||
http::{json, text},
|
||||
};
|
||||
|
||||
pub fn router() -> Router {
|
||||
Router::new()
|
||||
.route("/login", post(login))
|
||||
.route("/domains", get(list_domains))
|
||||
.route("/domains", delete(delete_domain))
|
||||
.route("/records", get(get_domain))
|
||||
.route("/records", put(add_record))
|
||||
}
|
||||
|
||||
async fn list_domains(_: Authorized, Extension(database): Extension<Database>) -> Response {
|
||||
let domains = match database.get_domains().await {
|
||||
Ok(domains) => domains,
|
||||
Err(err) => return text(500, &format!("{err}")),
|
||||
};
|
||||
|
||||
let Ok(domains) = serde_json::to_string(&domains) else {
|
||||
return text(500, "Failed to fetch domains")
|
||||
};
|
||||
|
||||
json(200, &domains)
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct DomainRequest {
|
||||
domain: String,
|
||||
}
|
||||
|
||||
async fn get_domain(
|
||||
_: Authorized,
|
||||
Extension(database): Extension<Database>,
|
||||
Query(query): Query<DomainRequest>,
|
||||
) -> Response {
|
||||
let records = match database.get_domain(&query.domain).await {
|
||||
Ok(records) => records,
|
||||
Err(err) => return text(500, &format!("{err}")),
|
||||
};
|
||||
|
||||
let Ok(records) = serde_json::to_string(&records) else {
|
||||
return text(500, "Failed to fetch records")
|
||||
};
|
||||
|
||||
json(200, &records)
|
||||
}
|
||||
|
||||
async fn delete_domain(
|
||||
_: Authorized,
|
||||
Extension(database): Extension<Database>,
|
||||
Body(body): Body,
|
||||
) -> Response {
|
||||
|
||||
let Ok(request) = serde_json::from_str::<DomainRequest>(&body) else {
|
||||
return text(400, "Missing request parameters")
|
||||
};
|
||||
|
||||
let Ok(domains) = database.get_domains().await else {
|
||||
return text(500, "Failed to delete domain")
|
||||
};
|
||||
|
||||
if !domains.contains(&request.domain) {
|
||||
return text(400, "Domain does not exist")
|
||||
}
|
||||
|
||||
if database.delete_domain(request.domain).await.is_err() {
|
||||
return text(500, "Failed to delete domain")
|
||||
};
|
||||
|
||||
return text(204, "Successfully deleted domain")
|
||||
}
|
||||
|
||||
async fn add_record(
|
||||
_: Authorized,
|
||||
Extension(database): Extension<Database>,
|
||||
Body(body): Body,
|
||||
) -> Response {
|
||||
let Ok(record) = serde_json::from_str::<DnsRecord>(&body) else {
|
||||
return text(400, "Invalid DNS record")
|
||||
};
|
||||
|
||||
let allowed = record.get_qtype().allowed_actions();
|
||||
if !allowed.1 {
|
||||
return text(400, "Not allowed to create record")
|
||||
}
|
||||
|
||||
let Ok(records) = database.get_records(&record.get_domain(), record.get_qtype()).await else {
|
||||
return text(500, "Failed to complete record check");
|
||||
};
|
||||
|
||||
if !records.is_empty() && !allowed.0 {
|
||||
return text(400, "Not allowed to create duplicate record")
|
||||
};
|
||||
|
||||
if records.contains(&record) {
|
||||
return text(400, "Not allowed to create duplicate record")
|
||||
}
|
||||
|
||||
if let Err(err) = database.add_record(record).await {
|
||||
return text(500, &format!("{err}"));
|
||||
}
|
||||
|
||||
return text(201, "Added record to database successfully");
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LoginRequest {
|
||||
user: String,
|
||||
pass: String,
|
||||
}
|
||||
|
||||
async fn login(
|
||||
Extension(config): Extension<Config>,
|
||||
Extension(cache): Extension<Cache<String, IpAddr>>,
|
||||
RequestIp(ip): RequestIp,
|
||||
cookies: Cookies,
|
||||
Body(body): Body,
|
||||
) -> Response {
|
||||
let Ok(request) = serde_json::from_str::<LoginRequest>(&body) else {
|
||||
return text(400, "Missing request parameters")
|
||||
};
|
||||
|
||||
if request.user != config.web_user || request.pass != config.web_pass {
|
||||
return text(400, "Invalid credentials");
|
||||
};
|
||||
|
||||
let token = Alphanumeric.sample_string(&mut rand::thread_rng(), 128);
|
||||
|
||||
cache.insert(token.clone(), ip).await;
|
||||
|
||||
let mut cookie = Cookie::new("auth", token);
|
||||
cookie.set_secure(true);
|
||||
cookie.set_http_only(true);
|
||||
cookie.set_path("/");
|
||||
|
||||
cookies.add(cookie);
|
||||
|
||||
text(200, "Successfully logged in")
|
||||
}
|
139
src/web/extract.rs
Normal file
139
src/web/extract.rs
Normal file
|
@ -0,0 +1,139 @@
|
|||
use std::{
|
||||
io::Read,
|
||||
net::{IpAddr, SocketAddr},
|
||||
};
|
||||
|
||||
use axum::{
|
||||
async_trait,
|
||||
body::HttpBody,
|
||||
extract::{ConnectInfo, FromRequest, FromRequestParts},
|
||||
http::{request::Parts, Request},
|
||||
response::Response,
|
||||
BoxError,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use moka::future::Cache;
|
||||
use tower_cookies::Cookies;
|
||||
|
||||
use super::http::text;
|
||||
|
||||
pub struct Authorized;
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for Authorized
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Response;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let Ok(Some(cookies)) = Option::<Cookies>::from_request_parts(parts, state).await else {
|
||||
return Err(text(403, "No cookies provided"))
|
||||
};
|
||||
|
||||
let Some(token) = cookies.get("auth") else {
|
||||
return Err(text(403, "No auth token provided"))
|
||||
};
|
||||
|
||||
let auth_ip: IpAddr;
|
||||
{
|
||||
let Some(cache) = parts.extensions.get::<Cache<String, IpAddr>>() else {
|
||||
return Err(text(500, "Failed to load auth store"))
|
||||
};
|
||||
|
||||
let Some(ip) = cache.get(token.value()) else {
|
||||
return Err(text(401, "Unauthorized"))
|
||||
};
|
||||
|
||||
auth_ip = ip
|
||||
}
|
||||
|
||||
let Ok(Some(RequestIp(ip))) = Option::<RequestIp>::from_request_parts(parts, state).await else {
|
||||
return Err(text(403, "You have no ip"))
|
||||
};
|
||||
|
||||
if auth_ip != ip {
|
||||
return Err(text(403, "Auth token does not match current ip"));
|
||||
}
|
||||
|
||||
Ok(Self)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RequestIp(pub IpAddr);
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for RequestIp
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Response;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
let headers = &parts.headers;
|
||||
|
||||
let forwardedfor = headers
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| {
|
||||
h.split(',')
|
||||
.rev()
|
||||
.find_map(|s| s.trim().parse::<IpAddr>().ok())
|
||||
});
|
||||
|
||||
if let Some(forwardedfor) = forwardedfor {
|
||||
return Ok(Self(forwardedfor));
|
||||
}
|
||||
|
||||
let realip = headers
|
||||
.get("x-real-ip")
|
||||
.and_then(|hv| hv.to_str().ok())
|
||||
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||
|
||||
if let Some(realip) = realip {
|
||||
return Ok(Self(realip));
|
||||
}
|
||||
|
||||
let realip = headers
|
||||
.get("x-real-ip")
|
||||
.and_then(|hv| hv.to_str().ok())
|
||||
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||
|
||||
if let Some(realip) = realip {
|
||||
return Ok(Self(realip));
|
||||
}
|
||||
|
||||
let info = parts.extensions.get::<ConnectInfo<SocketAddr>>();
|
||||
|
||||
if let Some(info) = info {
|
||||
return Ok(Self(info.0.ip()));
|
||||
}
|
||||
|
||||
Err(text(403, "You have no ip"))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Body(pub String);
|
||||
|
||||
#[async_trait]
|
||||
impl<S, B> FromRequest<S, B> for Body
|
||||
where
|
||||
B: HttpBody + Sync + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: Into<BoxError>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Response;
|
||||
|
||||
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let Ok(bytes) = Bytes::from_request(req, state).await else {
|
||||
return Err(text(413, "Payload too large"));
|
||||
};
|
||||
|
||||
let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else {
|
||||
return Err(text(400, "Invalid utf8 body"))
|
||||
};
|
||||
|
||||
Ok(Self(body))
|
||||
}
|
||||
}
|
31
src/web/file.rs
Normal file
31
src/web/file.rs
Normal file
|
@ -0,0 +1,31 @@
|
|||
use axum::{extract::Path, response::Response};
|
||||
|
||||
use super::http::serve;
|
||||
|
||||
pub async fn js(Path(path): Path<String>) -> Response {
|
||||
let path = format!("/js/{path}");
|
||||
serve(&path).await
|
||||
}
|
||||
|
||||
pub async fn css(Path(path): Path<String>) -> Response {
|
||||
let path = format!("/css/{path}");
|
||||
serve(&path).await
|
||||
}
|
||||
|
||||
pub async fn fonts(Path(path): Path<String>) -> Response {
|
||||
let path = format!("/fonts/{path}");
|
||||
serve(&path).await
|
||||
}
|
||||
|
||||
pub async fn image(Path(path): Path<String>) -> Response {
|
||||
let path = format!("/image/{path}");
|
||||
serve(&path).await
|
||||
}
|
||||
|
||||
pub async fn favicon() -> Response {
|
||||
serve("/favicon.ico").await
|
||||
}
|
||||
|
||||
pub async fn robots() -> Response {
|
||||
serve("/robots.txt").await
|
||||
}
|
50
src/web/http.rs
Normal file
50
src/web/http.rs
Normal file
|
@ -0,0 +1,50 @@
|
|||
use axum::{
|
||||
body::Body,
|
||||
http::{header::HeaderName, HeaderValue, Request, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::str;
|
||||
use tower::ServiceExt;
|
||||
use tower_http::services::ServeFile;
|
||||
|
||||
pub fn text(code: u16, msg: &str) -> Response {
|
||||
(status_code(code), msg.to_owned()).into_response()
|
||||
}
|
||||
|
||||
pub fn json(code: u16, json: &str) -> Response {
|
||||
let mut res = (status_code(code), json.to_owned()).into_response();
|
||||
res.headers_mut().insert(
|
||||
HeaderName::from_static("content-type"),
|
||||
HeaderValue::from_static("application/json"),
|
||||
);
|
||||
res
|
||||
}
|
||||
|
||||
pub async fn serve(path: &str) -> Response {
|
||||
if !path.chars().any(|c| c == '.') {
|
||||
return text(403, "Invalid file path");
|
||||
}
|
||||
|
||||
let path = format!("public{path}");
|
||||
let file = ServeFile::new(path);
|
||||
|
||||
let Ok(mut res) = file.oneshot(Request::new(Body::empty())).await else {
|
||||
tracing::error!("Error while fetching file");
|
||||
return text(500, "Error when fetching file")
|
||||
};
|
||||
|
||||
if res.status() != StatusCode::OK {
|
||||
return text(404, "File not found");
|
||||
}
|
||||
|
||||
res.headers_mut().insert(
|
||||
HeaderName::from_static("cache-control"),
|
||||
HeaderValue::from_static("max-age=300"),
|
||||
);
|
||||
|
||||
res.into_response()
|
||||
}
|
||||
|
||||
fn status_code(code: u16) -> StatusCode {
|
||||
StatusCode::from_u16(code).map_or(StatusCode::OK, |code| code)
|
||||
}
|
82
src/web/mod.rs
Normal file
82
src/web/mod.rs
Normal file
|
@ -0,0 +1,82 @@
|
|||
use std::net::{IpAddr, SocketAddr, TcpListener};
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::routing::get;
|
||||
use axum::{Extension, Router};
|
||||
use moka::future::Cache;
|
||||
use tokio::task::JoinHandle;
|
||||
use tower_cookies::CookieManagerLayer;
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::database::Database;
|
||||
use crate::Result;
|
||||
|
||||
mod api;
|
||||
mod extract;
|
||||
mod file;
|
||||
mod http;
|
||||
mod pages;
|
||||
|
||||
pub struct WebServer {
|
||||
config: Config,
|
||||
database: Database,
|
||||
addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl WebServer {
|
||||
pub async fn new(config: Config, database: Database) -> Result<Self> {
|
||||
let addr = format!("[::]:{}", config.web_port).parse::<SocketAddr>()?;
|
||||
Ok(Self {
|
||||
config,
|
||||
database,
|
||||
addr,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn run(&self) -> Result<JoinHandle<()>> {
|
||||
let config = self.config.clone();
|
||||
let database = self.database.clone();
|
||||
let listener = TcpListener::bind(self.addr)?;
|
||||
|
||||
info!(
|
||||
"Listening for HTTP traffic on [::]:{}",
|
||||
self.config.web_port
|
||||
);
|
||||
|
||||
let app = Self::router(config, database);
|
||||
let server = axum::Server::from_tcp(listener)?;
|
||||
|
||||
let web_handle = tokio::spawn(async move {
|
||||
if let Err(err) = server
|
||||
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
|
||||
.await
|
||||
{
|
||||
error!("{err}");
|
||||
}
|
||||
});
|
||||
|
||||
Ok(web_handle)
|
||||
}
|
||||
|
||||
fn router(config: Config, database: Database) -> Router {
|
||||
let cache: Cache<String, IpAddr> = Cache::builder()
|
||||
.time_to_live(Duration::from_secs(60 * 15))
|
||||
.max_capacity(config.dns_cache_size)
|
||||
.build();
|
||||
|
||||
Router::new()
|
||||
.nest("/", pages::router())
|
||||
.nest("/api", api::router())
|
||||
.layer(Extension(config))
|
||||
.layer(Extension(cache))
|
||||
.layer(Extension(database))
|
||||
.layer(CookieManagerLayer::new())
|
||||
.route("/js/*path", get(file::js))
|
||||
.route("/css/*path", get(file::css))
|
||||
.route("/fonts/*path", get(file::fonts))
|
||||
.route("/image/*path", get(file::image))
|
||||
.route("/favicon.ico", get(file::favicon))
|
||||
.route("/robots.txt", get(file::robots))
|
||||
}
|
||||
}
|
31
src/web/pages.rs
Normal file
31
src/web/pages.rs
Normal file
|
@ -0,0 +1,31 @@
|
|||
use axum::{response::Response, routing::get, Router};
|
||||
|
||||
use super::{extract::Authorized, http::serve};
|
||||
|
||||
pub fn router() -> Router {
|
||||
Router::new()
|
||||
.route("/", get(root))
|
||||
.route("/login", get(login))
|
||||
.route("/home", get(home))
|
||||
.route("/domain", get(domain))
|
||||
}
|
||||
|
||||
async fn root(user: Option<Authorized>) -> Response {
|
||||
if user.is_some() {
|
||||
home().await
|
||||
} else {
|
||||
login().await
|
||||
}
|
||||
}
|
||||
|
||||
async fn login() -> Response {
|
||||
serve("/login.html").await
|
||||
}
|
||||
|
||||
async fn home() -> Response {
|
||||
serve("/home.html").await
|
||||
}
|
||||
|
||||
async fn domain() -> Response {
|
||||
serve("/domain.html").await
|
||||
}
|
Loading…
Reference in a new issue