use std::net::IpAddr; use axum::{ extract::Query, response::Response, routing::{get, post, put, delete}, Extension, Router, }; use moka::future::Cache; use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; use tower_cookies::{Cookie, Cookies}; use crate::{config::Config, database::Database, dns::packet::record::DnsRecord}; use super::{ extract::{Authorized, Body, RequestIp}, http::{json, text}, }; pub fn router() -> Router { Router::new() .route("/login", post(login)) .route("/domains", get(list_domains)) .route("/domains", delete(delete_domain)) .route("/records", get(get_domain)) .route("/records", put(add_record)) } async fn list_domains(_: Authorized, Extension(database): Extension) -> Response { let domains = match database.get_domains().await { Ok(domains) => domains, Err(err) => return text(500, &format!("{err}")), }; let Ok(domains) = serde_json::to_string(&domains) else { return text(500, "Failed to fetch domains") }; json(200, &domains) } #[derive(Deserialize)] struct DomainRequest { domain: String, } async fn get_domain( _: Authorized, Extension(database): Extension, Query(query): Query, ) -> Response { let records = match database.get_domain(&query.domain).await { Ok(records) => records, Err(err) => return text(500, &format!("{err}")), }; let Ok(records) = serde_json::to_string(&records) else { return text(500, "Failed to fetch records") }; json(200, &records) } async fn delete_domain( _: Authorized, Extension(database): Extension, Body(body): Body, ) -> Response { let Ok(request) = serde_json::from_str::(&body) else { return text(400, "Missing request parameters") }; let Ok(domains) = database.get_domains().await else { return text(500, "Failed to delete domain") }; if !domains.contains(&request.domain) { return text(400, "Domain does not exist") } if database.delete_domain(request.domain).await.is_err() { return text(500, "Failed to delete domain") }; return text(204, "Successfully deleted domain") } async fn add_record( _: Authorized, Extension(database): Extension, Body(body): Body, ) -> Response { let Ok(record) = serde_json::from_str::(&body) else { return text(400, "Invalid DNS record") }; let allowed = record.get_qtype().allowed_actions(); if !allowed.1 { return text(400, "Not allowed to create record") } let Ok(records) = database.get_records(&record.get_domain(), record.get_qtype()).await else { return text(500, "Failed to complete record check"); }; if !records.is_empty() && !allowed.0 { return text(400, "Not allowed to create duplicate record") }; if records.contains(&record) { return text(400, "Not allowed to create duplicate record") } if let Err(err) = database.add_record(record).await { return text(500, &format!("{err}")); } return text(201, "Added record to database successfully"); } #[derive(Deserialize)] struct LoginRequest { user: String, pass: String, } async fn login( Extension(config): Extension, Extension(cache): Extension>, RequestIp(ip): RequestIp, cookies: Cookies, Body(body): Body, ) -> Response { let Ok(request) = serde_json::from_str::(&body) else { return text(400, "Missing request parameters") }; if request.user != config.web_user || request.pass != config.web_pass { return text(400, "Invalid credentials"); }; let token = Alphanumeric.sample_string(&mut rand::thread_rng(), 128); cache.insert(token.clone(), ip).await; let mut cookie = Cookie::new("auth", token); cookie.set_secure(true); cookie.set_http_only(true); cookie.set_path("/"); cookies.add(cookie); text(200, "Successfully logged in") }