diff options
Diffstat (limited to 'src/web')
-rw-r--r-- | src/web/api.rs | 156 | ||||
-rw-r--r-- | src/web/extract.rs | 139 | ||||
-rw-r--r-- | src/web/file.rs | 31 | ||||
-rw-r--r-- | src/web/http.rs | 50 | ||||
-rw-r--r-- | src/web/mod.rs | 82 | ||||
-rw-r--r-- | src/web/pages.rs | 31 |
6 files changed, 489 insertions, 0 deletions
diff --git a/src/web/api.rs b/src/web/api.rs new file mode 100644 index 0000000..1fddb5f --- /dev/null +++ b/src/web/api.rs @@ -0,0 +1,156 @@ +use std::net::IpAddr; + +use axum::{ + extract::Query, + response::Response, + routing::{get, post, put, delete}, + Extension, Router, +}; +use moka::future::Cache; +use rand::distributions::{Alphanumeric, DistString}; +use serde::Deserialize; +use tower_cookies::{Cookie, Cookies}; + +use crate::{config::Config, database::Database, dns::packet::record::DnsRecord}; + +use super::{ + extract::{Authorized, Body, RequestIp}, + http::{json, text}, +}; + +pub fn router() -> Router { + Router::new() + .route("/login", post(login)) + .route("/domains", get(list_domains)) + .route("/domains", delete(delete_domain)) + .route("/records", get(get_domain)) + .route("/records", put(add_record)) +} + +async fn list_domains(_: Authorized, Extension(database): Extension<Database>) -> Response { + let domains = match database.get_domains().await { + Ok(domains) => domains, + Err(err) => return text(500, &format!("{err}")), + }; + + let Ok(domains) = serde_json::to_string(&domains) else { + return text(500, "Failed to fetch domains") + }; + + json(200, &domains) +} + +#[derive(Deserialize)] +struct DomainRequest { + domain: String, +} + +async fn get_domain( + _: Authorized, + Extension(database): Extension<Database>, + Query(query): Query<DomainRequest>, +) -> Response { + let records = match database.get_domain(&query.domain).await { + Ok(records) => records, + Err(err) => return text(500, &format!("{err}")), + }; + + let Ok(records) = serde_json::to_string(&records) else { + return text(500, "Failed to fetch records") + }; + + json(200, &records) +} + +async fn delete_domain( + _: Authorized, + Extension(database): Extension<Database>, + Body(body): Body, +) -> Response { + + let Ok(request) = serde_json::from_str::<DomainRequest>(&body) else { + return text(400, "Missing request parameters") + }; + + let Ok(domains) = database.get_domains().await else { + return text(500, "Failed to delete domain") + }; + + if !domains.contains(&request.domain) { + return text(400, "Domain does not exist") + } + + if database.delete_domain(request.domain).await.is_err() { + return text(500, "Failed to delete domain") + }; + + return text(204, "Successfully deleted domain") +} + +async fn add_record( + _: Authorized, + Extension(database): Extension<Database>, + Body(body): Body, +) -> Response { + let Ok(record) = serde_json::from_str::<DnsRecord>(&body) else { + return text(400, "Invalid DNS record") + }; + + let allowed = record.get_qtype().allowed_actions(); + if !allowed.1 { + return text(400, "Not allowed to create record") + } + + let Ok(records) = database.get_records(&record.get_domain(), record.get_qtype()).await else { + return text(500, "Failed to complete record check"); + }; + + if !records.is_empty() && !allowed.0 { + return text(400, "Not allowed to create duplicate record") + }; + + if records.contains(&record) { + return text(400, "Not allowed to create duplicate record") + } + + if let Err(err) = database.add_record(record).await { + return text(500, &format!("{err}")); + } + + return text(201, "Added record to database successfully"); +} + +#[derive(Deserialize)] +struct LoginRequest { + user: String, + pass: String, +} + +async fn login( + Extension(config): Extension<Config>, + Extension(cache): Extension<Cache<String, IpAddr>>, + RequestIp(ip): RequestIp, + cookies: Cookies, + Body(body): Body, +) -> Response { + let Ok(request) = serde_json::from_str::<LoginRequest>(&body) else { + return text(400, "Missing request parameters") + }; + + if request.user != config.web_user || request.pass != config.web_pass { + return text(400, "Invalid credentials"); + }; + + let token = Alphanumeric.sample_string(&mut rand::thread_rng(), 128); + + cache.insert(token.clone(), ip).await; + + let mut cookie = Cookie::new("auth", token); + cookie.set_secure(true); + cookie.set_http_only(true); + cookie.set_path("/"); + + cookies.add(cookie); + + text(200, "Successfully logged in") +} diff --git a/src/web/extract.rs b/src/web/extract.rs new file mode 100644 index 0000000..4b6cd7c --- /dev/null +++ b/src/web/extract.rs @@ -0,0 +1,139 @@ +use std::{ + io::Read, + net::{IpAddr, SocketAddr}, +}; + +use axum::{ + async_trait, + body::HttpBody, + extract::{ConnectInfo, FromRequest, FromRequestParts}, + http::{request::Parts, Request}, + response::Response, + BoxError, +}; +use bytes::Bytes; +use moka::future::Cache; +use tower_cookies::Cookies; + +use super::http::text; + +pub struct Authorized; + +#[async_trait] +impl<S> FromRequestParts<S> for Authorized +where + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { + let Ok(Some(cookies)) = Option::<Cookies>::from_request_parts(parts, state).await else { + return Err(text(403, "No cookies provided")) + }; + + let Some(token) = cookies.get("auth") else { + return Err(text(403, "No auth token provided")) + }; + + let auth_ip: IpAddr; + { + let Some(cache) = parts.extensions.get::<Cache<String, IpAddr>>() else { + return Err(text(500, "Failed to load auth store")) + }; + + let Some(ip) = cache.get(token.value()) else { + return Err(text(401, "Unauthorized")) + }; + + auth_ip = ip + } + + let Ok(Some(RequestIp(ip))) = Option::<RequestIp>::from_request_parts(parts, state).await else { + return Err(text(403, "You have no ip")) + }; + + if auth_ip != ip { + return Err(text(403, "Auth token does not match current ip")); + } + + Ok(Self) + } +} + +pub struct RequestIp(pub IpAddr); + +#[async_trait] +impl<S> FromRequestParts<S> for RequestIp +where + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { + let headers = &parts.headers; + + let forwardedfor = headers + .get("x-forwarded-for") + .and_then(|h| h.to_str().ok()) + .and_then(|h| { + h.split(',') + .rev() + .find_map(|s| s.trim().parse::<IpAddr>().ok()) + }); + + if let Some(forwardedfor) = forwardedfor { + return Ok(Self(forwardedfor)); + } + + let realip = headers + .get("x-real-ip") + .and_then(|hv| hv.to_str().ok()) + .and_then(|s| s.parse::<IpAddr>().ok()); + + if let Some(realip) = realip { + return Ok(Self(realip)); + } + + let realip = headers + .get("x-real-ip") + .and_then(|hv| hv.to_str().ok()) + .and_then(|s| s.parse::<IpAddr>().ok()); + + if let Some(realip) = realip { + return Ok(Self(realip)); + } + + let info = parts.extensions.get::<ConnectInfo<SocketAddr>>(); + + if let Some(info) = info { + return Ok(Self(info.0.ip())); + } + + Err(text(403, "You have no ip")) + } +} + +pub struct Body(pub String); + +#[async_trait] +impl<S, B> FromRequest<S, B> for Body +where + B: HttpBody + Sync + Send + 'static, + B::Data: Send, + B::Error: Into<BoxError>, + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> { + let Ok(bytes) = Bytes::from_request(req, state).await else { + return Err(text(413, "Payload too large")); + }; + + let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else { + return Err(text(400, "Invalid utf8 body")) + }; + + Ok(Self(body)) + } +} diff --git a/src/web/file.rs b/src/web/file.rs new file mode 100644 index 0000000..73ecdc9 --- /dev/null +++ b/src/web/file.rs @@ -0,0 +1,31 @@ +use axum::{extract::Path, response::Response}; + +use super::http::serve; + +pub async fn js(Path(path): Path<String>) -> Response { + let path = format!("/js/{path}"); + serve(&path).await +} + +pub async fn css(Path(path): Path<String>) -> Response { + let path = format!("/css/{path}"); + serve(&path).await +} + +pub async fn fonts(Path(path): Path<String>) -> Response { + let path = format!("/fonts/{path}"); + serve(&path).await +} + +pub async fn image(Path(path): Path<String>) -> Response { + let path = format!("/image/{path}"); + serve(&path).await +} + +pub async fn favicon() -> Response { + serve("/favicon.ico").await +} + +pub async fn robots() -> Response { + serve("/robots.txt").await +} diff --git a/src/web/http.rs b/src/web/http.rs new file mode 100644 index 0000000..7ab1b11 --- /dev/null +++ b/src/web/http.rs @@ -0,0 +1,50 @@ +use axum::{ + body::Body, + http::{header::HeaderName, HeaderValue, Request, StatusCode}, + response::{IntoResponse, Response}, +}; +use std::str; +use tower::ServiceExt; +use tower_http::services::ServeFile; + +pub fn text(code: u16, msg: &str) -> Response { + (status_code(code), msg.to_owned()).into_response() +} + +pub fn json(code: u16, json: &str) -> Response { + let mut res = (status_code(code), json.to_owned()).into_response(); + res.headers_mut().insert( + HeaderName::from_static("content-type"), + HeaderValue::from_static("application/json"), + ); + res +} + +pub async fn serve(path: &str) -> Response { + if !path.chars().any(|c| c == '.') { + return text(403, "Invalid file path"); + } + + let path = format!("public{path}"); + let file = ServeFile::new(path); + + let Ok(mut res) = file.oneshot(Request::new(Body::empty())).await else { + tracing::error!("Error while fetching file"); + return text(500, "Error when fetching file") + }; + + if res.status() != StatusCode::OK { + return text(404, "File not found"); + } + + res.headers_mut().insert( + HeaderName::from_static("cache-control"), + HeaderValue::from_static("max-age=300"), + ); + + res.into_response() +} + +fn status_code(code: u16) -> StatusCode { + StatusCode::from_u16(code).map_or(StatusCode::OK, |code| code) +} diff --git a/src/web/mod.rs b/src/web/mod.rs new file mode 100644 index 0000000..530a3f9 --- /dev/null +++ b/src/web/mod.rs @@ -0,0 +1,82 @@ +use std::net::{IpAddr, SocketAddr, TcpListener}; +use std::time::Duration; + +use axum::routing::get; +use axum::{Extension, Router}; +use moka::future::Cache; +use tokio::task::JoinHandle; +use tower_cookies::CookieManagerLayer; +use tracing::{error, info}; + +use crate::config::Config; +use crate::database::Database; +use crate::Result; + +mod api; +mod extract; +mod file; +mod http; +mod pages; + +pub struct WebServer { + config: Config, + database: Database, + addr: SocketAddr, +} + +impl WebServer { + pub async fn new(config: Config, database: Database) -> Result<Self> { + let addr = format!("[::]:{}", config.web_port).parse::<SocketAddr>()?; + Ok(Self { + config, + database, + addr, + }) + } + + pub async fn run(&self) -> Result<JoinHandle<()>> { + let config = self.config.clone(); + let database = self.database.clone(); + let listener = TcpListener::bind(self.addr)?; + + info!( + "Listening for HTTP traffic on [::]:{}", + self.config.web_port + ); + + let app = Self::router(config, database); + let server = axum::Server::from_tcp(listener)?; + + let web_handle = tokio::spawn(async move { + if let Err(err) = server + .serve(app.into_make_service_with_connect_info::<SocketAddr>()) + .await + { + error!("{err}"); + } + }); + + Ok(web_handle) + } + + fn router(config: Config, database: Database) -> Router { + let cache: Cache<String, IpAddr> = Cache::builder() + .time_to_live(Duration::from_secs(60 * 15)) + .max_capacity(config.dns_cache_size) + .build(); + + Router::new() + .nest("/", pages::router()) + .nest("/api", api::router()) + .layer(Extension(config)) + .layer(Extension(cache)) + .layer(Extension(database)) + .layer(CookieManagerLayer::new()) + .route("/js/*path", get(file::js)) + .route("/css/*path", get(file::css)) + .route("/fonts/*path", get(file::fonts)) + .route("/image/*path", get(file::image)) + .route("/favicon.ico", get(file::favicon)) + .route("/robots.txt", get(file::robots)) + } +} diff --git a/src/web/pages.rs b/src/web/pages.rs new file mode 100644 index 0000000..a8605ef --- /dev/null +++ b/src/web/pages.rs @@ -0,0 +1,31 @@ +use axum::{response::Response, routing::get, Router}; + +use super::{extract::Authorized, http::serve}; + +pub fn router() -> Router { + Router::new() + .route("/", get(root)) + .route("/login", get(login)) + .route("/home", get(home)) + .route("/domain", get(domain)) +} + +async fn root(user: Option<Authorized>) -> Response { + if user.is_some() { + home().await + } else { + login().await + } +} + +async fn login() -> Response { + serve("/login.html").await +} + +async fn home() -> Response { + serve("/home.html").await +} + +async fn domain() -> Response { + serve("/domain.html").await +} |