diff options
Diffstat (limited to 'src/server')
-rw-r--r-- | src/server/binding.rs | 150 | ||||
-rw-r--r-- | src/server/mod.rs | 3 | ||||
-rw-r--r-- | src/server/resolver.rs | 165 | ||||
-rw-r--r-- | src/server/server.rs | 73 |
4 files changed, 0 insertions, 391 deletions
diff --git a/src/server/binding.rs b/src/server/binding.rs deleted file mode 100644 index 1c69651..0000000 --- a/src/server/binding.rs +++ /dev/null @@ -1,150 +0,0 @@ -use std::{ - net::{IpAddr, SocketAddr}, - sync::Arc, -}; - -use crate::packet::{buffer::PacketBuffer, Packet, Result}; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::{TcpListener, TcpStream, UdpSocket}, -}; -use tracing::trace; - -pub enum Binding { - UDP(Arc<UdpSocket>), - TCP(TcpListener), -} - -impl Binding { - pub async fn udp(addr: SocketAddr) -> Result<Self> { - let socket = UdpSocket::bind(addr).await?; - Ok(Self::UDP(Arc::new(socket))) - } - - pub async fn tcp(addr: SocketAddr) -> Result<Self> { - let socket = TcpListener::bind(addr).await?; - Ok(Self::TCP(socket)) - } - - pub fn name(&self) -> &str { - match self { - Binding::UDP(_) => "UDP", - Binding::TCP(_) => "TCP", - } - } - - pub async fn connect(&mut self) -> Result<Connection> { - match self { - Self::UDP(socket) => { - let mut buf = [0; 512]; - let (_, addr) = socket.recv_from(&mut buf).await?; - Ok(Connection::UDP(socket.clone(), addr, buf)) - } - Self::TCP(socket) => { - let (stream, _) = socket.accept().await?; - Ok(Connection::TCP(stream)) - } - } - } -} - -pub enum Connection { - UDP(Arc<UdpSocket>, SocketAddr, [u8; 512]), - TCP(TcpStream), -} - -impl Connection { - pub async fn read_packet(&mut self) -> Result<Packet> { - let data = self.read().await?; - let mut packet_buffer = PacketBuffer::new(data); - - let packet = Packet::from_buffer(&mut packet_buffer)?; - Ok(packet) - } - - pub async fn write_packet(self, mut packet: Packet) -> Result<()> { - let mut packet_buffer = PacketBuffer::new(Vec::new()); - packet.write(&mut packet_buffer)?; - - self.write(packet_buffer.buf).await?; - Ok(()) - } - - pub async fn request_packet(&self, mut packet: Packet, dest: (IpAddr, u16)) -> Result<Packet> { - let mut packet_buffer = PacketBuffer::new(Vec::new()); - packet.write(&mut packet_buffer)?; - - let data = self.request(packet_buffer.buf, dest).await?; - let mut packet_buffer = PacketBuffer::new(data); - - let packet = Packet::from_buffer(&mut packet_buffer)?; - Ok(packet) - } - - async fn read(&mut self) -> Result<Vec<u8>> { - trace!("Reading DNS packet"); - match self { - Self::UDP(_, _, src) => Ok(Vec::from(*src)), - Self::TCP(stream) => { - let size = stream.read_u16().await?; - let mut buf = Vec::with_capacity(size as usize); - stream.read_buf(&mut buf).await?; - Ok(buf) - } - } - } - - async fn write(self, mut buf: Vec<u8>) -> Result<()> { - trace!("Returning DNS packet"); - match self { - Self::UDP(socket, addr, _) => { - if buf.len() > 512 { - buf[2] = buf[2] | 0x03; - socket.send_to(&buf[0..512], addr).await?; - } else { - socket.send_to(&buf, addr).await?; - } - Ok(()) - } - Self::TCP(mut stream) => { - stream.write_u16(buf.len() as u16).await?; - stream.write(&buf[0..buf.len()]).await?; - Ok(()) - } - } - } - - async fn request(&self, buf: Vec<u8>, dest: (IpAddr, u16)) -> Result<Vec<u8>> { - match self { - Self::UDP(_socket, _addr, _src) => { - let local_addr = "[::]:0".parse::<SocketAddr>()?; - let socket = UdpSocket::bind(local_addr).await?; - socket.send_to(&buf, dest).await?; - - let mut buf = [0; 512]; - socket.recv_from(&mut buf).await?; - - Ok(Vec::from(buf)) - } - Self::TCP(_stream) => { - let mut stream = TcpStream::connect(dest).await?; - stream.write_u16((buf.len()) as u16).await?; - stream.write_all(&buf[0..buf.len()]).await?; - - stream.readable().await?; - let size = stream.read_u16().await?; - let mut buf = Vec::with_capacity(size as usize); - stream.read_buf(&mut buf).await?; - - Ok(buf) - } - } - } - - // fn pb(buf: &[u8]) { - // for i in 0..buf.len() { - // print!("{:02X?} ", buf[i]); - // } - // println!(""); - // } -} diff --git a/src/server/mod.rs b/src/server/mod.rs deleted file mode 100644 index 25076ef..0000000 --- a/src/server/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod binding; -mod resolver; -pub mod server; diff --git a/src/server/resolver.rs b/src/server/resolver.rs deleted file mode 100644 index 464620c..0000000 --- a/src/server/resolver.rs +++ /dev/null @@ -1,165 +0,0 @@ -use super::binding::Connection; -use crate::{ - config::Config, - packet::{ - query::QueryType, question::DnsQuestion, result::ResultCode, Packet, - Result, - }, get_time, -}; -use async_recursion::async_recursion; -use moka::future::Cache; -use std::{net::IpAddr, sync::Arc, time::Duration}; -use tracing::{error, trace}; - -pub struct Resolver { - request_id: u16, - connection: Connection, - config: Arc<Config>, - cache: Cache<DnsQuestion, (Packet, u64)>, -} - -impl Resolver { - pub fn new( - request_id: u16, - connection: Connection, - config: Arc<Config>, - cache: Cache<DnsQuestion, (Packet, u64)>, - ) -> Self { - Self { - request_id, - connection, - config, - cache, - } - } - - async fn lookup_cache(&mut self, qname: &str, qtype: QueryType) -> Option<Packet> { - let question = DnsQuestion::new(qname.to_string(), qtype); - let Some((packet, date)) = self.cache.get(&question) else { - return None - }; - - let now = get_time(); - let diff = Duration::from_millis(now - date).as_secs() as u32; - - for answer in &packet.answers { - let ttl = answer.get_ttl(); - if diff > ttl { - self.cache.invalidate(&question).await; - return None - } - } - - trace!("Found cached value for {qtype:?} {qname}"); - - Some(packet) - } - - async fn lookup(&mut self, qname: &str, qtype: QueryType, server: (IpAddr, u16)) -> Packet { - let mut packet = Packet::new(); - - packet.header.id = self.request_id; - packet.header.questions = 1; - packet.header.recursion_desired = true; - packet - .questions - .push(DnsQuestion::new(qname.to_string(), qtype)); - - let packet = match self.connection.request_packet(packet, server).await { - Ok(packet) => packet, - Err(e) => { - error!("Failed to complete nameserver request: {e}"); - let mut packet = Packet::new(); - packet.header.rescode = ResultCode::SERVFAIL; - packet - } - }; - - packet - } - - #[async_recursion] - async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet { - let question = DnsQuestion::new(qname.to_string(), qtype); - let mut ns = self.config.get_fallback_ns().clone(); - - if let Some(packet) = self.lookup_cache(qname, qtype).await { return packet } - - loop { - trace!("Attempting lookup of {qtype:?} {qname} with ns {ns}"); - - let ns_copy = ns; - - let server = (ns_copy, 53); - let response = self.lookup(qname, qtype, server).await; - - if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { - self.cache.insert(question, (response.clone(), get_time())).await; - return response; - } - - if response.header.rescode == ResultCode::NXDOMAIN { - self.cache.insert(question, (response.clone(), get_time())).await; - return response; - } - - if let Some(new_ns) = response.get_resolved_ns(qname) { - ns = new_ns; - continue; - } - - let new_ns_name = match response.get_unresolved_ns(qname) { - Some(x) => x, - None => { - self.cache.insert(question, (response.clone(), get_time())).await; - return response - }, - }; - - let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await; - - if let Some(new_ns) = recursive_response.get_random_a() { - ns = new_ns; - } else { - self.cache.insert(question, (response.clone(), get_time())).await; - return response; - } - } - } - - pub async fn handle_query(mut self) -> Result<()> { - let mut request = self.connection.read_packet().await?; - - let mut packet = Packet::new(); - packet.header.id = request.header.id; - packet.header.recursion_desired = true; - packet.header.recursion_available = true; - packet.header.response = true; - - if let Some(question) = request.questions.pop() { - trace!("Received query: {question:?}"); - - let result = self.recursive_lookup(&question.name, question.qtype).await; - packet.questions.push(question.clone()); - packet.header.rescode = result.header.rescode; - - for rec in result.answers { - trace!("Answer: {rec:?}"); - packet.answers.push(rec); - } - for rec in result.authorities { - trace!("Authority: {rec:?}"); - packet.authorities.push(rec); - } - for rec in result.resources { - trace!("Resource: {rec:?}"); - packet.resources.push(rec); - } - } else { - packet.header.rescode = ResultCode::FORMERR; - } - - self.connection.write_packet(packet).await?; - Ok(()) - } -} diff --git a/src/server/server.rs b/src/server/server.rs deleted file mode 100644 index e006bb1..0000000 --- a/src/server/server.rs +++ /dev/null @@ -1,73 +0,0 @@ -use moka::future::Cache; -use std::net::SocketAddr; -use std::sync::Arc; -use std::time::Duration; -use tokio::task::JoinHandle; -use tracing::{error, info}; - -use crate::config::Config; -use crate::packet::question::DnsQuestion; -use crate::packet::{Result, Packet}; - -use super::binding::Binding; -use super::resolver::Resolver; - -pub struct Server { - addr: SocketAddr, - config: Arc<Config>, - cache: Cache<DnsQuestion, (Packet, u64)>, -} - -impl Server { - pub async fn new(config: Config) -> Result<Self> { - let addr = format!("[::]:{}", config.get_port()).parse::<SocketAddr>()?; - let cache = Cache::builder() - .time_to_live(Duration::from_secs(60 * 60)) - .max_capacity(1_000) - .build(); - Ok(Self { - addr, - config: Arc::new(config), - cache, - }) - } - - pub async fn run(&self) -> Result<()> { - let tcp = Binding::tcp(self.addr).await?; - let tcp_handle = self.listen(tcp); - - let udp = Binding::udp(self.addr).await?; - let udp_handle = self.listen(udp); - - info!("Fallback DNS Server is set to: {:?}", self.config.get_fallback_ns()); - info!("Listening for TCP and UDP traffic on [::]:{}", self.config.get_port()); - - tokio::join!(tcp_handle) - .0 - .expect("Failed to join tcp thread"); - tokio::join!(udp_handle) - .0 - .expect("Failed to join udp thread"); - Ok(()) - } - - fn listen(&self, mut binding: Binding) -> JoinHandle<()> { - let config = self.config.clone(); - let cache = self.cache.clone(); - tokio::spawn(async move { - let mut id = 0; - loop { - let Ok(connection) = binding.connect().await else { continue }; - info!("Received request on {}", binding.name()); - - let resolver = Resolver::new(id, connection, config.clone(), cache.clone()); - - if let Err(err) = resolver.handle_query().await { - error!("{} request {} failed: {:?}", binding.name(), id, err); - }; - - id += 1; - } - }) - } -} |