diff options
author | Tyler Murphy <tylermurphy534@gmail.com> | 2023-03-06 18:50:08 -0500 |
---|---|---|
committer | Tyler Murphy <tylermurphy534@gmail.com> | 2023-03-06 18:50:08 -0500 |
commit | b1fb410affb7bcd2e714abac01d22c4a5332c344 (patch) | |
tree | 7ebb621ab9b73e3e1fbaeb0ef8c19abef95b7c9f /src | |
parent | finialize initial dns + caching (diff) | |
download | wrapper-b1fb410affb7bcd2e714abac01d22c4a5332c344.tar.gz wrapper-b1fb410affb7bcd2e714abac01d22c4a5332c344.tar.bz2 wrapper-b1fb410affb7bcd2e714abac01d22c4a5332c344.zip |
finish dns and start webserver
Diffstat (limited to '')
-rw-r--r-- | src/config.rs | 64 | ||||
-rw-r--r-- | src/database/mod.rs | 146 | ||||
-rw-r--r-- | src/dns/binding.rs (renamed from src/server/binding.rs) | 10 | ||||
-rw-r--r-- | src/dns/mod.rs (renamed from src/server/mod.rs) | 1 | ||||
-rw-r--r-- | src/dns/packet/buffer.rs (renamed from src/packet/buffer.rs) | 51 | ||||
-rw-r--r-- | src/dns/packet/header.rs (renamed from src/packet/header.rs) | 3 | ||||
-rw-r--r-- | src/dns/packet/mod.rs (renamed from src/packet/mod.rs) | 4 | ||||
-rw-r--r-- | src/dns/packet/query.rs (renamed from src/packet/query.rs) | 27 | ||||
-rw-r--r-- | src/dns/packet/question.rs (renamed from src/packet/question.rs) | 0 | ||||
-rw-r--r-- | src/dns/packet/record.rs (renamed from src/packet/record.rs) | 82 | ||||
-rw-r--r-- | src/dns/packet/result.rs (renamed from src/packet/result.rs) | 0 | ||||
-rw-r--r-- | src/dns/resolver.rs (renamed from src/server/resolver.rs) | 115 | ||||
-rw-r--r-- | src/dns/server.rs | 85 | ||||
-rw-r--r-- | src/main.rs | 44 | ||||
-rw-r--r-- | src/server/server.rs | 73 | ||||
-rw-r--r-- | src/web/api.rs | 156 | ||||
-rw-r--r-- | src/web/extract.rs | 139 | ||||
-rw-r--r-- | src/web/file.rs | 31 | ||||
-rw-r--r-- | src/web/http.rs | 50 | ||||
-rw-r--r-- | src/web/mod.rs | 82 | ||||
-rw-r--r-- | src/web/pages.rs | 31 |
21 files changed, 1001 insertions, 193 deletions
diff --git a/src/config.rs b/src/config.rs index 9350adf..547e853 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,35 +1,57 @@ -use std::net::IpAddr; +use std::{env, net::IpAddr, str::FromStr, fmt::Display}; #[derive(Clone)] pub struct Config { - fallback: IpAddr, - port: u16, + pub dns_fallback: IpAddr, + pub dns_port: u16, + pub dns_cache_size: u64, + + pub db_host: String, + pub db_port: u16, + pub db_user: String, + pub db_pass: String, + + pub web_user: String, + pub web_pass: String, + pub web_port: u16, } impl Config { pub fn new() -> Self { - let fallback = "9.9.9.9" - .parse::<IpAddr>() - .expect("Failed to create default ns fallback"); - Self { - fallback, - port: 2000, - } - } + let dns_port = Self::get_var::<u16>("WRAPPER_DNS_PORT", 53); + let dns_fallback = Self::get_var::<IpAddr>("WRAPPER_FALLBACK_DNS", [9, 9, 9, 9].into()); + let dns_cache_size = Self::get_var::<u64>("WRAPPER_CACHE_SIZE", 1000); - pub fn get_fallback_ns(&self) -> &IpAddr { - &self.fallback - } + let db_host = Self::get_var::<String>("WRAPPER_DB_HOST", String::from("localhost")); + let db_port = Self::get_var::<u16>("WRAPPER_DB_PORT", 27017); + let db_user = Self::get_var::<String>("WRAPPER_DB_USER", String::from("root")); + let db_pass = Self::get_var::<String>("WRAPPER_DB_PASS", String::from("")); - pub fn get_port(&self) -> u16 { - self.port - } + let web_user = Self::get_var::<String>("WRAPPER_WEB_USER", String::from("admin")); + let web_pass = Self::get_var::<String>("WRAPPER_WEB_PASS", String::from("wrapper")); + let web_port = Self::get_var::<u16>("WRAPPER_WEB_PORT", 80); + + Self { + dns_fallback, + dns_port, + dns_cache_size, - pub fn set_fallback_ns(&mut self, addr: &IpAddr) { - self.fallback = *addr; + db_host, + db_port, + db_user, + db_pass, + + web_user, + web_pass, + web_port, + } } - pub fn set_port(&mut self, port: u16) { - self.port = port; + fn get_var<T>(name: &str, default: T) -> T + where + T: FromStr + Display, + { + let env = env::var(name).unwrap_or(format!("{default}")); + env.parse::<T>().unwrap_or(default) } } diff --git a/src/database/mod.rs b/src/database/mod.rs new file mode 100644 index 0000000..0d81dc3 --- /dev/null +++ b/src/database/mod.rs @@ -0,0 +1,146 @@ +use futures::TryStreamExt; +use mongodb::{ + bson::doc, + options::{ClientOptions, Credential, ServerAddress}, + Client, +}; +use serde::{Deserialize, Serialize}; +use tracing::info; + +use crate::{ + config::Config, + dns::packet::{query::QueryType, record::DnsRecord}, +}; + +use crate::Result; + +#[derive(Clone)] +pub struct Database { + client: Client, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct StoredRecord { + record: DnsRecord, + domain: String, + prefix: String, +} + +impl StoredRecord { + fn get_domain_parts(domain: &str) -> (String, String) { + let parts: Vec<&str> = domain.split(".").collect(); + let len = parts.len(); + if len == 1 { + (String::new(), String::from(parts[0])) + } else if len == 2 { + (String::new(), String::from(parts.join("."))) + } else { + ( + String::from(parts[0..len - 2].join(".")), + String::from(parts[len - 2..len].join(".")), + ) + } + } +} + +impl From<DnsRecord> for StoredRecord { + fn from(record: DnsRecord) -> Self { + let (prefix, domain) = Self::get_domain_parts(&record.get_domain()); + Self { + record, + domain, + prefix, + } + } +} + +impl Into<DnsRecord> for StoredRecord { + fn into(self) -> DnsRecord { + self.record + } +} + +impl Database { + pub async fn new(config: Config) -> Result<Self> { + let options = ClientOptions::builder() + .hosts(vec![ServerAddress::Tcp { + host: config.db_host, + port: Some(config.db_port), + }]) + .credential( + Credential::builder() + .username(config.db_user) + .password(config.db_pass) + .build(), + ) + .max_pool_size(100) + .app_name(String::from("wrapper")) + .build(); + + let client = Client::with_options(options)?; + + client + .database("wrapper") + .run_command(doc! {"ping": 1}, None) + .await?; + + info!("Connection to mongodb successfully"); + + Ok(Database { client }) + } + + pub async fn get_records(&self, domain: &str, qtype: QueryType) -> Result<Vec<DnsRecord>> { + let (prefix, domain) = StoredRecord::get_domain_parts(domain); + Ok(self + .get_domain(&domain) + .await? + .into_iter() + .filter(|r| r.prefix == prefix) + .filter(|r| { + let rqtype = r.record.get_qtype(); + if qtype == QueryType::A { + return rqtype == QueryType::A || rqtype == QueryType::AR; + } else if qtype == QueryType::AAAA { + return rqtype == QueryType::AAAA || rqtype == QueryType::AAAAR; + } else { + r.record.get_qtype() == qtype + } + }) + .map(|r| r.into()) + .collect()) + } + + pub async fn get_domain(&self, domain: &str) -> Result<Vec<StoredRecord>> { + let db = self.client.database("wrapper"); + let col = db.collection::<StoredRecord>(domain); + + let filter = doc! { "domain": domain }; + let mut cursor = col.find(filter, None).await?; + + let mut records = Vec::new(); + while let Some(record) = cursor.try_next().await? { + records.push(record); + } + + Ok(records) + } + + pub async fn add_record(&self, record: DnsRecord) -> Result<()> { + let record = StoredRecord::from(record); + let db = self.client.database("wrapper"); + let col = db.collection::<StoredRecord>(&record.domain); + col.insert_one(record, None).await?; + Ok(()) + } + + pub async fn get_domains(&self) -> Result<Vec<String>> { + let db = self.client.database("wrapper"); + Ok(db.list_collection_names(None).await?) + } + + pub async fn delete_domain(&self, domain: String) -> Result<()> { + let db = self.client.database("wrapper"); + let col = db.collection::<StoredRecord>(&domain); + Ok(col.drop(None).await?) + } +} diff --git a/src/server/binding.rs b/src/dns/binding.rs index 1c69651..4c7e15f 100644 --- a/src/server/binding.rs +++ b/src/dns/binding.rs @@ -3,7 +3,8 @@ use std::{ sync::Arc, }; -use crate::packet::{buffer::PacketBuffer, Packet, Result}; +use super::packet::{buffer::PacketBuffer, Packet}; +use crate::Result; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream, UdpSocket}, @@ -140,11 +141,4 @@ impl Connection { } } } - - // fn pb(buf: &[u8]) { - // for i in 0..buf.len() { - // print!("{:02X?} ", buf[i]); - // } - // println!(""); - // } } diff --git a/src/server/mod.rs b/src/dns/mod.rs index 25076ef..6f1e59e 100644 --- a/src/server/mod.rs +++ b/src/dns/mod.rs @@ -1,3 +1,4 @@ mod binding; +pub mod packet; mod resolver; pub mod server; diff --git a/src/packet/buffer.rs b/src/dns/packet/buffer.rs index 4ecc605..058156e 100644 --- a/src/packet/buffer.rs +++ b/src/dns/packet/buffer.rs @@ -1,4 +1,4 @@ -use super::Result; +use crate::Result; pub struct PacketBuffer { pub buf: Vec<u8>, @@ -9,19 +9,9 @@ pub struct PacketBuffer { impl PacketBuffer { pub fn new(buf: Vec<u8>) -> Self { Self { + size: buf.len(), buf, pos: 0, - size: 0, - } - } - - fn check(&mut self, pos: usize) { - if self.size < pos { - self.size = pos; - } - - if self.buf.len() <= self.size { - self.buf.resize(self.size + 1, 0x00); } } @@ -42,32 +32,25 @@ impl PacketBuffer { } pub fn read(&mut self) -> Result<u8> { - // if self.pos >= 512 { - // error!("Tried to read past end of buffer"); - // return Err("End of buffer".into()); - // } - self.check(self.pos); + if self.pos >= self.size { + return Err("Tried to read past end of buffer".into()); + } let res = self.buf[self.pos]; self.pos += 1; - Ok(res) } pub fn get(&mut self, pos: usize) -> Result<u8> { - // if pos >= 512 { - // error!("Tried to read past end of buffer"); - // return Err("End of buffer".into()); - // } - self.check(pos); + if pos >= self.size { + return Err("Tried to read past end of buffer".into()); + } Ok(self.buf[pos]) } pub fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> { - // if start + len >= 512 { - // error!("Tried to read past end of buffer"); - // return Err("End of buffer".into()); - // } - self.check(start + len); + if start + len >= self.size { + return Err("Tried to read past end of buffer".into()); + } Ok(&self.buf[start..start + len]) } @@ -169,7 +152,13 @@ impl PacketBuffer { } pub fn write(&mut self, val: u8) -> Result<()> { - self.check(self.pos); + if self.size < self.pos { + self.size = self.pos; + } + + if self.buf.len() <= self.size { + self.buf.resize(self.size + 1, 0x00); + } self.buf[self.pos] = val; self.pos += 1; @@ -208,7 +197,9 @@ impl PacketBuffer { } } - self.write_u8(0)?; + if !qname.is_empty() { + self.write_u8(0)?; + } Ok(()) } diff --git a/src/packet/header.rs b/src/dns/packet/header.rs index a75f6ba..2355ecb 100644 --- a/src/packet/header.rs +++ b/src/dns/packet/header.rs @@ -1,4 +1,5 @@ -use super::{buffer::PacketBuffer, result::ResultCode, Result}; +use super::{buffer::PacketBuffer, result::ResultCode}; +use crate::Result; #[derive(Clone, Debug)] pub struct DnsHeader { diff --git a/src/packet/mod.rs b/src/dns/packet/mod.rs index 0b7cb7b..9873b94 100644 --- a/src/packet/mod.rs +++ b/src/dns/packet/mod.rs @@ -4,9 +4,7 @@ use self::{ buffer::PacketBuffer, header::DnsHeader, query::QueryType, question::DnsQuestion, record::DnsRecord, }; - -type Error = Box<dyn std::error::Error>; -pub type Result<T> = std::result::Result<T, Error>; +use crate::Result; pub mod buffer; pub mod header; diff --git a/src/packet/query.rs b/src/dns/packet/query.rs index cae6f09..732b9b2 100644 --- a/src/packet/query.rs +++ b/src/dns/packet/query.rs @@ -12,6 +12,8 @@ pub enum QueryType { SRV, // 33 OPT, // 41 CAA, // 257 + AR, // 1000 + AAAAR, // 1001 } impl QueryType { @@ -29,6 +31,8 @@ impl QueryType { Self::SRV => 33, Self::OPT => 41, Self::CAA => 257, + Self::AR => 1000, + Self::AAAAR => 1001, } } @@ -45,7 +49,30 @@ impl QueryType { 33 => Self::SRV, 41 => Self::OPT, 257 => Self::CAA, + 1000 => Self::AR, + 1001 => Self::AAAAR, _ => Self::UNKNOWN(num), } } + + pub fn allowed_actions(&self) -> (bool, bool) { + // 0. duplicates allowed + // 1. allowed to be created by database + match self { + QueryType::UNKNOWN(_) => (false, false), + QueryType::A => (true, true), + QueryType::NS => (false, true), + QueryType::CNAME => (false, true), + QueryType::SOA => (false, false), + QueryType::PTR => (false, true), + QueryType::MX => (false, true), + QueryType::TXT => (true, true), + QueryType::AAAA => (true, true), + QueryType::SRV => (false, true), + QueryType::OPT => (false, false), + QueryType::CAA => (false, true), + QueryType::AR => (false, true), + QueryType::AAAAR => (false, true), + } + } } diff --git a/src/packet/question.rs b/src/dns/packet/question.rs index 9042e1c..9042e1c 100644 --- a/src/packet/question.rs +++ b/src/dns/packet/question.rs diff --git a/src/packet/record.rs b/src/dns/packet/record.rs index c29dd8f..88008f0 100644 --- a/src/packet/record.rs +++ b/src/dns/packet/record.rs @@ -1,11 +1,12 @@ use std::net::{Ipv4Addr, Ipv6Addr}; +use rand::RngCore; +use serde::{Deserialize, Serialize}; use tracing::{trace, warn}; use super::{buffer::PacketBuffer, query::QueryType, Result}; -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[allow(dead_code)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] pub enum DnsRecord { UNKNOWN { domain: String, @@ -76,10 +77,17 @@ pub enum DnsRecord { value: String, ttl: u32, }, // 257 + AR { + domain: String, + ttl: u32, + }, + AAAAR { + domain: String, + ttl: u32, + }, } impl DnsRecord { - pub fn read(buffer: &mut PacketBuffer) -> Result<Self> { let mut domain = String::new(); buffer.read_qname(&mut domain)?; @@ -90,10 +98,10 @@ impl DnsRecord { let ttl = buffer.read_u32()?; let data_len = buffer.read_u16()?; - let header_pos = buffer.pos(); - trace!("Reading DNS Record TYPE: {:?}", qtype); + let header_pos = buffer.pos(); + match qtype { QueryType::A => { let raw_addr = buffer.read_u32()?; @@ -471,6 +479,29 @@ impl DnsRecord { let size = buffer.pos() - (pos + 2); buffer.set_u16(pos, size as u16)?; } + Self::AR { ref domain, ttl } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::A.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(4)?; + + let mut rand = rand::thread_rng(); + buffer.write_u32(rand.next_u32())?; + } + Self::AAAAR { ref domain, ttl } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::A.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(4)?; + + let mut rand = rand::thread_rng(); + buffer.write_u32(rand.next_u32())?; + buffer.write_u32(rand.next_u32())?; + buffer.write_u32(rand.next_u32())?; + buffer.write_u32(rand.next_u32())?; + } Self::UNKNOWN { .. } => { warn!("Skipping record: {self:?}"); } @@ -479,20 +510,35 @@ impl DnsRecord { Ok(buffer.pos() - start_pos) } + pub fn get_domain(&self) -> String { + self.get_shared_domain().0 + } + + pub fn get_qtype(&self) -> QueryType { + self.get_shared_domain().1 + } + pub fn get_ttl(&self) -> u32 { - match *self { - DnsRecord::UNKNOWN { .. } => 0, - DnsRecord::AAAA { ttl, .. } => ttl, - DnsRecord::A { ttl, .. } => ttl, - DnsRecord::NS { ttl, .. } => ttl, - DnsRecord::CNAME { ttl, .. } => ttl, - DnsRecord::SOA { ttl, .. } => ttl, - DnsRecord::PTR { ttl, .. } => ttl, - DnsRecord::MX { ttl, .. } => ttl, - DnsRecord::TXT { ttl, .. } => ttl, - DnsRecord::SRV { ttl, .. } => ttl, - DnsRecord::CAA { ttl, .. } => ttl, + self.get_shared_domain().2 + } + + fn get_shared_domain(&self) -> (String, QueryType, u32) { + match self { + DnsRecord::UNKNOWN { + domain, ttl, qtype, .. + } => (domain.clone(), QueryType::UNKNOWN(*qtype), *ttl), + DnsRecord::AAAA { domain, ttl, .. } => (domain.clone(), QueryType::AAAA, *ttl), + DnsRecord::A { domain, ttl, .. } => (domain.clone(), QueryType::A, *ttl), + DnsRecord::NS { domain, ttl, .. } => (domain.clone(), QueryType::NS, *ttl), + DnsRecord::CNAME { domain, ttl, .. } => (domain.clone(), QueryType::CNAME, *ttl), + DnsRecord::SOA { domain, ttl, .. } => (domain.clone(), QueryType::SOA, *ttl), + DnsRecord::PTR { domain, ttl, .. } => (domain.clone(), QueryType::PTR, *ttl), + DnsRecord::MX { domain, ttl, .. } => (domain.clone(), QueryType::MX, *ttl), + DnsRecord::TXT { domain, ttl, .. } => (domain.clone(), QueryType::TXT, *ttl), + DnsRecord::SRV { domain, ttl, .. } => (domain.clone(), QueryType::SRV, *ttl), + DnsRecord::CAA { domain, ttl, .. } => (domain.clone(), QueryType::CAA, *ttl), + DnsRecord::AR { domain, ttl, .. } => (domain.clone(), QueryType::AR, *ttl), + DnsRecord::AAAAR { domain, ttl, .. } => (domain.clone(), QueryType::AAAAR, *ttl), } } - } diff --git a/src/packet/result.rs b/src/dns/packet/result.rs index 41c8ba9..41c8ba9 100644 --- a/src/packet/result.rs +++ b/src/dns/packet/result.rs diff --git a/src/server/resolver.rs b/src/dns/resolver.rs index 464620c..18b5bba 100644 --- a/src/server/resolver.rs +++ b/src/dns/resolver.rs @@ -1,11 +1,7 @@ use super::binding::Connection; -use crate::{ - config::Config, - packet::{ - query::QueryType, question::DnsQuestion, result::ResultCode, Packet, - Result, - }, get_time, -}; +use super::packet::{query::QueryType, question::DnsQuestion, result::ResultCode, Packet}; +use crate::Result; +use crate::{config::Config, database::Database, get_time}; use async_recursion::async_recursion; use moka::future::Cache; use std::{net::IpAddr, sync::Arc, time::Duration}; @@ -15,6 +11,7 @@ pub struct Resolver { request_id: u16, connection: Connection, config: Arc<Config>, + database: Arc<Database>, cache: Cache<DnsQuestion, (Packet, u64)>, } @@ -23,18 +20,59 @@ impl Resolver { request_id: u16, connection: Connection, config: Arc<Config>, + database: Arc<Database>, cache: Cache<DnsQuestion, (Packet, u64)>, ) -> Self { Self { request_id, connection, config, + database, cache, } } - async fn lookup_cache(&mut self, qname: &str, qtype: QueryType) -> Option<Packet> { - let question = DnsQuestion::new(qname.to_string(), qtype); + async fn lookup_database(&self, question: &DnsQuestion) -> Option<Packet> { + let records = match self + .database + .get_records(&question.name, question.qtype) + .await + { + Ok(record) => record, + Err(err) => { + error!("{err}"); + return None; + } + }; + + if records.is_empty() { + return None; + } + + let mut packet = Packet::new(); + + packet.header.id = self.request_id; + packet.header.questions = 1; + packet.header.answers = records.len() as u16; + packet.header.recursion_desired = true; + packet + .questions + .push(DnsQuestion::new(question.name.to_string(), question.qtype)); + + for record in records { + packet.answers.push(record); + } + + trace!( + "Found stored value for {:?} {}", + question.qtype, + question.name + ); + + Some(packet) + } + + async fn lookup_cache(&self, question: &DnsQuestion) -> Option<Packet> { let Some((packet, date)) = self.cache.get(&question) else { return None }; @@ -46,16 +84,20 @@ impl Resolver { let ttl = answer.get_ttl(); if diff > ttl { self.cache.invalidate(&question).await; - return None + return None; } } - trace!("Found cached value for {qtype:?} {qname}"); + trace!( + "Found cached value for {:?} {}", + question.qtype, + question.name + ); Some(packet) } - async fn lookup(&mut self, qname: &str, qtype: QueryType, server: (IpAddr, u16)) -> Packet { + async fn lookup_fallback(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet { let mut packet = Packet::new(); packet.header.id = self.request_id; @@ -63,7 +105,7 @@ impl Resolver { packet.header.recursion_desired = true; packet .questions - .push(DnsQuestion::new(qname.to_string(), qtype)); + .push(DnsQuestion::new(question.name.to_string(), question.qtype)); let packet = match self.connection.request_packet(packet, server).await { Ok(packet) => packet, @@ -78,28 +120,47 @@ impl Resolver { packet } + async fn lookup(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet { + if let Some(packet) = self.lookup_cache(question).await { + return packet; + }; + + if let Some(packet) = self.lookup_database(question).await { + return packet; + }; + + trace!( + "Attempting lookup of {:?} {} with ns {}", + question.qtype, + question.name, + server.0 + ); + + self.lookup_fallback(question, server).await + } + #[async_recursion] async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet { let question = DnsQuestion::new(qname.to_string(), qtype); - let mut ns = self.config.get_fallback_ns().clone(); - - if let Some(packet) = self.lookup_cache(qname, qtype).await { return packet } + let mut ns = self.config.dns_fallback.clone(); loop { - trace!("Attempting lookup of {qtype:?} {qname} with ns {ns}"); - let ns_copy = ns; let server = (ns_copy, 53); - let response = self.lookup(qname, qtype, server).await; + let response = self.lookup(&question, server).await; if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { - self.cache.insert(question, (response.clone(), get_time())).await; + self.cache + .insert(question, (response.clone(), get_time())) + .await; return response; } if response.header.rescode == ResultCode::NXDOMAIN { - self.cache.insert(question, (response.clone(), get_time())).await; + self.cache + .insert(question, (response.clone(), get_time())) + .await; return response; } @@ -111,9 +172,11 @@ impl Resolver { let new_ns_name = match response.get_unresolved_ns(qname) { Some(x) => x, None => { - self.cache.insert(question, (response.clone(), get_time())).await; - return response - }, + self.cache + .insert(question, (response.clone(), get_time())) + .await; + return response; + } }; let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await; @@ -121,7 +184,9 @@ impl Resolver { if let Some(new_ns) = recursive_response.get_random_a() { ns = new_ns; } else { - self.cache.insert(question, (response.clone(), get_time())).await; + self.cache + .insert(question, (response.clone(), get_time())) + .await; return response; } } diff --git a/src/dns/server.rs b/src/dns/server.rs new file mode 100644 index 0000000..65d15df --- /dev/null +++ b/src/dns/server.rs @@ -0,0 +1,85 @@ +use super::{ + binding::Binding, + packet::{question::DnsQuestion, Packet}, + resolver::Resolver, +}; +use crate::{config::Config, database::Database, Result}; +use moka::future::Cache; +use std::{net::SocketAddr, sync::Arc, time::Duration}; +use tokio::task::JoinHandle; +use tracing::{error, info}; + +pub struct DnsServer { + addr: SocketAddr, + config: Arc<Config>, + database: Arc<Database>, + cache: Cache<DnsQuestion, (Packet, u64)>, +} + +impl DnsServer { + pub async fn new(config: Config, database: Database) -> Result<Self> { + let addr = format!("[::]:{}", config.dns_port).parse::<SocketAddr>()?; + let cache = Cache::builder() + .time_to_live(Duration::from_secs(60 * 60)) + .max_capacity(config.dns_cache_size) + .build(); + + info!("Created DNS cache with size of {}", config.dns_cache_size); + + Ok(Self { + addr, + config: Arc::new(config), + database: Arc::new(database), + cache, + }) + } + + pub async fn run(&self) -> Result<(JoinHandle<()>, JoinHandle<()>)> { + let tcp = Binding::tcp(self.addr).await?; + let tcp_handle = self.listen(tcp); + + let udp = Binding::udp(self.addr).await?; + let udp_handle = self.listen(udp); + + info!( + "Fallback DNS Server is set to: {:?}", + self.config.dns_fallback + ); + info!( + "Listening for TCP and UDP traffic on [::]:{}", + self.config.dns_port + ); + + Ok((udp_handle, tcp_handle)) + } + + fn listen(&self, mut binding: Binding) -> JoinHandle<()> { + let config = self.config.clone(); + let database = self.database.clone(); + let cache = self.cache.clone(); + tokio::spawn(async move { + let mut id = 0; + loop { + let Ok(connection) = binding.connect().await else { continue }; + info!("Received request on {}", binding.name()); + + let resolver = Resolver::new( + id, + connection, + config.clone(), + database.clone(), + cache.clone(), + ); + + let name = binding.name().to_string(); + tokio::spawn(async move { + if let Err(err) = resolver.handle_query().await { + error!("{} request {} failed: {:?}", name, id, err); + }; + }); + + id += 1; + } + }) + } +} diff --git a/src/main.rs b/src/main.rs index c891d50..679e87b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,19 +1,34 @@ -use std::{time::{UNIX_EPOCH, SystemTime}, env, net::IpAddr}; +use std::time::{SystemTime, UNIX_EPOCH}; use config::Config; -use server::server::Server; -use tracing::metadata::LevelFilter; +use database::Database; +use dotenv::dotenv; +use dns::server::DnsServer; +use tracing::{error, metadata::LevelFilter}; use tracing_subscriber::{ filter::filter_fn, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, Layer, }; +use web::WebServer; mod config; -mod packet; -mod server; +mod database; +mod dns; +mod web; + +type Error = Box<dyn std::error::Error>; +pub type Result<T> = std::result::Result<T, Error>; #[tokio::main] async fn main() { + if let Err(err) = run().await { + error!("{err}") + }; +} + +async fn run() -> Result<()> { + dotenv().ok(); + tracing_subscriber::registry() .with( tracing_subscriber::fmt::layer() @@ -24,19 +39,20 @@ async fn main() { ) .init(); - let mut config = Config::new(); + let config = Config::new(); + let database = Database::new(config.clone()).await?; - if let Ok(port) = env::var("PORT").unwrap_or(String::new()).parse::<u16>() { - config.set_port(port); - } + let dns_server = DnsServer::new(config.clone(), database.clone()).await?; + let (udp, tcp) = dns_server.run().await?; - if let Ok(fallback) = env::var("FALLBACK_DNS").unwrap_or(String::new()).parse::<IpAddr>() { - config.set_fallback_ns(&fallback); - } + let web_server = WebServer::new(config, database).await?; + let web = web_server.run().await?; - let server = Server::new(config).await.expect("Failed to bind server"); + tokio::join!(udp).0?; + tokio::join!(tcp).0?; + tokio::join!(web).0?; - server.run().await.unwrap(); + Ok(()) } pub fn get_time() -> u64 { diff --git a/src/server/server.rs b/src/server/server.rs deleted file mode 100644 index e006bb1..0000000 --- a/src/server/server.rs +++ /dev/null @@ -1,73 +0,0 @@ -use moka::future::Cache; -use std::net::SocketAddr; -use std::sync::Arc; -use std::time::Duration; -use tokio::task::JoinHandle; -use tracing::{error, info}; - -use crate::config::Config; -use crate::packet::question::DnsQuestion; -use crate::packet::{Result, Packet}; - -use super::binding::Binding; -use super::resolver::Resolver; - -pub struct Server { - addr: SocketAddr, - config: Arc<Config>, - cache: Cache<DnsQuestion, (Packet, u64)>, -} - -impl Server { - pub async fn new(config: Config) -> Result<Self> { - let addr = format!("[::]:{}", config.get_port()).parse::<SocketAddr>()?; - let cache = Cache::builder() - .time_to_live(Duration::from_secs(60 * 60)) - .max_capacity(1_000) - .build(); - Ok(Self { - addr, - config: Arc::new(config), - cache, - }) - } - - pub async fn run(&self) -> Result<()> { - let tcp = Binding::tcp(self.addr).await?; - let tcp_handle = self.listen(tcp); - - let udp = Binding::udp(self.addr).await?; - let udp_handle = self.listen(udp); - - info!("Fallback DNS Server is set to: {:?}", self.config.get_fallback_ns()); - info!("Listening for TCP and UDP traffic on [::]:{}", self.config.get_port()); - - tokio::join!(tcp_handle) - .0 - .expect("Failed to join tcp thread"); - tokio::join!(udp_handle) - .0 - .expect("Failed to join udp thread"); - Ok(()) - } - - fn listen(&self, mut binding: Binding) -> JoinHandle<()> { - let config = self.config.clone(); - let cache = self.cache.clone(); - tokio::spawn(async move { - let mut id = 0; - loop { - let Ok(connection) = binding.connect().await else { continue }; - info!("Received request on {}", binding.name()); - - let resolver = Resolver::new(id, connection, config.clone(), cache.clone()); - - if let Err(err) = resolver.handle_query().await { - error!("{} request {} failed: {:?}", binding.name(), id, err); - }; - - id += 1; - } - }) - } -} diff --git a/src/web/api.rs b/src/web/api.rs new file mode 100644 index 0000000..1fddb5f --- /dev/null +++ b/src/web/api.rs @@ -0,0 +1,156 @@ +use std::net::IpAddr; + +use axum::{ + extract::Query, + response::Response, + routing::{get, post, put, delete}, + Extension, Router, +}; +use moka::future::Cache; +use rand::distributions::{Alphanumeric, DistString}; +use serde::Deserialize; +use tower_cookies::{Cookie, Cookies}; + +use crate::{config::Config, database::Database, dns::packet::record::DnsRecord}; + +use super::{ + extract::{Authorized, Body, RequestIp}, + http::{json, text}, +}; + +pub fn router() -> Router { + Router::new() + .route("/login", post(login)) + .route("/domains", get(list_domains)) + .route("/domains", delete(delete_domain)) + .route("/records", get(get_domain)) + .route("/records", put(add_record)) +} + +async fn list_domains(_: Authorized, Extension(database): Extension<Database>) -> Response { + let domains = match database.get_domains().await { + Ok(domains) => domains, + Err(err) => return text(500, &format!("{err}")), + }; + + let Ok(domains) = serde_json::to_string(&domains) else { + return text(500, "Failed to fetch domains") + }; + + json(200, &domains) +} + +#[derive(Deserialize)] +struct DomainRequest { + domain: String, +} + +async fn get_domain( + _: Authorized, + Extension(database): Extension<Database>, + Query(query): Query<DomainRequest>, +) -> Response { + let records = match database.get_domain(&query.domain).await { + Ok(records) => records, + Err(err) => return text(500, &format!("{err}")), + }; + + let Ok(records) = serde_json::to_string(&records) else { + return text(500, "Failed to fetch records") + }; + + json(200, &records) +} + +async fn delete_domain( + _: Authorized, + Extension(database): Extension<Database>, + Body(body): Body, +) -> Response { + + let Ok(request) = serde_json::from_str::<DomainRequest>(&body) else { + return text(400, "Missing request parameters") + }; + + let Ok(domains) = database.get_domains().await else { + return text(500, "Failed to delete domain") + }; + + if !domains.contains(&request.domain) { + return text(400, "Domain does not exist") + } + + if database.delete_domain(request.domain).await.is_err() { + return text(500, "Failed to delete domain") + }; + + return text(204, "Successfully deleted domain") +} + +async fn add_record( + _: Authorized, + Extension(database): Extension<Database>, + Body(body): Body, +) -> Response { + let Ok(record) = serde_json::from_str::<DnsRecord>(&body) else { + return text(400, "Invalid DNS record") + }; + + let allowed = record.get_qtype().allowed_actions(); + if !allowed.1 { + return text(400, "Not allowed to create record") + } + + let Ok(records) = database.get_records(&record.get_domain(), record.get_qtype()).await else { + return text(500, "Failed to complete record check"); + }; + + if !records.is_empty() && !allowed.0 { + return text(400, "Not allowed to create duplicate record") + }; + + if records.contains(&record) { + return text(400, "Not allowed to create duplicate record") + } + + if let Err(err) = database.add_record(record).await { + return text(500, &format!("{err}")); + } + + return text(201, "Added record to database successfully"); +} + +#[derive(Deserialize)] +struct LoginRequest { + user: String, + pass: String, +} + +async fn login( + Extension(config): Extension<Config>, + Extension(cache): Extension<Cache<String, IpAddr>>, + RequestIp(ip): RequestIp, + cookies: Cookies, + Body(body): Body, +) -> Response { + let Ok(request) = serde_json::from_str::<LoginRequest>(&body) else { + return text(400, "Missing request parameters") + }; + + if request.user != config.web_user || request.pass != config.web_pass { + return text(400, "Invalid credentials"); + }; + + let token = Alphanumeric.sample_string(&mut rand::thread_rng(), 128); + + cache.insert(token.clone(), ip).await; + + let mut cookie = Cookie::new("auth", token); + cookie.set_secure(true); + cookie.set_http_only(true); + cookie.set_path("/"); + + cookies.add(cookie); + + text(200, "Successfully logged in") +} diff --git a/src/web/extract.rs b/src/web/extract.rs new file mode 100644 index 0000000..4b6cd7c --- /dev/null +++ b/src/web/extract.rs @@ -0,0 +1,139 @@ +use std::{ + io::Read, + net::{IpAddr, SocketAddr}, +}; + +use axum::{ + async_trait, + body::HttpBody, + extract::{ConnectInfo, FromRequest, FromRequestParts}, + http::{request::Parts, Request}, + response::Response, + BoxError, +}; +use bytes::Bytes; +use moka::future::Cache; +use tower_cookies::Cookies; + +use super::http::text; + +pub struct Authorized; + +#[async_trait] +impl<S> FromRequestParts<S> for Authorized +where + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { + let Ok(Some(cookies)) = Option::<Cookies>::from_request_parts(parts, state).await else { + return Err(text(403, "No cookies provided")) + }; + + let Some(token) = cookies.get("auth") else { + return Err(text(403, "No auth token provided")) + }; + + let auth_ip: IpAddr; + { + let Some(cache) = parts.extensions.get::<Cache<String, IpAddr>>() else { + return Err(text(500, "Failed to load auth store")) + }; + + let Some(ip) = cache.get(token.value()) else { + return Err(text(401, "Unauthorized")) + }; + + auth_ip = ip + } + + let Ok(Some(RequestIp(ip))) = Option::<RequestIp>::from_request_parts(parts, state).await else { + return Err(text(403, "You have no ip")) + }; + + if auth_ip != ip { + return Err(text(403, "Auth token does not match current ip")); + } + + Ok(Self) + } +} + +pub struct RequestIp(pub IpAddr); + +#[async_trait] +impl<S> FromRequestParts<S> for RequestIp +where + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { + let headers = &parts.headers; + + let forwardedfor = headers + .get("x-forwarded-for") + .and_then(|h| h.to_str().ok()) + .and_then(|h| { + h.split(',') + .rev() + .find_map(|s| s.trim().parse::<IpAddr>().ok()) + }); + + if let Some(forwardedfor) = forwardedfor { + return Ok(Self(forwardedfor)); + } + + let realip = headers + .get("x-real-ip") + .and_then(|hv| hv.to_str().ok()) + .and_then(|s| s.parse::<IpAddr>().ok()); + + if let Some(realip) = realip { + return Ok(Self(realip)); + } + + let realip = headers + .get("x-real-ip") + .and_then(|hv| hv.to_str().ok()) + .and_then(|s| s.parse::<IpAddr>().ok()); + + if let Some(realip) = realip { + return Ok(Self(realip)); + } + + let info = parts.extensions.get::<ConnectInfo<SocketAddr>>(); + + if let Some(info) = info { + return Ok(Self(info.0.ip())); + } + + Err(text(403, "You have no ip")) + } +} + +pub struct Body(pub String); + +#[async_trait] +impl<S, B> FromRequest<S, B> for Body +where + B: HttpBody + Sync + Send + 'static, + B::Data: Send, + B::Error: Into<BoxError>, + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> { + let Ok(bytes) = Bytes::from_request(req, state).await else { + return Err(text(413, "Payload too large")); + }; + + let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else { + return Err(text(400, "Invalid utf8 body")) + }; + + Ok(Self(body)) + } +} diff --git a/src/web/file.rs b/src/web/file.rs new file mode 100644 index 0000000..73ecdc9 --- /dev/null +++ b/src/web/file.rs @@ -0,0 +1,31 @@ +use axum::{extract::Path, response::Response}; + +use super::http::serve; + +pub async fn js(Path(path): Path<String>) -> Response { + let path = format!("/js/{path}"); + serve(&path).await +} + +pub async fn css(Path(path): Path<String>) -> Response { + let path = format!("/css/{path}"); + serve(&path).await +} + +pub async fn fonts(Path(path): Path<String>) -> Response { + let path = format!("/fonts/{path}"); + serve(&path).await +} + +pub async fn image(Path(path): Path<String>) -> Response { + let path = format!("/image/{path}"); + serve(&path).await +} + +pub async fn favicon() -> Response { + serve("/favicon.ico").await +} + +pub async fn robots() -> Response { + serve("/robots.txt").await +} diff --git a/src/web/http.rs b/src/web/http.rs new file mode 100644 index 0000000..7ab1b11 --- /dev/null +++ b/src/web/http.rs @@ -0,0 +1,50 @@ +use axum::{ + body::Body, + http::{header::HeaderName, HeaderValue, Request, StatusCode}, + response::{IntoResponse, Response}, +}; +use std::str; +use tower::ServiceExt; +use tower_http::services::ServeFile; + +pub fn text(code: u16, msg: &str) -> Response { + (status_code(code), msg.to_owned()).into_response() +} + +pub fn json(code: u16, json: &str) -> Response { + let mut res = (status_code(code), json.to_owned()).into_response(); + res.headers_mut().insert( + HeaderName::from_static("content-type"), + HeaderValue::from_static("application/json"), + ); + res +} + +pub async fn serve(path: &str) -> Response { + if !path.chars().any(|c| c == '.') { + return text(403, "Invalid file path"); + } + + let path = format!("public{path}"); + let file = ServeFile::new(path); + + let Ok(mut res) = file.oneshot(Request::new(Body::empty())).await else { + tracing::error!("Error while fetching file"); + return text(500, "Error when fetching file") + }; + + if res.status() != StatusCode::OK { + return text(404, "File not found"); + } + + res.headers_mut().insert( + HeaderName::from_static("cache-control"), + HeaderValue::from_static("max-age=300"), + ); + + res.into_response() +} + +fn status_code(code: u16) -> StatusCode { + StatusCode::from_u16(code).map_or(StatusCode::OK, |code| code) +} diff --git a/src/web/mod.rs b/src/web/mod.rs new file mode 100644 index 0000000..530a3f9 --- /dev/null +++ b/src/web/mod.rs @@ -0,0 +1,82 @@ +use std::net::{IpAddr, SocketAddr, TcpListener}; +use std::time::Duration; + +use axum::routing::get; +use axum::{Extension, Router}; +use moka::future::Cache; +use tokio::task::JoinHandle; +use tower_cookies::CookieManagerLayer; +use tracing::{error, info}; + +use crate::config::Config; +use crate::database::Database; +use crate::Result; + +mod api; +mod extract; +mod file; +mod http; +mod pages; + +pub struct WebServer { + config: Config, + database: Database, + addr: SocketAddr, +} + +impl WebServer { + pub async fn new(config: Config, database: Database) -> Result<Self> { + let addr = format!("[::]:{}", config.web_port).parse::<SocketAddr>()?; + Ok(Self { + config, + database, + addr, + }) + } + + pub async fn run(&self) -> Result<JoinHandle<()>> { + let config = self.config.clone(); + let database = self.database.clone(); + let listener = TcpListener::bind(self.addr)?; + + info!( + "Listening for HTTP traffic on [::]:{}", + self.config.web_port + ); + + let app = Self::router(config, database); + let server = axum::Server::from_tcp(listener)?; + + let web_handle = tokio::spawn(async move { + if let Err(err) = server + .serve(app.into_make_service_with_connect_info::<SocketAddr>()) + .await + { + error!("{err}"); + } + }); + + Ok(web_handle) + } + + fn router(config: Config, database: Database) -> Router { + let cache: Cache<String, IpAddr> = Cache::builder() + .time_to_live(Duration::from_secs(60 * 15)) + .max_capacity(config.dns_cache_size) + .build(); + + Router::new() + .nest("/", pages::router()) + .nest("/api", api::router()) + .layer(Extension(config)) + .layer(Extension(cache)) + .layer(Extension(database)) + .layer(CookieManagerLayer::new()) + .route("/js/*path", get(file::js)) + .route("/css/*path", get(file::css)) + .route("/fonts/*path", get(file::fonts)) + .route("/image/*path", get(file::image)) + .route("/favicon.ico", get(file::favicon)) + .route("/robots.txt", get(file::robots)) + } +} diff --git a/src/web/pages.rs b/src/web/pages.rs new file mode 100644 index 0000000..a8605ef --- /dev/null +++ b/src/web/pages.rs @@ -0,0 +1,31 @@ +use axum::{response::Response, routing::get, Router}; + +use super::{extract::Authorized, http::serve}; + +pub fn router() -> Router { + Router::new() + .route("/", get(root)) + .route("/login", get(login)) + .route("/home", get(home)) + .route("/domain", get(domain)) +} + +async fn root(user: Option<Authorized>) -> Response { + if user.is_some() { + home().await + } else { + login().await + } +} + +async fn login() -> Response { + serve("/login.html").await +} + +async fn home() -> Response { + serve("/home.html").await +} + +async fn domain() -> Response { + serve("/domain.html").await +} |