summaryrefslogtreecommitdiff
path: root/src/dns/resolver.rs
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/dns/resolver.rs (renamed from src/server/resolver.rs)115
1 files changed, 90 insertions, 25 deletions
diff --git a/src/server/resolver.rs b/src/dns/resolver.rs
index 464620c..18b5bba 100644
--- a/src/server/resolver.rs
+++ b/src/dns/resolver.rs
@@ -1,11 +1,7 @@
use super::binding::Connection;
-use crate::{
- config::Config,
- packet::{
- query::QueryType, question::DnsQuestion, result::ResultCode, Packet,
- Result,
- }, get_time,
-};
+use super::packet::{query::QueryType, question::DnsQuestion, result::ResultCode, Packet};
+use crate::Result;
+use crate::{config::Config, database::Database, get_time};
use async_recursion::async_recursion;
use moka::future::Cache;
use std::{net::IpAddr, sync::Arc, time::Duration};
@@ -15,6 +11,7 @@ pub struct Resolver {
request_id: u16,
connection: Connection,
config: Arc<Config>,
+ database: Arc<Database>,
cache: Cache<DnsQuestion, (Packet, u64)>,
}
@@ -23,18 +20,59 @@ impl Resolver {
request_id: u16,
connection: Connection,
config: Arc<Config>,
+ database: Arc<Database>,
cache: Cache<DnsQuestion, (Packet, u64)>,
) -> Self {
Self {
request_id,
connection,
config,
+ database,
cache,
}
}
- async fn lookup_cache(&mut self, qname: &str, qtype: QueryType) -> Option<Packet> {
- let question = DnsQuestion::new(qname.to_string(), qtype);
+ async fn lookup_database(&self, question: &DnsQuestion) -> Option<Packet> {
+ let records = match self
+ .database
+ .get_records(&question.name, question.qtype)
+ .await
+ {
+ Ok(record) => record,
+ Err(err) => {
+ error!("{err}");
+ return None;
+ }
+ };
+
+ if records.is_empty() {
+ return None;
+ }
+
+ let mut packet = Packet::new();
+
+ packet.header.id = self.request_id;
+ packet.header.questions = 1;
+ packet.header.answers = records.len() as u16;
+ packet.header.recursion_desired = true;
+ packet
+ .questions
+ .push(DnsQuestion::new(question.name.to_string(), question.qtype));
+
+ for record in records {
+ packet.answers.push(record);
+ }
+
+ trace!(
+ "Found stored value for {:?} {}",
+ question.qtype,
+ question.name
+ );
+
+ Some(packet)
+ }
+
+ async fn lookup_cache(&self, question: &DnsQuestion) -> Option<Packet> {
let Some((packet, date)) = self.cache.get(&question) else {
return None
};
@@ -46,16 +84,20 @@ impl Resolver {
let ttl = answer.get_ttl();
if diff > ttl {
self.cache.invalidate(&question).await;
- return None
+ return None;
}
}
- trace!("Found cached value for {qtype:?} {qname}");
+ trace!(
+ "Found cached value for {:?} {}",
+ question.qtype,
+ question.name
+ );
Some(packet)
}
- async fn lookup(&mut self, qname: &str, qtype: QueryType, server: (IpAddr, u16)) -> Packet {
+ async fn lookup_fallback(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
let mut packet = Packet::new();
packet.header.id = self.request_id;
@@ -63,7 +105,7 @@ impl Resolver {
packet.header.recursion_desired = true;
packet
.questions
- .push(DnsQuestion::new(qname.to_string(), qtype));
+ .push(DnsQuestion::new(question.name.to_string(), question.qtype));
let packet = match self.connection.request_packet(packet, server).await {
Ok(packet) => packet,
@@ -78,28 +120,47 @@ impl Resolver {
packet
}
+ async fn lookup(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
+ if let Some(packet) = self.lookup_cache(question).await {
+ return packet;
+ };
+
+ if let Some(packet) = self.lookup_database(question).await {
+ return packet;
+ };
+
+ trace!(
+ "Attempting lookup of {:?} {} with ns {}",
+ question.qtype,
+ question.name,
+ server.0
+ );
+
+ self.lookup_fallback(question, server).await
+ }
+
#[async_recursion]
async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet {
let question = DnsQuestion::new(qname.to_string(), qtype);
- let mut ns = self.config.get_fallback_ns().clone();
-
- if let Some(packet) = self.lookup_cache(qname, qtype).await { return packet }
+ let mut ns = self.config.dns_fallback.clone();
loop {
- trace!("Attempting lookup of {qtype:?} {qname} with ns {ns}");
-
let ns_copy = ns;
let server = (ns_copy, 53);
- let response = self.lookup(qname, qtype, server).await;
+ let response = self.lookup(&question, server).await;
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
- self.cache.insert(question, (response.clone(), get_time())).await;
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
return response;
}
if response.header.rescode == ResultCode::NXDOMAIN {
- self.cache.insert(question, (response.clone(), get_time())).await;
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
return response;
}
@@ -111,9 +172,11 @@ impl Resolver {
let new_ns_name = match response.get_unresolved_ns(qname) {
Some(x) => x,
None => {
- self.cache.insert(question, (response.clone(), get_time())).await;
- return response
- },
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
+ return response;
+ }
};
let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await;
@@ -121,7 +184,9 @@ impl Resolver {
if let Some(new_ns) = recursive_response.get_random_a() {
ns = new_ns;
} else {
- self.cache.insert(question, (response.clone(), get_time())).await;
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
return response;
}
}