summaryrefslogtreecommitdiff
path: root/src/dns
diff options
context:
space:
mode:
authorTyler Murphy <tylermurphy534@gmail.com>2023-03-06 18:50:08 -0500
committerTyler Murphy <tylermurphy534@gmail.com>2023-03-06 18:50:08 -0500
commitb1fb410affb7bcd2e714abac01d22c4a5332c344 (patch)
tree7ebb621ab9b73e3e1fbaeb0ef8c19abef95b7c9f /src/dns
parentfinialize initial dns + caching (diff)
downloadwrapper-b1fb410affb7bcd2e714abac01d22c4a5332c344.tar.gz
wrapper-b1fb410affb7bcd2e714abac01d22c4a5332c344.tar.bz2
wrapper-b1fb410affb7bcd2e714abac01d22c4a5332c344.zip
finish dns and start webserver
Diffstat (limited to 'src/dns')
-rw-r--r--src/dns/binding.rs144
-rw-r--r--src/dns/mod.rs4
-rw-r--r--src/dns/packet/buffer.rs227
-rw-r--r--src/dns/packet/header.rs102
-rw-r--r--src/dns/packet/mod.rs128
-rw-r--r--src/dns/packet/query.rs78
-rw-r--r--src/dns/packet/question.rs31
-rw-r--r--src/dns/packet/record.rs544
-rw-r--r--src/dns/packet/result.rs22
-rw-r--r--src/dns/resolver.rs230
-rw-r--r--src/dns/server.rs85
11 files changed, 1595 insertions, 0 deletions
diff --git a/src/dns/binding.rs b/src/dns/binding.rs
new file mode 100644
index 0000000..4c7e15f
--- /dev/null
+++ b/src/dns/binding.rs
@@ -0,0 +1,144 @@
+use std::{
+ net::{IpAddr, SocketAddr},
+ sync::Arc,
+};
+
+use super::packet::{buffer::PacketBuffer, Packet};
+use crate::Result;
+use tokio::{
+ io::{AsyncReadExt, AsyncWriteExt},
+ net::{TcpListener, TcpStream, UdpSocket},
+};
+use tracing::trace;
+
+pub enum Binding {
+ UDP(Arc<UdpSocket>),
+ TCP(TcpListener),
+}
+
+impl Binding {
+ pub async fn udp(addr: SocketAddr) -> Result<Self> {
+ let socket = UdpSocket::bind(addr).await?;
+ Ok(Self::UDP(Arc::new(socket)))
+ }
+
+ pub async fn tcp(addr: SocketAddr) -> Result<Self> {
+ let socket = TcpListener::bind(addr).await?;
+ Ok(Self::TCP(socket))
+ }
+
+ pub fn name(&self) -> &str {
+ match self {
+ Binding::UDP(_) => "UDP",
+ Binding::TCP(_) => "TCP",
+ }
+ }
+
+ pub async fn connect(&mut self) -> Result<Connection> {
+ match self {
+ Self::UDP(socket) => {
+ let mut buf = [0; 512];
+ let (_, addr) = socket.recv_from(&mut buf).await?;
+ Ok(Connection::UDP(socket.clone(), addr, buf))
+ }
+ Self::TCP(socket) => {
+ let (stream, _) = socket.accept().await?;
+ Ok(Connection::TCP(stream))
+ }
+ }
+ }
+}
+
+pub enum Connection {
+ UDP(Arc<UdpSocket>, SocketAddr, [u8; 512]),
+ TCP(TcpStream),
+}
+
+impl Connection {
+ pub async fn read_packet(&mut self) -> Result<Packet> {
+ let data = self.read().await?;
+ let mut packet_buffer = PacketBuffer::new(data);
+
+ let packet = Packet::from_buffer(&mut packet_buffer)?;
+ Ok(packet)
+ }
+
+ pub async fn write_packet(self, mut packet: Packet) -> Result<()> {
+ let mut packet_buffer = PacketBuffer::new(Vec::new());
+ packet.write(&mut packet_buffer)?;
+
+ self.write(packet_buffer.buf).await?;
+ Ok(())
+ }
+
+ pub async fn request_packet(&self, mut packet: Packet, dest: (IpAddr, u16)) -> Result<Packet> {
+ let mut packet_buffer = PacketBuffer::new(Vec::new());
+ packet.write(&mut packet_buffer)?;
+
+ let data = self.request(packet_buffer.buf, dest).await?;
+ let mut packet_buffer = PacketBuffer::new(data);
+
+ let packet = Packet::from_buffer(&mut packet_buffer)?;
+ Ok(packet)
+ }
+
+ async fn read(&mut self) -> Result<Vec<u8>> {
+ trace!("Reading DNS packet");
+ match self {
+ Self::UDP(_, _, src) => Ok(Vec::from(*src)),
+ Self::TCP(stream) => {
+ let size = stream.read_u16().await?;
+ let mut buf = Vec::with_capacity(size as usize);
+ stream.read_buf(&mut buf).await?;
+ Ok(buf)
+ }
+ }
+ }
+
+ async fn write(self, mut buf: Vec<u8>) -> Result<()> {
+ trace!("Returning DNS packet");
+ match self {
+ Self::UDP(socket, addr, _) => {
+ if buf.len() > 512 {
+ buf[2] = buf[2] | 0x03;
+ socket.send_to(&buf[0..512], addr).await?;
+ } else {
+ socket.send_to(&buf, addr).await?;
+ }
+ Ok(())
+ }
+ Self::TCP(mut stream) => {
+ stream.write_u16(buf.len() as u16).await?;
+ stream.write(&buf[0..buf.len()]).await?;
+ Ok(())
+ }
+ }
+ }
+
+ async fn request(&self, buf: Vec<u8>, dest: (IpAddr, u16)) -> Result<Vec<u8>> {
+ match self {
+ Self::UDP(_socket, _addr, _src) => {
+ let local_addr = "[::]:0".parse::<SocketAddr>()?;
+ let socket = UdpSocket::bind(local_addr).await?;
+ socket.send_to(&buf, dest).await?;
+
+ let mut buf = [0; 512];
+ socket.recv_from(&mut buf).await?;
+
+ Ok(Vec::from(buf))
+ }
+ Self::TCP(_stream) => {
+ let mut stream = TcpStream::connect(dest).await?;
+ stream.write_u16((buf.len()) as u16).await?;
+ stream.write_all(&buf[0..buf.len()]).await?;
+
+ stream.readable().await?;
+ let size = stream.read_u16().await?;
+ let mut buf = Vec::with_capacity(size as usize);
+ stream.read_buf(&mut buf).await?;
+
+ Ok(buf)
+ }
+ }
+ }
+}
diff --git a/src/dns/mod.rs b/src/dns/mod.rs
new file mode 100644
index 0000000..6f1e59e
--- /dev/null
+++ b/src/dns/mod.rs
@@ -0,0 +1,4 @@
+mod binding;
+pub mod packet;
+mod resolver;
+pub mod server;
diff --git a/src/dns/packet/buffer.rs b/src/dns/packet/buffer.rs
new file mode 100644
index 0000000..058156e
--- /dev/null
+++ b/src/dns/packet/buffer.rs
@@ -0,0 +1,227 @@
+use crate::Result;
+
+pub struct PacketBuffer {
+ pub buf: Vec<u8>,
+ pub pos: usize,
+ pub size: usize,
+}
+
+impl PacketBuffer {
+ pub fn new(buf: Vec<u8>) -> Self {
+ Self {
+ size: buf.len(),
+ buf,
+ pos: 0,
+ }
+ }
+
+ pub fn pos(&self) -> usize {
+ self.pos
+ }
+
+ pub fn step(&mut self, steps: usize) -> Result<()> {
+ self.pos += steps;
+
+ Ok(())
+ }
+
+ pub fn seek(&mut self, pos: usize) -> Result<()> {
+ self.pos = pos;
+
+ Ok(())
+ }
+
+ pub fn read(&mut self) -> Result<u8> {
+ if self.pos >= self.size {
+ return Err("Tried to read past end of buffer".into());
+ }
+ let res = self.buf[self.pos];
+ self.pos += 1;
+ Ok(res)
+ }
+
+ pub fn get(&mut self, pos: usize) -> Result<u8> {
+ if pos >= self.size {
+ return Err("Tried to read past end of buffer".into());
+ }
+ Ok(self.buf[pos])
+ }
+
+ pub fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
+ if start + len >= self.size {
+ return Err("Tried to read past end of buffer".into());
+ }
+ Ok(&self.buf[start..start + len])
+ }
+
+ pub fn read_u16(&mut self) -> Result<u16> {
+ let res = ((self.read()? as u16) << 8) | (self.read()? as u16);
+
+ Ok(res)
+ }
+
+ pub fn read_u32(&mut self) -> Result<u32> {
+ let res = ((self.read()? as u32) << 24)
+ | ((self.read()? as u32) << 16)
+ | ((self.read()? as u32) << 8)
+ | (self.read()? as u32);
+
+ Ok(res)
+ }
+
+ pub fn read_qname(&mut self, outstr: &mut String) -> Result<()> {
+ let mut pos = self.pos();
+ let mut jumped = false;
+
+ let mut delim = "";
+ let max_jumps = 5;
+ let mut jumps_performed = 0;
+ loop {
+ // Dns Packets are untrusted data, so we need to be paranoid. Someone
+ // can craft a packet with a cycle in the jump instructions. This guards
+ // against such packets.
+ if jumps_performed > max_jumps {
+ return Err(format!("Limit of {max_jumps} jumps exceeded").into());
+ }
+
+ let len = self.get(pos)?;
+
+ if (len & 0xC0) == 0xC0 {
+ if !jumped {
+ self.seek(pos + 2)?;
+ }
+
+ let b2 = self.get(pos + 1)? as u16;
+ let offset = (((len as u16) ^ 0xC0) << 8) | b2;
+ pos = offset as usize;
+ jumped = true;
+ jumps_performed += 1;
+ continue;
+ }
+
+ pos += 1;
+
+ if len == 0 {
+ break;
+ }
+
+ outstr.push_str(delim);
+
+ let str_buffer = self.get_range(pos, len as usize)?;
+ outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase());
+
+ delim = ".";
+
+ pos += len as usize;
+ }
+
+ if !jumped {
+ self.seek(pos)?;
+ }
+
+ Ok(())
+ }
+
+ pub fn read_string(&mut self, outstr: &mut String) -> Result<()> {
+ let len = self.read()?;
+
+ self.read_string_n(outstr, len)?;
+
+ Ok(())
+ }
+
+ pub fn read_string_n(&mut self, outstr: &mut String, len: u8) -> Result<()> {
+ let mut pos = self.pos;
+
+ let str_buffer = self.get_range(pos, len as usize)?;
+
+ let mut i = 0;
+ for b in str_buffer {
+ let c = *b as char;
+ if c == '\0' {
+ break;
+ }
+ outstr.push(c);
+ i += 1;
+ }
+
+ pos += i;
+ self.seek(pos)?;
+
+ Ok(())
+ }
+
+ pub fn write(&mut self, val: u8) -> Result<()> {
+ if self.size < self.pos {
+ self.size = self.pos;
+ }
+
+ if self.buf.len() <= self.size {
+ self.buf.resize(self.size + 1, 0x00);
+ }
+
+ self.buf[self.pos] = val;
+ self.pos += 1;
+ Ok(())
+ }
+
+ pub fn write_u8(&mut self, val: u8) -> Result<()> {
+ self.write(val)?;
+
+ Ok(())
+ }
+
+ pub fn write_u16(&mut self, val: u16) -> Result<()> {
+ self.write((val >> 8) as u8)?;
+ self.write((val & 0xFF) as u8)?;
+
+ Ok(())
+ }
+
+ pub fn write_u32(&mut self, val: u32) -> Result<()> {
+ self.write(((val >> 24) & 0xFF) as u8)?;
+ self.write(((val >> 16) & 0xFF) as u8)?;
+ self.write(((val >> 8) & 0xFF) as u8)?;
+ self.write((val & 0xFF) as u8)?;
+
+ Ok(())
+ }
+
+ pub fn write_qname(&mut self, qname: &str) -> Result<()> {
+ for label in qname.split('.') {
+ let len = label.len();
+
+ self.write_u8(len as u8)?;
+ for b in label.as_bytes() {
+ self.write_u8(*b)?;
+ }
+ }
+
+ if !qname.is_empty() {
+ self.write_u8(0)?;
+ }
+
+ Ok(())
+ }
+
+ pub fn write_string(&mut self, text: &str) -> Result<()> {
+ for b in text.as_bytes() {
+ self.write_u8(*b)?;
+ }
+
+ Ok(())
+ }
+
+ pub fn set(&mut self, pos: usize, val: u8) -> Result<()> {
+ self.buf[pos] = val;
+
+ Ok(())
+ }
+
+ pub fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> {
+ self.set(pos, (val >> 8) as u8)?;
+ self.set(pos + 1, (val & 0xFF) as u8)?;
+
+ Ok(())
+ }
+}
diff --git a/src/dns/packet/header.rs b/src/dns/packet/header.rs
new file mode 100644
index 0000000..2355ecb
--- /dev/null
+++ b/src/dns/packet/header.rs
@@ -0,0 +1,102 @@
+use super::{buffer::PacketBuffer, result::ResultCode};
+use crate::Result;
+
+#[derive(Clone, Debug)]
+pub struct DnsHeader {
+ pub id: u16, // 16 bits
+
+ pub recursion_desired: bool, // 1 bit
+ pub truncated_message: bool, // 1 bit
+ pub authoritative_answer: bool, // 1 bit
+ pub opcode: u8, // 4 bits
+ pub response: bool, // 1 bit
+
+ pub rescode: ResultCode, // 4 bits
+ pub checking_disabled: bool, // 1 bit
+ pub authed_data: bool, // 1 bit
+ pub z: bool, // 1 bit
+ pub recursion_available: bool, // 1 bit
+
+ pub questions: u16, // 16 bits
+ pub answers: u16, // 16 bits
+ pub authoritative_entries: u16, // 16 bits
+ pub resource_entries: u16, // 16 bits
+}
+
+impl DnsHeader {
+ pub fn new() -> Self {
+ Self {
+ id: 0,
+
+ recursion_desired: false,
+ truncated_message: false,
+ authoritative_answer: false,
+ opcode: 0,
+ response: false,
+
+ rescode: ResultCode::NOERROR,
+ checking_disabled: false,
+ authed_data: false,
+ z: false,
+ recursion_available: false,
+
+ questions: 0,
+ answers: 0,
+ authoritative_entries: 0,
+ resource_entries: 0,
+ }
+ }
+
+ pub fn read(&mut self, buffer: &mut PacketBuffer) -> Result<()> {
+ self.id = buffer.read_u16()?;
+ let flags = buffer.read_u16()?;
+ let a = (flags >> 8) as u8;
+ let b = (flags & 0xFF) as u8;
+ self.recursion_desired = (a & (1 << 0)) > 0;
+ self.truncated_message = (a & (1 << 1)) > 0;
+ self.authoritative_answer = (a & (1 << 2)) > 0;
+ self.opcode = (a >> 3) & 0x0F;
+ self.response = (a & (1 << 7)) > 0;
+
+ self.rescode = ResultCode::from_num(b & 0x0F);
+ self.checking_disabled = (b & (1 << 4)) > 0;
+ self.authed_data = (b & (1 << 5)) > 0;
+ self.z = (b & (1 << 6)) > 0;
+ self.recursion_available = (b & (1 << 7)) > 0;
+
+ self.questions = buffer.read_u16()?;
+ self.answers = buffer.read_u16()?;
+ self.authoritative_entries = buffer.read_u16()?;
+ self.resource_entries = buffer.read_u16()?;
+
+ // Return the constant header size
+ Ok(())
+ }
+
+ pub fn write(&self, buffer: &mut PacketBuffer) -> Result<()> {
+ buffer.write_u16(self.id)?;
+
+ buffer.write_u8(
+ (self.recursion_desired as u8)
+ | ((self.truncated_message as u8) << 1)
+ | ((self.authoritative_answer as u8) << 2)
+ | (self.opcode << 3)
+ | ((self.response as u8) << 7),
+ )?;
+
+ buffer.write_u8(
+ (self.rescode as u8)
+ | ((self.checking_disabled as u8) << 4)
+ | ((self.authed_data as u8) << 5)
+ | ((self.z as u8) << 6)
+ | ((self.recursion_available as u8) << 7),
+ )?;
+
+ buffer.write_u16(self.questions)?;
+ buffer.write_u16(self.answers)?;
+ buffer.write_u16(self.authoritative_entries)?;
+ buffer.write_u16(self.resource_entries)?;
+
+ Ok(())
+ }
+}
diff --git a/src/dns/packet/mod.rs b/src/dns/packet/mod.rs
new file mode 100644
index 0000000..9873b94
--- /dev/null
+++ b/src/dns/packet/mod.rs
@@ -0,0 +1,128 @@
+use std::net::IpAddr;
+
+use self::{
+ buffer::PacketBuffer, header::DnsHeader, query::QueryType, question::DnsQuestion,
+ record::DnsRecord,
+};
+use crate::Result;
+
+pub mod buffer;
+pub mod header;
+pub mod query;
+pub mod question;
+pub mod record;
+pub mod result;
+
+#[derive(Clone, Debug)]
+pub struct Packet {
+ pub header: DnsHeader,
+ pub questions: Vec<DnsQuestion>,
+ pub answers: Vec<DnsRecord>,
+ pub authorities: Vec<DnsRecord>,
+ pub resources: Vec<DnsRecord>,
+}
+
+impl Packet {
+ pub fn new() -> Self {
+ Self {
+ header: DnsHeader::new(),
+ questions: Vec::new(),
+ answers: Vec::new(),
+ authorities: Vec::new(),
+ resources: Vec::new(),
+ }
+ }
+
+ pub fn from_buffer(buffer: &mut PacketBuffer) -> Result<Self> {
+ let mut result = Self::new();
+ result.header.read(buffer)?;
+
+ for _ in 0..result.header.questions {
+ let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0));
+ question.read(buffer)?;
+ result.questions.push(question);
+ }
+
+ for _ in 0..result.header.answers {
+ let rec = DnsRecord::read(buffer)?;
+ result.answers.push(rec);
+ }
+ for _ in 0..result.header.authoritative_entries {
+ let rec = DnsRecord::read(buffer)?;
+ result.authorities.push(rec);
+ }
+ for _ in 0..result.header.resource_entries {
+ let rec = DnsRecord::read(buffer)?;
+ result.resources.push(rec);
+ }
+
+ Ok(result)
+ }
+
+ pub fn write(&mut self, buffer: &mut PacketBuffer) -> Result<()> {
+ self.header.questions = self.questions.len() as u16;
+ self.header.answers = self.answers.len() as u16;
+ self.header.authoritative_entries = self.authorities.len() as u16;
+ self.header.resource_entries = self.resources.len() as u16;
+
+ self.header.write(buffer)?;
+
+ for question in &self.questions {
+ question.write(buffer)?;
+ }
+ for rec in &self.answers {
+ rec.write(buffer)?;
+ }
+ for rec in &self.authorities {
+ rec.write(buffer)?;
+ }
+ for rec in &self.resources {
+ rec.write(buffer)?;
+ }
+
+ Ok(())
+ }
+
+ pub fn get_random_a(&self) -> Option<IpAddr> {
+ self.answers
+ .iter()
+ .filter_map(|record| match record {
+ DnsRecord::A { addr, .. } => Some(IpAddr::V4(*addr)),
+ DnsRecord::AAAA { addr, .. } => Some(IpAddr::V6(*addr)),
+ _ => None,
+ })
+ .next()
+ }
+
+ fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator<Item = (&'a str, &'a str)> {
+ self.authorities
+ .iter()
+ .filter_map(|record| match record {
+ DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())),
+ _ => None,
+ })
+ .filter(move |(domain, _)| qname.ends_with(*domain))
+ }
+
+ pub fn get_resolved_ns(&self, qname: &str) -> Option<IpAddr> {
+ self.get_ns(qname)
+ .flat_map(|(_, host)| {
+ self.resources
+ .iter()
+ .filter_map(move |record| match record {
+ DnsRecord::A { domain, addr, .. } if domain == host => {
+ Some(IpAddr::V4(*addr))
+ }
+ DnsRecord::AAAA { domain, addr, .. } if domain == host => {
+ Some(IpAddr::V6(*addr))
+ }
+ _ => None,
+ })
+ })
+ .next()
+ }
+
+ pub fn get_unresolved_ns<'a>(&'a self, qname: &'a str) -> Option<&'a str> {
+ self.get_ns(qname).map(|(_, host)| host).next()
+ }
+}
diff --git a/src/dns/packet/query.rs b/src/dns/packet/query.rs
new file mode 100644
index 0000000..732b9b2
--- /dev/null
+++ b/src/dns/packet/query.rs
@@ -0,0 +1,78 @@
+#[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)]
+pub enum QueryType {
+ UNKNOWN(u16),
+ A, // 1
+ NS, // 2
+ CNAME, // 5
+ SOA, // 6
+ PTR, // 12
+ MX, // 15
+ TXT, // 16
+ AAAA, // 28
+ SRV, // 33
+ OPT, // 41
+ CAA, // 257
+ AR, // 1000
+ AAAAR, // 1001
+}
+
+impl QueryType {
+ pub fn to_num(&self) -> u16 {
+ match *self {
+ Self::UNKNOWN(x) => x,
+ Self::A => 1,
+ Self::NS => 2,
+ Self::CNAME => 5,
+ Self::SOA => 6,
+ Self::PTR => 12,
+ Self::MX => 15,
+ Self::TXT => 16,
+ Self::AAAA => 28,
+ Self::SRV => 33,
+ Self::OPT => 41,
+ Self::CAA => 257,
+ Self::AR => 1000,
+ Self::AAAAR => 1001,
+ }
+ }
+
+ pub fn from_num(num: u16) -> Self {
+ match num {
+ 1 => Self::A,
+ 2 => Self::NS,
+ 5 => Self::CNAME,
+ 6 => Self::SOA,
+ 12 => Self::PTR,
+ 15 => Self::MX,
+ 16 => Self::TXT,
+ 28 => Self::AAAA,
+ 33 => Self::SRV,
+ 41 => Self::OPT,
+ 257 => Self::CAA,
+ 1000 => Self::AR,
+ 1001 => Self::AAAAR,
+ _ => Self::UNKNOWN(num),
+ }
+ }
+
+ pub fn allowed_actions(&self) -> (bool, bool) {
+ // 0. duplicates allowed
+ // 1. allowed to be created by database
+ match self {
+ QueryType::UNKNOWN(_) => (false, false),
+ QueryType::A => (true, true),
+ QueryType::NS => (false, true),
+ QueryType::CNAME => (false, true),
+ QueryType::SOA => (false, false),
+ QueryType::PTR => (false, true),
+ QueryType::MX => (false, true),
+ QueryType::TXT => (true, true),
+ QueryType::AAAA => (true, true),
+ QueryType::SRV => (false, true),
+ QueryType::OPT => (false, false),
+ QueryType::CAA => (false, true),
+ QueryType::AR => (false, true),
+ QueryType::AAAAR => (false, true),
+ }
+ }
+}
diff --git a/src/dns/packet/question.rs b/src/dns/packet/question.rs
new file mode 100644
index 0000000..9042e1c
--- /dev/null
+++ b/src/dns/packet/question.rs
@@ -0,0 +1,31 @@
+use super::{buffer::PacketBuffer, query::QueryType, Result};
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct DnsQuestion {
+ pub name: String,
+ pub qtype: QueryType,
+}
+
+impl DnsQuestion {
+ pub fn new(name: String, qtype: QueryType) -> Self {
+ Self { name, qtype }
+ }
+
+ pub fn read(&mut self, buffer: &mut PacketBuffer) -> Result<()> {
+ buffer.read_qname(&mut self.name)?;
+ self.qtype = QueryType::from_num(buffer.read_u16()?); // qtype
+ let _ = buffer.read_u16()?; // class
+
+ Ok(())
+ }
+
+ pub fn write(&self, buffer: &mut PacketBuffer) -> Result<()> {
+ buffer.write_qname(&self.name)?;
+
+ let typenum = self.qtype.to_num();
+ buffer.write_u16(typenum)?;
+ buffer.write_u16(1)?;
+
+ Ok(())
+ }
+}
diff --git a/src/dns/packet/record.rs b/src/dns/packet/record.rs
new file mode 100644
index 0000000..88008f0
--- /dev/null
+++ b/src/dns/packet/record.rs
@@ -0,0 +1,544 @@
+use std::net::{Ipv4Addr, Ipv6Addr};
+
+use rand::RngCore;
+use serde::{Deserialize, Serialize};
+use tracing::{trace, warn};
+
+use super::{buffer::PacketBuffer, query::QueryType, Result};
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
+pub enum DnsRecord {
+ UNKNOWN {
+ domain: String,
+ qtype: u16,
+ data_len: u16,
+ ttl: u32,
+ }, // 0
+ A {
+ domain: String,
+ addr: Ipv4Addr,
+ ttl: u32,
+ }, // 1
+ NS {
+ domain: String,
+ host: String,
+ ttl: u32,
+ }, // 2
+ CNAME {
+ domain: String,
+ host: String,
+ ttl: u32,
+ }, // 5
+ SOA {
+ domain: String,
+ mname: String,
+ nname: String,
+ serial: u32,
+ refresh: u32,
+ retry: u32,
+ expire: u32,
+ minimum: u32,
+ ttl: u32,
+ }, // 6
+ PTR {
+ domain: String,
+ pointer: String,
+ ttl: u32,
+ }, // 12
+ MX {
+ domain: String,
+ priority: u16,
+ host: String,
+ ttl: u32,
+ }, // 15
+ TXT {
+ domain: String,
+ text: Vec<String>,
+ ttl: u32,
+ }, //16
+ AAAA {
+ domain: String,
+ addr: Ipv6Addr,
+ ttl: u32,
+ }, // 28
+ SRV {
+ domain: String,
+ priority: u16,
+ weight: u16,
+ port: u16,
+ target: String,
+ ttl: u32,
+ }, // 33
+ CAA {
+ domain: String,
+ flags: u8,
+ length: u8,
+ tag: String,
+ value: String,
+ ttl: u32,
+ }, // 257
+ AR {
+ domain: String,
+ ttl: u32,
+ },
+ AAAAR {
+ domain: String,
+ ttl: u32,
+ },
+}
+
+impl DnsRecord {
+ pub fn read(buffer: &mut PacketBuffer) -> Result<Self> {
+ let mut domain = String::new();
+ buffer.read_qname(&mut domain)?;
+
+ let qtype_num = buffer.read_u16()?;
+ let qtype = QueryType::from_num(qtype_num);
+ let _ = buffer.read_u16()?;
+ let ttl = buffer.read_u32()?;
+ let data_len = buffer.read_u16()?;
+
+ trace!("Reading DNS Record TYPE: {:?}", qtype);
+
+ let header_pos = buffer.pos();
+
+ match qtype {
+ QueryType::A => {
+ let raw_addr = buffer.read_u32()?;
+ let addr = Ipv4Addr::new(
+ ((raw_addr >> 24) & 0xFF) as u8,
+ ((raw_addr >> 16) & 0xFF) as u8,
+ ((raw_addr >> 8) & 0xFF) as u8,
+ (raw_addr & 0xFF) as u8,
+ );
+
+ Ok(Self::A { domain, addr, ttl })
+ }
+ QueryType::AAAA => {
+ let raw_addr1 = buffer.read_u32()?;
+ let raw_addr2 = buffer.read_u32()?;
+ let raw_addr3 = buffer.read_u32()?;
+ let raw_addr4 = buffer.read_u32()?;
+ let addr = Ipv6Addr::new(
+ ((raw_addr1 >> 16) & 0xFFFF) as u16,
+ (raw_addr1 & 0xFFFF) as u16,
+ ((raw_addr2 >> 16) & 0xFFFF) as u16,
+ (raw_addr2 & 0xFFFF) as u16,
+ ((raw_addr3 >> 16) & 0xFFFF) as u16,
+ (raw_addr3 & 0xFFFF) as u16,
+ ((raw_addr4 >> 16) & 0xFFFF) as u16,
+ (raw_addr4 & 0xFFFF) as u16,
+ );
+
+ Ok(Self::AAAA { domain, addr, ttl })
+ }
+ QueryType::NS => {
+ let mut ns = String::new();
+ buffer.read_qname(&mut ns)?;
+
+ Ok(Self::NS {
+ domain,
+ host: ns,
+ ttl,
+ })
+ }
+ QueryType::CNAME => {
+ let mut cname = String::new();
+ buffer.read_qname(&mut cname)?;
+
+ Ok(Self::CNAME {
+ domain,
+ host: cname,
+ ttl,
+ })
+ }
+ QueryType::SOA => {
+ let mut mname = String::new();
+ buffer.read_qname(&mut mname)?;
+
+ let mut nname = String::new();
+ buffer.read_qname(&mut nname)?;
+
+ let serial = buffer.read_u32()?;
+ let refresh = buffer.read_u32()?;
+ let retry = buffer.read_u32()?;
+ let expire = buffer.read_u32()?;
+ let minimum = buffer.read_u32()?;
+
+ Ok(Self::SOA {
+ domain,
+ mname,
+ nname,
+ serial,
+ refresh,
+ retry,
+ expire,
+ minimum,
+ ttl,
+ })
+ }
+ QueryType::PTR => {
+ let mut pointer = String::new();
+ buffer.read_qname(&mut pointer)?;
+
+ Ok(Self::PTR {
+ domain,
+ pointer,
+ ttl,
+ })
+ }
+ QueryType::MX => {
+ let priority = buffer.read_u16()?;
+ let mut mx = String::new();
+ buffer.read_qname(&mut mx)?;
+
+ Ok(Self::MX {
+ domain,
+ priority,
+ host: mx,
+ ttl,
+ })
+ }
+ QueryType::TXT => {
+ let mut text = Vec::new();
+
+ loop {
+ let mut s = String::new();
+ buffer.read_string(&mut s)?;
+
+ if s.len() == 0 {
+ break;
+ } else {
+ text.push(s);
+ }
+ }
+
+ Ok(Self::TXT { domain, text, ttl })
+ }
+ QueryType::SRV => {
+ let priority = buffer.read_u16()?;
+ let weight = buffer.read_u16()?;
+ let port = buffer.read_u16()?;
+
+ let mut target = String::new();
+ buffer.read_qname(&mut target)?;
+
+ Ok(Self::SRV {
+ domain,
+ priority,
+ weight,
+ port,
+ target,
+ ttl,
+ })
+ }
+ QueryType::CAA => {
+ let flags = buffer.read()?;
+ let length = buffer.read()?;
+
+ let mut tag = String::new();
+ buffer.read_string_n(&mut tag, length)?;
+
+ let value_len = (data_len as usize) + header_pos - buffer.pos;
+ let mut value = String::new();
+ buffer.read_string_n(&mut value, value_len as u8)?;
+
+ Ok(Self::CAA {
+ domain,
+ flags,
+ length,
+ tag,
+ value,
+ ttl,
+ })
+ }
+ QueryType::UNKNOWN(_) | _ => {
+ buffer.step(data_len as usize)?;
+
+ Ok(Self::UNKNOWN {
+ domain,
+ qtype: qtype_num,
+ data_len,
+ ttl,
+ })
+ }
+ }
+ }
+
+ pub fn write(&self, buffer: &mut PacketBuffer) -> Result<usize> {
+ let start_pos = buffer.pos();
+
+ trace!("Writing DNS Record {:?}", self);
+
+ match *self {
+ Self::A {
+ ref domain,
+ ref addr,
+ ttl,
+ } => {
+ buffer.write_qname(domain)?;
+ buffer.write_u16(QueryType::A.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+ buffer.write_u16(4)?;
+
+ let octets = addr.octets();
+ buffer.write_u8(octets[0])?;
+ buffer.write_u8(octets[1])?;
+ buffer.write_u8(octets[2])?;
+ buffer.write_u8(octets[3])?;
+ }
+ Self::NS {
+ ref domain,
+ ref host,
+ ttl,
+ } => {
+ buffer.write_qname(domain)?;
+ buffer.write_u16(QueryType::NS.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+
+ let pos = buffer.pos();
+ buffer.write_u16(0)?;
+
+ buffer.write_qname(host)?;
+
+ let size = buffer.pos() - (pos + 2);
+ buffer.set_u16(pos, size as u16)?;
+ }
+ Self::CNAME {
+ ref domain,
+ ref host,
+ ttl,
+ } => {
+ buffer.write_qname(domain)?;
+ buffer.write_u16(QueryType::CNAME.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+
+ let pos = buffer.pos();
+ buffer.write_u16(0)?;
+
+ buffer.write_qname(host)?;
+
+ let size = buffer.pos() - (pos + 2);
+ buffer.set_u16(pos, size as u16)?;
+ }
+ Self::SOA {
+ ref domain,
+ ref mname,
+ ref nname,
+ serial,
+ refresh,
+ retry,
+ expire,
+ minimum,
+ ttl,
+ } => {
+ buffer.write_qname(domain)?;
+ buffer.write_u16(QueryType::SOA.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+
+ let pos = buffer.pos();
+ buffer.write_u16(0)?;
+
+ buffer.write_qname(mname)?;
+ buffer.write_qname(nname)?;
+ buffer.write_u32(serial)?;
+ buffer.write_u32(refresh)?;
+ buffer.write_u32(retry)?;
+ buffer.write_u32(expire)?;
+ buffer.write_u32(minimum)?;
+
+ let size = buffer.pos() - (pos + 2);
+ buffer.set_u16(pos, size as u16)?;
+ }
+ Self::PTR {
+ ref domain,
+ ref pointer,
+ ttl,
+ } => {
+ buffer.write_qname(domain)?;
+ buffer.write_u16(QueryType::NS.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+
+ let pos = buffer.pos();
+ buffer.write_u16(0)?;
+
+ buffer.write_qname(&pointer)?;
+
+ let size = buffer.pos() - (pos + 2);
+ buffer.set_u16(pos, size as u16)?;
+ }
+ Self::MX {
+ ref domain,
+ priority,
+ ref host,
+ ttl,
+ } => {
+ buffer.write_qname(domain)?;
+ buffer.write_u16(QueryType::MX.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+
+ let pos = buffer.pos();
+ buffer.write_u16(0)?;
+
+ buffer.write_u16(priority)?;
+ buffer.write_qname(host)?;
+
+ let size = buffer.pos() - (pos + 2);
+ buffer.set_u16(pos, size as u16)?;
+ }
+ Self::TXT {
+ ref domain,
+ ref text,
+ ttl,
+ } => {
+ buffer.write_qname(&domain)?;
+ buffer.write_u16(QueryType::TXT.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+
+ let pos = buffer.pos();
+ buffer.write_u16(0)?;
+
+ if text.is_empty() {
+ return Ok(buffer.pos() - start_pos);
+ }
+
+ for s in text {
+ buffer.write_u8(s.len() as u8)?;
+ buffer.write_string(&s)?;
+ }
+ let size = buffer.pos() - (pos + 2);
+ buffer.set_u16(pos, size as u16)?;
+ }
+ Self::AAAA {
+ ref domain,
+ ref addr,
+ ttl,
+ } => {
+ buffer.write_qname(domain)?;
+ buffer.write_u16(QueryType::AAAA.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+ buffer.write_u16(16)?;
+
+ for octet in &addr.segments() {
+ buffer.write_u16(*octet)?;
+ }
+ }
+ Self::SRV {
+ ref domain,
+ priority,
+ weight,
+ port,
+ ref target,
+ ttl,
+ } => {
+ buffer.write_qname(domain)?;
+ buffer.write_u16(QueryType::SRV.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+
+ let pos = buffer.pos();
+ buffer.write_u16(0)?;
+
+ buffer.write_u16(priority)?;
+ buffer.write_u16(weight)?;
+ buffer.write_u16(port)?;
+ buffer.write_qname(target)?;
+
+ let size = buffer.pos() - (pos + 2);
+ buffer.set_u16(pos, size as u16)?;
+ }
+ Self::CAA {
+ ref domain,
+ flags,
+ length,
+ ref tag,
+ ref value,
+ ttl,
+ } => {
+ buffer.write_qname(domain)?;
+ buffer.write_u16(QueryType::CAA.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+
+ let pos = buffer.pos();
+ buffer.write_u16(0)?;
+
+ buffer.write_u8(flags)?;
+ buffer.write_u8(length)?;
+ buffer.write_string(tag)?;
+ buffer.write_string(value)?;
+
+ let size = buffer.pos() - (pos + 2);
+ buffer.set_u16(pos, size as u16)?;
+ }
+ Self::AR { ref domain, ttl } => {
+ buffer.write_qname(domain)?;
+ buffer.write_u16(QueryType::A.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+ buffer.write_u16(4)?;
+
+ let mut rand = rand::thread_rng();
+ buffer.write_u32(rand.next_u32())?;
+ }
+ Self::AAAAR { ref domain, ttl } => {
+ buffer.write_qname(domain)?;
+ buffer.write_u16(QueryType::A.to_num())?;
+ buffer.write_u16(1)?;
+ buffer.write_u32(ttl)?;
+ buffer.write_u16(4)?;
+
+ let mut rand = rand::thread_rng();
+ buffer.write_u32(rand.next_u32())?;
+ buffer.write_u32(rand.next_u32())?;
+ buffer.write_u32(rand.next_u32())?;
+ buffer.write_u32(rand.next_u32())?;
+ }
+ Self::UNKNOWN { .. } => {
+ warn!("Skipping record: {self:?}");
+ }
+ }
+
+ Ok(buffer.pos() - start_pos)
+ }
+
+ pub fn get_domain(&self) -> String {
+ self.get_shared_domain().0
+ }
+
+ pub fn get_qtype(&self) -> QueryType {
+ self.get_shared_domain().1
+ }
+
+ pub fn get_ttl(&self) -> u32 {
+ self.get_shared_domain().2
+ }
+
+ fn get_shared_domain(&self) -> (String, QueryType, u32) {
+ match self {
+ DnsRecord::UNKNOWN {
+ domain, ttl, qtype, ..
+ } => (domain.clone(), QueryType::UNKNOWN(*qtype), *ttl),
+ DnsRecord::AAAA { domain, ttl, .. } => (domain.clone(), QueryType::AAAA, *ttl),
+ DnsRecord::A { domain, ttl, .. } => (domain.clone(), QueryType::A, *ttl),
+ DnsRecord::NS { domain, ttl, .. } => (domain.clone(), QueryType::NS, *ttl),
+ DnsRecord::CNAME { domain, ttl, .. } => (domain.clone(), QueryType::CNAME, *ttl),
+ DnsRecord::SOA { domain, ttl, .. } => (domain.clone(), QueryType::SOA, *ttl),
+ DnsRecord::PTR { domain, ttl, .. } => (domain.clone(), QueryType::PTR, *ttl),
+ DnsRecord::MX { domain, ttl, .. } => (domain.clone(), QueryType::MX, *ttl),
+ DnsRecord::TXT { domain, ttl, .. } => (domain.clone(), QueryType::TXT, *ttl),
+ DnsRecord::SRV { domain, ttl, .. } => (domain.clone(), QueryType::SRV, *ttl),
+ DnsRecord::CAA { domain, ttl, .. } => (domain.clone(), QueryType::CAA, *ttl),
+ DnsRecord::AR { domain, ttl, .. } => (domain.clone(), QueryType::AR, *ttl),
+ DnsRecord::AAAAR { domain, ttl, .. } => (domain.clone(), QueryType::AAAAR, *ttl),
+ }
+ }
+}
diff --git a/src/dns/packet/result.rs b/src/dns/packet/result.rs
new file mode 100644
index 0000000..41c8ba9
--- /dev/null
+++ b/src/dns/packet/result.rs
@@ -0,0 +1,22 @@
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub enum ResultCode {
+ NOERROR = 0,
+ FORMERR = 1,
+ SERVFAIL = 2,
+ NXDOMAIN = 3,
+ NOTIMP = 4,
+ REFUSED = 5,
+}
+
+impl ResultCode {
+ pub fn from_num(num: u8) -> Self {
+ match num {
+ 1 => Self::FORMERR,
+ 2 => Self::SERVFAIL,
+ 3 => Self::NXDOMAIN,
+ 4 => Self::NOTIMP,
+ 5 => Self::REFUSED,
+ 0 | _ => Self::NOERROR,
+ }
+ }
+}
diff --git a/src/dns/resolver.rs b/src/dns/resolver.rs
new file mode 100644
index 0000000..18b5bba
--- /dev/null
+++ b/src/dns/resolver.rs
@@ -0,0 +1,230 @@
+use super::binding::Connection;
+use super::packet::{query::QueryType, question::DnsQuestion, result::ResultCode, Packet};
+use crate::Result;
+use crate::{config::Config, database::Database, get_time};
+use async_recursion::async_recursion;
+use moka::future::Cache;
+use std::{net::IpAddr, sync::Arc, time::Duration};
+use tracing::{error, trace};
+
+pub struct Resolver {
+ request_id: u16,
+ connection: Connection,
+ config: Arc<Config>,
+ database: Arc<Database>,
+ cache: Cache<DnsQuestion, (Packet, u64)>,
+}
+
+impl Resolver {
+ pub fn new(
+ request_id: u16,
+ connection: Connection,
+ config: Arc<Config>,
+ database: Arc<Database>,
+ cache: Cache<DnsQuestion, (Packet, u64)>,
+ ) -> Self {
+ Self {
+ request_id,
+ connection,
+ config,
+ database,
+ cache,
+ }
+ }
+
+ async fn lookup_database(&self, question: &DnsQuestion) -> Option<Packet> {
+ let records = match self
+ .database
+ .get_records(&question.name, question.qtype)
+ .await
+ {
+ Ok(record) => record,
+ Err(err) => {
+ error!("{err}");
+ return None;
+ }
+ };
+
+ if records.is_empty() {
+ return None;
+ }
+
+ let mut packet = Packet::new();
+
+ packet.header.id = self.request_id;
+ packet.header.questions = 1;
+ packet.header.answers = records.len() as u16;
+ packet.header.recursion_desired = true;
+ packet
+ .questions
+ .push(DnsQuestion::new(question.name.to_string(), question.qtype));
+
+ for record in records {
+ packet.answers.push(record);
+ }
+
+ trace!(
+ "Found stored value for {:?} {}",
+ question.qtype,
+ question.name
+ );
+
+ Some(packet)
+ }
+
+ async fn lookup_cache(&self, question: &DnsQuestion) -> Option<Packet> {
+ let Some((packet, date)) = self.cache.get(&question) else {
+ return None
+ };
+
+ let now = get_time();
+ let diff = Duration::from_millis(now - date).as_secs() as u32;
+
+ for answer in &packet.answers {
+ let ttl = answer.get_ttl();
+ if diff > ttl {
+ self.cache.invalidate(&question).await;
+ return None;
+ }
+ }
+
+ trace!(
+ "Found cached value for {:?} {}",
+ question.qtype,
+ question.name
+ );
+
+ Some(packet)
+ }
+
+ async fn lookup_fallback(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
+ let mut packet = Packet::new();
+
+ packet.header.id = self.request_id;
+ packet.header.questions = 1;
+ packet.header.recursion_desired = true;
+ packet
+ .questions
+ .push(DnsQuestion::new(question.name.to_string(), question.qtype));
+
+ let packet = match self.connection.request_packet(packet, server).await {
+ Ok(packet) => packet,
+ Err(e) => {
+ error!("Failed to complete nameserver request: {e}");
+ let mut packet = Packet::new();
+ packet.header.rescode = ResultCode::SERVFAIL;
+ packet
+ }
+ };
+
+ packet
+ }
+
+ async fn lookup(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
+ if let Some(packet) = self.lookup_cache(question).await {
+ return packet;
+ };
+
+ if let Some(packet) = self.lookup_database(question).await {
+ return packet;
+ };
+
+ trace!(
+ "Attempting lookup of {:?} {} with ns {}",
+ question.qtype,
+ question.name,
+ server.0
+ );
+
+ self.lookup_fallback(question, server).await
+ }
+
+ #[async_recursion]
+ async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet {
+ let question = DnsQuestion::new(qname.to_string(), qtype);
+ let mut ns = self.config.dns_fallback.clone();
+
+ loop {
+ let ns_copy = ns;
+
+ let server = (ns_copy, 53);
+ let response = self.lookup(&question, server).await;
+
+ if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
+ return response;
+ }
+
+ if response.header.rescode == ResultCode::NXDOMAIN {
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
+ return response;
+ }
+
+ if let Some(new_ns) = response.get_resolved_ns(qname) {
+ ns = new_ns;
+ continue;
+ }
+
+ let new_ns_name = match response.get_unresolved_ns(qname) {
+ Some(x) => x,
+ None => {
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
+ return response;
+ }
+ };
+
+ let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await;
+
+ if let Some(new_ns) = recursive_response.get_random_a() {
+ ns = new_ns;
+ } else {
+ self.cache
+ .insert(question, (response.clone(), get_time()))
+ .await;
+ return response;
+ }
+ }
+ }
+
+ pub async fn handle_query(mut self) -> Result<()> {
+ let mut request = self.connection.read_packet().await?;
+
+ let mut packet = Packet::new();
+ packet.header.id = request.header.id;
+ packet.header.recursion_desired = true;
+ packet.header.recursion_available = true;
+ packet.header.response = true;
+
+ if let Some(question) = request.questions.pop() {
+ trace!("Received query: {question:?}");
+
+ let result = self.recursive_lookup(&question.name, question.qtype).await;
+ packet.questions.push(question.clone());
+ packet.header.rescode = result.header.rescode;
+
+ for rec in result.answers {
+ trace!("Answer: {rec:?}");
+ packet.answers.push(rec);
+ }
+ for rec in result.authorities {
+ trace!("Authority: {rec:?}");
+ packet.authorities.push(rec);
+ }
+ for rec in result.resources {
+ trace!("Resource: {rec:?}");
+ packet.resources.push(rec);
+ }
+ } else {
+ packet.header.rescode = ResultCode::FORMERR;
+ }
+
+ self.connection.write_packet(packet).await?;
+ Ok(())
+ }
+}
diff --git a/src/dns/server.rs b/src/dns/server.rs
new file mode 100644
index 0000000..65d15df
--- /dev/null
+++ b/src/dns/server.rs
@@ -0,0 +1,85 @@
+use super::{
+ binding::Binding,
+ packet::{question::DnsQuestion, Packet},
+ resolver::Resolver,
+};
+use crate::{config::Config, database::Database, Result};
+use moka::future::Cache;
+use std::{net::SocketAddr, sync::Arc, time::Duration};
+use tokio::task::JoinHandle;
+use tracing::{error, info};
+
+pub struct DnsServer {
+ addr: SocketAddr,
+ config: Arc<Config>,
+ database: Arc<Database>,
+ cache: Cache<DnsQuestion, (Packet, u64)>,
+}
+
+impl DnsServer {
+ pub async fn new(config: Config, database: Database) -> Result<Self> {
+ let addr = format!("[::]:{}", config.dns_port).parse::<SocketAddr>()?;
+ let cache = Cache::builder()
+ .time_to_live(Duration::from_secs(60 * 60))
+ .max_capacity(config.dns_cache_size)
+ .build();
+
+ info!("Created DNS cache with size of {}", config.dns_cache_size);
+
+ Ok(Self {
+ addr,
+ config: Arc::new(config),
+ database: Arc::new(database),
+ cache,
+ })
+ }
+
+ pub async fn run(&self) -> Result<(JoinHandle<()>, JoinHandle<()>)> {
+ let tcp = Binding::tcp(self.addr).await?;
+ let tcp_handle = self.listen(tcp);
+
+ let udp = Binding::udp(self.addr).await?;
+ let udp_handle = self.listen(udp);
+
+ info!(
+ "Fallback DNS Server is set to: {:?}",
+ self.config.dns_fallback
+ );
+ info!(
+ "Listening for TCP and UDP traffic on [::]:{}",
+ self.config.dns_port
+ );
+
+ Ok((udp_handle, tcp_handle))
+ }
+
+ fn listen(&self, mut binding: Binding) -> JoinHandle<()> {
+ let config = self.config.clone();
+ let database = self.database.clone();
+ let cache = self.cache.clone();
+ tokio::spawn(async move {
+ let mut id = 0;
+ loop {
+ let Ok(connection) = binding.connect().await else { continue };
+ info!("Received request on {}", binding.name());
+
+ let resolver = Resolver::new(
+ id,
+ connection,
+ config.clone(),
+ database.clone(),
+ cache.clone(),
+ );
+
+ let name = binding.name().to_string();
+ tokio::spawn(async move {
+ if let Err(err) = resolver.handle_query().await {
+ error!("{} request {} failed: {:?}", name, id, err);
+ };
+ });
+
+ id += 1;
+ }
+ })
+ }
+}