summaryrefslogtreecommitdiff
path: root/src/server
diff options
context:
space:
mode:
Diffstat (limited to 'src/server')
-rw-r--r--src/server/binding.rs150
-rw-r--r--src/server/mod.rs3
-rw-r--r--src/server/resolver.rs165
-rw-r--r--src/server/server.rs73
4 files changed, 391 insertions, 0 deletions
diff --git a/src/server/binding.rs b/src/server/binding.rs
new file mode 100644
index 0000000..1c69651
--- /dev/null
+++ b/src/server/binding.rs
@@ -0,0 +1,150 @@
+use std::{
+ net::{IpAddr, SocketAddr},
+ sync::Arc,
+};
+
+use crate::packet::{buffer::PacketBuffer, Packet, Result};
+use tokio::{
+ io::{AsyncReadExt, AsyncWriteExt},
+ net::{TcpListener, TcpStream, UdpSocket},
+};
+use tracing::trace;
+
+pub enum Binding {
+ UDP(Arc<UdpSocket>),
+ TCP(TcpListener),
+}
+
+impl Binding {
+ pub async fn udp(addr: SocketAddr) -> Result<Self> {
+ let socket = UdpSocket::bind(addr).await?;
+ Ok(Self::UDP(Arc::new(socket)))
+ }
+
+ pub async fn tcp(addr: SocketAddr) -> Result<Self> {
+ let socket = TcpListener::bind(addr).await?;
+ Ok(Self::TCP(socket))
+ }
+
+ pub fn name(&self) -> &str {
+ match self {
+ Binding::UDP(_) => "UDP",
+ Binding::TCP(_) => "TCP",
+ }
+ }
+
+ pub async fn connect(&mut self) -> Result<Connection> {
+ match self {
+ Self::UDP(socket) => {
+ let mut buf = [0; 512];
+ let (_, addr) = socket.recv_from(&mut buf).await?;
+ Ok(Connection::UDP(socket.clone(), addr, buf))
+ }
+ Self::TCP(socket) => {
+ let (stream, _) = socket.accept().await?;
+ Ok(Connection::TCP(stream))
+ }
+ }
+ }
+}
+
+pub enum Connection {
+ UDP(Arc<UdpSocket>, SocketAddr, [u8; 512]),
+ TCP(TcpStream),
+}
+
+impl Connection {
+ pub async fn read_packet(&mut self) -> Result<Packet> {
+ let data = self.read().await?;
+ let mut packet_buffer = PacketBuffer::new(data);
+
+ let packet = Packet::from_buffer(&mut packet_buffer)?;
+ Ok(packet)
+ }
+
+ pub async fn write_packet(self, mut packet: Packet) -> Result<()> {
+ let mut packet_buffer = PacketBuffer::new(Vec::new());
+ packet.write(&mut packet_buffer)?;
+
+ self.write(packet_buffer.buf).await?;
+ Ok(())
+ }
+
+ pub async fn request_packet(&self, mut packet: Packet, dest: (IpAddr, u16)) -> Result<Packet> {
+ let mut packet_buffer = PacketBuffer::new(Vec::new());
+ packet.write(&mut packet_buffer)?;
+
+ let data = self.request(packet_buffer.buf, dest).await?;
+ let mut packet_buffer = PacketBuffer::new(data);
+
+ let packet = Packet::from_buffer(&mut packet_buffer)?;
+ Ok(packet)
+ }
+
+ async fn read(&mut self) -> Result<Vec<u8>> {
+ trace!("Reading DNS packet");
+ match self {
+ Self::UDP(_, _, src) => Ok(Vec::from(*src)),
+ Self::TCP(stream) => {
+ let size = stream.read_u16().await?;
+ let mut buf = Vec::with_capacity(size as usize);
+ stream.read_buf(&mut buf).await?;
+ Ok(buf)
+ }
+ }
+ }
+
+ async fn write(self, mut buf: Vec<u8>) -> Result<()> {
+ trace!("Returning DNS packet");
+ match self {
+ Self::UDP(socket, addr, _) => {
+ if buf.len() > 512 {
+ buf[2] = buf[2] | 0x03;
+ socket.send_to(&buf[0..512], addr).await?;
+ } else {
+ socket.send_to(&buf, addr).await?;
+ }
+ Ok(())
+ }
+ Self::TCP(mut stream) => {
+ stream.write_u16(buf.len() as u16).await?;
+ stream.write(&buf[0..buf.len()]).await?;
+ Ok(())
+ }
+ }
+ }
+
+ async fn request(&self, buf: Vec<u8>, dest: (IpAddr, u16)) -> Result<Vec<u8>> {
+ match self {
+ Self::UDP(_socket, _addr, _src) => {
+ let local_addr = "[::]:0".parse::<SocketAddr>()?;
+ let socket = UdpSocket::bind(local_addr).await?;
+ socket.send_to(&buf, dest).await?;
+
+ let mut buf = [0; 512];
+ socket.recv_from(&mut buf).await?;
+
+ Ok(Vec::from(buf))
+ }
+ Self::TCP(_stream) => {
+ let mut stream = TcpStream::connect(dest).await?;
+ stream.write_u16((buf.len()) as u16).await?;
+ stream.write_all(&buf[0..buf.len()]).await?;
+
+ stream.readable().await?;
+ let size = stream.read_u16().await?;
+ let mut buf = Vec::with_capacity(size as usize);
+ stream.read_buf(&mut buf).await?;
+
+ Ok(buf)
+ }
+ }
+ }
+
+ // fn pb(buf: &[u8]) {
+ // for i in 0..buf.len() {
+ // print!("{:02X?} ", buf[i]);
+ // }
+ // println!("");
+ // }
+}
diff --git a/src/server/mod.rs b/src/server/mod.rs
new file mode 100644
index 0000000..25076ef
--- /dev/null
+++ b/src/server/mod.rs
@@ -0,0 +1,3 @@
+mod binding;
+mod resolver;
+pub mod server;
diff --git a/src/server/resolver.rs b/src/server/resolver.rs
new file mode 100644
index 0000000..464620c
--- /dev/null
+++ b/src/server/resolver.rs
@@ -0,0 +1,165 @@
+use super::binding::Connection;
+use crate::{
+ config::Config,
+ packet::{
+ query::QueryType, question::DnsQuestion, result::ResultCode, Packet,
+ Result,
+ }, get_time,
+};
+use async_recursion::async_recursion;
+use moka::future::Cache;
+use std::{net::IpAddr, sync::Arc, time::Duration};
+use tracing::{error, trace};
+
+pub struct Resolver {
+ request_id: u16,
+ connection: Connection,
+ config: Arc<Config>,
+ cache: Cache<DnsQuestion, (Packet, u64)>,
+}
+
+impl Resolver {
+ pub fn new(
+ request_id: u16,
+ connection: Connection,
+ config: Arc<Config>,
+ cache: Cache<DnsQuestion, (Packet, u64)>,
+ ) -> Self {
+ Self {
+ request_id,
+ connection,
+ config,
+ cache,
+ }
+ }
+
+ async fn lookup_cache(&mut self, qname: &str, qtype: QueryType) -> Option<Packet> {
+ let question = DnsQuestion::new(qname.to_string(), qtype);
+ let Some((packet, date)) = self.cache.get(&question) else {
+ return None
+ };
+
+ let now = get_time();
+ let diff = Duration::from_millis(now - date).as_secs() as u32;
+
+ for answer in &packet.answers {
+ let ttl = answer.get_ttl();
+ if diff > ttl {
+ self.cache.invalidate(&question).await;
+ return None
+ }
+ }
+
+ trace!("Found cached value for {qtype:?} {qname}");
+
+ Some(packet)
+ }
+
+ async fn lookup(&mut self, qname: &str, qtype: QueryType, server: (IpAddr, u16)) -> Packet {
+ let mut packet = Packet::new();
+
+ packet.header.id = self.request_id;
+ packet.header.questions = 1;
+ packet.header.recursion_desired = true;
+ packet
+ .questions
+ .push(DnsQuestion::new(qname.to_string(), qtype));
+
+ let packet = match self.connection.request_packet(packet, server).await {
+ Ok(packet) => packet,
+ Err(e) => {
+ error!("Failed to complete nameserver request: {e}");
+ let mut packet = Packet::new();
+ packet.header.rescode = ResultCode::SERVFAIL;
+ packet
+ }
+ };
+
+ packet
+ }
+
+ #[async_recursion]
+ async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet {
+ let question = DnsQuestion::new(qname.to_string(), qtype);
+ let mut ns = self.config.get_fallback_ns().clone();
+
+ if let Some(packet) = self.lookup_cache(qname, qtype).await { return packet }
+
+ loop {
+ trace!("Attempting lookup of {qtype:?} {qname} with ns {ns}");
+
+ let ns_copy = ns;
+
+ let server = (ns_copy, 53);
+ let response = self.lookup(qname, qtype, server).await;
+
+ if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
+ self.cache.insert(question, (response.clone(), get_time())).await;
+ return response;
+ }
+
+ if response.header.rescode == ResultCode::NXDOMAIN {
+ self.cache.insert(question, (response.clone(), get_time())).await;
+ return response;
+ }
+
+ if let Some(new_ns) = response.get_resolved_ns(qname) {
+ ns = new_ns;
+ continue;
+ }
+
+ let new_ns_name = match response.get_unresolved_ns(qname) {
+ Some(x) => x,
+ None => {
+ self.cache.insert(question, (response.clone(), get_time())).await;
+ return response
+ },
+ };
+
+ let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await;
+
+ if let Some(new_ns) = recursive_response.get_random_a() {
+ ns = new_ns;
+ } else {
+ self.cache.insert(question, (response.clone(), get_time())).await;
+ return response;
+ }
+ }
+ }
+
+ pub async fn handle_query(mut self) -> Result<()> {
+ let mut request = self.connection.read_packet().await?;
+
+ let mut packet = Packet::new();
+ packet.header.id = request.header.id;
+ packet.header.recursion_desired = true;
+ packet.header.recursion_available = true;
+ packet.header.response = true;
+
+ if let Some(question) = request.questions.pop() {
+ trace!("Received query: {question:?}");
+
+ let result = self.recursive_lookup(&question.name, question.qtype).await;
+ packet.questions.push(question.clone());
+ packet.header.rescode = result.header.rescode;
+
+ for rec in result.answers {
+ trace!("Answer: {rec:?}");
+ packet.answers.push(rec);
+ }
+ for rec in result.authorities {
+ trace!("Authority: {rec:?}");
+ packet.authorities.push(rec);
+ }
+ for rec in result.resources {
+ trace!("Resource: {rec:?}");
+ packet.resources.push(rec);
+ }
+ } else {
+ packet.header.rescode = ResultCode::FORMERR;
+ }
+
+ self.connection.write_packet(packet).await?;
+ Ok(())
+ }
+}
diff --git a/src/server/server.rs b/src/server/server.rs
new file mode 100644
index 0000000..e006bb1
--- /dev/null
+++ b/src/server/server.rs
@@ -0,0 +1,73 @@
+use moka::future::Cache;
+use std::net::SocketAddr;
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::task::JoinHandle;
+use tracing::{error, info};
+
+use crate::config::Config;
+use crate::packet::question::DnsQuestion;
+use crate::packet::{Result, Packet};
+
+use super::binding::Binding;
+use super::resolver::Resolver;
+
+pub struct Server {
+ addr: SocketAddr,
+ config: Arc<Config>,
+ cache: Cache<DnsQuestion, (Packet, u64)>,
+}
+
+impl Server {
+ pub async fn new(config: Config) -> Result<Self> {
+ let addr = format!("[::]:{}", config.get_port()).parse::<SocketAddr>()?;
+ let cache = Cache::builder()
+ .time_to_live(Duration::from_secs(60 * 60))
+ .max_capacity(1_000)
+ .build();
+ Ok(Self {
+ addr,
+ config: Arc::new(config),
+ cache,
+ })
+ }
+
+ pub async fn run(&self) -> Result<()> {
+ let tcp = Binding::tcp(self.addr).await?;
+ let tcp_handle = self.listen(tcp);
+
+ let udp = Binding::udp(self.addr).await?;
+ let udp_handle = self.listen(udp);
+
+ info!("Fallback DNS Server is set to: {:?}", self.config.get_fallback_ns());
+ info!("Listening for TCP and UDP traffic on [::]:{}", self.config.get_port());
+
+ tokio::join!(tcp_handle)
+ .0
+ .expect("Failed to join tcp thread");
+ tokio::join!(udp_handle)
+ .0
+ .expect("Failed to join udp thread");
+ Ok(())
+ }
+
+ fn listen(&self, mut binding: Binding) -> JoinHandle<()> {
+ let config = self.config.clone();
+ let cache = self.cache.clone();
+ tokio::spawn(async move {
+ let mut id = 0;
+ loop {
+ let Ok(connection) = binding.connect().await else { continue };
+ info!("Received request on {}", binding.name());
+
+ let resolver = Resolver::new(id, connection, config.clone(), cache.clone());
+
+ if let Err(err) = resolver.handle_query().await {
+ error!("{} request {} failed: {:?}", binding.name(), id, err);
+ };
+
+ id += 1;
+ }
+ })
+ }
+}