summaryrefslogtreecommitdiff
path: root/src/server/resolver.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/server/resolver.rs')
-rw-r--r--src/server/resolver.rs165
1 files changed, 165 insertions, 0 deletions
diff --git a/src/server/resolver.rs b/src/server/resolver.rs
new file mode 100644
index 0000000..464620c
--- /dev/null
+++ b/src/server/resolver.rs
@@ -0,0 +1,165 @@
+use super::binding::Connection;
+use crate::{
+ config::Config,
+ packet::{
+ query::QueryType, question::DnsQuestion, result::ResultCode, Packet,
+ Result,
+ }, get_time,
+};
+use async_recursion::async_recursion;
+use moka::future::Cache;
+use std::{net::IpAddr, sync::Arc, time::Duration};
+use tracing::{error, trace};
+
+pub struct Resolver {
+ request_id: u16,
+ connection: Connection,
+ config: Arc<Config>,
+ cache: Cache<DnsQuestion, (Packet, u64)>,
+}
+
+impl Resolver {
+ pub fn new(
+ request_id: u16,
+ connection: Connection,
+ config: Arc<Config>,
+ cache: Cache<DnsQuestion, (Packet, u64)>,
+ ) -> Self {
+ Self {
+ request_id,
+ connection,
+ config,
+ cache,
+ }
+ }
+
+ async fn lookup_cache(&mut self, qname: &str, qtype: QueryType) -> Option<Packet> {
+ let question = DnsQuestion::new(qname.to_string(), qtype);
+ let Some((packet, date)) = self.cache.get(&question) else {
+ return None
+ };
+
+ let now = get_time();
+ let diff = Duration::from_millis(now - date).as_secs() as u32;
+
+ for answer in &packet.answers {
+ let ttl = answer.get_ttl();
+ if diff > ttl {
+ self.cache.invalidate(&question).await;
+ return None
+ }
+ }
+
+ trace!("Found cached value for {qtype:?} {qname}");
+
+ Some(packet)
+ }
+
+ async fn lookup(&mut self, qname: &str, qtype: QueryType, server: (IpAddr, u16)) -> Packet {
+ let mut packet = Packet::new();
+
+ packet.header.id = self.request_id;
+ packet.header.questions = 1;
+ packet.header.recursion_desired = true;
+ packet
+ .questions
+ .push(DnsQuestion::new(qname.to_string(), qtype));
+
+ let packet = match self.connection.request_packet(packet, server).await {
+ Ok(packet) => packet,
+ Err(e) => {
+ error!("Failed to complete nameserver request: {e}");
+ let mut packet = Packet::new();
+ packet.header.rescode = ResultCode::SERVFAIL;
+ packet
+ }
+ };
+
+ packet
+ }
+
+ #[async_recursion]
+ async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet {
+ let question = DnsQuestion::new(qname.to_string(), qtype);
+ let mut ns = self.config.get_fallback_ns().clone();
+
+ if let Some(packet) = self.lookup_cache(qname, qtype).await { return packet }
+
+ loop {
+ trace!("Attempting lookup of {qtype:?} {qname} with ns {ns}");
+
+ let ns_copy = ns;
+
+ let server = (ns_copy, 53);
+ let response = self.lookup(qname, qtype, server).await;
+
+ if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
+ self.cache.insert(question, (response.clone(), get_time())).await;
+ return response;
+ }
+
+ if response.header.rescode == ResultCode::NXDOMAIN {
+ self.cache.insert(question, (response.clone(), get_time())).await;
+ return response;
+ }
+
+ if let Some(new_ns) = response.get_resolved_ns(qname) {
+ ns = new_ns;
+ continue;
+ }
+
+ let new_ns_name = match response.get_unresolved_ns(qname) {
+ Some(x) => x,
+ None => {
+ self.cache.insert(question, (response.clone(), get_time())).await;
+ return response
+ },
+ };
+
+ let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await;
+
+ if let Some(new_ns) = recursive_response.get_random_a() {
+ ns = new_ns;
+ } else {
+ self.cache.insert(question, (response.clone(), get_time())).await;
+ return response;
+ }
+ }
+ }
+
+ pub async fn handle_query(mut self) -> Result<()> {
+ let mut request = self.connection.read_packet().await?;
+
+ let mut packet = Packet::new();
+ packet.header.id = request.header.id;
+ packet.header.recursion_desired = true;
+ packet.header.recursion_available = true;
+ packet.header.response = true;
+
+ if let Some(question) = request.questions.pop() {
+ trace!("Received query: {question:?}");
+
+ let result = self.recursive_lookup(&question.name, question.qtype).await;
+ packet.questions.push(question.clone());
+ packet.header.rescode = result.header.rescode;
+
+ for rec in result.answers {
+ trace!("Answer: {rec:?}");
+ packet.answers.push(rec);
+ }
+ for rec in result.authorities {
+ trace!("Authority: {rec:?}");
+ packet.authorities.push(rec);
+ }
+ for rec in result.resources {
+ trace!("Resource: {rec:?}");
+ packet.resources.push(rec);
+ }
+ } else {
+ packet.header.rescode = ResultCode::FORMERR;
+ }
+
+ self.connection.write_packet(packet).await?;
+ Ok(())
+ }
+}