summaryrefslogtreecommitdiff
path: root/src/dns
diff options
context:
space:
mode:
Diffstat (limited to 'src/dns')
-rw-r--r--src/dns/binding.rs144
-rw-r--r--src/dns/mod.rs4
-rw-r--r--src/dns/packet/buffer.rs227
-rw-r--r--src/dns/packet/header.rs102
-rw-r--r--src/dns/packet/mod.rs128
-rw-r--r--src/dns/packet/query.rs78
-rw-r--r--src/dns/packet/question.rs31
-rw-r--r--src/dns/packet/record.rs544
-rw-r--r--src/dns/packet/result.rs22
-rw-r--r--src/dns/resolver.rs230
-rw-r--r--src/dns/server.rs85
11 files changed, 0 insertions, 1595 deletions
diff --git a/src/dns/binding.rs b/src/dns/binding.rs
deleted file mode 100644
index 4c7e15f..0000000
--- a/src/dns/binding.rs
+++ /dev/null
@@ -1,144 +0,0 @@
-use std::{
- net::{IpAddr, SocketAddr},
- sync::Arc,
-};
-
-use super::packet::{buffer::PacketBuffer, Packet};
-use crate::Result;
-use tokio::{
- io::{AsyncReadExt, AsyncWriteExt},
- net::{TcpListener, TcpStream, UdpSocket},
-};
-use tracing::trace;
-
-pub enum Binding {
- UDP(Arc<UdpSocket>),
- TCP(TcpListener),
-}
-
-impl Binding {
- pub async fn udp(addr: SocketAddr) -> Result<Self> {
- let socket = UdpSocket::bind(addr).await?;
- Ok(Self::UDP(Arc::new(socket)))
- }
-
- pub async fn tcp(addr: SocketAddr) -> Result<Self> {
- let socket = TcpListener::bind(addr).await?;
- Ok(Self::TCP(socket))
- }
-
- pub fn name(&self) -> &str {
- match self {
- Binding::UDP(_) => "UDP",
- Binding::TCP(_) => "TCP",
- }
- }
-
- pub async fn connect(&mut self) -> Result<Connection> {
- match self {
- Self::UDP(socket) => {
- let mut buf = [0; 512];
- let (_, addr) = socket.recv_from(&mut buf).await?;
- Ok(Connection::UDP(socket.clone(), addr, buf))
- }
- Self::TCP(socket) => {
- let (stream, _) = socket.accept().await?;
- Ok(Connection::TCP(stream))
- }
- }
- }
-}
-
-pub enum Connection {
- UDP(Arc<UdpSocket>, SocketAddr, [u8; 512]),
- TCP(TcpStream),
-}
-
-impl Connection {
- pub async fn read_packet(&mut self) -> Result<Packet> {
- let data = self.read().await?;
- let mut packet_buffer = PacketBuffer::new(data);
-
- let packet = Packet::from_buffer(&mut packet_buffer)?;
- Ok(packet)
- }
-
- pub async fn write_packet(self, mut packet: Packet) -> Result<()> {
- let mut packet_buffer = PacketBuffer::new(Vec::new());
- packet.write(&mut packet_buffer)?;
-
- self.write(packet_buffer.buf).await?;
- Ok(())
- }
-
- pub async fn request_packet(&self, mut packet: Packet, dest: (IpAddr, u16)) -> Result<Packet> {
- let mut packet_buffer = PacketBuffer::new(Vec::new());
- packet.write(&mut packet_buffer)?;
-
- let data = self.request(packet_buffer.buf, dest).await?;
- let mut packet_buffer = PacketBuffer::new(data);
-
- let packet = Packet::from_buffer(&mut packet_buffer)?;
- Ok(packet)
- }
-
- async fn read(&mut self) -> Result<Vec<u8>> {
- trace!("Reading DNS packet");
- match self {
- Self::UDP(_, _, src) => Ok(Vec::from(*src)),
- Self::TCP(stream) => {
- let size = stream.read_u16().await?;
- let mut buf = Vec::with_capacity(size as usize);
- stream.read_buf(&mut buf).await?;
- Ok(buf)
- }
- }
- }
-
- async fn write(self, mut buf: Vec<u8>) -> Result<()> {
- trace!("Returning DNS packet");
- match self {
- Self::UDP(socket, addr, _) => {
- if buf.len() > 512 {
- buf[2] = buf[2] | 0x03;
- socket.send_to(&buf[0..512], addr).await?;
- } else {
- socket.send_to(&buf, addr).await?;
- }
- Ok(())
- }
- Self::TCP(mut stream) => {
- stream.write_u16(buf.len() as u16).await?;
- stream.write(&buf[0..buf.len()]).await?;
- Ok(())
- }
- }
- }
-
- async fn request(&self, buf: Vec<u8>, dest: (IpAddr, u16)) -> Result<Vec<u8>> {
- match self {
- Self::UDP(_socket, _addr, _src) => {
- let local_addr = "[::]:0".parse::<SocketAddr>()?;
- let socket = UdpSocket::bind(local_addr).await?;
- socket.send_to(&buf, dest).await?;
-
- let mut buf = [0; 512];
- socket.recv_from(&mut buf).await?;
-
- Ok(Vec::from(buf))
- }
- Self::TCP(_stream) => {
- let mut stream = TcpStream::connect(dest).await?;
- stream.write_u16((buf.len()) as u16).await?;
- stream.write_all(&buf[0..buf.len()]).await?;
-
- stream.readable().await?;
- let size = stream.read_u16().await?;
- let mut buf = Vec::with_capacity(size as usize);
- stream.read_buf(&mut buf).await?;
-
- Ok(buf)
- }
- }
- }
-}
diff --git a/src/dns/mod.rs b/src/dns/mod.rs
deleted file mode 100644
index 6f1e59e..0000000
--- a/src/dns/mod.rs
+++ /dev/null
@@ -1,4 +0,0 @@
-mod binding;
-pub mod packet;
-mod resolver;
-pub mod server;
diff --git a/src/dns/packet/buffer.rs b/src/dns/packet/buffer.rs
deleted file mode 100644
index 058156e..0000000
--- a/src/dns/packet/buffer.rs
+++ /dev/null
@@ -1,227 +0,0 @@
-use crate::Result;
-
-pub struct PacketBuffer {
- pub buf: Vec<u8>,
- pub pos: usize,
- pub size: usize,
-}
-
-impl PacketBuffer {
- pub fn new(buf: Vec<u8>) -> Self {
- Self {
- size: buf.len(),
- buf,
- pos: 0,
- }
- }
-
- pub fn pos(&self) -> usize {
- self.pos
- }
-
- pub fn step(&mut self, steps: usize) -> Result<()> {
- self.pos += steps;
-
- Ok(())
- }
-
- pub fn seek(&mut self, pos: usize) -> Result<()> {
- self.pos = pos;
-
- Ok(())
- }
-
- pub fn read(&mut self) -> Result<u8> {
- if self.pos >= self.size {
- return Err("Tried to read past end of buffer".into());
- }
- let res = self.buf[self.pos];
- self.pos += 1;
- Ok(res)
- }
-
- pub fn get(&mut self, pos: usize) -> Result<u8> {
- if pos >= self.size {
- return Err("Tried to read past end of buffer".into());
- }
- Ok(self.buf[pos])
- }
-
- pub fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
- if start + len >= self.size {
- return Err("Tried to read past end of buffer".into());
- }
- Ok(&self.buf[start..start + len])
- }
-
- pub fn read_u16(&mut self) -> Result<u16> {
- let res = ((self.read()? as u16) << 8) | (self.read()? as u16);
-
- Ok(res)
- }
-
- pub fn read_u32(&mut self) -> Result<u32> {
- let res = ((self.read()? as u32) << 24)
- | ((self.read()? as u32) << 16)
- | ((self.read()? as u32) << 8)
- | (self.read()? as u32);
-
- Ok(res)
- }
-
- pub fn read_qname(&mut self, outstr: &mut String) -> Result<()> {
- let mut pos = self.pos();
- let mut jumped = false;
-
- let mut delim = "";
- let max_jumps = 5;
- let mut jumps_performed = 0;
- loop {
- // Dns Packets are untrusted data, so we need to be paranoid. Someone
- // can craft a packet with a cycle in the jump instructions. This guards
- // against such packets.
- if jumps_performed > max_jumps {
- return Err(format!("Limit of {max_jumps} jumps exceeded").into());
- }
-
- let len = self.get(pos)?;
-
- if (len & 0xC0) == 0xC0 {
- if !jumped {
- self.seek(pos + 2)?;
- }
-
- let b2 = self.get(pos + 1)? as u16;
- let offset = (((len as u16) ^ 0xC0) << 8) | b2;
- pos = offset as usize;
- jumped = true;
- jumps_performed += 1;
- continue;
- }
-
- pos += 1;
-
- if len == 0 {
- break;
- }
-
- outstr.push_str(delim);
-
- let str_buffer = self.get_range(pos, len as usize)?;
- outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase());
-
- delim = ".";
-
- pos += len as usize;
- }
-
- if !jumped {
- self.seek(pos)?;
- }
-
- Ok(())
- }
-
- pub fn read_string(&mut self, outstr: &mut String) -> Result<()> {
- let len = self.read()?;
-
- self.read_string_n(outstr, len)?;
-
- Ok(())
- }
-
- pub fn read_string_n(&mut self, outstr: &mut String, len: u8) -> Result<()> {
- let mut pos = self.pos;
-
- let str_buffer = self.get_range(pos, len as usize)?;
-
- let mut i = 0;
- for b in str_buffer {
- let c = *b as char;
- if c == '\0' {
- break;
- }
- outstr.push(c);
- i += 1;
- }
-
- pos += i;
- self.seek(pos)?;
-
- Ok(())
- }
-
- pub fn write(&mut self, val: u8) -> Result<()> {
- if self.size < self.pos {
- self.size = self.pos;
- }
-
- if self.buf.len() <= self.size {
- self.buf.resize(self.size + 1, 0x00);
- }
-
- self.buf[self.pos] = val;
- self.pos += 1;
- Ok(())
- }
-
- pub fn write_u8(&mut self, val: u8) -> Result<()> {
- self.write(val)?;
-
- Ok(())
- }
-
- pub fn write_u16(&mut self, val: u16) -> Result<()> {
- self.write((val >> 8) as u8)?;
- self.write((val & 0xFF) as u8)?;
-
- Ok(())
- }
-
- pub fn write_u32(&mut self, val: u32) -> Result<()> {
- self.write(((val >> 24) & 0xFF) as u8)?;
- self.write(((val >> 16) & 0xFF) as u8)?;
- self.write(((val >> 8) & 0xFF) as u8)?;
- self.write((val & 0xFF) as u8)?;
-
- Ok(())
- }
-
- pub fn write_qname(&mut self, qname: &str) -> Result<()> {
- for label in qname.split('.') {
- let len = label.len();
-
- self.write_u8(len as u8)?;
- for b in label.as_bytes() {
- self.write_u8(*b)?;
- }
- }
-
- if !qname.is_empty() {
- self.write_u8(0)?;
- }
-
- Ok(())
- }
-
- pub fn write_string(&mut self, text: &str) -> Result<()> {
- for b in text.as_bytes() {
- self.write_u8(*b)?;
- }
-
- Ok(())
- }
-
- pub fn set(&mut self, pos: usize, val: u8) -> Result<()> {
- self.buf[pos] = val;
-
- Ok(())
- }
-
- pub fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> {
- self.set(pos, (val >> 8) as u8)?;
- self.set(pos + 1, (val & 0xFF) as u8)?;
-
- Ok(())
- }
-}
diff --git a/src/dns/packet/header.rs b/src/dns/packet/header.rs
deleted file mode 100644
index 2355ecb..0000000
--- a/src/dns/packet/header.rs
+++ /dev/null
@@ -1,102 +0,0 @@
-use super::{buffer::PacketBuffer, result::ResultCode};
-use crate::Result;
-
-#[derive(Clone, Debug)]
-pub struct DnsHeader {
- pub id: u16, // 16 bits
-
- pub recursion_desired: bool, // 1 bit
- pub truncated_message: bool, // 1 bit
- pub authoritative_answer: bool, // 1 bit
- pub opcode: u8, // 4 bits
- pub response: bool, // 1 bit
-
- pub rescode: ResultCode, // 4 bits
- pub checking_disabled: bool, // 1 bit
- pub authed_data: bool, // 1 bit
- pub z: bool, // 1 bit
- pub recursion_available: bool, // 1 bit
-
- pub questions: u16, // 16 bits
- pub answers: u16, // 16 bits
- pub authoritative_entries: u16, // 16 bits
- pub resource_entries: u16, // 16 bits
-}
-
-impl DnsHeader {
- pub fn new() -> Self {
- Self {
- id: 0,
-
- recursion_desired: false,
- truncated_message: false,
- authoritative_answer: false,
- opcode: 0,
- response: false,
-
- rescode: ResultCode::NOERROR,
- checking_disabled: false,
- authed_data: false,
- z: false,
- recursion_available: false,
-
- questions: 0,
- answers: 0,
- authoritative_entries: 0,
- resource_entries: 0,
- }
- }
-
- pub fn read(&mut self, buffer: &mut PacketBuffer) -> Result<()> {
- self.id = buffer.read_u16()?;
- let flags = buffer.read_u16()?;
- let a = (flags >> 8) as u8;
- let b = (flags & 0xFF) as u8;
- self.recursion_desired = (a & (1 << 0)) > 0;
- self.truncated_message = (a & (1 << 1)) > 0;
- self.authoritative_answer = (a & (1 << 2)) > 0;
- self.opcode = (a >> 3) & 0x0F;
- self.response = (a & (1 << 7)) > 0;
-
- self.rescode = ResultCode::from_num(b & 0x0F);
- self.checking_disabled = (b & (1 << 4)) > 0;
- self.authed_data = (b & (1 << 5)) > 0;
- self.z = (b & (1 << 6)) > 0;
- self.recursion_available = (b & (1 << 7)) > 0;
-
- self.questions = buffer.read_u16()?;
- self.answers = buffer.read_u16()?;
- self.authoritative_entries = buffer.read_u16()?;
- self.resource_entries = buffer.read_u16()?;
-
- // Return the constant header size
- Ok(())
- }
-
- pub fn write(&self, buffer: &mut PacketBuffer) -> Result<()> {
- buffer.write_u16(self.id)?;
-
- buffer.write_u8(
- (self.recursion_desired as u8)
- | ((self.truncated_message as u8) << 1)
- | ((self.authoritative_answer as u8) << 2)
- | (self.opcode << 3)
- | ((self.response as u8) << 7),
- )?;
-
- buffer.write_u8(
- (self.rescode as u8)
- | ((self.checking_disabled as u8) << 4)
- | ((self.authed_data as u8) << 5)
- | ((self.z as u8) << 6)
- | ((self.recursion_available as u8) << 7),
- )?;
-
- buffer.write_u16(self.questions)?;
- buffer.write_u16(self.answers)?;
- buffer.write_u16(self.authoritative_entries)?;
- buffer.write_u16(self.resource_entries)?;
-
- Ok(())
- }
-}
diff --git a/src/dns/packet/mod.rs b/src/dns/packet/mod.rs
deleted file mode 100644
index 9873b94..0000000
--- a/src/dns/packet/mod.rs
+++ /dev/null
@@ -1,128 +0,0 @@
-use std::net::IpAddr;
-
-use self::{
- buffer::PacketBuffer, header::DnsHeader, query::QueryType, question::DnsQuestion,
- record::DnsRecord,
-};
-use crate::Result;
-
-pub mod buffer;
-pub mod header;
-pub mod query;
-pub mod question;
-pub mod record;
-pub mod result;
-
-#[derive(Clone, Debug)]
-pub struct Packet {
- pub header: DnsHeader,
- pub questions: Vec<DnsQuestion>,
- pub answers: Vec<DnsRecord>,
- pub authorities: Vec<DnsRecord>,
- pub resources: Vec<DnsRecord>,
-}
-
-impl Packet {
- pub fn new() -> Self {
- Self {
- header: DnsHeader::new(),
- questions: Vec::new(),
- answers: Vec::new(),
- authorities: Vec::new(),
- resources: Vec::new(),
- }
- }
-
- pub fn from_buffer(buffer: &mut PacketBuffer) -> Result<Self> {
- let mut result = Self::new();
- result.header.read(buffer)?;
-
- for _ in 0..result.header.questions {
- let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0));
- question.read(buffer)?;
- result.questions.push(question);
- }
-
- for _ in 0..result.header.answers {
- let rec = DnsRecord::read(buffer)?;
- result.answers.push(rec);
- }
- for _ in 0..result.header.authoritative_entries {
- let rec = DnsRecord::read(buffer)?;
- result.authorities.push(rec);
- }
- for _ in 0..result.header.resource_entries {
- let rec = DnsRecord::read(buffer)?;
- result.resources.push(rec);
- }
-
- Ok(result)
- }
-
- pub fn write(&mut self, buffer: &mut PacketBuffer) -> Result<()> {
- self.header.questions = self.questions.len() as u16;
- self.header.answers = self.answers.len() as u16;
- self.header.authoritative_entries = self.authorities.len() as u16;
- self.header.resource_entries = self.resources.len() as u16;
-
- self.header.write(buffer)?;
-
- for question in &self.questions {
- question.write(buffer)?;
- }
- for rec in &self.answers {
- rec.write(buffer)?;
- }
- for rec in &self.authorities {
- rec.write(buffer)?;
- }
- for rec in &self.resources {
- rec.write(buffer)?;
- }
-
- Ok(())
- }
-
- pub fn get_random_a(&self) -> Option<IpAddr> {
- self.answers
- .iter()
- .filter_map(|record| match record {
- DnsRecord::A { addr, .. } => Some(IpAddr::V4(*addr)),
- DnsRecord::AAAA { addr, .. } => Some(IpAddr::V6(*addr)),
- _ => None,
- })
- .next()
- }
-
- fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator<Item = (&'a str, &'a str)> {
- self.authorities
- .iter()
- .filter_map(|record| match record {
- DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())),
- _ => None,
- })
- .filter(move |(domain, _)| qname.ends_with(*domain))
- }
-
- pub fn get_resolved_ns(&self, qname: &str) -> Option<IpAddr> {
- self.get_ns(qname)
- .flat_map(|(_, host)| {
- self.resources
- .iter()
- .filter_map(move |record| match record {
- DnsRecord::A { domain, addr, .. } if domain == host => {
- Some(IpAddr::V4(*addr))
- }
- DnsRecord::AAAA { domain, addr, .. } if domain == host => {
- Some(IpAddr::V6(*addr))
- }
- _ => None,
- })
- })
- .next()
- }
-
- pub fn get_unresolved_ns<'a>(&'a self, qname: &'a str) -> Option<&'a str> {
- self.get_ns(qname).map(|(_, host)| host).next()
- }
-}
diff --git a/src/dns/packet/query.rs b/src/dns/packet/query.rs
deleted file mode 100644
index 732b9b2..0000000
--- a/src/dns/packet/query.rs
+++ /dev/null
@@ -1,78 +0,0 @@
-#[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)]
-pub enum QueryType {
- UNKNOWN(u16),
- A, // 1
- NS, // 2
- CNAME, // 5
- SOA, // 6
- PTR, // 12
- MX, // 15
- TXT, // 16
- AAAA, // 28
- SRV, // 33
- OPT, // 41
- CAA, // 257
- AR, // 1000
- AAAAR, // 1001
-}
-
-impl QueryType {
- pub fn to_num(&self) -> u16 {
- match *self {
- Self::UNKNOWN(x) => x,
- Self::A => 1,
- Self::NS => 2,
- Self::CNAME => 5,
- Self::SOA => 6,
- Self::PTR => 12,
- Self::MX => 15,
- Self::TXT => 16,
- Self::AAAA => 28,
- Self::SRV => 33,
- Self::OPT => 41,
- Self::CAA => 257,
- Self::AR => 1000,
- Self::AAAAR => 1001,
- }
- }
-
- pub fn from_num(num: u16) -> Self {
- match num {
- 1 => Self::A,
- 2 => Self::NS,
- 5 => Self::CNAME,
- 6 => Self::SOA,
- 12 => Self::PTR,
- 15 => Self::MX,
- 16 => Self::TXT,
- 28 => Self::AAAA,
- 33 => Self::SRV,
- 41 => Self::OPT,
- 257 => Self::CAA,
- 1000 => Self::AR,
- 1001 => Self::AAAAR,
- _ => Self::UNKNOWN(num),
- }
- }
-
- pub fn allowed_actions(&self) -> (bool, bool) {
- // 0. duplicates allowed
- // 1. allowed to be created by database
- match self {
- QueryType::UNKNOWN(_) => (false, false),
- QueryType::A => (true, true),
- QueryType::NS => (false, true),
- QueryType::CNAME => (false, true),
- QueryType::SOA => (false, false),
- QueryType::PTR => (false, true),
- QueryType::MX => (false, true),
- QueryType::TXT => (true, true),
- QueryType::AAAA => (true, true),
- QueryType::SRV => (false, true),
- QueryType::OPT => (false, false),
- QueryType::CAA => (false, true),
- QueryType::AR => (false, true),
- QueryType::AAAAR => (false, true),
- }
- }
-}
diff --git a/src/dns/packet/question.rs b/src/dns/packet/question.rs
deleted file mode 100644
index 9042e1c..0000000
--- a/src/dns/packet/question.rs
+++ /dev/null
@@ -1,31 +0,0 @@
-use super::{buffer::PacketBuffer, query::QueryType, Result};
-
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
-pub struct DnsQuestion {
- pub name: String,
- pub qtype: QueryType,
-}
-
-impl DnsQuestion {
- pub fn new(name: String, qtype: QueryType) -> Self {
- Self { name, qtype }
- }
-
- pub fn read(&mut self, buffer: &mut PacketBuffer) -> Result<()> {
- buffer.read_qname(&mut self.name)?;
- self.qtype = QueryType::from_num(buffer.read_u16()?); // qtype
- let _ = buffer.read_u16()?; // class
-
- Ok(())
- }
-
- pub fn write(&self, buffer: &mut PacketBuffer) -> Result<()> {
- buffer.write_qname(&self.name)?;
-
- let typenum = self.qtype.to_num();
- buffer.write_u16(typenum)?;
- buffer.write_u16(1)?;
-
- Ok(())
- }
-}
diff --git a/src/dns/packet/record.rs b/src/dns/packet/record.rs
deleted file mode 100644
index 88008f0..0000000
--- a/src/dns/packet/record.rs
+++ /dev/null
@@ -1,544 +0,0 @@
-use std::net::{Ipv4Addr, Ipv6Addr};
-
-use rand::RngCore;
-use serde::{Deserialize, Serialize};
-use tracing::{trace, warn};
-
-use super::{buffer::PacketBuffer, query::QueryType, Result};
-
-#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
-pub enum DnsRecord {
- UNKNOWN {
- domain: String,
- qtype: u16,
- data_len: u16,
- ttl: u32,
- }, // 0
- A {
- domain: String,
- addr: Ipv4Addr,
- ttl: u32,
- }, // 1
- NS {
- domain: String,
- host: String,
- ttl: u32,
- }, // 2
- CNAME {
- domain: String,
- host: String,
- ttl: u32,
- }, // 5
- SOA {
- domain: String,
- mname: String,
- nname: String,
- serial: u32,
- refresh: u32,
- retry: u32,
- expire: u32,
- minimum: u32,
- ttl: u32,
- }, // 6
- PTR {
- domain: String,
- pointer: String,
- ttl: u32,
- }, // 12
- MX {
- domain: String,
- priority: u16,
- host: String,
- ttl: u32,
- }, // 15
- TXT {
- domain: String,
- text: Vec<String>,
- ttl: u32,
- }, //16
- AAAA {
- domain: String,
- addr: Ipv6Addr,
- ttl: u32,
- }, // 28
- SRV {
- domain: String,
- priority: u16,
- weight: u16,
- port: u16,
- target: String,
- ttl: u32,
- }, // 33
- CAA {
- domain: String,
- flags: u8,
- length: u8,
- tag: String,
- value: String,
- ttl: u32,
- }, // 257
- AR {
- domain: String,
- ttl: u32,
- },
- AAAAR {
- domain: String,
- ttl: u32,
- },
-}
-
-impl DnsRecord {
- pub fn read(buffer: &mut PacketBuffer) -> Result<Self> {
- let mut domain = String::new();
- buffer.read_qname(&mut domain)?;
-
- let qtype_num = buffer.read_u16()?;
- let qtype = QueryType::from_num(qtype_num);
- let _ = buffer.read_u16()?;
- let ttl = buffer.read_u32()?;
- let data_len = buffer.read_u16()?;
-
- trace!("Reading DNS Record TYPE: {:?}", qtype);
-
- let header_pos = buffer.pos();
-
- match qtype {
- QueryType::A => {
- let raw_addr = buffer.read_u32()?;
- let addr = Ipv4Addr::new(
- ((raw_addr >> 24) & 0xFF) as u8,
- ((raw_addr >> 16) & 0xFF) as u8,
- ((raw_addr >> 8) & 0xFF) as u8,
- (raw_addr & 0xFF) as u8,
- );
-
- Ok(Self::A { domain, addr, ttl })
- }
- QueryType::AAAA => {
- let raw_addr1 = buffer.read_u32()?;
- let raw_addr2 = buffer.read_u32()?;
- let raw_addr3 = buffer.read_u32()?;
- let raw_addr4 = buffer.read_u32()?;
- let addr = Ipv6Addr::new(
- ((raw_addr1 >> 16) & 0xFFFF) as u16,
- (raw_addr1 & 0xFFFF) as u16,
- ((raw_addr2 >> 16) & 0xFFFF) as u16,
- (raw_addr2 & 0xFFFF) as u16,
- ((raw_addr3 >> 16) & 0xFFFF) as u16,
- (raw_addr3 & 0xFFFF) as u16,
- ((raw_addr4 >> 16) & 0xFFFF) as u16,
- (raw_addr4 & 0xFFFF) as u16,
- );
-
- Ok(Self::AAAA { domain, addr, ttl })
- }
- QueryType::NS => {
- let mut ns = String::new();
- buffer.read_qname(&mut ns)?;
-
- Ok(Self::NS {
- domain,
- host: ns,
- ttl,
- })
- }
- QueryType::CNAME => {
- let mut cname = String::new();
- buffer.read_qname(&mut cname)?;
-
- Ok(Self::CNAME {
- domain,
- host: cname,
- ttl,
- })
- }
- QueryType::SOA => {
- let mut mname = String::new();
- buffer.read_qname(&mut mname)?;
-
- let mut nname = String::new();
- buffer.read_qname(&mut nname)?;
-
- let serial = buffer.read_u32()?;
- let refresh = buffer.read_u32()?;
- let retry = buffer.read_u32()?;
- let expire = buffer.read_u32()?;
- let minimum = buffer.read_u32()?;
-
- Ok(Self::SOA {
- domain,
- mname,
- nname,
- serial,
- refresh,
- retry,
- expire,
- minimum,
- ttl,
- })
- }
- QueryType::PTR => {
- let mut pointer = String::new();
- buffer.read_qname(&mut pointer)?;
-
- Ok(Self::PTR {
- domain,
- pointer,
- ttl,
- })
- }
- QueryType::MX => {
- let priority = buffer.read_u16()?;
- let mut mx = String::new();
- buffer.read_qname(&mut mx)?;
-
- Ok(Self::MX {
- domain,
- priority,
- host: mx,
- ttl,
- })
- }
- QueryType::TXT => {
- let mut text = Vec::new();
-
- loop {
- let mut s = String::new();
- buffer.read_string(&mut s)?;
-
- if s.len() == 0 {
- break;
- } else {
- text.push(s);
- }
- }
-
- Ok(Self::TXT { domain, text, ttl })
- }
- QueryType::SRV => {
- let priority = buffer.read_u16()?;
- let weight = buffer.read_u16()?;
- let port = buffer.read_u16()?;
-
- let mut target = String::new();
- buffer.read_qname(&mut target)?;
-
- Ok(Self::SRV {
- domain,
- priority,
- weight,
- port,
- target,
- ttl,
- })
- }
- QueryType::CAA => {
- let flags = buffer.read()?;
- let length = buffer.read()?;
-
- let mut tag = String::new();
- buffer.read_string_n(&mut tag, length)?;
-
- let value_len = (data_len as usize) + header_pos - buffer.pos;
- let mut value = String::new();
- buffer.read_string_n(&mut value, value_len as u8)?;
-
- Ok(Self::CAA {
- domain,
- flags,
- length,
- tag,
- value,
- ttl,
- })
- }
- QueryType::UNKNOWN(_) | _ => {
- buffer.step(data_len as usize)?;
-
- Ok(Self::UNKNOWN {
- domain,
- qtype: qtype_num,
- data_len,
- ttl,
- })
- }
- }
- }
-
- pub fn write(&self, buffer: &mut PacketBuffer) -> Result<usize> {
- let start_pos = buffer.pos();
-
- trace!("Writing DNS Record {:?}", self);
-
- match *self {
- Self::A {
- ref domain,
- ref addr,
- ttl,
- } => {
- buffer.write_qname(domain)?;
- buffer.write_u16(QueryType::A.to_num())?;
- buffer.write_u16(1)?;
- buffer.write_u32(ttl)?;
- buffer.write_u16(4)?;
-
- let octets = addr.octets();
- buffer.write_u8(octets[0])?;
- buffer.write_u8(octets[1])?;
- buffer.write_u8(octets[2])?;
- buffer.write_u8(octets[3])?;
- }
- Self::NS {
- ref domain,
- ref host,
- ttl,
- } => {
- buffer.write_qname(domain)?;
- buffer.write_u16(QueryType::NS.to_num())?;
- buffer.write_u16(1)?;
- buffer.write_u32(ttl)?;
-
- let pos = buffer.pos();
- buffer.write_u16(0)?;
-
- buffer.write_qname(host)?;
-
- let size = buffer.pos() - (pos + 2);
- buffer.set_u16(pos, size as u16)?;
- }
- Self::CNAME {
- ref domain,
- ref host,
- ttl,
- } => {
- buffer.write_qname(domain)?;
- buffer.write_u16(QueryType::CNAME.to_num())?;
- buffer.write_u16(1)?;
- buffer.write_u32(ttl)?;
-
- let pos = buffer.pos();
- buffer.write_u16(0)?;
-
- buffer.write_qname(host)?;
-
- let size = buffer.pos() - (pos + 2);
- buffer.set_u16(pos, size as u16)?;
- }
- Self::SOA {
- ref domain,
- ref mname,
- ref nname,
- serial,
- refresh,
- retry,
- expire,
- minimum,
- ttl,
- } => {
- buffer.write_qname(domain)?;
- buffer.write_u16(QueryType::SOA.to_num())?;
- buffer.write_u16(1)?;
- buffer.write_u32(ttl)?;
-
- let pos = buffer.pos();
- buffer.write_u16(0)?;
-
- buffer.write_qname(mname)?;
- buffer.write_qname(nname)?;
- buffer.write_u32(serial)?;
- buffer.write_u32(refresh)?;
- buffer.write_u32(retry)?;
- buffer.write_u32(expire)?;
- buffer.write_u32(minimum)?;
-
- let size = buffer.pos() - (pos + 2);
- buffer.set_u16(pos, size as u16)?;
- }
- Self::PTR {
- ref domain,
- ref pointer,
- ttl,
- } => {
- buffer.write_qname(domain)?;
- buffer.write_u16(QueryType::NS.to_num())?;
- buffer.write_u16(1)?;
- buffer.write_u32(ttl)?;
-
- let pos = buffer.pos();
- buffer.write_u16(0)?;
-
- buffer.write_qname(&pointer)?;
-
- let size = buffer.pos() - (pos + 2);
- buffer.set_u16(pos, size as u16)?;
- }
- Self::MX {
- ref domain,
- priority,
- ref host,
- ttl,
- } => {
- buffer.write_qname(domain)?;
- buffer.write_u16(QueryType::MX.to_num())?;
- buffer.write_u16(1)?;
- buffer.write_u32(ttl)?;
-
- let pos = buffer.pos();
- buffer.write_u16(0)?;
-
- buffer.write_u16(priority)?;
- buffer.write_qname(host)?;
-
- let size = buffer.pos() - (pos + 2);
- buffer.set_u16(pos, size as u16)?;
- }
- Self::TXT {
- ref domain,
- ref text,
- ttl,
- } => {
- buffer.write_qname(&domain)?;
- buffer.write_u16(QueryType::TXT.to_num())?;
- buffer.write_u16(1)?;
- buffer.write_u32(ttl)?;
-
- let pos = buffer.pos();
- buffer.write_u16(0)?;
-
- if text.is_empty() {
- return Ok(buffer.pos() - start_pos);
- }
-
- for s in text {
- buffer.write_u8(s.len() as u8)?;
- buffer.write_string(&s)?;
- }
- let size = buffer.pos() - (pos + 2);
- buffer.set_u16(pos, size as u16)?;
- }
- Self::AAAA {
- ref domain,
- ref addr,
- ttl,
- } => {
- buffer.write_qname(domain)?;
- buffer.write_u16(QueryType::AAAA.to_num())?;
- buffer.write_u16(1)?;
- buffer.write_u32(ttl)?;
- buffer.write_u16(16)?;
-
- for octet in &addr.segments() {
- buffer.write_u16(*octet)?;
- }
- }
- Self::SRV {
- ref domain,
- priority,
- weight,
- port,
- ref target,
- ttl,
- } => {
- buffer.write_qname(domain)?;
- buffer.write_u16(QueryType::SRV.to_num())?;
- buffer.write_u16(1)?;
- buffer.write_u32(ttl)?;
-
- let pos = buffer.pos();
- buffer.write_u16(0)?;
-
- buffer.write_u16(priority)?;
- buffer.write_u16(weight)?;
- buffer.write_u16(port)?;
- buffer.write_qname(target)?;
-
- let size = buffer.pos() - (pos + 2);
- buffer.set_u16(pos, size as u16)?;
- }
- Self::CAA {
- ref domain,
- flags,
- length,
- ref tag,
- ref value,
- ttl,
- } => {
- buffer.write_qname(domain)?;
- buffer.write_u16(QueryType::CAA.to_num())?;
- buffer.write_u16(1)?;
- buffer.write_u32(ttl)?;
-
- let pos = buffer.pos();
- buffer.write_u16(0)?;
-
- buffer.write_u8(flags)?;
- buffer.write_u8(length)?;
- buffer.write_string(tag)?;
- buffer.write_string(value)?;
-
- let size = buffer.pos() - (pos + 2);
- buffer.set_u16(pos, size as u16)?;
- }
- Self::AR { ref domain, ttl } => {
- buffer.write_qname(domain)?;
- buffer.write_u16(QueryType::A.to_num())?;
- buffer.write_u16(1)?;
- buffer.write_u32(ttl)?;
- buffer.write_u16(4)?;
-
- let mut rand = rand::thread_rng();
- buffer.write_u32(rand.next_u32())?;
- }
- Self::AAAAR { ref domain, ttl } => {
- buffer.write_qname(domain)?;
- buffer.write_u16(QueryType::A.to_num())?;
- buffer.write_u16(1)?;
- buffer.write_u32(ttl)?;
- buffer.write_u16(4)?;
-
- let mut rand = rand::thread_rng();
- buffer.write_u32(rand.next_u32())?;
- buffer.write_u32(rand.next_u32())?;
- buffer.write_u32(rand.next_u32())?;
- buffer.write_u32(rand.next_u32())?;
- }
- Self::UNKNOWN { .. } => {
- warn!("Skipping record: {self:?}");
- }
- }
-
- Ok(buffer.pos() - start_pos)
- }
-
- pub fn get_domain(&self) -> String {
- self.get_shared_domain().0
- }
-
- pub fn get_qtype(&self) -> QueryType {
- self.get_shared_domain().1
- }
-
- pub fn get_ttl(&self) -> u32 {
- self.get_shared_domain().2
- }
-
- fn get_shared_domain(&self) -> (String, QueryType, u32) {
- match self {
- DnsRecord::UNKNOWN {
- domain, ttl, qtype, ..
- } => (domain.clone(), QueryType::UNKNOWN(*qtype), *ttl),
- DnsRecord::AAAA { domain, ttl, .. } => (domain.clone(), QueryType::AAAA, *ttl),
- DnsRecord::A { domain, ttl, .. } => (domain.clone(), QueryType::A, *ttl),
- DnsRecord::NS { domain, ttl, .. } => (domain.clone(), QueryType::NS, *ttl),
- DnsRecord::CNAME { domain, ttl, .. } => (domain.clone(), QueryType::CNAME, *ttl),
- DnsRecord::SOA { domain, ttl, .. } => (domain.clone(), QueryType::SOA, *ttl),
- DnsRecord::PTR { domain, ttl, .. } => (domain.clone(), QueryType::PTR, *ttl),
- DnsRecord::MX { domain, ttl, .. } => (domain.clone(), QueryType::MX, *ttl),
- DnsRecord::TXT { domain, ttl, .. } => (domain.clone(), QueryType::TXT, *ttl),
- DnsRecord::SRV { domain, ttl, .. } => (domain.clone(), QueryType::SRV, *ttl),
- DnsRecord::CAA { domain, ttl, .. } => (domain.clone(), QueryType::CAA, *ttl),
- DnsRecord::AR { domain, ttl, .. } => (domain.clone(), QueryType::AR, *ttl),
- DnsRecord::AAAAR { domain, ttl, .. } => (domain.clone(), QueryType::AAAAR, *ttl),
- }
- }
-}
diff --git a/src/dns/packet/result.rs b/src/dns/packet/result.rs
deleted file mode 100644
index 41c8ba9..0000000
--- a/src/dns/packet/result.rs
+++ /dev/null
@@ -1,22 +0,0 @@
-#[derive(Copy, Clone, Debug, PartialEq, Eq)]
-pub enum ResultCode {
- NOERROR = 0,
- FORMERR = 1,
- SERVFAIL = 2,
- NXDOMAIN = 3,
- NOTIMP = 4,
- REFUSED = 5,
-}
-
-impl ResultCode {
- pub fn from_num(num: u8) -> Self {
- match num {
- 1 => Self::FORMERR,
- 2 => Self::SERVFAIL,
- 3 => Self::NXDOMAIN,
- 4 => Self::NOTIMP,
- 5 => Self::REFUSED,
- 0 | _ => Self::NOERROR,
- }
- }
-}
diff --git a/src/dns/resolver.rs b/src/dns/resolver.rs
deleted file mode 100644
index 18b5bba..0000000
--- a/src/dns/resolver.rs
+++ /dev/null
@@ -1,230 +0,0 @@
-use super::binding::Connection;
-use super::packet::{query::QueryType, question::DnsQuestion, result::ResultCode, Packet};
-use crate::Result;
-use crate::{config::Config, database::Database, get_time};
-use async_recursion::async_recursion;
-use moka::future::Cache;
-use std::{net::IpAddr, sync::Arc, time::Duration};
-use tracing::{error, trace};
-
-pub struct Resolver {
- request_id: u16,
- connection: Connection,
- config: Arc<Config>,
- database: Arc<Database>,
- cache: Cache<DnsQuestion, (Packet, u64)>,
-}
-
-impl Resolver {
- pub fn new(
- request_id: u16,
- connection: Connection,
- config: Arc<Config>,
- database: Arc<Database>,
- cache: Cache<DnsQuestion, (Packet, u64)>,
- ) -> Self {
- Self {
- request_id,
- connection,
- config,
- database,
- cache,
- }
- }
-
- async fn lookup_database(&self, question: &DnsQuestion) -> Option<Packet> {
- let records = match self
- .database
- .get_records(&question.name, question.qtype)
- .await
- {
- Ok(record) => record,
- Err(err) => {
- error!("{err}");
- return None;
- }
- };
-
- if records.is_empty() {
- return None;
- }
-
- let mut packet = Packet::new();
-
- packet.header.id = self.request_id;
- packet.header.questions = 1;
- packet.header.answers = records.len() as u16;
- packet.header.recursion_desired = true;
- packet
- .questions
- .push(DnsQuestion::new(question.name.to_string(), question.qtype));
-
- for record in records {
- packet.answers.push(record);
- }
-
- trace!(
- "Found stored value for {:?} {}",
- question.qtype,
- question.name
- );
-
- Some(packet)
- }
-
- async fn lookup_cache(&self, question: &DnsQuestion) -> Option<Packet> {
- let Some((packet, date)) = self.cache.get(&question) else {
- return None
- };
-
- let now = get_time();
- let diff = Duration::from_millis(now - date).as_secs() as u32;
-
- for answer in &packet.answers {
- let ttl = answer.get_ttl();
- if diff > ttl {
- self.cache.invalidate(&question).await;
- return None;
- }
- }
-
- trace!(
- "Found cached value for {:?} {}",
- question.qtype,
- question.name
- );
-
- Some(packet)
- }
-
- async fn lookup_fallback(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
- let mut packet = Packet::new();
-
- packet.header.id = self.request_id;
- packet.header.questions = 1;
- packet.header.recursion_desired = true;
- packet
- .questions
- .push(DnsQuestion::new(question.name.to_string(), question.qtype));
-
- let packet = match self.connection.request_packet(packet, server).await {
- Ok(packet) => packet,
- Err(e) => {
- error!("Failed to complete nameserver request: {e}");
- let mut packet = Packet::new();
- packet.header.rescode = ResultCode::SERVFAIL;
- packet
- }
- };
-
- packet
- }
-
- async fn lookup(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet {
- if let Some(packet) = self.lookup_cache(question).await {
- return packet;
- };
-
- if let Some(packet) = self.lookup_database(question).await {
- return packet;
- };
-
- trace!(
- "Attempting lookup of {:?} {} with ns {}",
- question.qtype,
- question.name,
- server.0
- );
-
- self.lookup_fallback(question, server).await
- }
-
- #[async_recursion]
- async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet {
- let question = DnsQuestion::new(qname.to_string(), qtype);
- let mut ns = self.config.dns_fallback.clone();
-
- loop {
- let ns_copy = ns;
-
- let server = (ns_copy, 53);
- let response = self.lookup(&question, server).await;
-
- if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
- self.cache
- .insert(question, (response.clone(), get_time()))
- .await;
- return response;
- }
-
- if response.header.rescode == ResultCode::NXDOMAIN {
- self.cache
- .insert(question, (response.clone(), get_time()))
- .await;
- return response;
- }
-
- if let Some(new_ns) = response.get_resolved_ns(qname) {
- ns = new_ns;
- continue;
- }
-
- let new_ns_name = match response.get_unresolved_ns(qname) {
- Some(x) => x,
- None => {
- self.cache
- .insert(question, (response.clone(), get_time()))
- .await;
- return response;
- }
- };
-
- let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await;
-
- if let Some(new_ns) = recursive_response.get_random_a() {
- ns = new_ns;
- } else {
- self.cache
- .insert(question, (response.clone(), get_time()))
- .await;
- return response;
- }
- }
- }
-
- pub async fn handle_query(mut self) -> Result<()> {
- let mut request = self.connection.read_packet().await?;
-
- let mut packet = Packet::new();
- packet.header.id = request.header.id;
- packet.header.recursion_desired = true;
- packet.header.recursion_available = true;
- packet.header.response = true;
-
- if let Some(question) = request.questions.pop() {
- trace!("Received query: {question:?}");
-
- let result = self.recursive_lookup(&question.name, question.qtype).await;
- packet.questions.push(question.clone());
- packet.header.rescode = result.header.rescode;
-
- for rec in result.answers {
- trace!("Answer: {rec:?}");
- packet.answers.push(rec);
- }
- for rec in result.authorities {
- trace!("Authority: {rec:?}");
- packet.authorities.push(rec);
- }
- for rec in result.resources {
- trace!("Resource: {rec:?}");
- packet.resources.push(rec);
- }
- } else {
- packet.header.rescode = ResultCode::FORMERR;
- }
-
- self.connection.write_packet(packet).await?;
- Ok(())
- }
-}
diff --git a/src/dns/server.rs b/src/dns/server.rs
deleted file mode 100644
index 65d15df..0000000
--- a/src/dns/server.rs
+++ /dev/null
@@ -1,85 +0,0 @@
-use super::{
- binding::Binding,
- packet::{question::DnsQuestion, Packet},
- resolver::Resolver,
-};
-use crate::{config::Config, database::Database, Result};
-use moka::future::Cache;
-use std::{net::SocketAddr, sync::Arc, time::Duration};
-use tokio::task::JoinHandle;
-use tracing::{error, info};
-
-pub struct DnsServer {
- addr: SocketAddr,
- config: Arc<Config>,
- database: Arc<Database>,
- cache: Cache<DnsQuestion, (Packet, u64)>,
-}
-
-impl DnsServer {
- pub async fn new(config: Config, database: Database) -> Result<Self> {
- let addr = format!("[::]:{}", config.dns_port).parse::<SocketAddr>()?;
- let cache = Cache::builder()
- .time_to_live(Duration::from_secs(60 * 60))
- .max_capacity(config.dns_cache_size)
- .build();
-
- info!("Created DNS cache with size of {}", config.dns_cache_size);
-
- Ok(Self {
- addr,
- config: Arc::new(config),
- database: Arc::new(database),
- cache,
- })
- }
-
- pub async fn run(&self) -> Result<(JoinHandle<()>, JoinHandle<()>)> {
- let tcp = Binding::tcp(self.addr).await?;
- let tcp_handle = self.listen(tcp);
-
- let udp = Binding::udp(self.addr).await?;
- let udp_handle = self.listen(udp);
-
- info!(
- "Fallback DNS Server is set to: {:?}",
- self.config.dns_fallback
- );
- info!(
- "Listening for TCP and UDP traffic on [::]:{}",
- self.config.dns_port
- );
-
- Ok((udp_handle, tcp_handle))
- }
-
- fn listen(&self, mut binding: Binding) -> JoinHandle<()> {
- let config = self.config.clone();
- let database = self.database.clone();
- let cache = self.cache.clone();
- tokio::spawn(async move {
- let mut id = 0;
- loop {
- let Ok(connection) = binding.connect().await else { continue };
- info!("Received request on {}", binding.name());
-
- let resolver = Resolver::new(
- id,
- connection,
- config.clone(),
- database.clone(),
- cache.clone(),
- );
-
- let name = binding.name().to_string();
- tokio::spawn(async move {
- if let Err(err) = resolver.handle_query().await {
- error!("{} request {} failed: {:?}", name, id, err);
- };
- });
-
- id += 1;
- }
- })
- }
-}