summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTyler Murphy <tylermurphy534@gmail.com>2023-03-06 18:50:08 -0500
committerTyler Murphy <tylermurphy534@gmail.com>2023-03-06 18:50:08 -0500
commitb1fb410affb7bcd2e714abac01d22c4a5332c344 (patch)
tree7ebb621ab9b73e3e1fbaeb0ef8c19abef95b7c9f /src
parentfinialize initial dns + caching (diff)
downloadwrapper-b1fb410affb7bcd2e714abac01d22c4a5332c344.tar.gz
wrapper-b1fb410affb7bcd2e714abac01d22c4a5332c344.tar.bz2
wrapper-b1fb410affb7bcd2e714abac01d22c4a5332c344.zip
finish dns and start webserver
Diffstat (limited to '')
-rw-r--r--src/config.rs64
-rw-r--r--src/database/mod.rs146
-rw-r--r--src/dns/binding.rs (renamed from src/server/binding.rs)10
-rw-r--r--src/dns/mod.rs (renamed from src/server/mod.rs)1
-rw-r--r--src/dns/packet/buffer.rs (renamed from src/packet/buffer.rs)51
-rw-r--r--src/dns/packet/header.rs (renamed from src/packet/header.rs)3
-rw-r--r--src/dns/packet/mod.rs (renamed from src/packet/mod.rs)4
-rw-r--r--src/dns/packet/query.rs (renamed from src/packet/query.rs)27
-rw-r--r--src/dns/packet/question.rs (renamed from src/packet/question.rs)0
-rw-r--r--src/dns/packet/record.rs (renamed from src/packet/record.rs)82
-rw-r--r--src/dns/packet/result.rs (renamed from src/packet/result.rs)0
-rw-r--r--src/dns/resolver.rs (renamed from src/server/resolver.rs)115
-rw-r--r--src/dns/server.rs85
-rw-r--r--src/main.rs44
-rw-r--r--src/server/server.rs73
-rw-r--r--src/web/api.rs156
-rw-r--r--src/web/extract.rs139
-rw-r--r--src/web/file.rs31
-rw-r--r--src/web/http.rs50
-rw-r--r--src/web/mod.rs82
-rw-r--r--src/web/pages.rs31
21 files changed, 1001 insertions, 193 deletions
diff --git a/src/config.rs b/src/config.rs
index 9350adf..547e853 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,35 +1,57 @@
-use std::net::IpAddr;
+use std::{env, net::IpAddr, str::FromStr, fmt::Display};
#[derive(Clone)]
pub struct Config {
- fallback: IpAddr,
- port: u16,
+ pub dns_fallback: IpAddr,
+ pub dns_port: u16,
+ pub dns_cache_size: u64,
+
+ pub db_host: String,
+ pub db_port: u16,
+ pub db_user: String,
+ pub db_pass: String,
+
+ pub web_user: String,
+ pub web_pass: String,
+ pub web_port: u16,
}
impl Config {
pub fn new() -> Self {
- let fallback = "9.9.9.9"
- .parse::<IpAddr>()
- .expect("Failed to create default ns fallback");
- Self {
- fallback,
- port: 2000,
- }
- }
+ let dns_port = Self::get_var::<u16>("WRAPPER_DNS_PORT", 53);
+ let dns_fallback = Self::get_var::<IpAddr>("WRAPPER_FALLBACK_DNS", [9, 9, 9, 9].into());
+ let dns_cache_size = Self::get_var::<u64>("WRAPPER_CACHE_SIZE", 1000);
- pub fn get_fallback_ns(&self) -> &IpAddr {
- &self.fallback
- }
+ let db_host = Self::get_var::<String>("WRAPPER_DB_HOST", String::from("localhost"));
+ let db_port = Self::get_var::<u16>("WRAPPER_DB_PORT", 27017);
+ let db_user = Self::get_var::<String>("WRAPPER_DB_USER", String::from("root"));
+ let db_pass = Self::get_var::<String>("WRAPPER_DB_PASS", String::from(""));
- pub fn get_port(&self) -> u16 {
- self.port
- }
+ let web_user = Self::get_var::<String>("WRAPPER_WEB_USER", String::from("admin"));
+ let web_pass = Self::get_var::<String>("WRAPPER_WEB_PASS", String::from("wrapper"));
+ let web_port = Self::get_var::<u16>("WRAPPER_WEB_PORT", 80);
+
+ Self {
+ dns_fallback,
+ dns_port,
+ dns_cache_size,
- pub fn set_fallback_ns(&mut self, addr: &IpAddr) {
- self.fallback = *addr;
+ db_host,
+ db_port,
+ db_user,
+ db_pass,
+
+ web_user,
+ web_pass,
+ web_port,
+ }
}
- pub fn set_port(&mut self, port: u16) {
- self.port = port;
+ fn get_var<T>(name: &str, default: T) -> T
+ where
+ T: FromStr + Display,
+ {
+ let env = env::var(name).unwrap_or(format!("{default}"));
+ env.parse::<T>().unwrap_or(default)
}
}
diff --git a/src/database/mod.rs b/src/database/mod.rs
new file mode 100644
index 0000000..0d81dc3
--- /dev/null
+++ b/src/database/mod.rs
@@ -0,0 +1,146 @@
+use futures::TryStreamExt;
+use mongodb::{
+ bson::doc,
+ options::{ClientOptions, Credential, ServerAddress},
+ Client,
+};
+use serde::{Deserialize, Serialize};
+use tracing::info;
+
+use crate::{
+ config::Config,
+ dns::packet::{query::QueryType, record::DnsRecord},
+};
+
+use crate::Result;
+
+#[derive(Clone)]
+pub struct Database {
+ client: Client,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct StoredRecord {
+ record: DnsRecord,
+ domain: String,
+ prefix: String,
+}
+
+impl StoredRecord {
+ fn get_domain_parts(domain: &str) -> (String, String) {
+ let parts: Vec<&str> = domain.split(".").collect();
+ let len = parts.len();
+ if len == 1 {
+ (String::new(), String::from(parts[0]))
+ } else if len == 2 {
+ (String::new(), String::from(parts.join(".")))
+ } else {
+ (
+ String::from(parts[0..len - 2].join(".")),
+ String::from(parts[len - 2..len].join(".")),
+ )
+ }
+ }
+}
+
+impl From<DnsRecord> for StoredRecord {
+ fn from(record: DnsRecord) -> Self {
+ let (prefix, domain) = Self::get_domain_parts(&record.get_domain());
+ Self {
+ record,
+ domain,
+ prefix,
+ }
+ }
+}
+
+impl Into<DnsRecord> for StoredRecord {
+ fn into(self) -> DnsRecord {
+ self.record
+ }
+}
+
+impl Database {
+ pub async fn new(config: Config) -> Result<Self> {
+ let options = ClientOptions::builder()
+ .hosts(vec![ServerAddress::Tcp {
+ host: config.db_host,
+ port: Some(config.db_port),
+ }])
+ .credential(
+ Credential::builder()
+ .username(config.db_user)
+ .password(config.db_pass)
+ .build(),
+ )
+ .max_pool_size(100)
+ .app_name(String::from("wrapper"))
+ .build();
+
+ let client = Client::with_options(options)?;
+
+ client
+ .database("wrapper")
+ .run_command(doc! {"ping": 1}, None)
+ .await?;
+
+ info!("Connection to mongodb successfully");
+
+ Ok(Database { client })
+ }
+
+ pub async fn get_records(&self, domain: &str, qtype: QueryType) -> Result<Vec<DnsRecord>> {
+ let (prefix, domain) = StoredRecord::get_domain_parts(domain);
+ Ok(self
+ .get_domain(&domain)
+ .await?
+ .into_iter()
+ .filter(|r| r.prefix == prefix)
+ .filter(|r| {
+ let rqtype = r.record.get_qtype();
+ if qtype == QueryType::A {
+ return rqtype == QueryType::A || rqtype == QueryType::AR;
+ } else if qtype == QueryType::AAAA {
+ return rqtype == QueryType::AAAA || rqtype == QueryType::AAAAR;
+ } else {
+ r.record.get_qtype() == qtype
+ }
+ })
+ .map(|r| r.into())
+ .collect())
+ }
+
+ pub async fn get_domain(&self, domain: &str) -> Result<Vec<StoredRecord>> {
+ let db = self.client.database("wrapper");
+ let col = db.collection::<StoredRecord>(domain);
+
+ let filter = doc! { "domain": domain };
+ let mut cursor = col.find(filter, None).await?;
+
+ let mut records = Vec::new();
+ while let Some(record) = cursor.try_next().await? {
+ records.push(record);
+ }
+
+ Ok(records)
+ }
+
+ pub async fn add_record(&self, record: DnsRecord) -> Result<()> {
+ let record = StoredRecord::from(record);
+ let db = self.client.database("wrapper");
+ let col = db.collection::<StoredRecord>(&record.domain);
+ col.insert_one(record, None).await?;
+ Ok(())
+ }
+
+ pub async fn get_domains(&self) -> Result<Vec<String>> {
+ let db = self.client.database("wrapper");
+ Ok(db.list_collection_names(None).await?)
+ }
+
+ pub async fn delete_domain(&self, domain: String) -> Result<()> {
+ let db = self.client.database("wrapper");
+ let col = db.collection::<StoredRecord>(&domain);
+ Ok(col.drop(None).await?)
+ }
+}
diff --git a/src/server/binding.rs b/src/dns/binding.rs
index 1c69651..4c7e15f 100644
--- a/src/server/binding.rs
+++ b/src/dns/binding.rs
@@ -3,7 +3,8 @@ use std::{
sync::Arc,
};
-use crate::packet::{buffer::PacketBuffer, Packet, Result};
+use super::packet::{buffer::PacketBuffer, Packet};
+use crate::Result;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream, UdpSocket},
@@ -140,11 +141,4 @@ impl Connection {
}
}
}
-
- // fn pb(buf: &[u8]) {
- // for i in 0..buf.len() {
- // print!("{:02X?} ", buf[i]);
- // }
- // println!("");
- // }
}
diff --git a/src/server/mod.rs b/src/dns/mod.rs
index 25076ef..6f1e59e 100644
--- a/src/server/mod.rs
+++ b/src/dns/mod.rs
@@ -1,3 +1,4 @@
mod binding;
+pub mod packet;
mod resolver;
pub mod server;
diff --git a/src/packet/buffer.rs b/src/dns/packet/buffer.rs
index 4ecc605..058156e 100644
--- a/src/packet/buffer.rs
+++ b/src/dns/packet/buffer.rs
@@ -1,4 +1,4 @@
-use super::Result;
+use crate::Result;
pub struct PacketBuffer {
pub buf: Vec<u8>,
@@ -9,19 +9,9 @@ pub struct PacketBuffer {
impl PacketBuffer {
pub fn new(buf: Vec<u8>) -> Self {
Self {
+ size: buf.len(),
buf,
pos: 0,
- size: 0,
- }
- }
-
- fn check(&mut self, pos: usize) {
- if self.size < pos {
- self.size = pos;
- }
-
- if self.buf.len() <= self.size {
- self.buf.resize(self.size + 1, 0x00);
}
}
@@ -42,32 +32,25 @@ impl PacketBuffer {
}
pub fn read(&mut self) -> Result<u8> {
- // if self.pos >= 512 {
- // error!("Tried to read past end of buffer");
- // return Err("End of buffer".into());
- // }
- self.check(self.pos);
+ if self.pos >= self.size {
+ return Err("Tried to read past end of buffer".into());
+ }
let res = self.buf[self.pos];
self.pos += 1;
-
Ok(res)
}
pub fn get(&mut self, pos: usize) -> Result<u8> {
- // if pos >= 512 {
- // error!("Tried to read past end of buffer");
- // return Err("End of buffer".into());
- // }
- self.check(pos);
+ if pos >= self.size {
+ return Err("Tried to read past end of buffer".into());
+ }
Ok(self.buf[pos])
}
pub fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
- // if start + len >= 512 {
- // error!("Tried to read past end of buffer");
- // return Err("End of buffer".into());
- // }
- self.check(start + len);
+ if start + len >= self.size {
+ return Err("Tried to read past end of buffer".into());
+ }
Ok(&self.buf[start..start + len])
}
@@ -169,7 +152,13 @@ impl PacketBuffer {
}
pub fn write(&mut self, val: u8) -> Result<()> {
- self.check(self.pos);
+ if self.size < self.pos {
+ self.size = self.pos;
+ }
+
+ if self.buf.len() <= self.size {
+ self.buf.resize(self.size + 1, 0x00);
+ }
self.buf[self.pos] = val;
self.pos += 1;
@@ -208,7 +197,9 @@ impl PacketBuffer {
}
}
- self.write_u8(0)?;
+ if !qname.is_empty() {
+ self.write_u8(0)?;
+ }
Ok(())
}
diff --git a/src/packet/header.rs b/src/dns/packet/header.rs
index a75f6ba..2355ecb 100644
--- a/src/packet/header.rs
+++ b/src/dns/packet/header.rs
@@ -1,4 +1,5 @@
-use super::{buffer::PacketBuffer, result::ResultCode, Result};
+use super::{buffer::PacketBuffer, result::ResultCode};
+use crate::Result;
#[derive(Clone, Debug)]
pub struct DnsHeader {
diff --git a/src/packet/mod.rs b/src/dns/packet/mod.rs
index 0b7cb7b..9873b94 100644
--- a/src/packet/mod.rs
+++ b/src/dns/packet/mod.rs
@@ -4,9 +4,7 @@ use self::{
buffer::PacketBuffer, header::DnsHeader, query::QueryType, question::DnsQuestion,
record::DnsRecord,
};
-
-type Error = Box<dyn std::error::Error>;
-pub type Result<T> = std::result::Result<T, Error>;
+use crate::Result;
pub mod buffer;
pub mod header;
diff --git a/src/packet/query.rs b/src/dns/packet/query.rs
index cae6f09..732b9b2 100644
--- a/src/packet/query.rs
+++ b/src/dns/packet/query.rs
@@ -12,6 +12,8 @@ pub enum QueryType {
SRV, // 33
OPT, // 41
CAA, // 257
+ AR, // 1000
+ AAAAR, // 1001
}
impl QueryType {
@@ -29,6 +31,8 @@ impl QueryType {
Self::SRV => 33,
Self::OPT => 41,
Self::CAA => 257,
+ Self::AR => 1000,
+ Self::AAAAR => 1001,
}
}
@@ -45,7 +49,30 @@ impl QueryType {
33 => Self::SRV,
41 => Self::OPT,
257 => Self::CAA,
+ 1000 => Self::AR,
+ 1001 => Self::AAAAR,
_ => Self::UNKNOWN(num),
}
}
+
+ pub fn allowed_actions(&self) -> (bool, bool) {
+ // 0. duplicates allowed
+ // 1. allowed to be created by database
+ match self {
+ QueryType::UNKNOWN(_) => (false, false),
+ QueryType::A => (true, true),
+ QueryType::NS => (false, true),
+ QueryType::CNAME => (false, true),
+ QueryType::SOA => (false, false),
+ QueryType::PTR => (false, true),
+ QueryType::MX => (false, true),
+ QueryType::TXT => (true, true),
+ QueryType::AAAA => (true, true),
+ QueryType::SRV => (false, true),
+ QueryType::OPT => (false, false),
+ QueryType::CAA => (false, true),
+ QueryType::AR => (false, true),
+ QueryType::AAAAR => (false, true),
+ }
+ }
}
diff --git a/src/packet/question.rs b/src/dns/packet/question.rs
index 9042e1c..9042e1c 100644
--- a/src/packet/question.rs
+++ b/src/dns/packet/question.rs
diff --git a/src/packet/record.rs b/src/dns/packet/record.rs
index c29dd8f..88008f0 100644
--- a/src/packet/record.rs
+++ b/src/dns/packet/record.rs
@@ -1,11 +1,12 @@
use std::net::{Ipv4Addr, Ipv6Addr};
+use rand::RngCore;
+use serde::{Deserialize, Serialize};
use tracing::{trace, warn};
use super::{buffer::PacketBuffer, query::QueryType, Result};
-#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
-#[allow(dead_code)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub enum DnsRecord {
UNKNOWN {
domain: String,
@@ -76,10 +77,17 @@ pub enum DnsRecord {
value: String,
ttl: u32,
}, // 257
+ AR {
+ domain: String,
+ ttl: u32,
+ },
+ AAAAR {
+ domain: String,
+ ttl: u32,
+ },
}
impl DnsRecord {
-
pub fn read(buffer: &mut PacketBuffer) -> Result<Self> {
let mut domain = String::new();
buffer.read_qname(&mut domain)?;
@@ -90,10 +98,10 @@ impl DnsRecord {
let ttl = buffer.read_u32()?;
let data_len = buffer.read_u16()?;
- let header_pos = buffer.pos();
-
trace!("Reading DNS Record TYPE: {:?}", qtype);
+ let header_pos = buffer.pos();
+
match qtype {
QueryType::A => {
let raw_addr = buffer.read_u32()?;
@@ -471,6 +479,29 @@ impl DnsRecord {
let size = buffer.pos() - (pos + 2);
buffer.set_u16(pos, size as u16)?;
}
+ Self::AR { ref domain, ttl } => {
+ buffer.write_qname(domain)?;
+ buffer.write_u16(QueryType::A.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+ buffer.write_u16(4)?;
+
+ let mut rand = rand::thread_rng();
+ buffer.write_u32(rand.next_u32())?;
+ }
+ Self::AAAAR { ref domain, ttl } => {
+ buffer.write_qname(domain)?;
+ buffer.write_u16(QueryType::A.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+ buffer.write_u16(4)?;
+
+ let mut rand = rand::thread_rng();
+ buffer.write_u32(rand.next_u32())?;
+ buffer.write_u32(rand.next_u32())?;
+ buffer.write_u32(rand.next_u32())?;
+ buffer.write_u32(rand.next_u32())?;
+ }
Self::UNKNOWN { .. } => {
warn!("Skipping record: {self:?}");
}
@@ -479,20 +510,35 @@ impl DnsRecord {
Ok(buffer.pos() - start_pos)
}
+ pub fn get_domain(&self) -> String {
+ self.get_shared_domain().0
+ }
+
+ pub fn get_qtype(&self) -> QueryType {
+ self.get_shared_domain().1
+ }
+
pub fn get_ttl(&self) -> u32 {
- match *self {
- DnsRecord::UNKNOWN { .. } => 0,
- DnsRecord::AAAA { ttl, .. } => ttl,
- DnsRecord::A { ttl, .. } => ttl,
- DnsRecord::NS { ttl, .. } => ttl,
- DnsRecord::CNAME { ttl, .. } => ttl,
- DnsRecord::SOA { ttl, .. } => ttl,
- DnsRecord::PTR { ttl, .. } => ttl,
- DnsRecord::MX { ttl, .. } => ttl,
- DnsRecord::TXT { ttl, .. } => ttl,
- DnsRecord::SRV { ttl, .. } => ttl,
- DnsRecord::CAA { ttl, .. } => ttl,
+ self.get_shared_domain().2
+ }
+
+ fn get_shared_domain(&self) -> (String, QueryType, u32) {
+ match self {
+ DnsRecord::UNKNOWN {
+ domain, ttl, qtype, ..
+ } => (domain.clone(), QueryType::UNKNOWN(*qtype), *ttl),
+ DnsRecord::AAAA { domain, ttl, .. } => (domain.clone(), QueryType::AAAA, *ttl),
+ DnsRecord::A { domain, ttl, .. } => (domain.clone(), QueryType::A, *ttl),
+ DnsRecord::NS { domain, ttl, .. } => (domain.clone(), QueryType::NS, *ttl),
+ DnsRecord::CNAME { domain, ttl, .. } => (domain.clone(), QueryType::CNAME, *ttl),
+ DnsRecord::SOA { domain, ttl, .. } => (domain.clone(), QueryType::SOA, *ttl),
+ DnsRecord::PTR { domain, ttl, .. } => (domain.clone(), QueryType::PTR, *ttl),
+ DnsRecord::MX { domain, ttl, .. } => (domain.clone(), QueryType::MX, *ttl),
+ DnsRecord::TXT { domain, ttl, .. } => (domain.clone(), QueryType::TXT, *ttl),
+ DnsRecord::SRV { domain, ttl, .. } => (domain.clone(), QueryType::SRV, *ttl),
+ DnsRecord::CAA { domain, ttl, .. } => (domain.clone(), QueryType::CAA, *ttl),
+ DnsRecord::AR { domain, ttl, .. } => (domain.clone(), QueryType::AR, *ttl),
+ DnsRecord::AAAAR { domain, ttl, .. } => (domain.clone(), QueryType::AAAAR, *ttl),
}
}
-
}
diff --git a/src/packet/result.rs b/src/dns/packet/result.rs
index 41c8ba9..41c8ba9 100644
--- a/src/packet/result.rs
+++ b/src/dns/packet/result.rs
diff --git a/src/server/resolver.rs b/src/dns/resolver.rs
index 464620c..18b5bba 100644
--- a/src/server/resolver.rs
+++ b/src/dns/resolver.rs
@@ -1,11 +1,7 @@
use super::binding::Connection;
-use crate::{
- config::Config,
- packet::{
- query::QueryType, question::DnsQuestion, result::ResultCode, Packet,
- Result,
- }, get_time,
-};
+use super::packet::{query::QueryType, question::DnsQuestion, result::ResultCode, Packet};
+use crate::Result;
+use crate::{config::Config, database::Database, get_time};
use async_recursion::async_recursion;
use moka::future::Cache;
use std::{net::IpAddr, sync::Arc, time::Duration};
@@ -15,6 +11,7 @@ pub struct Resolver {
request_id: u16,
connection: Connection,
config: Arc<Config>,
+ database: Arc<Database>,
cache: Cache<DnsQuestion, (Packet, u64)>,
}
@@ -23,18 +20,59 @@ impl Resolver {
request_id: u16,
connection: Connection,
config: Arc<Config>,
+ database: Arc<Database>,
cache: Cache<DnsQuestion, (Packet, u64)>,
) -> Self {
Self {
request_id,
connection,
config,
+ database,
cache,
}
}
- async fn lookup_cache(&mut self, qname: &str, qtype: QueryType) -> Option<Packet> {
- let question = DnsQuestion::new(qname.to_string(), qtype);
+ async fn lookup_database(&self, question: &DnsQuestion) -> Option<Packet> {
+ let records = match self
+ .database
+ .get_records(&question.name, question.qtype)
+ .await
+ {
+ Ok(record) => record,
+ Err(err) => {
+ error!("{err}");
+ return None;
+ }
+ };
+
+ if records.is_empty() {
+ return None;
+ }
+
+ let mut packet = Packet::new();
+
+ packet.header.id = self.request_id;
+ packet.header.questions = 1;
+ packet.header.answers = records.len() as u16;
+ packet.header.recursion_desired = true;
+ packet
+ .questions
+ .push(DnsQuestion::new(question.name.to_string(), question.qtype));
+
+ for record in records {
+ packet.answers.push(record);
+ }
+
+ trace!(
+ "Found stored value for {:?} {}",
+ question.qtype,
+ question.name
+ );
+
+ Some(packet)
+ }
+
+ async fn lookup_cache(&self, question: &DnsQuestion) -> Option<Packet> {
let Some((packet, date)) = self.cache.get(&question) else {
return None
};
@@ -46,16 +84,20 @@ impl Resolver {
let ttl = answer.get_ttl();
if diff > ttl {
self.cache.invalidate(&question).await;
- return None
+ return None;
}
}
- trace!("Found cached value for {qtype:?} {qname}");
+ trace!(
+ "Found cached value for {:?} {}",
+ question.qtype,
+ question.name
+ );
Some(packet)
}
- async fn lookup(&mut self, qname: &str, qtype: QueryType, server: (IpAddr, u16)) -> Packet {
+ async fn lookup_fallback(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
let mut packet = Packet::new();
packet.header.id = self.request_id;
@@ -63,7 +105,7 @@ impl Resolver {
packet.header.recursion_desired = true;
packet
.questions
- .push(DnsQuestion::new(qname.to_string(), qtype));
+ .push(DnsQuestion::new(question.name.to_string(), question.qtype));
let packet = match self.connection.request_packet(packet, server).await {
Ok(packet) => packet,
@@ -78,28 +120,47 @@ impl Resolver {
packet
}
+ async fn lookup(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
+ if let Some(packet) = self.lookup_cache(question).await {
+ return packet;
+ };
+
+ if let Some(packet) = self.lookup_database(question).await {
+ return packet;
+ };
+
+ trace!(
+ "Attempting lookup of {:?} {} with ns {}",
+ question.qtype,
+ question.name,
+ server.0
+ );
+
+ self.lookup_fallback(question, server).await
+ }
+
#[async_recursion]
async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet {
let question = DnsQuestion::new(qname.to_string(), qtype);
- let mut ns = self.config.get_fallback_ns().clone();
-
- if let Some(packet) = self.lookup_cache(qname, qtype).await { return packet }
+ let mut ns = self.config.dns_fallback.clone();
loop {
- trace!("Attempting lookup of {qtype:?} {qname} with ns {ns}");
-
let ns_copy = ns;
let server = (ns_copy, 53);
- let response = self.lookup(qname, qtype, server).await;
+ let response = self.lookup(&question, server).await;
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
- self.cache.insert(question, (response.clone(), get_time())).await;
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
return response;
}
if response.header.rescode == ResultCode::NXDOMAIN {
- self.cache.insert(question, (response.clone(), get_time())).await;
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
return response;
}
@@ -111,9 +172,11 @@ impl Resolver {
let new_ns_name = match response.get_unresolved_ns(qname) {
Some(x) => x,
None => {
- self.cache.insert(question, (response.clone(), get_time())).await;
- return response
- },
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
+ return response;
+ }
};
let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await;
@@ -121,7 +184,9 @@ impl Resolver {
if let Some(new_ns) = recursive_response.get_random_a() {
ns = new_ns;
} else {
- self.cache.insert(question, (response.clone(), get_time())).await;
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
return response;
}
}
diff --git a/src/dns/server.rs b/src/dns/server.rs
new file mode 100644
index 0000000..65d15df
--- /dev/null
+++ b/src/dns/server.rs
@@ -0,0 +1,85 @@
+use super::{
+ binding::Binding,
+ packet::{question::DnsQuestion, Packet},
+ resolver::Resolver,
+};
+use crate::{config::Config, database::Database, Result};
+use moka::future::Cache;
+use std::{net::SocketAddr, sync::Arc, time::Duration};
+use tokio::task::JoinHandle;
+use tracing::{error, info};
+
+pub struct DnsServer {
+ addr: SocketAddr,
+ config: Arc<Config>,
+ database: Arc<Database>,
+ cache: Cache<DnsQuestion, (Packet, u64)>,
+}
+
+impl DnsServer {
+ pub async fn new(config: Config, database: Database) -> Result<Self> {
+ let addr = format!("[::]:{}", config.dns_port).parse::<SocketAddr>()?;
+ let cache = Cache::builder()
+ .time_to_live(Duration::from_secs(60 * 60))
+ .max_capacity(config.dns_cache_size)
+ .build();
+
+ info!("Created DNS cache with size of {}", config.dns_cache_size);
+
+ Ok(Self {
+ addr,
+ config: Arc::new(config),
+ database: Arc::new(database),
+ cache,
+ })
+ }
+
+ pub async fn run(&self) -> Result<(JoinHandle<()>, JoinHandle<()>)> {
+ let tcp = Binding::tcp(self.addr).await?;
+ let tcp_handle = self.listen(tcp);
+
+ let udp = Binding::udp(self.addr).await?;
+ let udp_handle = self.listen(udp);
+
+ info!(
+ "Fallback DNS Server is set to: {:?}",
+ self.config.dns_fallback
+ );
+ info!(
+ "Listening for TCP and UDP traffic on [::]:{}",
+ self.config.dns_port
+ );
+
+ Ok((udp_handle, tcp_handle))
+ }
+
+ fn listen(&self, mut binding: Binding) -> JoinHandle<()> {
+ let config = self.config.clone();
+ let database = self.database.clone();
+ let cache = self.cache.clone();
+ tokio::spawn(async move {
+ let mut id = 0;
+ loop {
+ let Ok(connection) = binding.connect().await else { continue };
+ info!("Received request on {}", binding.name());
+
+ let resolver = Resolver::new(
+ id,
+ connection,
+ config.clone(),
+ database.clone(),
+ cache.clone(),
+ );
+
+ let name = binding.name().to_string();
+ tokio::spawn(async move {
+ if let Err(err) = resolver.handle_query().await {
+ error!("{} request {} failed: {:?}", name, id, err);
+ };
+ });
+
+ id += 1;
+ }
+ })
+ }
+}
diff --git a/src/main.rs b/src/main.rs
index c891d50..679e87b 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,19 +1,34 @@
-use std::{time::{UNIX_EPOCH, SystemTime}, env, net::IpAddr};
+use std::time::{SystemTime, UNIX_EPOCH};
use config::Config;
-use server::server::Server;
-use tracing::metadata::LevelFilter;
+use database::Database;
+use dotenv::dotenv;
+use dns::server::DnsServer;
+use tracing::{error, metadata::LevelFilter};
use tracing_subscriber::{
filter::filter_fn, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, Layer,
};
+use web::WebServer;
mod config;
-mod packet;
-mod server;
+mod database;
+mod dns;
+mod web;
+
+type Error = Box<dyn std::error::Error>;
+pub type Result<T> = std::result::Result<T, Error>;
#[tokio::main]
async fn main() {
+ if let Err(err) = run().await {
+ error!("{err}")
+ };
+}
+
+async fn run() -> Result<()> {
+ dotenv().ok();
+
tracing_subscriber::registry()
.with(
tracing_subscriber::fmt::layer()
@@ -24,19 +39,20 @@ async fn main() {
)
.init();
- let mut config = Config::new();
+ let config = Config::new();
+ let database = Database::new(config.clone()).await?;
- if let Ok(port) = env::var("PORT").unwrap_or(String::new()).parse::<u16>() {
- config.set_port(port);
- }
+ let dns_server = DnsServer::new(config.clone(), database.clone()).await?;
+ let (udp, tcp) = dns_server.run().await?;
- if let Ok(fallback) = env::var("FALLBACK_DNS").unwrap_or(String::new()).parse::<IpAddr>() {
- config.set_fallback_ns(&fallback);
- }
+ let web_server = WebServer::new(config, database).await?;
+ let web = web_server.run().await?;
- let server = Server::new(config).await.expect("Failed to bind server");
+ tokio::join!(udp).0?;
+ tokio::join!(tcp).0?;
+ tokio::join!(web).0?;
- server.run().await.unwrap();
+ Ok(())
}
pub fn get_time() -> u64 {
diff --git a/src/server/server.rs b/src/server/server.rs
deleted file mode 100644
index e006bb1..0000000
--- a/src/server/server.rs
+++ /dev/null
@@ -1,73 +0,0 @@
-use moka::future::Cache;
-use std::net::SocketAddr;
-use std::sync::Arc;
-use std::time::Duration;
-use tokio::task::JoinHandle;
-use tracing::{error, info};
-
-use crate::config::Config;
-use crate::packet::question::DnsQuestion;
-use crate::packet::{Result, Packet};
-
-use super::binding::Binding;
-use super::resolver::Resolver;
-
-pub struct Server {
- addr: SocketAddr,
- config: Arc<Config>,
- cache: Cache<DnsQuestion, (Packet, u64)>,
-}
-
-impl Server {
- pub async fn new(config: Config) -> Result<Self> {
- let addr = format!("[::]:{}", config.get_port()).parse::<SocketAddr>()?;
- let cache = Cache::builder()
- .time_to_live(Duration::from_secs(60 * 60))
- .max_capacity(1_000)
- .build();
- Ok(Self {
- addr,
- config: Arc::new(config),
- cache,
- })
- }
-
- pub async fn run(&self) -> Result<()> {
- let tcp = Binding::tcp(self.addr).await?;
- let tcp_handle = self.listen(tcp);
-
- let udp = Binding::udp(self.addr).await?;
- let udp_handle = self.listen(udp);
-
- info!("Fallback DNS Server is set to: {:?}", self.config.get_fallback_ns());
- info!("Listening for TCP and UDP traffic on [::]:{}", self.config.get_port());
-
- tokio::join!(tcp_handle)
- .0
- .expect("Failed to join tcp thread");
- tokio::join!(udp_handle)
- .0
- .expect("Failed to join udp thread");
- Ok(())
- }
-
- fn listen(&self, mut binding: Binding) -> JoinHandle<()> {
- let config = self.config.clone();
- let cache = self.cache.clone();
- tokio::spawn(async move {
- let mut id = 0;
- loop {
- let Ok(connection) = binding.connect().await else { continue };
- info!("Received request on {}", binding.name());
-
- let resolver = Resolver::new(id, connection, config.clone(), cache.clone());
-
- if let Err(err) = resolver.handle_query().await {
- error!("{} request {} failed: {:?}", binding.name(), id, err);
- };
-
- id += 1;
- }
- })
- }
-}
diff --git a/src/web/api.rs b/src/web/api.rs
new file mode 100644
index 0000000..1fddb5f
--- /dev/null
+++ b/src/web/api.rs
@@ -0,0 +1,156 @@
+use std::net::IpAddr;
+
+use axum::{
+ extract::Query,
+ response::Response,
+ routing::{get, post, put, delete},
+ Extension, Router,
+};
+use moka::future::Cache;
+use rand::distributions::{Alphanumeric, DistString};
+use serde::Deserialize;
+use tower_cookies::{Cookie, Cookies};
+
+use crate::{config::Config, database::Database, dns::packet::record::DnsRecord};
+
+use super::{
+ extract::{Authorized, Body, RequestIp},
+ http::{json, text},
+};
+
+pub fn router() -> Router {
+ Router::new()
+ .route("/login", post(login))
+ .route("/domains", get(list_domains))
+ .route("/domains", delete(delete_domain))
+ .route("/records", get(get_domain))
+ .route("/records", put(add_record))
+}
+
+async fn list_domains(_: Authorized, Extension(database): Extension<Database>) -> Response {
+ let domains = match database.get_domains().await {
+ Ok(domains) => domains,
+ Err(err) => return text(500, &format!("{err}")),
+ };
+
+ let Ok(domains) = serde_json::to_string(&domains) else {
+ return text(500, "Failed to fetch domains")
+ };
+
+ json(200, &domains)
+}
+
+#[derive(Deserialize)]
+struct DomainRequest {
+ domain: String,
+}
+
+async fn get_domain(
+ _: Authorized,
+ Extension(database): Extension<Database>,
+ Query(query): Query<DomainRequest>,
+) -> Response {
+ let records = match database.get_domain(&query.domain).await {
+ Ok(records) => records,
+ Err(err) => return text(500, &format!("{err}")),
+ };
+
+ let Ok(records) = serde_json::to_string(&records) else {
+ return text(500, "Failed to fetch records")
+ };
+
+ json(200, &records)
+}
+
+async fn delete_domain(
+ _: Authorized,
+ Extension(database): Extension<Database>,
+ Body(body): Body,
+) -> Response {
+
+ let Ok(request) = serde_json::from_str::<DomainRequest>(&body) else {
+ return text(400, "Missing request parameters")
+ };
+
+ let Ok(domains) = database.get_domains().await else {
+ return text(500, "Failed to delete domain")
+ };
+
+ if !domains.contains(&request.domain) {
+ return text(400, "Domain does not exist")
+ }
+
+ if database.delete_domain(request.domain).await.is_err() {
+ return text(500, "Failed to delete domain")
+ };
+
+ return text(204, "Successfully deleted domain")
+}
+
+async fn add_record(
+ _: Authorized,
+ Extension(database): Extension<Database>,
+ Body(body): Body,
+) -> Response {
+ let Ok(record) = serde_json::from_str::<DnsRecord>(&body) else {
+ return text(400, "Invalid DNS record")
+ };
+
+ let allowed = record.get_qtype().allowed_actions();
+ if !allowed.1 {
+ return text(400, "Not allowed to create record")
+ }
+
+ let Ok(records) = database.get_records(&record.get_domain(), record.get_qtype()).await else {
+ return text(500, "Failed to complete record check");
+ };
+
+ if !records.is_empty() && !allowed.0 {
+ return text(400, "Not allowed to create duplicate record")
+ };
+
+ if records.contains(&record) {
+ return text(400, "Not allowed to create duplicate record")
+ }
+
+ if let Err(err) = database.add_record(record).await {
+ return text(500, &format!("{err}"));
+ }
+
+ return text(201, "Added record to database successfully");
+}
+
+#[derive(Deserialize)]
+struct LoginRequest {
+ user: String,
+ pass: String,
+}
+
+async fn login(
+ Extension(config): Extension<Config>,
+ Extension(cache): Extension<Cache<String, IpAddr>>,
+ RequestIp(ip): RequestIp,
+ cookies: Cookies,
+ Body(body): Body,
+) -> Response {
+ let Ok(request) = serde_json::from_str::<LoginRequest>(&body) else {
+ return text(400, "Missing request parameters")
+ };
+
+ if request.user != config.web_user || request.pass != config.web_pass {
+ return text(400, "Invalid credentials");
+ };
+
+ let token = Alphanumeric.sample_string(&mut rand::thread_rng(), 128);
+
+ cache.insert(token.clone(), ip).await;
+
+ let mut cookie = Cookie::new("auth", token);
+ cookie.set_secure(true);
+ cookie.set_http_only(true);
+ cookie.set_path("/");
+
+ cookies.add(cookie);
+
+ text(200, "Successfully logged in")
+}
diff --git a/src/web/extract.rs b/src/web/extract.rs
new file mode 100644
index 0000000..4b6cd7c
--- /dev/null
+++ b/src/web/extract.rs
@@ -0,0 +1,139 @@
+use std::{
+ io::Read,
+ net::{IpAddr, SocketAddr},
+};
+
+use axum::{
+ async_trait,
+ body::HttpBody,
+ extract::{ConnectInfo, FromRequest, FromRequestParts},
+ http::{request::Parts, Request},
+ response::Response,
+ BoxError,
+};
+use bytes::Bytes;
+use moka::future::Cache;
+use tower_cookies::Cookies;
+
+use super::http::text;
+
+pub struct Authorized;
+
+#[async_trait]
+impl<S> FromRequestParts<S> for Authorized
+where
+ S: Send + Sync,
+{
+ type Rejection = Response;
+
+ async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
+ let Ok(Some(cookies)) = Option::<Cookies>::from_request_parts(parts, state).await else {
+ return Err(text(403, "No cookies provided"))
+ };
+
+ let Some(token) = cookies.get("auth") else {
+ return Err(text(403, "No auth token provided"))
+ };
+
+ let auth_ip: IpAddr;
+ {
+ let Some(cache) = parts.extensions.get::<Cache<String, IpAddr>>() else {
+ return Err(text(500, "Failed to load auth store"))
+ };
+
+ let Some(ip) = cache.get(token.value()) else {
+ return Err(text(401, "Unauthorized"))
+ };
+
+ auth_ip = ip
+ }
+
+ let Ok(Some(RequestIp(ip))) = Option::<RequestIp>::from_request_parts(parts, state).await else {
+ return Err(text(403, "You have no ip"))
+ };
+
+ if auth_ip != ip {
+ return Err(text(403, "Auth token does not match current ip"));
+ }
+
+ Ok(Self)
+ }
+}
+
+pub struct RequestIp(pub IpAddr);
+
+#[async_trait]
+impl<S> FromRequestParts<S> for RequestIp
+where
+ S: Send + Sync,
+{
+ type Rejection = Response;
+
+ async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
+ let headers = &parts.headers;
+
+ let forwardedfor = headers
+ .get("x-forwarded-for")
+ .and_then(|h| h.to_str().ok())
+ .and_then(|h| {
+ h.split(',')
+ .rev()
+ .find_map(|s| s.trim().parse::<IpAddr>().ok())
+ });
+
+ if let Some(forwardedfor) = forwardedfor {
+ return Ok(Self(forwardedfor));
+ }
+
+ let realip = headers
+ .get("x-real-ip")
+ .and_then(|hv| hv.to_str().ok())
+ .and_then(|s| s.parse::<IpAddr>().ok());
+
+ if let Some(realip) = realip {
+ return Ok(Self(realip));
+ }
+
+ let realip = headers
+ .get("x-real-ip")
+ .and_then(|hv| hv.to_str().ok())
+ .and_then(|s| s.parse::<IpAddr>().ok());
+
+ if let Some(realip) = realip {
+ return Ok(Self(realip));
+ }
+
+ let info = parts.extensions.get::<ConnectInfo<SocketAddr>>();
+
+ if let Some(info) = info {
+ return Ok(Self(info.0.ip()));
+ }
+
+ Err(text(403, "You have no ip"))
+ }
+}
+
+pub struct Body(pub String);
+
+#[async_trait]
+impl<S, B> FromRequest<S, B> for Body
+where
+ B: HttpBody + Sync + Send + 'static,
+ B::Data: Send,
+ B::Error: Into<BoxError>,
+ S: Send + Sync,
+{
+ type Rejection = Response;
+
+ async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
+ let Ok(bytes) = Bytes::from_request(req, state).await else {
+ return Err(text(413, "Payload too large"));
+ };
+
+ let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else {
+ return Err(text(400, "Invalid utf8 body"))
+ };
+
+ Ok(Self(body))
+ }
+}
diff --git a/src/web/file.rs b/src/web/file.rs
new file mode 100644
index 0000000..73ecdc9
--- /dev/null
+++ b/src/web/file.rs
@@ -0,0 +1,31 @@
+use axum::{extract::Path, response::Response};
+
+use super::http::serve;
+
+pub async fn js(Path(path): Path<String>) -> Response {
+ let path = format!("/js/{path}");
+ serve(&path).await
+}
+
+pub async fn css(Path(path): Path<String>) -> Response {
+ let path = format!("/css/{path}");
+ serve(&path).await
+}
+
+pub async fn fonts(Path(path): Path<String>) -> Response {
+ let path = format!("/fonts/{path}");
+ serve(&path).await
+}
+
+pub async fn image(Path(path): Path<String>) -> Response {
+ let path = format!("/image/{path}");
+ serve(&path).await
+}
+
+pub async fn favicon() -> Response {
+ serve("/favicon.ico").await
+}
+
+pub async fn robots() -> Response {
+ serve("/robots.txt").await
+}
diff --git a/src/web/http.rs b/src/web/http.rs
new file mode 100644
index 0000000..7ab1b11
--- /dev/null
+++ b/src/web/http.rs
@@ -0,0 +1,50 @@
+use axum::{
+ body::Body,
+ http::{header::HeaderName, HeaderValue, Request, StatusCode},
+ response::{IntoResponse, Response},
+};
+use std::str;
+use tower::ServiceExt;
+use tower_http::services::ServeFile;
+
+pub fn text(code: u16, msg: &str) -> Response {
+ (status_code(code), msg.to_owned()).into_response()
+}
+
+pub fn json(code: u16, json: &str) -> Response {
+ let mut res = (status_code(code), json.to_owned()).into_response();
+ res.headers_mut().insert(
+ HeaderName::from_static("content-type"),
+ HeaderValue::from_static("application/json"),
+ );
+ res
+}
+
+pub async fn serve(path: &str) -> Response {
+ if !path.chars().any(|c| c == '.') {
+ return text(403, "Invalid file path");
+ }
+
+ let path = format!("public{path}");
+ let file = ServeFile::new(path);
+
+ let Ok(mut res) = file.oneshot(Request::new(Body::empty())).await else {
+ tracing::error!("Error while fetching file");
+ return text(500, "Error when fetching file")
+ };
+
+ if res.status() != StatusCode::OK {
+ return text(404, "File not found");
+ }
+
+ res.headers_mut().insert(
+ HeaderName::from_static("cache-control"),
+ HeaderValue::from_static("max-age=300"),
+ );
+
+ res.into_response()
+}
+
+fn status_code(code: u16) -> StatusCode {
+ StatusCode::from_u16(code).map_or(StatusCode::OK, |code| code)
+}
diff --git a/src/web/mod.rs b/src/web/mod.rs
new file mode 100644
index 0000000..530a3f9
--- /dev/null
+++ b/src/web/mod.rs
@@ -0,0 +1,82 @@
+use std::net::{IpAddr, SocketAddr, TcpListener};
+use std::time::Duration;
+
+use axum::routing::get;
+use axum::{Extension, Router};
+use moka::future::Cache;
+use tokio::task::JoinHandle;
+use tower_cookies::CookieManagerLayer;
+use tracing::{error, info};
+
+use crate::config::Config;
+use crate::database::Database;
+use crate::Result;
+
+mod api;
+mod extract;
+mod file;
+mod http;
+mod pages;
+
+pub struct WebServer {
+ config: Config,
+ database: Database,
+ addr: SocketAddr,
+}
+
+impl WebServer {
+ pub async fn new(config: Config, database: Database) -> Result<Self> {
+ let addr = format!("[::]:{}", config.web_port).parse::<SocketAddr>()?;
+ Ok(Self {
+ config,
+ database,
+ addr,
+ })
+ }
+
+ pub async fn run(&self) -> Result<JoinHandle<()>> {
+ let config = self.config.clone();
+ let database = self.database.clone();
+ let listener = TcpListener::bind(self.addr)?;
+
+ info!(
+ "Listening for HTTP traffic on [::]:{}",
+ self.config.web_port
+ );
+
+ let app = Self::router(config, database);
+ let server = axum::Server::from_tcp(listener)?;
+
+ let web_handle = tokio::spawn(async move {
+ if let Err(err) = server
+ .serve(app.into_make_service_with_connect_info::<SocketAddr>())
+ .await
+ {
+ error!("{err}");
+ }
+ });
+
+ Ok(web_handle)
+ }
+
+ fn router(config: Config, database: Database) -> Router {
+ let cache: Cache<String, IpAddr> = Cache::builder()
+ .time_to_live(Duration::from_secs(60 * 15))
+ .max_capacity(config.dns_cache_size)
+ .build();
+
+ Router::new()
+ .nest("/", pages::router())
+ .nest("/api", api::router())
+ .layer(Extension(config))
+ .layer(Extension(cache))
+ .layer(Extension(database))
+ .layer(CookieManagerLayer::new())
+ .route("/js/*path", get(file::js))
+ .route("/css/*path", get(file::css))
+ .route("/fonts/*path", get(file::fonts))
+ .route("/image/*path", get(file::image))
+ .route("/favicon.ico", get(file::favicon))
+ .route("/robots.txt", get(file::robots))
+ }
+}
diff --git a/src/web/pages.rs b/src/web/pages.rs
new file mode 100644
index 0000000..a8605ef
--- /dev/null
+++ b/src/web/pages.rs
@@ -0,0 +1,31 @@
+use axum::{response::Response, routing::get, Router};
+
+use super::{extract::Authorized, http::serve};
+
+pub fn router() -> Router {
+ Router::new()
+ .route("/", get(root))
+ .route("/login", get(login))
+ .route("/home", get(home))
+ .route("/domain", get(domain))
+}
+
+async fn root(user: Option<Authorized>) -> Response {
+ if user.is_some() {
+ home().await
+ } else {
+ login().await
+ }
+}
+
+async fn login() -> Response {
+ serve("/login.html").await
+}
+
+async fn home() -> Response {
+ serve("/home.html").await
+}
+
+async fn domain() -> Response {
+ serve("/domain.html").await
+}