diff options
author | Tyler Murphy <tylermurphy534@gmail.com> | 2023-03-03 00:10:21 -0500 |
---|---|---|
committer | Tyler Murphy <tylermurphy534@gmail.com> | 2023-03-03 00:10:21 -0500 |
commit | 0f40ab89e3b523ac206077d932a0e2d40d75f7e0 (patch) | |
tree | c4914050d1bbca8af77347220c0785c8ebefa213 /src/server | |
parent | clippy my beloved (diff) | |
download | wrapper-0f40ab89e3b523ac206077d932a0e2d40d75f7e0.tar.gz wrapper-0f40ab89e3b523ac206077d932a0e2d40d75f7e0.tar.bz2 wrapper-0f40ab89e3b523ac206077d932a0e2d40d75f7e0.zip |
finialize initial dns + caching
Diffstat (limited to 'src/server')
-rw-r--r-- | src/server/binding.rs | 150 | ||||
-rw-r--r-- | src/server/mod.rs | 3 | ||||
-rw-r--r-- | src/server/resolver.rs | 165 | ||||
-rw-r--r-- | src/server/server.rs | 73 |
4 files changed, 391 insertions, 0 deletions
diff --git a/src/server/binding.rs b/src/server/binding.rs new file mode 100644 index 0000000..1c69651 --- /dev/null +++ b/src/server/binding.rs @@ -0,0 +1,150 @@ +use std::{ + net::{IpAddr, SocketAddr}, + sync::Arc, +}; + +use crate::packet::{buffer::PacketBuffer, Packet, Result}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpListener, TcpStream, UdpSocket}, +}; +use tracing::trace; + +pub enum Binding { + UDP(Arc<UdpSocket>), + TCP(TcpListener), +} + +impl Binding { + pub async fn udp(addr: SocketAddr) -> Result<Self> { + let socket = UdpSocket::bind(addr).await?; + Ok(Self::UDP(Arc::new(socket))) + } + + pub async fn tcp(addr: SocketAddr) -> Result<Self> { + let socket = TcpListener::bind(addr).await?; + Ok(Self::TCP(socket)) + } + + pub fn name(&self) -> &str { + match self { + Binding::UDP(_) => "UDP", + Binding::TCP(_) => "TCP", + } + } + + pub async fn connect(&mut self) -> Result<Connection> { + match self { + Self::UDP(socket) => { + let mut buf = [0; 512]; + let (_, addr) = socket.recv_from(&mut buf).await?; + Ok(Connection::UDP(socket.clone(), addr, buf)) + } + Self::TCP(socket) => { + let (stream, _) = socket.accept().await?; + Ok(Connection::TCP(stream)) + } + } + } +} + +pub enum Connection { + UDP(Arc<UdpSocket>, SocketAddr, [u8; 512]), + TCP(TcpStream), +} + +impl Connection { + pub async fn read_packet(&mut self) -> Result<Packet> { + let data = self.read().await?; + let mut packet_buffer = PacketBuffer::new(data); + + let packet = Packet::from_buffer(&mut packet_buffer)?; + Ok(packet) + } + + pub async fn write_packet(self, mut packet: Packet) -> Result<()> { + let mut packet_buffer = PacketBuffer::new(Vec::new()); + packet.write(&mut packet_buffer)?; + + self.write(packet_buffer.buf).await?; + Ok(()) + } + + pub async fn request_packet(&self, mut packet: Packet, dest: (IpAddr, u16)) -> Result<Packet> { + let mut packet_buffer = PacketBuffer::new(Vec::new()); + packet.write(&mut packet_buffer)?; + + let data = self.request(packet_buffer.buf, dest).await?; + let mut packet_buffer = PacketBuffer::new(data); + + let packet = Packet::from_buffer(&mut packet_buffer)?; + Ok(packet) + } + + async fn read(&mut self) -> Result<Vec<u8>> { + trace!("Reading DNS packet"); + match self { + Self::UDP(_, _, src) => Ok(Vec::from(*src)), + Self::TCP(stream) => { + let size = stream.read_u16().await?; + let mut buf = Vec::with_capacity(size as usize); + stream.read_buf(&mut buf).await?; + Ok(buf) + } + } + } + + async fn write(self, mut buf: Vec<u8>) -> Result<()> { + trace!("Returning DNS packet"); + match self { + Self::UDP(socket, addr, _) => { + if buf.len() > 512 { + buf[2] = buf[2] | 0x03; + socket.send_to(&buf[0..512], addr).await?; + } else { + socket.send_to(&buf, addr).await?; + } + Ok(()) + } + Self::TCP(mut stream) => { + stream.write_u16(buf.len() as u16).await?; + stream.write(&buf[0..buf.len()]).await?; + Ok(()) + } + } + } + + async fn request(&self, buf: Vec<u8>, dest: (IpAddr, u16)) -> Result<Vec<u8>> { + match self { + Self::UDP(_socket, _addr, _src) => { + let local_addr = "[::]:0".parse::<SocketAddr>()?; + let socket = UdpSocket::bind(local_addr).await?; + socket.send_to(&buf, dest).await?; + + let mut buf = [0; 512]; + socket.recv_from(&mut buf).await?; + + Ok(Vec::from(buf)) + } + Self::TCP(_stream) => { + let mut stream = TcpStream::connect(dest).await?; + stream.write_u16((buf.len()) as u16).await?; + stream.write_all(&buf[0..buf.len()]).await?; + + stream.readable().await?; + let size = stream.read_u16().await?; + let mut buf = Vec::with_capacity(size as usize); + stream.read_buf(&mut buf).await?; + + Ok(buf) + } + } + } + + // fn pb(buf: &[u8]) { + // for i in 0..buf.len() { + // print!("{:02X?} ", buf[i]); + // } + // println!(""); + // } +} diff --git a/src/server/mod.rs b/src/server/mod.rs new file mode 100644 index 0000000..25076ef --- /dev/null +++ b/src/server/mod.rs @@ -0,0 +1,3 @@ +mod binding; +mod resolver; +pub mod server; diff --git a/src/server/resolver.rs b/src/server/resolver.rs new file mode 100644 index 0000000..464620c --- /dev/null +++ b/src/server/resolver.rs @@ -0,0 +1,165 @@ +use super::binding::Connection; +use crate::{ + config::Config, + packet::{ + query::QueryType, question::DnsQuestion, result::ResultCode, Packet, + Result, + }, get_time, +}; +use async_recursion::async_recursion; +use moka::future::Cache; +use std::{net::IpAddr, sync::Arc, time::Duration}; +use tracing::{error, trace}; + +pub struct Resolver { + request_id: u16, + connection: Connection, + config: Arc<Config>, + cache: Cache<DnsQuestion, (Packet, u64)>, +} + +impl Resolver { + pub fn new( + request_id: u16, + connection: Connection, + config: Arc<Config>, + cache: Cache<DnsQuestion, (Packet, u64)>, + ) -> Self { + Self { + request_id, + connection, + config, + cache, + } + } + + async fn lookup_cache(&mut self, qname: &str, qtype: QueryType) -> Option<Packet> { + let question = DnsQuestion::new(qname.to_string(), qtype); + let Some((packet, date)) = self.cache.get(&question) else { + return None + }; + + let now = get_time(); + let diff = Duration::from_millis(now - date).as_secs() as u32; + + for answer in &packet.answers { + let ttl = answer.get_ttl(); + if diff > ttl { + self.cache.invalidate(&question).await; + return None + } + } + + trace!("Found cached value for {qtype:?} {qname}"); + + Some(packet) + } + + async fn lookup(&mut self, qname: &str, qtype: QueryType, server: (IpAddr, u16)) -> Packet { + let mut packet = Packet::new(); + + packet.header.id = self.request_id; + packet.header.questions = 1; + packet.header.recursion_desired = true; + packet + .questions + .push(DnsQuestion::new(qname.to_string(), qtype)); + + let packet = match self.connection.request_packet(packet, server).await { + Ok(packet) => packet, + Err(e) => { + error!("Failed to complete nameserver request: {e}"); + let mut packet = Packet::new(); + packet.header.rescode = ResultCode::SERVFAIL; + packet + } + }; + + packet + } + + #[async_recursion] + async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet { + let question = DnsQuestion::new(qname.to_string(), qtype); + let mut ns = self.config.get_fallback_ns().clone(); + + if let Some(packet) = self.lookup_cache(qname, qtype).await { return packet } + + loop { + trace!("Attempting lookup of {qtype:?} {qname} with ns {ns}"); + + let ns_copy = ns; + + let server = (ns_copy, 53); + let response = self.lookup(qname, qtype, server).await; + + if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { + self.cache.insert(question, (response.clone(), get_time())).await; + return response; + } + + if response.header.rescode == ResultCode::NXDOMAIN { + self.cache.insert(question, (response.clone(), get_time())).await; + return response; + } + + if let Some(new_ns) = response.get_resolved_ns(qname) { + ns = new_ns; + continue; + } + + let new_ns_name = match response.get_unresolved_ns(qname) { + Some(x) => x, + None => { + self.cache.insert(question, (response.clone(), get_time())).await; + return response + }, + }; + + let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await; + + if let Some(new_ns) = recursive_response.get_random_a() { + ns = new_ns; + } else { + self.cache.insert(question, (response.clone(), get_time())).await; + return response; + } + } + } + + pub async fn handle_query(mut self) -> Result<()> { + let mut request = self.connection.read_packet().await?; + + let mut packet = Packet::new(); + packet.header.id = request.header.id; + packet.header.recursion_desired = true; + packet.header.recursion_available = true; + packet.header.response = true; + + if let Some(question) = request.questions.pop() { + trace!("Received query: {question:?}"); + + let result = self.recursive_lookup(&question.name, question.qtype).await; + packet.questions.push(question.clone()); + packet.header.rescode = result.header.rescode; + + for rec in result.answers { + trace!("Answer: {rec:?}"); + packet.answers.push(rec); + } + for rec in result.authorities { + trace!("Authority: {rec:?}"); + packet.authorities.push(rec); + } + for rec in result.resources { + trace!("Resource: {rec:?}"); + packet.resources.push(rec); + } + } else { + packet.header.rescode = ResultCode::FORMERR; + } + + self.connection.write_packet(packet).await?; + Ok(()) + } +} diff --git a/src/server/server.rs b/src/server/server.rs new file mode 100644 index 0000000..e006bb1 --- /dev/null +++ b/src/server/server.rs @@ -0,0 +1,73 @@ +use moka::future::Cache; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; +use tokio::task::JoinHandle; +use tracing::{error, info}; + +use crate::config::Config; +use crate::packet::question::DnsQuestion; +use crate::packet::{Result, Packet}; + +use super::binding::Binding; +use super::resolver::Resolver; + +pub struct Server { + addr: SocketAddr, + config: Arc<Config>, + cache: Cache<DnsQuestion, (Packet, u64)>, +} + +impl Server { + pub async fn new(config: Config) -> Result<Self> { + let addr = format!("[::]:{}", config.get_port()).parse::<SocketAddr>()?; + let cache = Cache::builder() + .time_to_live(Duration::from_secs(60 * 60)) + .max_capacity(1_000) + .build(); + Ok(Self { + addr, + config: Arc::new(config), + cache, + }) + } + + pub async fn run(&self) -> Result<()> { + let tcp = Binding::tcp(self.addr).await?; + let tcp_handle = self.listen(tcp); + + let udp = Binding::udp(self.addr).await?; + let udp_handle = self.listen(udp); + + info!("Fallback DNS Server is set to: {:?}", self.config.get_fallback_ns()); + info!("Listening for TCP and UDP traffic on [::]:{}", self.config.get_port()); + + tokio::join!(tcp_handle) + .0 + .expect("Failed to join tcp thread"); + tokio::join!(udp_handle) + .0 + .expect("Failed to join udp thread"); + Ok(()) + } + + fn listen(&self, mut binding: Binding) -> JoinHandle<()> { + let config = self.config.clone(); + let cache = self.cache.clone(); + tokio::spawn(async move { + let mut id = 0; + loop { + let Ok(connection) = binding.connect().await else { continue }; + info!("Received request on {}", binding.name()); + + let resolver = Resolver::new(id, connection, config.clone(), cache.clone()); + + if let Err(err) = resolver.handle_query().await { + error!("{} request {} failed: {:?}", binding.name(), id, err); + }; + + id += 1; + } + }) + } +} |