diff options
author | Tyler Murphy <tylermurphy534@gmail.com> | 2023-03-06 18:50:08 -0500 |
---|---|---|
committer | Tyler Murphy <tylermurphy534@gmail.com> | 2023-03-06 18:50:08 -0500 |
commit | b1fb410affb7bcd2e714abac01d22c4a5332c344 (patch) | |
tree | 7ebb621ab9b73e3e1fbaeb0ef8c19abef95b7c9f /src/dns/resolver.rs | |
parent | finialize initial dns + caching (diff) | |
download | wrapper-b1fb410affb7bcd2e714abac01d22c4a5332c344.tar.gz wrapper-b1fb410affb7bcd2e714abac01d22c4a5332c344.tar.bz2 wrapper-b1fb410affb7bcd2e714abac01d22c4a5332c344.zip |
finish dns and start webserver
Diffstat (limited to 'src/dns/resolver.rs')
-rw-r--r-- | src/dns/resolver.rs | 230 |
1 files changed, 230 insertions, 0 deletions
diff --git a/src/dns/resolver.rs b/src/dns/resolver.rs new file mode 100644 index 0000000..18b5bba --- /dev/null +++ b/src/dns/resolver.rs @@ -0,0 +1,230 @@ +use super::binding::Connection; +use super::packet::{query::QueryType, question::DnsQuestion, result::ResultCode, Packet}; +use crate::Result; +use crate::{config::Config, database::Database, get_time}; +use async_recursion::async_recursion; +use moka::future::Cache; +use std::{net::IpAddr, sync::Arc, time::Duration}; +use tracing::{error, trace}; + +pub struct Resolver { + request_id: u16, + connection: Connection, + config: Arc<Config>, + database: Arc<Database>, + cache: Cache<DnsQuestion, (Packet, u64)>, +} + +impl Resolver { + pub fn new( + request_id: u16, + connection: Connection, + config: Arc<Config>, + database: Arc<Database>, + cache: Cache<DnsQuestion, (Packet, u64)>, + ) -> Self { + Self { + request_id, + connection, + config, + database, + cache, + } + } + + async fn lookup_database(&self, question: &DnsQuestion) -> Option<Packet> { + let records = match self + .database + .get_records(&question.name, question.qtype) + .await + { + Ok(record) => record, + Err(err) => { + error!("{err}"); + return None; + } + }; + + if records.is_empty() { + return None; + } + + let mut packet = Packet::new(); + + packet.header.id = self.request_id; + packet.header.questions = 1; + packet.header.answers = records.len() as u16; + packet.header.recursion_desired = true; + packet + .questions + .push(DnsQuestion::new(question.name.to_string(), question.qtype)); + + for record in records { + packet.answers.push(record); + } + + trace!( + "Found stored value for {:?} {}", + question.qtype, + question.name + ); + + Some(packet) + } + + async fn lookup_cache(&self, question: &DnsQuestion) -> Option<Packet> { + let Some((packet, date)) = self.cache.get(&question) else { + return None + }; + + let now = get_time(); + let diff = Duration::from_millis(now - date).as_secs() as u32; + + for answer in &packet.answers { + let ttl = answer.get_ttl(); + if diff > ttl { + self.cache.invalidate(&question).await; + return None; + } + } + + trace!( + "Found cached value for {:?} {}", + question.qtype, + question.name + ); + + Some(packet) + } + + async fn lookup_fallback(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet { + let mut packet = Packet::new(); + + packet.header.id = self.request_id; + packet.header.questions = 1; + packet.header.recursion_desired = true; + packet + .questions + .push(DnsQuestion::new(question.name.to_string(), question.qtype)); + + let packet = match self.connection.request_packet(packet, server).await { + Ok(packet) => packet, + Err(e) => { + error!("Failed to complete nameserver request: {e}"); + let mut packet = Packet::new(); + packet.header.rescode = ResultCode::SERVFAIL; + packet + } + }; + + packet + } + + async fn lookup(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet { + if let Some(packet) = self.lookup_cache(question).await { + return packet; + }; + + if let Some(packet) = self.lookup_database(question).await { + return packet; + }; + + trace!( + "Attempting lookup of {:?} {} with ns {}", + question.qtype, + question.name, + server.0 + ); + + self.lookup_fallback(question, server).await + } + + #[async_recursion] + async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet { + let question = DnsQuestion::new(qname.to_string(), qtype); + let mut ns = self.config.dns_fallback.clone(); + + loop { + let ns_copy = ns; + + let server = (ns_copy, 53); + let response = self.lookup(&question, server).await; + + if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { + self.cache + .insert(question, (response.clone(), get_time())) + .await; + return response; + } + + if response.header.rescode == ResultCode::NXDOMAIN { + self.cache + .insert(question, (response.clone(), get_time())) + .await; + return response; + } + + if let Some(new_ns) = response.get_resolved_ns(qname) { + ns = new_ns; + continue; + } + + let new_ns_name = match response.get_unresolved_ns(qname) { + Some(x) => x, + None => { + self.cache + .insert(question, (response.clone(), get_time())) + .await; + return response; + } + }; + + let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await; + + if let Some(new_ns) = recursive_response.get_random_a() { + ns = new_ns; + } else { + self.cache + .insert(question, (response.clone(), get_time())) + .await; + return response; + } + } + } + + pub async fn handle_query(mut self) -> Result<()> { + let mut request = self.connection.read_packet().await?; + + let mut packet = Packet::new(); + packet.header.id = request.header.id; + packet.header.recursion_desired = true; + packet.header.recursion_available = true; + packet.header.response = true; + + if let Some(question) = request.questions.pop() { + trace!("Received query: {question:?}"); + + let result = self.recursive_lookup(&question.name, question.qtype).await; + packet.questions.push(question.clone()); + packet.header.rescode = result.header.rescode; + + for rec in result.answers { + trace!("Answer: {rec:?}"); + packet.answers.push(rec); + } + for rec in result.authorities { + trace!("Authority: {rec:?}"); + packet.authorities.push(rec); + } + for rec in result.resources { + trace!("Resource: {rec:?}"); + packet.resources.push(rec); + } + } else { + packet.header.rescode = ResultCode::FORMERR; + } + + self.connection.write_packet(packet).await?; + Ok(()) + } +} |