summaryrefslogtreecommitdiff
path: root/src/dns
diff options
context:
space:
mode:
authorTyler Murphy <tylermurphy534@gmail.com>2023-03-06 18:50:08 -0500
committerTyler Murphy <tylermurphy534@gmail.com>2023-03-06 18:50:08 -0500
commitb1fb410affb7bcd2e714abac01d22c4a5332c344 (patch)
tree7ebb621ab9b73e3e1fbaeb0ef8c19abef95b7c9f /src/dns
parentfinialize initial dns + caching (diff)
downloadwrapper-b1fb410affb7bcd2e714abac01d22c4a5332c344.tar.gz
wrapper-b1fb410affb7bcd2e714abac01d22c4a5332c344.tar.bz2
wrapper-b1fb410affb7bcd2e714abac01d22c4a5332c344.zip
finish dns and start webserver
Diffstat (limited to '')
-rw-r--r--src/dns/binding.rs (renamed from src/server/binding.rs)10
-rw-r--r--src/dns/mod.rs (renamed from src/server/mod.rs)1
-rw-r--r--src/dns/packet/buffer.rs (renamed from src/packet/buffer.rs)51
-rw-r--r--src/dns/packet/header.rs (renamed from src/packet/header.rs)3
-rw-r--r--src/dns/packet/mod.rs (renamed from src/packet/mod.rs)4
-rw-r--r--src/dns/packet/query.rs (renamed from src/packet/query.rs)27
-rw-r--r--src/dns/packet/question.rs (renamed from src/packet/question.rs)0
-rw-r--r--src/dns/packet/record.rs (renamed from src/packet/record.rs)82
-rw-r--r--src/dns/packet/result.rs (renamed from src/packet/result.rs)0
-rw-r--r--src/dns/resolver.rs (renamed from src/server/resolver.rs)115
-rw-r--r--src/dns/server.rs85
11 files changed, 293 insertions, 85 deletions
diff --git a/src/server/binding.rs b/src/dns/binding.rs
index 1c69651..4c7e15f 100644
--- a/src/server/binding.rs
+++ b/src/dns/binding.rs
@@ -3,7 +3,8 @@ use std::{
sync::Arc,
};
-use crate::packet::{buffer::PacketBuffer, Packet, Result};
+use super::packet::{buffer::PacketBuffer, Packet};
+use crate::Result;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream, UdpSocket},
@@ -140,11 +141,4 @@ impl Connection {
}
}
}
-
- // fn pb(buf: &[u8]) {
- // for i in 0..buf.len() {
- // print!("{:02X?} ", buf[i]);
- // }
- // println!("");
- // }
}
diff --git a/src/server/mod.rs b/src/dns/mod.rs
index 25076ef..6f1e59e 100644
--- a/src/server/mod.rs
+++ b/src/dns/mod.rs
@@ -1,3 +1,4 @@
mod binding;
+pub mod packet;
mod resolver;
pub mod server;
diff --git a/src/packet/buffer.rs b/src/dns/packet/buffer.rs
index 4ecc605..058156e 100644
--- a/src/packet/buffer.rs
+++ b/src/dns/packet/buffer.rs
@@ -1,4 +1,4 @@
-use super::Result;
+use crate::Result;
pub struct PacketBuffer {
pub buf: Vec<u8>,
@@ -9,19 +9,9 @@ pub struct PacketBuffer {
impl PacketBuffer {
pub fn new(buf: Vec<u8>) -> Self {
Self {
+ size: buf.len(),
buf,
pos: 0,
- size: 0,
- }
- }
-
- fn check(&mut self, pos: usize) {
- if self.size < pos {
- self.size = pos;
- }
-
- if self.buf.len() <= self.size {
- self.buf.resize(self.size + 1, 0x00);
}
}
@@ -42,32 +32,25 @@ impl PacketBuffer {
}
pub fn read(&mut self) -> Result<u8> {
- // if self.pos >= 512 {
- // error!("Tried to read past end of buffer");
- // return Err("End of buffer".into());
- // }
- self.check(self.pos);
+ if self.pos >= self.size {
+ return Err("Tried to read past end of buffer".into());
+ }
let res = self.buf[self.pos];
self.pos += 1;
-
Ok(res)
}
pub fn get(&mut self, pos: usize) -> Result<u8> {
- // if pos >= 512 {
- // error!("Tried to read past end of buffer");
- // return Err("End of buffer".into());
- // }
- self.check(pos);
+ if pos >= self.size {
+ return Err("Tried to read past end of buffer".into());
+ }
Ok(self.buf[pos])
}
pub fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
- // if start + len >= 512 {
- // error!("Tried to read past end of buffer");
- // return Err("End of buffer".into());
- // }
- self.check(start + len);
+ if start + len >= self.size {
+ return Err("Tried to read past end of buffer".into());
+ }
Ok(&self.buf[start..start + len])
}
@@ -169,7 +152,13 @@ impl PacketBuffer {
}
pub fn write(&mut self, val: u8) -> Result<()> {
- self.check(self.pos);
+ if self.size < self.pos {
+ self.size = self.pos;
+ }
+
+ if self.buf.len() <= self.size {
+ self.buf.resize(self.size + 1, 0x00);
+ }
self.buf[self.pos] = val;
self.pos += 1;
@@ -208,7 +197,9 @@ impl PacketBuffer {
}
}
- self.write_u8(0)?;
+ if !qname.is_empty() {
+ self.write_u8(0)?;
+ }
Ok(())
}
diff --git a/src/packet/header.rs b/src/dns/packet/header.rs
index a75f6ba..2355ecb 100644
--- a/src/packet/header.rs
+++ b/src/dns/packet/header.rs
@@ -1,4 +1,5 @@
-use super::{buffer::PacketBuffer, result::ResultCode, Result};
+use super::{buffer::PacketBuffer, result::ResultCode};
+use crate::Result;
#[derive(Clone, Debug)]
pub struct DnsHeader {
diff --git a/src/packet/mod.rs b/src/dns/packet/mod.rs
index 0b7cb7b..9873b94 100644
--- a/src/packet/mod.rs
+++ b/src/dns/packet/mod.rs
@@ -4,9 +4,7 @@ use self::{
buffer::PacketBuffer, header::DnsHeader, query::QueryType, question::DnsQuestion,
record::DnsRecord,
};
-
-type Error = Box<dyn std::error::Error>;
-pub type Result<T> = std::result::Result<T, Error>;
+use crate::Result;
pub mod buffer;
pub mod header;
diff --git a/src/packet/query.rs b/src/dns/packet/query.rs
index cae6f09..732b9b2 100644
--- a/src/packet/query.rs
+++ b/src/dns/packet/query.rs
@@ -12,6 +12,8 @@ pub enum QueryType {
SRV, // 33
OPT, // 41
CAA, // 257
+ AR, // 1000
+ AAAAR, // 1001
}
impl QueryType {
@@ -29,6 +31,8 @@ impl QueryType {
Self::SRV => 33,
Self::OPT => 41,
Self::CAA => 257,
+ Self::AR => 1000,
+ Self::AAAAR => 1001,
}
}
@@ -45,7 +49,30 @@ impl QueryType {
33 => Self::SRV,
41 => Self::OPT,
257 => Self::CAA,
+ 1000 => Self::AR,
+ 1001 => Self::AAAAR,
_ => Self::UNKNOWN(num),
}
}
+
+ pub fn allowed_actions(&self) -> (bool, bool) {
+ // 0. duplicates allowed
+ // 1. allowed to be created by database
+ match self {
+ QueryType::UNKNOWN(_) => (false, false),
+ QueryType::A => (true, true),
+ QueryType::NS => (false, true),
+ QueryType::CNAME => (false, true),
+ QueryType::SOA => (false, false),
+ QueryType::PTR => (false, true),
+ QueryType::MX => (false, true),
+ QueryType::TXT => (true, true),
+ QueryType::AAAA => (true, true),
+ QueryType::SRV => (false, true),
+ QueryType::OPT => (false, false),
+ QueryType::CAA => (false, true),
+ QueryType::AR => (false, true),
+ QueryType::AAAAR => (false, true),
+ }
+ }
}
diff --git a/src/packet/question.rs b/src/dns/packet/question.rs
index 9042e1c..9042e1c 100644
--- a/src/packet/question.rs
+++ b/src/dns/packet/question.rs
diff --git a/src/packet/record.rs b/src/dns/packet/record.rs
index c29dd8f..88008f0 100644
--- a/src/packet/record.rs
+++ b/src/dns/packet/record.rs
@@ -1,11 +1,12 @@
use std::net::{Ipv4Addr, Ipv6Addr};
+use rand::RngCore;
+use serde::{Deserialize, Serialize};
use tracing::{trace, warn};
use super::{buffer::PacketBuffer, query::QueryType, Result};
-#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
-#[allow(dead_code)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub enum DnsRecord {
UNKNOWN {
domain: String,
@@ -76,10 +77,17 @@ pub enum DnsRecord {
value: String,
ttl: u32,
}, // 257
+ AR {
+ domain: String,
+ ttl: u32,
+ },
+ AAAAR {
+ domain: String,
+ ttl: u32,
+ },
}
impl DnsRecord {
-
pub fn read(buffer: &mut PacketBuffer) -> Result<Self> {
let mut domain = String::new();
buffer.read_qname(&mut domain)?;
@@ -90,10 +98,10 @@ impl DnsRecord {
let ttl = buffer.read_u32()?;
let data_len = buffer.read_u16()?;
- let header_pos = buffer.pos();
-
trace!("Reading DNS Record TYPE: {:?}", qtype);
+ let header_pos = buffer.pos();
+
match qtype {
QueryType::A => {
let raw_addr = buffer.read_u32()?;
@@ -471,6 +479,29 @@ impl DnsRecord {
let size = buffer.pos() - (pos + 2);
buffer.set_u16(pos, size as u16)?;
}
+ Self::AR { ref domain, ttl } => {
+ buffer.write_qname(domain)?;
+ buffer.write_u16(QueryType::A.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+ buffer.write_u16(4)?;
+
+ let mut rand = rand::thread_rng();
+ buffer.write_u32(rand.next_u32())?;
+ }
+ Self::AAAAR { ref domain, ttl } => {
+ buffer.write_qname(domain)?;
+ buffer.write_u16(QueryType::A.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+ buffer.write_u16(4)?;
+
+ let mut rand = rand::thread_rng();
+ buffer.write_u32(rand.next_u32())?;
+ buffer.write_u32(rand.next_u32())?;
+ buffer.write_u32(rand.next_u32())?;
+ buffer.write_u32(rand.next_u32())?;
+ }
Self::UNKNOWN { .. } => {
warn!("Skipping record: {self:?}");
}
@@ -479,20 +510,35 @@ impl DnsRecord {
Ok(buffer.pos() - start_pos)
}
+ pub fn get_domain(&self) -> String {
+ self.get_shared_domain().0
+ }
+
+ pub fn get_qtype(&self) -> QueryType {
+ self.get_shared_domain().1
+ }
+
pub fn get_ttl(&self) -> u32 {
- match *self {
- DnsRecord::UNKNOWN { .. } => 0,
- DnsRecord::AAAA { ttl, .. } => ttl,
- DnsRecord::A { ttl, .. } => ttl,
- DnsRecord::NS { ttl, .. } => ttl,
- DnsRecord::CNAME { ttl, .. } => ttl,
- DnsRecord::SOA { ttl, .. } => ttl,
- DnsRecord::PTR { ttl, .. } => ttl,
- DnsRecord::MX { ttl, .. } => ttl,
- DnsRecord::TXT { ttl, .. } => ttl,
- DnsRecord::SRV { ttl, .. } => ttl,
- DnsRecord::CAA { ttl, .. } => ttl,
+ self.get_shared_domain().2
+ }
+
+ fn get_shared_domain(&self) -> (String, QueryType, u32) {
+ match self {
+ DnsRecord::UNKNOWN {
+ domain, ttl, qtype, ..
+ } => (domain.clone(), QueryType::UNKNOWN(*qtype), *ttl),
+ DnsRecord::AAAA { domain, ttl, .. } => (domain.clone(), QueryType::AAAA, *ttl),
+ DnsRecord::A { domain, ttl, .. } => (domain.clone(), QueryType::A, *ttl),
+ DnsRecord::NS { domain, ttl, .. } => (domain.clone(), QueryType::NS, *ttl),
+ DnsRecord::CNAME { domain, ttl, .. } => (domain.clone(), QueryType::CNAME, *ttl),
+ DnsRecord::SOA { domain, ttl, .. } => (domain.clone(), QueryType::SOA, *ttl),
+ DnsRecord::PTR { domain, ttl, .. } => (domain.clone(), QueryType::PTR, *ttl),
+ DnsRecord::MX { domain, ttl, .. } => (domain.clone(), QueryType::MX, *ttl),
+ DnsRecord::TXT { domain, ttl, .. } => (domain.clone(), QueryType::TXT, *ttl),
+ DnsRecord::SRV { domain, ttl, .. } => (domain.clone(), QueryType::SRV, *ttl),
+ DnsRecord::CAA { domain, ttl, .. } => (domain.clone(), QueryType::CAA, *ttl),
+ DnsRecord::AR { domain, ttl, .. } => (domain.clone(), QueryType::AR, *ttl),
+ DnsRecord::AAAAR { domain, ttl, .. } => (domain.clone(), QueryType::AAAAR, *ttl),
}
}
-
}
diff --git a/src/packet/result.rs b/src/dns/packet/result.rs
index 41c8ba9..41c8ba9 100644
--- a/src/packet/result.rs
+++ b/src/dns/packet/result.rs
diff --git a/src/server/resolver.rs b/src/dns/resolver.rs
index 464620c..18b5bba 100644
--- a/src/server/resolver.rs
+++ b/src/dns/resolver.rs
@@ -1,11 +1,7 @@
use super::binding::Connection;
-use crate::{
- config::Config,
- packet::{
- query::QueryType, question::DnsQuestion, result::ResultCode, Packet,
- Result,
- }, get_time,
-};
+use super::packet::{query::QueryType, question::DnsQuestion, result::ResultCode, Packet};
+use crate::Result;
+use crate::{config::Config, database::Database, get_time};
use async_recursion::async_recursion;
use moka::future::Cache;
use std::{net::IpAddr, sync::Arc, time::Duration};
@@ -15,6 +11,7 @@ pub struct Resolver {
request_id: u16,
connection: Connection,
config: Arc<Config>,
+ database: Arc<Database>,
cache: Cache<DnsQuestion, (Packet, u64)>,
}
@@ -23,18 +20,59 @@ impl Resolver {
request_id: u16,
connection: Connection,
config: Arc<Config>,
+ database: Arc<Database>,
cache: Cache<DnsQuestion, (Packet, u64)>,
) -> Self {
Self {
request_id,
connection,
config,
+ database,
cache,
}
}
- async fn lookup_cache(&mut self, qname: &str, qtype: QueryType) -> Option<Packet> {
- let question = DnsQuestion::new(qname.to_string(), qtype);
+ async fn lookup_database(&self, question: &DnsQuestion) -> Option<Packet> {
+ let records = match self
+ .database
+ .get_records(&question.name, question.qtype)
+ .await
+ {
+ Ok(record) => record,
+ Err(err) => {
+ error!("{err}");
+ return None;
+ }
+ };
+
+ if records.is_empty() {
+ return None;
+ }
+
+ let mut packet = Packet::new();
+
+ packet.header.id = self.request_id;
+ packet.header.questions = 1;
+ packet.header.answers = records.len() as u16;
+ packet.header.recursion_desired = true;
+ packet
+ .questions
+ .push(DnsQuestion::new(question.name.to_string(), question.qtype));
+
+ for record in records {
+ packet.answers.push(record);
+ }
+
+ trace!(
+ "Found stored value for {:?} {}",
+ question.qtype,
+ question.name
+ );
+
+ Some(packet)
+ }
+
+ async fn lookup_cache(&self, question: &DnsQuestion) -> Option<Packet> {
let Some((packet, date)) = self.cache.get(&question) else {
return None
};
@@ -46,16 +84,20 @@ impl Resolver {
let ttl = answer.get_ttl();
if diff > ttl {
self.cache.invalidate(&question).await;
- return None
+ return None;
}
}
- trace!("Found cached value for {qtype:?} {qname}");
+ trace!(
+ "Found cached value for {:?} {}",
+ question.qtype,
+ question.name
+ );
Some(packet)
}
- async fn lookup(&mut self, qname: &str, qtype: QueryType, server: (IpAddr, u16)) -> Packet {
+ async fn lookup_fallback(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
let mut packet = Packet::new();
packet.header.id = self.request_id;
@@ -63,7 +105,7 @@ impl Resolver {
packet.header.recursion_desired = true;
packet
.questions
- .push(DnsQuestion::new(qname.to_string(), qtype));
+ .push(DnsQuestion::new(question.name.to_string(), question.qtype));
let packet = match self.connection.request_packet(packet, server).await {
Ok(packet) => packet,
@@ -78,28 +120,47 @@ impl Resolver {
packet
}
+ async fn lookup(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
+ if let Some(packet) = self.lookup_cache(question).await {
+ return packet;
+ };
+
+ if let Some(packet) = self.lookup_database(question).await {
+ return packet;
+ };
+
+ trace!(
+ "Attempting lookup of {:?} {} with ns {}",
+ question.qtype,
+ question.name,
+ server.0
+ );
+
+ self.lookup_fallback(question, server).await
+ }
+
#[async_recursion]
async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet {
let question = DnsQuestion::new(qname.to_string(), qtype);
- let mut ns = self.config.get_fallback_ns().clone();
-
- if let Some(packet) = self.lookup_cache(qname, qtype).await { return packet }
+ let mut ns = self.config.dns_fallback.clone();
loop {
- trace!("Attempting lookup of {qtype:?} {qname} with ns {ns}");
-
let ns_copy = ns;
let server = (ns_copy, 53);
- let response = self.lookup(qname, qtype, server).await;
+ let response = self.lookup(&question, server).await;
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
- self.cache.insert(question, (response.clone(), get_time())).await;
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
return response;
}
if response.header.rescode == ResultCode::NXDOMAIN {
- self.cache.insert(question, (response.clone(), get_time())).await;
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
return response;
}
@@ -111,9 +172,11 @@ impl Resolver {
let new_ns_name = match response.get_unresolved_ns(qname) {
Some(x) => x,
None => {
- self.cache.insert(question, (response.clone(), get_time())).await;
- return response
- },
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
+ return response;
+ }
};
let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await;
@@ -121,7 +184,9 @@ impl Resolver {
if let Some(new_ns) = recursive_response.get_random_a() {
ns = new_ns;
} else {
- self.cache.insert(question, (response.clone(), get_time())).await;
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
return response;
}
}
diff --git a/src/dns/server.rs b/src/dns/server.rs
new file mode 100644
index 0000000..65d15df
--- /dev/null
+++ b/src/dns/server.rs
@@ -0,0 +1,85 @@
+use super::{
+ binding::Binding,
+ packet::{question::DnsQuestion, Packet},
+ resolver::Resolver,
+};
+use crate::{config::Config, database::Database, Result};
+use moka::future::Cache;
+use std::{net::SocketAddr, sync::Arc, time::Duration};
+use tokio::task::JoinHandle;
+use tracing::{error, info};
+
+pub struct DnsServer {
+ addr: SocketAddr,
+ config: Arc<Config>,
+ database: Arc<Database>,
+ cache: Cache<DnsQuestion, (Packet, u64)>,
+}
+
+impl DnsServer {
+ pub async fn new(config: Config, database: Database) -> Result<Self> {
+ let addr = format!("[::]:{}", config.dns_port).parse::<SocketAddr>()?;
+ let cache = Cache::builder()
+ .time_to_live(Duration::from_secs(60 * 60))
+ .max_capacity(config.dns_cache_size)
+ .build();
+
+ info!("Created DNS cache with size of {}", config.dns_cache_size);
+
+ Ok(Self {
+ addr,
+ config: Arc::new(config),
+ database: Arc::new(database),
+ cache,
+ })
+ }
+
+ pub async fn run(&self) -> Result<(JoinHandle<()>, JoinHandle<()>)> {
+ let tcp = Binding::tcp(self.addr).await?;
+ let tcp_handle = self.listen(tcp);
+
+ let udp = Binding::udp(self.addr).await?;
+ let udp_handle = self.listen(udp);
+
+ info!(
+ "Fallback DNS Server is set to: {:?}",
+ self.config.dns_fallback
+ );
+ info!(
+ "Listening for TCP and UDP traffic on [::]:{}",
+ self.config.dns_port
+ );
+
+ Ok((udp_handle, tcp_handle))
+ }
+
+ fn listen(&self, mut binding: Binding) -> JoinHandle<()> {
+ let config = self.config.clone();
+ let database = self.database.clone();
+ let cache = self.cache.clone();
+ tokio::spawn(async move {
+ let mut id = 0;
+ loop {
+ let Ok(connection) = binding.connect().await else { continue };
+ info!("Received request on {}", binding.name());
+
+ let resolver = Resolver::new(
+ id,
+ connection,
+ config.clone(),
+ database.clone(),
+ cache.clone(),
+ );
+
+ let name = binding.name().to_string();
+ tokio::spawn(async move {
+ if let Err(err) = resolver.handle_query().await {
+ error!("{} request {} failed: {:?}", name, id, err);
+ };
+ });
+
+ id += 1;
+ }
+ })
+ }
+}