From bb85374b79086cd8efde24d23a1bffeb97cae26b Mon Sep 17 00:00:00 2001 From: Tyler Murphy Date: Wed, 5 Apr 2023 23:08:09 -0400 Subject: new c version --- src/config.rs | 57 ----- src/database/mod.rs | 146 ------------ src/dns/binding.rs | 144 ------------ src/dns/mod.rs | 4 - src/dns/packet/buffer.rs | 227 ------------------- src/dns/packet/header.rs | 102 --------- src/dns/packet/mod.rs | 128 ----------- src/dns/packet/query.rs | 78 ------- src/dns/packet/question.rs | 31 --- src/dns/packet/record.rs | 544 --------------------------------------------- src/dns/packet/result.rs | 22 -- src/dns/resolver.rs | 230 ------------------- src/dns/server.rs | 85 ------- src/io/log.c | 49 ++++ src/io/log.h | 45 ++++ src/main.c | 32 +++ src/main.rs | 64 ------ src/packet/buffer.c | 240 ++++++++++++++++++++ src/packet/buffer.h | 51 +++++ src/packet/header.c | 93 ++++++++ src/packet/header.h | 41 ++++ src/packet/packet.c | 171 ++++++++++++++ src/packet/packet.h | 23 ++ src/packet/question.c | 94 ++++++++ src/packet/question.h | 15 ++ src/packet/record.c | 540 ++++++++++++++++++++++++++++++++++++++++++++ src/packet/record.h | 101 +++++++++ src/server/addr.c | 233 +++++++++++++++++++ src/server/addr.h | 69 ++++++ src/server/binding.c | 245 ++++++++++++++++++++ src/server/binding.h | 42 ++++ src/server/resolver.c | 166 ++++++++++++++ src/server/resolver.h | 6 + src/server/server.c | 100 +++++++++ src/server/server.h | 12 + src/web/api.rs | 156 ------------- src/web/extract.rs | 139 ------------ src/web/file.rs | 31 --- src/web/http.rs | 50 ----- src/web/mod.rs | 82 ------- src/web/pages.rs | 31 --- 41 files changed, 2368 insertions(+), 2351 deletions(-) delete mode 100644 src/config.rs delete mode 100644 src/database/mod.rs delete mode 100644 src/dns/binding.rs delete mode 100644 src/dns/mod.rs delete mode 100644 src/dns/packet/buffer.rs delete mode 100644 src/dns/packet/header.rs delete mode 100644 src/dns/packet/mod.rs delete mode 100644 src/dns/packet/query.rs delete mode 100644 src/dns/packet/question.rs delete mode 100644 src/dns/packet/record.rs delete mode 100644 src/dns/packet/result.rs delete mode 100644 src/dns/resolver.rs delete mode 100644 src/dns/server.rs create mode 100644 src/io/log.c create mode 100644 src/io/log.h create mode 100644 src/main.c delete mode 100644 src/main.rs create mode 100644 src/packet/buffer.c create mode 100644 src/packet/buffer.h create mode 100644 src/packet/header.c create mode 100644 src/packet/header.h create mode 100644 src/packet/packet.c create mode 100644 src/packet/packet.h create mode 100644 src/packet/question.c create mode 100644 src/packet/question.h create mode 100644 src/packet/record.c create mode 100644 src/packet/record.h create mode 100644 src/server/addr.c create mode 100644 src/server/addr.h create mode 100644 src/server/binding.c create mode 100644 src/server/binding.h create mode 100644 src/server/resolver.c create mode 100644 src/server/resolver.h create mode 100644 src/server/server.c create mode 100644 src/server/server.h delete mode 100644 src/web/api.rs delete mode 100644 src/web/extract.rs delete mode 100644 src/web/file.rs delete mode 100644 src/web/http.rs delete mode 100644 src/web/mod.rs delete mode 100644 src/web/pages.rs (limited to 'src') diff --git a/src/config.rs b/src/config.rs deleted file mode 100644 index 547e853..0000000 --- a/src/config.rs +++ /dev/null @@ -1,57 +0,0 @@ -use std::{env, net::IpAddr, str::FromStr, fmt::Display}; - -#[derive(Clone)] -pub struct Config { - pub dns_fallback: IpAddr, - pub dns_port: u16, - pub dns_cache_size: u64, - - pub db_host: String, - pub db_port: u16, - pub db_user: String, - pub db_pass: String, - - pub web_user: String, - pub web_pass: String, - pub web_port: u16, -} - -impl Config { - pub fn new() -> Self { - let dns_port = Self::get_var::("WRAPPER_DNS_PORT", 53); - let dns_fallback = Self::get_var::("WRAPPER_FALLBACK_DNS", [9, 9, 9, 9].into()); - let dns_cache_size = Self::get_var::("WRAPPER_CACHE_SIZE", 1000); - - let db_host = Self::get_var::("WRAPPER_DB_HOST", String::from("localhost")); - let db_port = Self::get_var::("WRAPPER_DB_PORT", 27017); - let db_user = Self::get_var::("WRAPPER_DB_USER", String::from("root")); - let db_pass = Self::get_var::("WRAPPER_DB_PASS", String::from("")); - - let web_user = Self::get_var::("WRAPPER_WEB_USER", String::from("admin")); - let web_pass = Self::get_var::("WRAPPER_WEB_PASS", String::from("wrapper")); - let web_port = Self::get_var::("WRAPPER_WEB_PORT", 80); - - Self { - dns_fallback, - dns_port, - dns_cache_size, - - db_host, - db_port, - db_user, - db_pass, - - web_user, - web_pass, - web_port, - } - } - - fn get_var(name: &str, default: T) -> T - where - T: FromStr + Display, - { - let env = env::var(name).unwrap_or(format!("{default}")); - env.parse::().unwrap_or(default) - } -} diff --git a/src/database/mod.rs b/src/database/mod.rs deleted file mode 100644 index 0d81dc3..0000000 --- a/src/database/mod.rs +++ /dev/null @@ -1,146 +0,0 @@ -use futures::TryStreamExt; -use mongodb::{ - bson::doc, - options::{ClientOptions, Credential, ServerAddress}, - Client, -}; -use serde::{Deserialize, Serialize}; -use tracing::info; - -use crate::{ - config::Config, - dns::packet::{query::QueryType, record::DnsRecord}, -}; - -use crate::Result; - -#[derive(Clone)] -pub struct Database { - client: Client, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct StoredRecord { - record: DnsRecord, - domain: String, - prefix: String, -} - -impl StoredRecord { - fn get_domain_parts(domain: &str) -> (String, String) { - let parts: Vec<&str> = domain.split(".").collect(); - let len = parts.len(); - if len == 1 { - (String::new(), String::from(parts[0])) - } else if len == 2 { - (String::new(), String::from(parts.join("."))) - } else { - ( - String::from(parts[0..len - 2].join(".")), - String::from(parts[len - 2..len].join(".")), - ) - } - } -} - -impl From for StoredRecord { - fn from(record: DnsRecord) -> Self { - let (prefix, domain) = Self::get_domain_parts(&record.get_domain()); - Self { - record, - domain, - prefix, - } - } -} - -impl Into for StoredRecord { - fn into(self) -> DnsRecord { - self.record - } -} - -impl Database { - pub async fn new(config: Config) -> Result { - let options = ClientOptions::builder() - .hosts(vec![ServerAddress::Tcp { - host: config.db_host, - port: Some(config.db_port), - }]) - .credential( - Credential::builder() - .username(config.db_user) - .password(config.db_pass) - .build(), - ) - .max_pool_size(100) - .app_name(String::from("wrapper")) - .build(); - - let client = Client::with_options(options)?; - - client - .database("wrapper") - .run_command(doc! {"ping": 1}, None) - .await?; - - info!("Connection to mongodb successfully"); - - Ok(Database { client }) - } - - pub async fn get_records(&self, domain: &str, qtype: QueryType) -> Result> { - let (prefix, domain) = StoredRecord::get_domain_parts(domain); - Ok(self - .get_domain(&domain) - .await? - .into_iter() - .filter(|r| r.prefix == prefix) - .filter(|r| { - let rqtype = r.record.get_qtype(); - if qtype == QueryType::A { - return rqtype == QueryType::A || rqtype == QueryType::AR; - } else if qtype == QueryType::AAAA { - return rqtype == QueryType::AAAA || rqtype == QueryType::AAAAR; - } else { - r.record.get_qtype() == qtype - } - }) - .map(|r| r.into()) - .collect()) - } - - pub async fn get_domain(&self, domain: &str) -> Result> { - let db = self.client.database("wrapper"); - let col = db.collection::(domain); - - let filter = doc! { "domain": domain }; - let mut cursor = col.find(filter, None).await?; - - let mut records = Vec::new(); - while let Some(record) = cursor.try_next().await? { - records.push(record); - } - - Ok(records) - } - - pub async fn add_record(&self, record: DnsRecord) -> Result<()> { - let record = StoredRecord::from(record); - let db = self.client.database("wrapper"); - let col = db.collection::(&record.domain); - col.insert_one(record, None).await?; - Ok(()) - } - - pub async fn get_domains(&self) -> Result> { - let db = self.client.database("wrapper"); - Ok(db.list_collection_names(None).await?) - } - - pub async fn delete_domain(&self, domain: String) -> Result<()> { - let db = self.client.database("wrapper"); - let col = db.collection::(&domain); - Ok(col.drop(None).await?) - } -} diff --git a/src/dns/binding.rs b/src/dns/binding.rs deleted file mode 100644 index 4c7e15f..0000000 --- a/src/dns/binding.rs +++ /dev/null @@ -1,144 +0,0 @@ -use std::{ - net::{IpAddr, SocketAddr}, - sync::Arc, -}; - -use super::packet::{buffer::PacketBuffer, Packet}; -use crate::Result; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::{TcpListener, TcpStream, UdpSocket}, -}; -use tracing::trace; - -pub enum Binding { - UDP(Arc), - TCP(TcpListener), -} - -impl Binding { - pub async fn udp(addr: SocketAddr) -> Result { - let socket = UdpSocket::bind(addr).await?; - Ok(Self::UDP(Arc::new(socket))) - } - - pub async fn tcp(addr: SocketAddr) -> Result { - let socket = TcpListener::bind(addr).await?; - Ok(Self::TCP(socket)) - } - - pub fn name(&self) -> &str { - match self { - Binding::UDP(_) => "UDP", - Binding::TCP(_) => "TCP", - } - } - - pub async fn connect(&mut self) -> Result { - match self { - Self::UDP(socket) => { - let mut buf = [0; 512]; - let (_, addr) = socket.recv_from(&mut buf).await?; - Ok(Connection::UDP(socket.clone(), addr, buf)) - } - Self::TCP(socket) => { - let (stream, _) = socket.accept().await?; - Ok(Connection::TCP(stream)) - } - } - } -} - -pub enum Connection { - UDP(Arc, SocketAddr, [u8; 512]), - TCP(TcpStream), -} - -impl Connection { - pub async fn read_packet(&mut self) -> Result { - let data = self.read().await?; - let mut packet_buffer = PacketBuffer::new(data); - - let packet = Packet::from_buffer(&mut packet_buffer)?; - Ok(packet) - } - - pub async fn write_packet(self, mut packet: Packet) -> Result<()> { - let mut packet_buffer = PacketBuffer::new(Vec::new()); - packet.write(&mut packet_buffer)?; - - self.write(packet_buffer.buf).await?; - Ok(()) - } - - pub async fn request_packet(&self, mut packet: Packet, dest: (IpAddr, u16)) -> Result { - let mut packet_buffer = PacketBuffer::new(Vec::new()); - packet.write(&mut packet_buffer)?; - - let data = self.request(packet_buffer.buf, dest).await?; - let mut packet_buffer = PacketBuffer::new(data); - - let packet = Packet::from_buffer(&mut packet_buffer)?; - Ok(packet) - } - - async fn read(&mut self) -> Result> { - trace!("Reading DNS packet"); - match self { - Self::UDP(_, _, src) => Ok(Vec::from(*src)), - Self::TCP(stream) => { - let size = stream.read_u16().await?; - let mut buf = Vec::with_capacity(size as usize); - stream.read_buf(&mut buf).await?; - Ok(buf) - } - } - } - - async fn write(self, mut buf: Vec) -> Result<()> { - trace!("Returning DNS packet"); - match self { - Self::UDP(socket, addr, _) => { - if buf.len() > 512 { - buf[2] = buf[2] | 0x03; - socket.send_to(&buf[0..512], addr).await?; - } else { - socket.send_to(&buf, addr).await?; - } - Ok(()) - } - Self::TCP(mut stream) => { - stream.write_u16(buf.len() as u16).await?; - stream.write(&buf[0..buf.len()]).await?; - Ok(()) - } - } - } - - async fn request(&self, buf: Vec, dest: (IpAddr, u16)) -> Result> { - match self { - Self::UDP(_socket, _addr, _src) => { - let local_addr = "[::]:0".parse::()?; - let socket = UdpSocket::bind(local_addr).await?; - socket.send_to(&buf, dest).await?; - - let mut buf = [0; 512]; - socket.recv_from(&mut buf).await?; - - Ok(Vec::from(buf)) - } - Self::TCP(_stream) => { - let mut stream = TcpStream::connect(dest).await?; - stream.write_u16((buf.len()) as u16).await?; - stream.write_all(&buf[0..buf.len()]).await?; - - stream.readable().await?; - let size = stream.read_u16().await?; - let mut buf = Vec::with_capacity(size as usize); - stream.read_buf(&mut buf).await?; - - Ok(buf) - } - } - } -} diff --git a/src/dns/mod.rs b/src/dns/mod.rs deleted file mode 100644 index 6f1e59e..0000000 --- a/src/dns/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod binding; -pub mod packet; -mod resolver; -pub mod server; diff --git a/src/dns/packet/buffer.rs b/src/dns/packet/buffer.rs deleted file mode 100644 index 058156e..0000000 --- a/src/dns/packet/buffer.rs +++ /dev/null @@ -1,227 +0,0 @@ -use crate::Result; - -pub struct PacketBuffer { - pub buf: Vec, - pub pos: usize, - pub size: usize, -} - -impl PacketBuffer { - pub fn new(buf: Vec) -> Self { - Self { - size: buf.len(), - buf, - pos: 0, - } - } - - pub fn pos(&self) -> usize { - self.pos - } - - pub fn step(&mut self, steps: usize) -> Result<()> { - self.pos += steps; - - Ok(()) - } - - pub fn seek(&mut self, pos: usize) -> Result<()> { - self.pos = pos; - - Ok(()) - } - - pub fn read(&mut self) -> Result { - if self.pos >= self.size { - return Err("Tried to read past end of buffer".into()); - } - let res = self.buf[self.pos]; - self.pos += 1; - Ok(res) - } - - pub fn get(&mut self, pos: usize) -> Result { - if pos >= self.size { - return Err("Tried to read past end of buffer".into()); - } - Ok(self.buf[pos]) - } - - pub fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> { - if start + len >= self.size { - return Err("Tried to read past end of buffer".into()); - } - Ok(&self.buf[start..start + len]) - } - - pub fn read_u16(&mut self) -> Result { - let res = ((self.read()? as u16) << 8) | (self.read()? as u16); - - Ok(res) - } - - pub fn read_u32(&mut self) -> Result { - let res = ((self.read()? as u32) << 24) - | ((self.read()? as u32) << 16) - | ((self.read()? as u32) << 8) - | (self.read()? as u32); - - Ok(res) - } - - pub fn read_qname(&mut self, outstr: &mut String) -> Result<()> { - let mut pos = self.pos(); - let mut jumped = false; - - let mut delim = ""; - let max_jumps = 5; - let mut jumps_performed = 0; - loop { - // Dns Packets are untrusted data, so we need to be paranoid. Someone - // can craft a packet with a cycle in the jump instructions. This guards - // against such packets. - if jumps_performed > max_jumps { - return Err(format!("Limit of {max_jumps} jumps exceeded").into()); - } - - let len = self.get(pos)?; - - if (len & 0xC0) == 0xC0 { - if !jumped { - self.seek(pos + 2)?; - } - - let b2 = self.get(pos + 1)? as u16; - let offset = (((len as u16) ^ 0xC0) << 8) | b2; - pos = offset as usize; - jumped = true; - jumps_performed += 1; - continue; - } - - pos += 1; - - if len == 0 { - break; - } - - outstr.push_str(delim); - - let str_buffer = self.get_range(pos, len as usize)?; - outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase()); - - delim = "."; - - pos += len as usize; - } - - if !jumped { - self.seek(pos)?; - } - - Ok(()) - } - - pub fn read_string(&mut self, outstr: &mut String) -> Result<()> { - let len = self.read()?; - - self.read_string_n(outstr, len)?; - - Ok(()) - } - - pub fn read_string_n(&mut self, outstr: &mut String, len: u8) -> Result<()> { - let mut pos = self.pos; - - let str_buffer = self.get_range(pos, len as usize)?; - - let mut i = 0; - for b in str_buffer { - let c = *b as char; - if c == '\0' { - break; - } - outstr.push(c); - i += 1; - } - - pos += i; - self.seek(pos)?; - - Ok(()) - } - - pub fn write(&mut self, val: u8) -> Result<()> { - if self.size < self.pos { - self.size = self.pos; - } - - if self.buf.len() <= self.size { - self.buf.resize(self.size + 1, 0x00); - } - - self.buf[self.pos] = val; - self.pos += 1; - Ok(()) - } - - pub fn write_u8(&mut self, val: u8) -> Result<()> { - self.write(val)?; - - Ok(()) - } - - pub fn write_u16(&mut self, val: u16) -> Result<()> { - self.write((val >> 8) as u8)?; - self.write((val & 0xFF) as u8)?; - - Ok(()) - } - - pub fn write_u32(&mut self, val: u32) -> Result<()> { - self.write(((val >> 24) & 0xFF) as u8)?; - self.write(((val >> 16) & 0xFF) as u8)?; - self.write(((val >> 8) & 0xFF) as u8)?; - self.write((val & 0xFF) as u8)?; - - Ok(()) - } - - pub fn write_qname(&mut self, qname: &str) -> Result<()> { - for label in qname.split('.') { - let len = label.len(); - - self.write_u8(len as u8)?; - for b in label.as_bytes() { - self.write_u8(*b)?; - } - } - - if !qname.is_empty() { - self.write_u8(0)?; - } - - Ok(()) - } - - pub fn write_string(&mut self, text: &str) -> Result<()> { - for b in text.as_bytes() { - self.write_u8(*b)?; - } - - Ok(()) - } - - pub fn set(&mut self, pos: usize, val: u8) -> Result<()> { - self.buf[pos] = val; - - Ok(()) - } - - pub fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> { - self.set(pos, (val >> 8) as u8)?; - self.set(pos + 1, (val & 0xFF) as u8)?; - - Ok(()) - } -} diff --git a/src/dns/packet/header.rs b/src/dns/packet/header.rs deleted file mode 100644 index 2355ecb..0000000 --- a/src/dns/packet/header.rs +++ /dev/null @@ -1,102 +0,0 @@ -use super::{buffer::PacketBuffer, result::ResultCode}; -use crate::Result; - -#[derive(Clone, Debug)] -pub struct DnsHeader { - pub id: u16, // 16 bits - - pub recursion_desired: bool, // 1 bit - pub truncated_message: bool, // 1 bit - pub authoritative_answer: bool, // 1 bit - pub opcode: u8, // 4 bits - pub response: bool, // 1 bit - - pub rescode: ResultCode, // 4 bits - pub checking_disabled: bool, // 1 bit - pub authed_data: bool, // 1 bit - pub z: bool, // 1 bit - pub recursion_available: bool, // 1 bit - - pub questions: u16, // 16 bits - pub answers: u16, // 16 bits - pub authoritative_entries: u16, // 16 bits - pub resource_entries: u16, // 16 bits -} - -impl DnsHeader { - pub fn new() -> Self { - Self { - id: 0, - - recursion_desired: false, - truncated_message: false, - authoritative_answer: false, - opcode: 0, - response: false, - - rescode: ResultCode::NOERROR, - checking_disabled: false, - authed_data: false, - z: false, - recursion_available: false, - - questions: 0, - answers: 0, - authoritative_entries: 0, - resource_entries: 0, - } - } - - pub fn read(&mut self, buffer: &mut PacketBuffer) -> Result<()> { - self.id = buffer.read_u16()?; - let flags = buffer.read_u16()?; - let a = (flags >> 8) as u8; - let b = (flags & 0xFF) as u8; - self.recursion_desired = (a & (1 << 0)) > 0; - self.truncated_message = (a & (1 << 1)) > 0; - self.authoritative_answer = (a & (1 << 2)) > 0; - self.opcode = (a >> 3) & 0x0F; - self.response = (a & (1 << 7)) > 0; - - self.rescode = ResultCode::from_num(b & 0x0F); - self.checking_disabled = (b & (1 << 4)) > 0; - self.authed_data = (b & (1 << 5)) > 0; - self.z = (b & (1 << 6)) > 0; - self.recursion_available = (b & (1 << 7)) > 0; - - self.questions = buffer.read_u16()?; - self.answers = buffer.read_u16()?; - self.authoritative_entries = buffer.read_u16()?; - self.resource_entries = buffer.read_u16()?; - - // Return the constant header size - Ok(()) - } - - pub fn write(&self, buffer: &mut PacketBuffer) -> Result<()> { - buffer.write_u16(self.id)?; - - buffer.write_u8( - (self.recursion_desired as u8) - | ((self.truncated_message as u8) << 1) - | ((self.authoritative_answer as u8) << 2) - | (self.opcode << 3) - | ((self.response as u8) << 7), - )?; - - buffer.write_u8( - (self.rescode as u8) - | ((self.checking_disabled as u8) << 4) - | ((self.authed_data as u8) << 5) - | ((self.z as u8) << 6) - | ((self.recursion_available as u8) << 7), - )?; - - buffer.write_u16(self.questions)?; - buffer.write_u16(self.answers)?; - buffer.write_u16(self.authoritative_entries)?; - buffer.write_u16(self.resource_entries)?; - - Ok(()) - } -} diff --git a/src/dns/packet/mod.rs b/src/dns/packet/mod.rs deleted file mode 100644 index 9873b94..0000000 --- a/src/dns/packet/mod.rs +++ /dev/null @@ -1,128 +0,0 @@ -use std::net::IpAddr; - -use self::{ - buffer::PacketBuffer, header::DnsHeader, query::QueryType, question::DnsQuestion, - record::DnsRecord, -}; -use crate::Result; - -pub mod buffer; -pub mod header; -pub mod query; -pub mod question; -pub mod record; -pub mod result; - -#[derive(Clone, Debug)] -pub struct Packet { - pub header: DnsHeader, - pub questions: Vec, - pub answers: Vec, - pub authorities: Vec, - pub resources: Vec, -} - -impl Packet { - pub fn new() -> Self { - Self { - header: DnsHeader::new(), - questions: Vec::new(), - answers: Vec::new(), - authorities: Vec::new(), - resources: Vec::new(), - } - } - - pub fn from_buffer(buffer: &mut PacketBuffer) -> Result { - let mut result = Self::new(); - result.header.read(buffer)?; - - for _ in 0..result.header.questions { - let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0)); - question.read(buffer)?; - result.questions.push(question); - } - - for _ in 0..result.header.answers { - let rec = DnsRecord::read(buffer)?; - result.answers.push(rec); - } - for _ in 0..result.header.authoritative_entries { - let rec = DnsRecord::read(buffer)?; - result.authorities.push(rec); - } - for _ in 0..result.header.resource_entries { - let rec = DnsRecord::read(buffer)?; - result.resources.push(rec); - } - - Ok(result) - } - - pub fn write(&mut self, buffer: &mut PacketBuffer) -> Result<()> { - self.header.questions = self.questions.len() as u16; - self.header.answers = self.answers.len() as u16; - self.header.authoritative_entries = self.authorities.len() as u16; - self.header.resource_entries = self.resources.len() as u16; - - self.header.write(buffer)?; - - for question in &self.questions { - question.write(buffer)?; - } - for rec in &self.answers { - rec.write(buffer)?; - } - for rec in &self.authorities { - rec.write(buffer)?; - } - for rec in &self.resources { - rec.write(buffer)?; - } - - Ok(()) - } - - pub fn get_random_a(&self) -> Option { - self.answers - .iter() - .filter_map(|record| match record { - DnsRecord::A { addr, .. } => Some(IpAddr::V4(*addr)), - DnsRecord::AAAA { addr, .. } => Some(IpAddr::V6(*addr)), - _ => None, - }) - .next() - } - - fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator { - self.authorities - .iter() - .filter_map(|record| match record { - DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())), - _ => None, - }) - .filter(move |(domain, _)| qname.ends_with(*domain)) - } - - pub fn get_resolved_ns(&self, qname: &str) -> Option { - self.get_ns(qname) - .flat_map(|(_, host)| { - self.resources - .iter() - .filter_map(move |record| match record { - DnsRecord::A { domain, addr, .. } if domain == host => { - Some(IpAddr::V4(*addr)) - } - DnsRecord::AAAA { domain, addr, .. } if domain == host => { - Some(IpAddr::V6(*addr)) - } - _ => None, - }) - }) - .next() - } - - pub fn get_unresolved_ns<'a>(&'a self, qname: &'a str) -> Option<&'a str> { - self.get_ns(qname).map(|(_, host)| host).next() - } -} diff --git a/src/dns/packet/query.rs b/src/dns/packet/query.rs deleted file mode 100644 index 732b9b2..0000000 --- a/src/dns/packet/query.rs +++ /dev/null @@ -1,78 +0,0 @@ -#[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)] -pub enum QueryType { - UNKNOWN(u16), - A, // 1 - NS, // 2 - CNAME, // 5 - SOA, // 6 - PTR, // 12 - MX, // 15 - TXT, // 16 - AAAA, // 28 - SRV, // 33 - OPT, // 41 - CAA, // 257 - AR, // 1000 - AAAAR, // 1001 -} - -impl QueryType { - pub fn to_num(&self) -> u16 { - match *self { - Self::UNKNOWN(x) => x, - Self::A => 1, - Self::NS => 2, - Self::CNAME => 5, - Self::SOA => 6, - Self::PTR => 12, - Self::MX => 15, - Self::TXT => 16, - Self::AAAA => 28, - Self::SRV => 33, - Self::OPT => 41, - Self::CAA => 257, - Self::AR => 1000, - Self::AAAAR => 1001, - } - } - - pub fn from_num(num: u16) -> Self { - match num { - 1 => Self::A, - 2 => Self::NS, - 5 => Self::CNAME, - 6 => Self::SOA, - 12 => Self::PTR, - 15 => Self::MX, - 16 => Self::TXT, - 28 => Self::AAAA, - 33 => Self::SRV, - 41 => Self::OPT, - 257 => Self::CAA, - 1000 => Self::AR, - 1001 => Self::AAAAR, - _ => Self::UNKNOWN(num), - } - } - - pub fn allowed_actions(&self) -> (bool, bool) { - // 0. duplicates allowed - // 1. allowed to be created by database - match self { - QueryType::UNKNOWN(_) => (false, false), - QueryType::A => (true, true), - QueryType::NS => (false, true), - QueryType::CNAME => (false, true), - QueryType::SOA => (false, false), - QueryType::PTR => (false, true), - QueryType::MX => (false, true), - QueryType::TXT => (true, true), - QueryType::AAAA => (true, true), - QueryType::SRV => (false, true), - QueryType::OPT => (false, false), - QueryType::CAA => (false, true), - QueryType::AR => (false, true), - QueryType::AAAAR => (false, true), - } - } -} diff --git a/src/dns/packet/question.rs b/src/dns/packet/question.rs deleted file mode 100644 index 9042e1c..0000000 --- a/src/dns/packet/question.rs +++ /dev/null @@ -1,31 +0,0 @@ -use super::{buffer::PacketBuffer, query::QueryType, Result}; - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct DnsQuestion { - pub name: String, - pub qtype: QueryType, -} - -impl DnsQuestion { - pub fn new(name: String, qtype: QueryType) -> Self { - Self { name, qtype } - } - - pub fn read(&mut self, buffer: &mut PacketBuffer) -> Result<()> { - buffer.read_qname(&mut self.name)?; - self.qtype = QueryType::from_num(buffer.read_u16()?); // qtype - let _ = buffer.read_u16()?; // class - - Ok(()) - } - - pub fn write(&self, buffer: &mut PacketBuffer) -> Result<()> { - buffer.write_qname(&self.name)?; - - let typenum = self.qtype.to_num(); - buffer.write_u16(typenum)?; - buffer.write_u16(1)?; - - Ok(()) - } -} diff --git a/src/dns/packet/record.rs b/src/dns/packet/record.rs deleted file mode 100644 index 88008f0..0000000 --- a/src/dns/packet/record.rs +++ /dev/null @@ -1,544 +0,0 @@ -use std::net::{Ipv4Addr, Ipv6Addr}; - -use rand::RngCore; -use serde::{Deserialize, Serialize}; -use tracing::{trace, warn}; - -use super::{buffer::PacketBuffer, query::QueryType, Result}; - -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] -pub enum DnsRecord { - UNKNOWN { - domain: String, - qtype: u16, - data_len: u16, - ttl: u32, - }, // 0 - A { - domain: String, - addr: Ipv4Addr, - ttl: u32, - }, // 1 - NS { - domain: String, - host: String, - ttl: u32, - }, // 2 - CNAME { - domain: String, - host: String, - ttl: u32, - }, // 5 - SOA { - domain: String, - mname: String, - nname: String, - serial: u32, - refresh: u32, - retry: u32, - expire: u32, - minimum: u32, - ttl: u32, - }, // 6 - PTR { - domain: String, - pointer: String, - ttl: u32, - }, // 12 - MX { - domain: String, - priority: u16, - host: String, - ttl: u32, - }, // 15 - TXT { - domain: String, - text: Vec, - ttl: u32, - }, //16 - AAAA { - domain: String, - addr: Ipv6Addr, - ttl: u32, - }, // 28 - SRV { - domain: String, - priority: u16, - weight: u16, - port: u16, - target: String, - ttl: u32, - }, // 33 - CAA { - domain: String, - flags: u8, - length: u8, - tag: String, - value: String, - ttl: u32, - }, // 257 - AR { - domain: String, - ttl: u32, - }, - AAAAR { - domain: String, - ttl: u32, - }, -} - -impl DnsRecord { - pub fn read(buffer: &mut PacketBuffer) -> Result { - let mut domain = String::new(); - buffer.read_qname(&mut domain)?; - - let qtype_num = buffer.read_u16()?; - let qtype = QueryType::from_num(qtype_num); - let _ = buffer.read_u16()?; - let ttl = buffer.read_u32()?; - let data_len = buffer.read_u16()?; - - trace!("Reading DNS Record TYPE: {:?}", qtype); - - let header_pos = buffer.pos(); - - match qtype { - QueryType::A => { - let raw_addr = buffer.read_u32()?; - let addr = Ipv4Addr::new( - ((raw_addr >> 24) & 0xFF) as u8, - ((raw_addr >> 16) & 0xFF) as u8, - ((raw_addr >> 8) & 0xFF) as u8, - (raw_addr & 0xFF) as u8, - ); - - Ok(Self::A { domain, addr, ttl }) - } - QueryType::AAAA => { - let raw_addr1 = buffer.read_u32()?; - let raw_addr2 = buffer.read_u32()?; - let raw_addr3 = buffer.read_u32()?; - let raw_addr4 = buffer.read_u32()?; - let addr = Ipv6Addr::new( - ((raw_addr1 >> 16) & 0xFFFF) as u16, - (raw_addr1 & 0xFFFF) as u16, - ((raw_addr2 >> 16) & 0xFFFF) as u16, - (raw_addr2 & 0xFFFF) as u16, - ((raw_addr3 >> 16) & 0xFFFF) as u16, - (raw_addr3 & 0xFFFF) as u16, - ((raw_addr4 >> 16) & 0xFFFF) as u16, - (raw_addr4 & 0xFFFF) as u16, - ); - - Ok(Self::AAAA { domain, addr, ttl }) - } - QueryType::NS => { - let mut ns = String::new(); - buffer.read_qname(&mut ns)?; - - Ok(Self::NS { - domain, - host: ns, - ttl, - }) - } - QueryType::CNAME => { - let mut cname = String::new(); - buffer.read_qname(&mut cname)?; - - Ok(Self::CNAME { - domain, - host: cname, - ttl, - }) - } - QueryType::SOA => { - let mut mname = String::new(); - buffer.read_qname(&mut mname)?; - - let mut nname = String::new(); - buffer.read_qname(&mut nname)?; - - let serial = buffer.read_u32()?; - let refresh = buffer.read_u32()?; - let retry = buffer.read_u32()?; - let expire = buffer.read_u32()?; - let minimum = buffer.read_u32()?; - - Ok(Self::SOA { - domain, - mname, - nname, - serial, - refresh, - retry, - expire, - minimum, - ttl, - }) - } - QueryType::PTR => { - let mut pointer = String::new(); - buffer.read_qname(&mut pointer)?; - - Ok(Self::PTR { - domain, - pointer, - ttl, - }) - } - QueryType::MX => { - let priority = buffer.read_u16()?; - let mut mx = String::new(); - buffer.read_qname(&mut mx)?; - - Ok(Self::MX { - domain, - priority, - host: mx, - ttl, - }) - } - QueryType::TXT => { - let mut text = Vec::new(); - - loop { - let mut s = String::new(); - buffer.read_string(&mut s)?; - - if s.len() == 0 { - break; - } else { - text.push(s); - } - } - - Ok(Self::TXT { domain, text, ttl }) - } - QueryType::SRV => { - let priority = buffer.read_u16()?; - let weight = buffer.read_u16()?; - let port = buffer.read_u16()?; - - let mut target = String::new(); - buffer.read_qname(&mut target)?; - - Ok(Self::SRV { - domain, - priority, - weight, - port, - target, - ttl, - }) - } - QueryType::CAA => { - let flags = buffer.read()?; - let length = buffer.read()?; - - let mut tag = String::new(); - buffer.read_string_n(&mut tag, length)?; - - let value_len = (data_len as usize) + header_pos - buffer.pos; - let mut value = String::new(); - buffer.read_string_n(&mut value, value_len as u8)?; - - Ok(Self::CAA { - domain, - flags, - length, - tag, - value, - ttl, - }) - } - QueryType::UNKNOWN(_) | _ => { - buffer.step(data_len as usize)?; - - Ok(Self::UNKNOWN { - domain, - qtype: qtype_num, - data_len, - ttl, - }) - } - } - } - - pub fn write(&self, buffer: &mut PacketBuffer) -> Result { - let start_pos = buffer.pos(); - - trace!("Writing DNS Record {:?}", self); - - match *self { - Self::A { - ref domain, - ref addr, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::A.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - buffer.write_u16(4)?; - - let octets = addr.octets(); - buffer.write_u8(octets[0])?; - buffer.write_u8(octets[1])?; - buffer.write_u8(octets[2])?; - buffer.write_u8(octets[3])?; - } - Self::NS { - ref domain, - ref host, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::NS.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - - let pos = buffer.pos(); - buffer.write_u16(0)?; - - buffer.write_qname(host)?; - - let size = buffer.pos() - (pos + 2); - buffer.set_u16(pos, size as u16)?; - } - Self::CNAME { - ref domain, - ref host, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::CNAME.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - - let pos = buffer.pos(); - buffer.write_u16(0)?; - - buffer.write_qname(host)?; - - let size = buffer.pos() - (pos + 2); - buffer.set_u16(pos, size as u16)?; - } - Self::SOA { - ref domain, - ref mname, - ref nname, - serial, - refresh, - retry, - expire, - minimum, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::SOA.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - - let pos = buffer.pos(); - buffer.write_u16(0)?; - - buffer.write_qname(mname)?; - buffer.write_qname(nname)?; - buffer.write_u32(serial)?; - buffer.write_u32(refresh)?; - buffer.write_u32(retry)?; - buffer.write_u32(expire)?; - buffer.write_u32(minimum)?; - - let size = buffer.pos() - (pos + 2); - buffer.set_u16(pos, size as u16)?; - } - Self::PTR { - ref domain, - ref pointer, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::NS.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - - let pos = buffer.pos(); - buffer.write_u16(0)?; - - buffer.write_qname(&pointer)?; - - let size = buffer.pos() - (pos + 2); - buffer.set_u16(pos, size as u16)?; - } - Self::MX { - ref domain, - priority, - ref host, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::MX.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - - let pos = buffer.pos(); - buffer.write_u16(0)?; - - buffer.write_u16(priority)?; - buffer.write_qname(host)?; - - let size = buffer.pos() - (pos + 2); - buffer.set_u16(pos, size as u16)?; - } - Self::TXT { - ref domain, - ref text, - ttl, - } => { - buffer.write_qname(&domain)?; - buffer.write_u16(QueryType::TXT.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - - let pos = buffer.pos(); - buffer.write_u16(0)?; - - if text.is_empty() { - return Ok(buffer.pos() - start_pos); - } - - for s in text { - buffer.write_u8(s.len() as u8)?; - buffer.write_string(&s)?; - } - let size = buffer.pos() - (pos + 2); - buffer.set_u16(pos, size as u16)?; - } - Self::AAAA { - ref domain, - ref addr, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::AAAA.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - buffer.write_u16(16)?; - - for octet in &addr.segments() { - buffer.write_u16(*octet)?; - } - } - Self::SRV { - ref domain, - priority, - weight, - port, - ref target, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::SRV.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - - let pos = buffer.pos(); - buffer.write_u16(0)?; - - buffer.write_u16(priority)?; - buffer.write_u16(weight)?; - buffer.write_u16(port)?; - buffer.write_qname(target)?; - - let size = buffer.pos() - (pos + 2); - buffer.set_u16(pos, size as u16)?; - } - Self::CAA { - ref domain, - flags, - length, - ref tag, - ref value, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::CAA.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - - let pos = buffer.pos(); - buffer.write_u16(0)?; - - buffer.write_u8(flags)?; - buffer.write_u8(length)?; - buffer.write_string(tag)?; - buffer.write_string(value)?; - - let size = buffer.pos() - (pos + 2); - buffer.set_u16(pos, size as u16)?; - } - Self::AR { ref domain, ttl } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::A.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - buffer.write_u16(4)?; - - let mut rand = rand::thread_rng(); - buffer.write_u32(rand.next_u32())?; - } - Self::AAAAR { ref domain, ttl } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::A.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - buffer.write_u16(4)?; - - let mut rand = rand::thread_rng(); - buffer.write_u32(rand.next_u32())?; - buffer.write_u32(rand.next_u32())?; - buffer.write_u32(rand.next_u32())?; - buffer.write_u32(rand.next_u32())?; - } - Self::UNKNOWN { .. } => { - warn!("Skipping record: {self:?}"); - } - } - - Ok(buffer.pos() - start_pos) - } - - pub fn get_domain(&self) -> String { - self.get_shared_domain().0 - } - - pub fn get_qtype(&self) -> QueryType { - self.get_shared_domain().1 - } - - pub fn get_ttl(&self) -> u32 { - self.get_shared_domain().2 - } - - fn get_shared_domain(&self) -> (String, QueryType, u32) { - match self { - DnsRecord::UNKNOWN { - domain, ttl, qtype, .. - } => (domain.clone(), QueryType::UNKNOWN(*qtype), *ttl), - DnsRecord::AAAA { domain, ttl, .. } => (domain.clone(), QueryType::AAAA, *ttl), - DnsRecord::A { domain, ttl, .. } => (domain.clone(), QueryType::A, *ttl), - DnsRecord::NS { domain, ttl, .. } => (domain.clone(), QueryType::NS, *ttl), - DnsRecord::CNAME { domain, ttl, .. } => (domain.clone(), QueryType::CNAME, *ttl), - DnsRecord::SOA { domain, ttl, .. } => (domain.clone(), QueryType::SOA, *ttl), - DnsRecord::PTR { domain, ttl, .. } => (domain.clone(), QueryType::PTR, *ttl), - DnsRecord::MX { domain, ttl, .. } => (domain.clone(), QueryType::MX, *ttl), - DnsRecord::TXT { domain, ttl, .. } => (domain.clone(), QueryType::TXT, *ttl), - DnsRecord::SRV { domain, ttl, .. } => (domain.clone(), QueryType::SRV, *ttl), - DnsRecord::CAA { domain, ttl, .. } => (domain.clone(), QueryType::CAA, *ttl), - DnsRecord::AR { domain, ttl, .. } => (domain.clone(), QueryType::AR, *ttl), - DnsRecord::AAAAR { domain, ttl, .. } => (domain.clone(), QueryType::AAAAR, *ttl), - } - } -} diff --git a/src/dns/packet/result.rs b/src/dns/packet/result.rs deleted file mode 100644 index 41c8ba9..0000000 --- a/src/dns/packet/result.rs +++ /dev/null @@ -1,22 +0,0 @@ -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum ResultCode { - NOERROR = 0, - FORMERR = 1, - SERVFAIL = 2, - NXDOMAIN = 3, - NOTIMP = 4, - REFUSED = 5, -} - -impl ResultCode { - pub fn from_num(num: u8) -> Self { - match num { - 1 => Self::FORMERR, - 2 => Self::SERVFAIL, - 3 => Self::NXDOMAIN, - 4 => Self::NOTIMP, - 5 => Self::REFUSED, - 0 | _ => Self::NOERROR, - } - } -} diff --git a/src/dns/resolver.rs b/src/dns/resolver.rs deleted file mode 100644 index 18b5bba..0000000 --- a/src/dns/resolver.rs +++ /dev/null @@ -1,230 +0,0 @@ -use super::binding::Connection; -use super::packet::{query::QueryType, question::DnsQuestion, result::ResultCode, Packet}; -use crate::Result; -use crate::{config::Config, database::Database, get_time}; -use async_recursion::async_recursion; -use moka::future::Cache; -use std::{net::IpAddr, sync::Arc, time::Duration}; -use tracing::{error, trace}; - -pub struct Resolver { - request_id: u16, - connection: Connection, - config: Arc, - database: Arc, - cache: Cache, -} - -impl Resolver { - pub fn new( - request_id: u16, - connection: Connection, - config: Arc, - database: Arc, - cache: Cache, - ) -> Self { - Self { - request_id, - connection, - config, - database, - cache, - } - } - - async fn lookup_database(&self, question: &DnsQuestion) -> Option { - let records = match self - .database - .get_records(&question.name, question.qtype) - .await - { - Ok(record) => record, - Err(err) => { - error!("{err}"); - return None; - } - }; - - if records.is_empty() { - return None; - } - - let mut packet = Packet::new(); - - packet.header.id = self.request_id; - packet.header.questions = 1; - packet.header.answers = records.len() as u16; - packet.header.recursion_desired = true; - packet - .questions - .push(DnsQuestion::new(question.name.to_string(), question.qtype)); - - for record in records { - packet.answers.push(record); - } - - trace!( - "Found stored value for {:?} {}", - question.qtype, - question.name - ); - - Some(packet) - } - - async fn lookup_cache(&self, question: &DnsQuestion) -> Option { - let Some((packet, date)) = self.cache.get(&question) else { - return None - }; - - let now = get_time(); - let diff = Duration::from_millis(now - date).as_secs() as u32; - - for answer in &packet.answers { - let ttl = answer.get_ttl(); - if diff > ttl { - self.cache.invalidate(&question).await; - return None; - } - } - - trace!( - "Found cached value for {:?} {}", - question.qtype, - question.name - ); - - Some(packet) - } - - async fn lookup_fallback(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet { - let mut packet = Packet::new(); - - packet.header.id = self.request_id; - packet.header.questions = 1; - packet.header.recursion_desired = true; - packet - .questions - .push(DnsQuestion::new(question.name.to_string(), question.qtype)); - - let packet = match self.connection.request_packet(packet, server).await { - Ok(packet) => packet, - Err(e) => { - error!("Failed to complete nameserver request: {e}"); - let mut packet = Packet::new(); - packet.header.rescode = ResultCode::SERVFAIL; - packet - } - }; - - packet - } - - async fn lookup(&self, question: &DnsQuestion, server: (IpAddr, u16)) -> Packet { - if let Some(packet) = self.lookup_cache(question).await { - return packet; - }; - - if let Some(packet) = self.lookup_database(question).await { - return packet; - }; - - trace!( - "Attempting lookup of {:?} {} with ns {}", - question.qtype, - question.name, - server.0 - ); - - self.lookup_fallback(question, server).await - } - - #[async_recursion] - async fn recursive_lookup(&mut self, qname: &str, qtype: QueryType) -> Packet { - let question = DnsQuestion::new(qname.to_string(), qtype); - let mut ns = self.config.dns_fallback.clone(); - - loop { - let ns_copy = ns; - - let server = (ns_copy, 53); - let response = self.lookup(&question, server).await; - - if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { - self.cache - .insert(question, (response.clone(), get_time())) - .await; - return response; - } - - if response.header.rescode == ResultCode::NXDOMAIN { - self.cache - .insert(question, (response.clone(), get_time())) - .await; - return response; - } - - if let Some(new_ns) = response.get_resolved_ns(qname) { - ns = new_ns; - continue; - } - - let new_ns_name = match response.get_unresolved_ns(qname) { - Some(x) => x, - None => { - self.cache - .insert(question, (response.clone(), get_time())) - .await; - return response; - } - }; - - let recursive_response = self.recursive_lookup(new_ns_name, QueryType::A).await; - - if let Some(new_ns) = recursive_response.get_random_a() { - ns = new_ns; - } else { - self.cache - .insert(question, (response.clone(), get_time())) - .await; - return response; - } - } - } - - pub async fn handle_query(mut self) -> Result<()> { - let mut request = self.connection.read_packet().await?; - - let mut packet = Packet::new(); - packet.header.id = request.header.id; - packet.header.recursion_desired = true; - packet.header.recursion_available = true; - packet.header.response = true; - - if let Some(question) = request.questions.pop() { - trace!("Received query: {question:?}"); - - let result = self.recursive_lookup(&question.name, question.qtype).await; - packet.questions.push(question.clone()); - packet.header.rescode = result.header.rescode; - - for rec in result.answers { - trace!("Answer: {rec:?}"); - packet.answers.push(rec); - } - for rec in result.authorities { - trace!("Authority: {rec:?}"); - packet.authorities.push(rec); - } - for rec in result.resources { - trace!("Resource: {rec:?}"); - packet.resources.push(rec); - } - } else { - packet.header.rescode = ResultCode::FORMERR; - } - - self.connection.write_packet(packet).await?; - Ok(()) - } -} diff --git a/src/dns/server.rs b/src/dns/server.rs deleted file mode 100644 index 65d15df..0000000 --- a/src/dns/server.rs +++ /dev/null @@ -1,85 +0,0 @@ -use super::{ - binding::Binding, - packet::{question::DnsQuestion, Packet}, - resolver::Resolver, -}; -use crate::{config::Config, database::Database, Result}; -use moka::future::Cache; -use std::{net::SocketAddr, sync::Arc, time::Duration}; -use tokio::task::JoinHandle; -use tracing::{error, info}; - -pub struct DnsServer { - addr: SocketAddr, - config: Arc, - database: Arc, - cache: Cache, -} - -impl DnsServer { - pub async fn new(config: Config, database: Database) -> Result { - let addr = format!("[::]:{}", config.dns_port).parse::()?; - let cache = Cache::builder() - .time_to_live(Duration::from_secs(60 * 60)) - .max_capacity(config.dns_cache_size) - .build(); - - info!("Created DNS cache with size of {}", config.dns_cache_size); - - Ok(Self { - addr, - config: Arc::new(config), - database: Arc::new(database), - cache, - }) - } - - pub async fn run(&self) -> Result<(JoinHandle<()>, JoinHandle<()>)> { - let tcp = Binding::tcp(self.addr).await?; - let tcp_handle = self.listen(tcp); - - let udp = Binding::udp(self.addr).await?; - let udp_handle = self.listen(udp); - - info!( - "Fallback DNS Server is set to: {:?}", - self.config.dns_fallback - ); - info!( - "Listening for TCP and UDP traffic on [::]:{}", - self.config.dns_port - ); - - Ok((udp_handle, tcp_handle)) - } - - fn listen(&self, mut binding: Binding) -> JoinHandle<()> { - let config = self.config.clone(); - let database = self.database.clone(); - let cache = self.cache.clone(); - tokio::spawn(async move { - let mut id = 0; - loop { - let Ok(connection) = binding.connect().await else { continue }; - info!("Received request on {}", binding.name()); - - let resolver = Resolver::new( - id, - connection, - config.clone(), - database.clone(), - cache.clone(), - ); - - let name = binding.name().to_string(); - tokio::spawn(async move { - if let Err(err) = resolver.handle_query().await { - error!("{} request {} failed: {:?}", name, id, err); - }; - }); - - id += 1; - } - }) - } -} diff --git a/src/io/log.c b/src/io/log.c new file mode 100644 index 0000000..ddf56ff --- /dev/null +++ b/src/io/log.c @@ -0,0 +1,49 @@ +#include +#include +#include +#include +#include + +#include "log.h" + +#ifdef LOG + +void logmsg(LogLevel level, const char* msg, ...) { + + INIT_LOG_BOUNDS + INIT_LOG_BUFFER(buffer) + + time_t now = time(NULL); + struct tm *tm = localtime(&now); + APPEND(buffer, "\x1b[97m%02d:%02d:%02d ", tm->tm_hour, tm->tm_min, tm->tm_sec); + + switch (level) { + case DEBUG: + APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 95, "DEBUG"); + break; + case TRACE: + APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 96, "TRACE"); + break; + case INFO: + APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 92, "INFO"); + break; + case WARN: + APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 93, "WARN"); + break; + case ERROR: + APPEND(buffer, "\x1b[%dm%s\x1b[97m ", 91, "ERROR"); + break; + break; + } + + va_list valist; + va_start(valist, msg); + t += vsnprintf(buffer + t, BUF_LENGTH - t, msg, valist); + va_end(valist); + + APPEND(buffer, "\n"); + + fwrite(&buffer, t, 1, stdout); +} + +#endif \ No newline at end of file diff --git a/src/io/log.h b/src/io/log.h new file mode 100644 index 0000000..c2fbd90 --- /dev/null +++ b/src/io/log.h @@ -0,0 +1,45 @@ +#pragma once + +#include + +#define LOG +#ifdef LOG + +typedef enum { + DEBUG, + TRACE, + INFO, + WARN, + ERROR, +} LogLevel; + +#define BUF_LENGTH 256 +#define INIT_LOG_BUFFER(name) char name[BUF_LENGTH]; +#define INIT_LOG_BOUNDS int t = 0; +#define APPEND(buffer, msg, ...) t += snprintf(buffer + t, BUF_LENGTH - t, msg, ##__VA_ARGS__); +#define LOGONLY(code) code + +void logmsg(LogLevel level, const char* msg, ...) + __attribute__ ((__format__(printf, 2, 3))); + +#define DEBUG(msg, ...) logmsg(DEBUG, msg, ##__VA_ARGS__) +#define TRACE(msg, ...) logmsg(TRACE, msg, ##__VA_ARGS__) +#define INFO(msg, ...) logmsg(INFO, msg, ##__VA_ARGS__) +#define WARN(msg, ...) logmsg(WARN, msg, ##__VA_ARGS__) +#define ERROR(msg, ...) logmsg(ERROR, msg, ##__VA_ARGS__) + +#else + +#define BUF_LENGTH +#define INIT_LOG_BUFFER(name) +#define INIT_LOG_BOUNDS +#define APPEND(buffer, msg, ...) +#define LOGONLY(code) + +#define DEBUG(msg, ...) +#define TRACE(msg, ...) +#define INFO(msg, ...) +#define WARN(msg, ...) +#define ERROR(msg, ...) + +#endif \ No newline at end of file diff --git a/src/main.c b/src/main.c new file mode 100644 index 0000000..13dae57 --- /dev/null +++ b/src/main.c @@ -0,0 +1,32 @@ +#include "server/server.h" + +#include +#include + +#define DEFAULT_PORT 53 + +static uint16_t get_port(const char* port_str) { + if (port_str == NULL) { + return DEFAULT_PORT; + } + + uint16_t port; + if ((port = strtoul(port_str, NULL, 10)) == 0) { + return DEFAULT_PORT; + } + + return port; +} + +int main(void) { + + const char* port_str = getenv("PORT"); + uint16_t port = get_port(port_str); + + Server server; + server_init(port, &server); + server_run(&server); + server_free(&server); + + return EXIT_SUCCESS; +} diff --git a/src/main.rs b/src/main.rs deleted file mode 100644 index 679e87b..0000000 --- a/src/main.rs +++ /dev/null @@ -1,64 +0,0 @@ -use std::time::{SystemTime, UNIX_EPOCH}; - -use config::Config; - -use database::Database; -use dotenv::dotenv; -use dns::server::DnsServer; -use tracing::{error, metadata::LevelFilter}; -use tracing_subscriber::{ - filter::filter_fn, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, Layer, -}; -use web::WebServer; - -mod config; -mod database; -mod dns; -mod web; - -type Error = Box; -pub type Result = std::result::Result; - -#[tokio::main] -async fn main() { - if let Err(err) = run().await { - error!("{err}") - }; -} - -async fn run() -> Result<()> { - dotenv().ok(); - - tracing_subscriber::registry() - .with( - tracing_subscriber::fmt::layer() - .with_filter(LevelFilter::TRACE) - .with_filter(filter_fn(|metadata| { - metadata.target().starts_with("wrapper") - })), - ) - .init(); - - let config = Config::new(); - let database = Database::new(config.clone()).await?; - - let dns_server = DnsServer::new(config.clone(), database.clone()).await?; - let (udp, tcp) = dns_server.run().await?; - - let web_server = WebServer::new(config, database).await?; - let web = web_server.run().await?; - - tokio::join!(udp).0?; - tokio::join!(tcp).0?; - tokio::join!(web).0?; - - Ok(()) -} - -pub fn get_time() -> u64 { - let start = SystemTime::now(); - let since_the_epoch = start - .duration_since(UNIX_EPOCH) - .expect("Time went backwards"); - since_the_epoch.as_millis() as u64 -} diff --git a/src/packet/buffer.c b/src/packet/buffer.c new file mode 100644 index 0000000..28dd73b --- /dev/null +++ b/src/packet/buffer.c @@ -0,0 +1,240 @@ +#include "buffer.h" + +#include +#include +#include + +struct PacketBuffer { + uint8_t* arr; + int capacity; + int index; + int size; +}; + +PacketBuffer* buffer_create(int capacity) { + PacketBuffer* buffer = malloc(sizeof(PacketBuffer)); + buffer->arr = malloc(capacity); + buffer->capacity = capacity; + buffer->size = 0; + buffer->index = 0; + return buffer; +} + +void buffer_free(PacketBuffer* buffer) { + free(buffer->arr); + free(buffer); +} + +void buffer_seek(PacketBuffer* buffer, int index) { + buffer->index = index; +} + +uint8_t buffer_read(PacketBuffer* buffer) { + if (buffer->index > buffer->size) { + return 0; + } + uint8_t data = buffer->arr[buffer->index]; + buffer->index++; + return data; +} + +uint16_t buffer_read_short(PacketBuffer* buffer) { + return + (uint16_t) buffer_read(buffer) << 8 | + (uint16_t) buffer_read(buffer); +} + +uint32_t buffer_read_int(PacketBuffer* buffer) { + return + (uint32_t) buffer_read(buffer) << 24 | + (uint32_t) buffer_read(buffer) << 16 | + (uint32_t) buffer_read(buffer) << 8 | + (uint32_t) buffer_read(buffer); +} + +uint8_t buffer_get(PacketBuffer* buffer, int index) { + if (index > buffer->size) { + return 0; + } + uint8_t data = buffer->arr[index]; + return data; +} + +uint8_t* buffer_get_range(PacketBuffer* buffer, int start, int len) { + uint8_t* arr = malloc(len); + for (int i = 0; i < len; i++) { + arr[i] = buffer_get(buffer, start + i); + } + return arr; +} + +uint16_t buffer_get_size(PacketBuffer* buffer) { + return (uint16_t) buffer->size + 1; +} + +static void write(uint8_t** buffer, uint8_t* size, uint8_t* capacity, uint8_t data) { + if (*size == *capacity) { + *capacity *= 2; + *buffer = realloc(*buffer, *capacity); + } + (*buffer)[*size] = data; + (*size)++; +} + +void buffer_read_qname(PacketBuffer* buffer, uint8_t** out) { + int index = buffer->index; + int jumped = 0; + + int max_jumps = 5; + int jumps_performed = 0; + + uint8_t length = 0; + uint8_t capacity = 8; + *out = malloc(capacity); + write(out, &length, &capacity, 0); + + while(1) { + if (jumps_performed > max_jumps) { + break; + } + + uint8_t len = buffer_get(buffer, index); + + if ((len & 0xC0) == 0xC0) { + if (jumped == 0) { + buffer_seek(buffer, index + 2); + } + + uint16_t b2 = (uint16_t) buffer_get(buffer, index + 1); + uint16_t offset = ((((uint16_t) len) ^ 0xC0) << 8) | b2; + index = (int) offset; + jumped = 1; + jumps_performed++; + continue; + } + + index++; + + if (len == 0) { + break; + } + + if (length > 1) { + write(out, &length, &capacity, '.'); + } + + uint8_t* range = buffer_get_range(buffer, index, len); + for (uint8_t i = 0; i < len; i++) { + write(out, &length, &capacity, range[i]); + } + free(range); + + index += (int) len; + } + + if (jumped == 0) { + buffer_seek(buffer, index); + } + + (*out)[0] = length - 1; +} + +void buffer_read_string(PacketBuffer* buffer, uint8_t** out) { + uint8_t len = buffer_read(buffer); + buffer_read_n(buffer, out, len); +} + +void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len) { + *out = malloc(len + 1); + *out[0] = len; + memcpy(*out + 1, buffer->arr + buffer->index, len); + buffer->index += len; +} + +void buffer_write(PacketBuffer* buffer, uint8_t data) { + if(buffer->index == buffer->capacity) { + buffer->capacity *= 2; + buffer->arr = realloc(buffer->arr, buffer->capacity); + } + if (buffer->size < buffer->index) { + buffer->size = buffer->index; + } + buffer->arr[buffer->index] = data; + buffer->index++; +} + +void buffer_write_short(PacketBuffer* buffer, uint16_t data) { + buffer_write(buffer, (uint8_t)(data >> 8)); + buffer_write(buffer, (uint8_t)(data & 0xFF)); +} + +void buffer_write_int(PacketBuffer* buffer, uint32_t data) { + buffer_write(buffer, (uint8_t)(data >> 24)); + buffer_write(buffer, (uint8_t)(data >> 16)); + buffer_write(buffer, (uint8_t)(data >> 8)); + buffer_write(buffer, (uint8_t)(data & 0xFF)); +} + +void buffer_write_qname(PacketBuffer* buffer, uint8_t* in) { + uint8_t part = 0; + uint8_t len = in[0]; + + buffer_write(buffer, 0); + + if (len == 0) { + return; + } + + for(uint8_t i = 0; i < len; i ++) { + if (in[i+1] == '.') { + buffer_set(buffer, part, buffer->index - (int)part - 1); + buffer_write(buffer, 0); + part = 0; + } else { + buffer_write(buffer, in[i+1]); + part++; + } + } + buffer_set(buffer, part, buffer->index - (int)part - 1); + buffer_write(buffer, 0); +} + +void buffer_write_string(PacketBuffer* buffer, uint8_t* in) { + buffer_write(buffer, in[0]); + buffer_write_n(buffer, in + 1, in[0]); +} + +void buffer_write_n(PacketBuffer* buffer, uint8_t* in, int len) { + if (buffer->size + len >= buffer->capacity) { + buffer->capacity *= 2; + buffer->capacity += len; + buffer->arr = realloc(buffer->arr, buffer->capacity); + } + memcpy(buffer->arr + buffer->index, in, len); + buffer->size += len; + buffer->index += len; +} + +void buffer_set(PacketBuffer* buffer, uint8_t data, int index) { + if (index > buffer->size) { + return; + } + buffer->arr[index] = data; +} + +void buffer_set_uint16_t(PacketBuffer* buffer, uint16_t data, int index) { + buffer_set(buffer, (uint8_t)(data >> 8), index); + buffer_set(buffer, (uint8_t)(data & 0xFF), index + 1); +} + +int buffer_get_index(PacketBuffer* buffer) { + return buffer->index; +} + +void buffer_step(PacketBuffer* buffer, int len) { + buffer->index += len; +} + +uint8_t* buffer_get_ptr(PacketBuffer* buffer) { + return buffer->arr; +} diff --git a/src/packet/buffer.h b/src/packet/buffer.h new file mode 100644 index 0000000..ad3145d --- /dev/null +++ b/src/packet/buffer.h @@ -0,0 +1,51 @@ +#pragma once + +#include + +typedef struct PacketBuffer PacketBuffer; + +PacketBuffer* buffer_create(int capacity); + +void buffer_free(PacketBuffer* buffer); + +void buffer_seek(PacketBuffer* buffer, int index); + +uint8_t buffer_read(PacketBuffer* buffer); + +uint16_t buffer_read_short(PacketBuffer* buffer); + +uint32_t buffer_read_int(PacketBuffer* buffer); + +uint8_t buffer_get(PacketBuffer* buffer, int index); + +uint8_t* buffer_get_range(PacketBuffer* buffer, int start, int len); + +uint16_t buffer_get_size(PacketBuffer* buffer); + +void buffer_read_qname(PacketBuffer* buffer, uint8_t** out); + +void buffer_read_string(PacketBuffer* buffer, uint8_t** out); + +void buffer_read_n(PacketBuffer* buffer, uint8_t** out, uint8_t len); + +void buffer_write(PacketBuffer* buffer, uint8_t data); + +void buffer_write_short(PacketBuffer* buffer, uint16_t data); + +void buffer_write_int(PacketBuffer* buffer, uint32_t data); + +void buffer_write_qname(PacketBuffer* buffer, uint8_t* in); + +void buffer_write_string(PacketBuffer* buffer, uint8_t* in); + +void buffer_write_n(PacketBuffer* buffer, uint8_t* in, int len); + +void buffer_set(PacketBuffer* buffer, uint8_t data, int index); + +void buffer_set_uint16_t(PacketBuffer* buffer, uint16_t data, int index); + +int buffer_get_index(PacketBuffer* buffer); + +void buffer_step(PacketBuffer* buffer, int len); + +uint8_t* buffer_get_ptr(PacketBuffer* buffer); diff --git a/src/packet/header.c b/src/packet/header.c new file mode 100644 index 0000000..fd601ce --- /dev/null +++ b/src/packet/header.c @@ -0,0 +1,93 @@ +#include "header.h" +#include "buffer.h" + +#include +#include +#include + +uint8_t rescode_to_id(ResultCode code) { + switch(code) { + case NOERROR: + return 0; + case FORMERR: + return 1; + case SERVFAIL: + return 2; + case NXDOMAIN: + return 3; + case NOTIMP: + return 4; + case REFUSED: + return 5; + default: + return 2; + } +} + +ResultCode rescode_from_id(uint8_t id) { + switch(id) { + case 0: + return NOERROR; + case 1: + return FORMERR; + case 2: + return SERVFAIL; + case 3: + return NXDOMAIN; + case 4: + return NOTIMP; + case 5: + return REFUSED; + default: + return FORMERR; + } +} + +void read_header(PacketBuffer* buffer, Header* header) { + // memset(header, 0, sizeof(Header)); + header->id = buffer_read_short(buffer); + + uint8_t a = buffer_read(buffer); + header->recursion_desired = (a & (1 << 0)) > 0; + header->truncated_message = (a & (1 << 1)) > 0; + header->authorative_answer = (a & (1 << 2)) > 0; + header->opcode = (a >> 3) & 0x0F; + header->response = (a & (1 << 7)) > 0; + + uint8_t b = buffer_read(buffer); + header->rescode = rescode_from_id(b & 0x0F); + header->checking_disabled = (b & (1 << 4)) > 0; + header->authed_data = (b& (1 << 4)) > 0; + header->z = (b & (1 << 6)) > 0; + header->recursion_available = (b & (1 << 7)) > 0; + + header->questions = buffer_read_short(buffer); + header->answers = buffer_read_short(buffer); + header->authoritative_entries = buffer_read_short(buffer); + header->resource_entries = buffer_read_short(buffer); +} + +void write_header(PacketBuffer* buffer, Header* header) { + buffer_write_short(buffer, header->id); + + buffer_write(buffer, + ((uint8_t) header->recursion_desired) | + ((uint8_t) header->truncated_message << 1) | + ((uint8_t) header->authorative_answer << 2) | + (header->opcode << 3) | + ((uint8_t) header->response << 7) + ); + + buffer_write(buffer, + (rescode_to_id(header->rescode)) | + ((uint8_t) header->checking_disabled << 4) | + ((uint8_t) header->authed_data << 5) | + ((uint8_t) header->z << 6) | + ((uint8_t) header->recursion_available << 7) + ); + + buffer_write_short(buffer, header->questions); + buffer_write_short(buffer, header->answers); + buffer_write_short(buffer, header->authoritative_entries); + buffer_write_short(buffer, header->resource_entries); +} diff --git a/src/packet/header.h b/src/packet/header.h new file mode 100644 index 0000000..d9a8cea --- /dev/null +++ b/src/packet/header.h @@ -0,0 +1,41 @@ +#pragma once + +#include "buffer.h" + +#include + +typedef enum { + NOERROR, // 0 + FORMERR, // 1 + SERVFAIL, // 2 + NXDOMAIN, // 3, + NOTIMP, // 4 + REFUSED, // 5 +} ResultCode; + +uint8_t rescode_to_id(ResultCode code); +ResultCode rescode_from_id(uint8_t id); + +typedef struct { + uint16_t id; + + bool recursion_desired; // 1 bit + bool truncated_message; // 1 bit + bool authorative_answer; // 1 bit + uint8_t opcode; // 4 bits + bool response; // 1 bit + + ResultCode rescode; // 4 bits + bool checking_disabled; // 1 bit + bool authed_data; // 1 bit + bool z; // 1 bit + bool recursion_available; // 1 bit + + uint16_t questions; // 16 bits + uint16_t answers; // 16 bits + uint16_t authoritative_entries; // 16 bits + uint16_t resource_entries; // 16 bits +} Header; + +void read_header(PacketBuffer* buffer, Header* header); +void write_header(PacketBuffer* buffer, Header* header); diff --git a/src/packet/packet.c b/src/packet/packet.c new file mode 100644 index 0000000..9b1159d --- /dev/null +++ b/src/packet/packet.c @@ -0,0 +1,171 @@ +#include "packet.h" +#include "buffer.h" +#include "header.h" +#include "question.h" +#include "record.h" + +#include +#include +#include + +void read_packet(PacketBuffer* buffer, Packet* packet) { + read_header(buffer, &packet->header); + + packet->questions = malloc(sizeof(Question) * packet->header.questions); + for(uint16_t i = 0; i < packet->header.questions; i++) { + read_question(buffer, &packet->questions[i]); + } + + packet->answers = malloc(sizeof(Record) * packet->header.answers); + for(uint16_t i = 0; i < packet->header.answers; i++) { + read_record(buffer, &packet->answers[i]); + } + + packet->authorities = malloc(sizeof(Record) * packet->header.authoritative_entries); + for(uint16_t i = 0; i < packet->header.authoritative_entries; i++) { + read_record(buffer, &packet->authorities[i]); + } + + packet->resources = malloc(sizeof(Record) * packet->header.resource_entries); + for(uint16_t i = 0; i < packet->header.resource_entries; i++) { + read_record(buffer, &packet->resources[i]); + } +} + +void write_packet(PacketBuffer* buffer, Packet* packet) { + write_header(buffer, &packet->header); + + for(uint16_t i = 0; i < packet->header.questions; i++) { + write_question(buffer, &packet->questions[i]); + } + + for(uint16_t i = 0; i < packet->header.answers; i++) { + write_record(buffer, &packet->answers[i]); + } + + for(uint16_t i = 0; i < packet->header.authoritative_entries; i++) { + write_record(buffer, &packet->authorities[i]); + } + + for(uint16_t i = 0; i < packet->header.resource_entries; i++) { + write_record(buffer, &packet->resources[i]); + } +} + +void free_packet(Packet* packet) { + + for(uint16_t i = 0; i < packet->header.questions; i++) { + free_question(&packet->questions[i]); + } + free(packet->questions); + + for(uint16_t i = 0; i < packet->header.answers; i++) { + free_record(&packet->answers[i]); + } + free(packet->answers); + + for(uint16_t i = 0; i < packet->header.authoritative_entries; i++) { + free_record(&packet->authorities[i]); + } + free(packet->authorities); + + for(uint16_t i = 0; i < packet->header.resource_entries; i++) { + free_record(&packet->resources[i]); + } + free(packet->resources); +} + +bool get_random_a(Packet* packet, IpAddr* addr) { + for (uint16_t i = 0; i < packet->header.answers; i++) { + Record record = packet->answers[i]; + if (record.type == A) { + create_ip_addr((char*) &record.data.a.addr, addr); + return true; + } else if (record.type == AAAA) { + create_ip_addr6((char*) &record.data.aaaa.addr, addr); + return true; + } + } + return false; +} + +static bool ends_with(uint8_t* full, uint8_t* end) { + uint8_t check = end[0]; + uint8_t len = full[0]; + + if (check > len) { + return false; + } + + for(uint8_t i = 0; i < check; i++) { + if (end[check - 1 - i] != full[len - 1 - i]) { + return false; + } + } + + return true; +} + +static bool equals(uint8_t* a, uint8_t* b) { + if (a[0] != b[0]) { + return false; + } + + for(uint8_t i = 1; i < a[0] + 1; i++) { + if(a[i] != b[i]) { + return false; + } + } + + return true; +} + +bool get_resolved_ns(Packet* packet, uint8_t* qname, IpAddr* addr) { + for (uint16_t i = 0; i < packet->header.authoritative_entries; i++) { + Record record = packet->authorities[i]; + if (record.type != NS) { + continue; + } + + if(!ends_with(qname, record.domain)) { + continue; + } + + for (uint16_t i = 0; i < packet->header.resource_entries; i++) { + Record resource = packet->resources[i]; + if (!equals(record.data.ns.host, resource.domain)) { + continue; + } + + if (resource.type == A) { + create_ip_addr((char*) &record.data.a.addr, addr); + return true; + } else if (resource.type == AAAA) { + create_ip_addr6((char*) &record.data.aaaa.addr, addr); + return true; + } + } + } + return false; +} + +bool get_unresoled_ns(Packet* packet, uint8_t* qname, Question* question) { + for (uint16_t i = 0; i < packet->header.authoritative_entries; i++) { + Record record = packet->authorities[i]; + if (record.type != NS) { + continue; + } + + if(!ends_with(qname, record.domain)) { + continue; + } + + uint8_t* host = record.data.ns.host; + + question->qtype = NS; + question->domain = malloc(host[0] + 1); + memcpy(question->domain, host, host[0] + 1); + return true; + } + return false; +} \ No newline at end of file diff --git a/src/packet/packet.h b/src/packet/packet.h new file mode 100644 index 0000000..aa1c35e --- /dev/null +++ b/src/packet/packet.h @@ -0,0 +1,23 @@ +#pragma once + +#include "buffer.h" +#include "question.h" +#include "header.h" +#include "record.h" +#include "../server/addr.h" + +typedef struct { + Header header; + Question* questions; + Record* answers; + Record* authorities; + Record* resources; +} Packet; + +void read_packet(PacketBuffer* buffer, Packet* packet); +void write_packet(PacketBuffer* buffer, Packet* packet); +void free_packet(Packet* packet); + +bool get_random_a(Packet* packet, IpAddr* addr); +bool get_resolved_ns(Packet* packet, uint8_t* qname, IpAddr* addr); +bool get_unresoled_ns(Packet* packet, uint8_t* qname, Question* question); diff --git a/src/packet/question.c b/src/packet/question.c new file mode 100644 index 0000000..c2807d0 --- /dev/null +++ b/src/packet/question.c @@ -0,0 +1,94 @@ +#include +#include + +#include "question.h" +#include "buffer.h" +#include "record.h" +#include "../io/log.h" + +void read_question(PacketBuffer* buffer, Question* question) { + buffer_read_qname(buffer, &question->domain); + + uint16_t qtype_num = buffer_read_short(buffer); + record_from_id(qtype_num, &question->qtype); + question->cls = buffer_read_short(buffer); + + INIT_LOG_BUFFER(log) + LOGONLY(print_question(question, log);) + TRACE("Reading question: %s", log); +} + +void write_question(PacketBuffer* buffer, Question* question) { + buffer_write_qname(buffer, question->domain); + + uint16_t id = record_to_id(question->qtype); + buffer_write_short(buffer, id); + + buffer_write_short(buffer, question->cls); + + INIT_LOG_BUFFER(log) + LOGONLY(print_question(question, log);) + TRACE("Writing question: %s", log); +} + +void free_question(Question* question) { + free(question->domain); +} + +void print_question(Question* question, char* buffer) { + INIT_LOG_BOUNDS + switch (question->cls) { + case 1: + APPEND(buffer, "IN ");; + break; + case 3: + APPEND(buffer, "CH "); + break; + case 4: + APPEND(buffer, "HS "); + break; + default: + APPEND(buffer, "?? "); + break; + } + switch(question->qtype) { + case UNKOWN: + APPEND(buffer, "UNKOWN "); + break; + case A: + APPEND(buffer, "A "); + break; + case NS: + APPEND(buffer, "NS "); + break; + case CNAME: + APPEND(buffer, "CNAME "); + break; + case SOA: + APPEND(buffer, "SOA "); + break; + case PTR: + APPEND(buffer, "PTR "); + break; + case MX: + APPEND(buffer, "MX "); + break; + case TXT: + APPEND(buffer, "TXT "); + break; + case AAAA: + APPEND(buffer, "AAAA "); + break; + case SRV: + APPEND(buffer, "SRV "); + break; + case CAA: + APPEND(buffer, "CAA "); + break; + break; + } + APPEND(buffer, "%.*s", + question->domain[0], + question->domain + 1 + ); +} \ No newline at end of file diff --git a/src/packet/question.h b/src/packet/question.h new file mode 100644 index 0000000..e8c385a --- /dev/null +++ b/src/packet/question.h @@ -0,0 +1,15 @@ +#pragma once + +#include "buffer.h" +#include "record.h" + +typedef struct { + uint8_t* domain; + RecordType qtype; + uint16_t cls; +} Question; + +void read_question(PacketBuffer* buffer, Question* question); +void write_question(PacketBuffer* buffer, Question* question); +void free_question(Question* question); +void print_question(Question* question, char* buffer); diff --git a/src/packet/record.c b/src/packet/record.c new file mode 100644 index 0000000..29c3bf0 --- /dev/null +++ b/src/packet/record.c @@ -0,0 +1,540 @@ +#include +#include +#include + +#include "record.h" +#include "buffer.h" +#include "../io/log.h" + +uint16_t record_to_id(RecordType type) { + switch (type) { + case A: + return 1; + case NS: + return 2; + case CNAME: + return 5; + case SOA: + return 6; + case PTR: + return 12; + case MX: + return 15; + case TXT: + return 16; + case AAAA: + return 28; + case SRV: + return 33; + case CAA: + return 257; + default: + return 0; + } +} + +void record_from_id(uint16_t i, RecordType* type) { + switch (i) { + case 1: + *type = A; + break; + case 2: + *type = NS; + break; + case 5: + *type = CNAME; + break; + case 6: + *type = SOA; + break; + case 12: + *type = PTR; + break; + case 15: + *type = MX; + break; + case 16: + *type = TXT; + break; + case 28: + *type = AAAA; + break; + case 33: + *type = SRV; + break; + case 257: + *type = CAA; + break; + default: + *type = UNKOWN; + } +} + +static void read_a_record(PacketBuffer* buffer, Record* record) { + ARecord data; + data.addr[0] = buffer_read(buffer); + data.addr[1] = buffer_read(buffer); + data.addr[2] = buffer_read(buffer); + data.addr[3] = buffer_read(buffer); + + record->data.a = data; +} + +static void read_ns_record(PacketBuffer* buffer, Record* record) { + NSRecord data; + buffer_read_qname(buffer, &data.host); + + record->data.ns = data; +} + +static void read_cname_record(PacketBuffer* buffer, Record* record) { + CNAMERecord data; + buffer_read_qname(buffer, &data.host); + + record->data.cname = data; +} + +static void read_soa_record(PacketBuffer* buffer, Record* record) { + SOARecord data; + buffer_read_qname(buffer, &data.mname); + buffer_read_qname(buffer, &data.nname); + data.serial = buffer_read_int(buffer); + data.refresh = buffer_read_int(buffer); + data.retry = buffer_read_int(buffer); + data.expire = buffer_read_int(buffer); + data.minimum = buffer_read_int(buffer); + + record->data.soa = data; +} + +static void read_ptr_record(PacketBuffer* buffer, Record* record) { + PTRRecord data; + buffer_read_qname(buffer, &data.pointer); + + record->data.ptr = data; +} + +static void read_mx_record(PacketBuffer* buffer, Record* record) { + MXRecord data; + data.priority = buffer_read_short(buffer); + buffer_read_qname(buffer, &data.host); + + record->data.mx = data; +} + +static void read_txt_record(PacketBuffer* buffer, Record* record) { + TXTRecord data; + data.len = 0; + data.text = malloc(sizeof(uint8_t*) * 2); + + uint8_t capacity = 2; + while (1) { + if (data.len >= capacity) { + capacity *= 2; + data.text = realloc(data.text, sizeof(uint8_t*) * capacity); + } + + buffer_read_string(buffer, &data.text[data.len]); + if(data.text[data.len][0] == 0) break; + data.len++; + } + + record->data.txt = data; +} + +static void read_aaaa_record(PacketBuffer* buffer, Record* record) { + AAAARecord data; + for (int i = 0; i < 16; i++) { + data.addr[i] = buffer_read(buffer); + } + + record->data.aaaa = data; +} + +static void read_srv_record(PacketBuffer* buffer, Record* record) { + SRVRecord data; + data.priority = buffer_read_short(buffer); + data.weight = buffer_read_short(buffer); + data.port = buffer_read_short(buffer); + buffer_read_qname(buffer, &data.target); + + record->data.srv = data; +} + +static void read_caa_record(PacketBuffer* buffer, Record* record, int header_pos) { + CAARecord data; + data.flags = buffer_read(buffer); + data.length = buffer_read(buffer); + buffer_read_n(buffer, &data.tag, data.length); + int value_len = ((int)record->len) + header_pos - buffer_get_index(buffer); + buffer_read_n(buffer, &data.value, (uint8_t)value_len); + + record->data.caa = data; +} + +void read_record(PacketBuffer* buffer, Record* record) { + buffer_read_qname(buffer, &record->domain); + + uint16_t qtype_num = buffer_read_short(buffer); + record_from_id(qtype_num, &record->type); + + record->cls = buffer_read_short(buffer); + record->ttl = buffer_read_int(buffer); + record->len = buffer_read_short(buffer); + + int header_pos = buffer_get_index(buffer); + + switch (record->type) { + case A: + read_a_record(buffer, record); + break; + case NS: + read_ns_record(buffer, record); + break; + case CNAME: + read_cname_record(buffer, record); + break; + case SOA: + read_soa_record(buffer, record); + break; + case PTR: + read_ptr_record(buffer, record); + break; + case MX: + read_mx_record(buffer, record); + break; + case TXT: + read_txt_record(buffer, record); + break; + case AAAA: + read_aaaa_record(buffer, record); + break; + case SRV: + read_srv_record(buffer, record); + break; + case CAA: + read_caa_record(buffer, record, header_pos); + break; + default: + buffer_step(buffer, record->len); + return; + } + + INIT_LOG_BUFFER(log) + LOGONLY(print_record(record, log);) + TRACE("Reading record: %s", log); +} + +static void write_a_record(PacketBuffer* buffer, Record* record) { + ARecord data = record->data.a; + buffer_write_short(buffer, 4); + buffer_write(buffer, record->data.a.addr[0]); + buffer_write(buffer, data.addr[1]); + buffer_write(buffer, data.addr[2]); + buffer_write(buffer, data.addr[3]); +} + +static void write_ns_record(PacketBuffer* buffer, Record* record) { + NSRecord data = record->data.ns; + int pos = buffer_get_index(buffer); + buffer_write_short(buffer, 0); + + buffer_write_qname(buffer, data.host); + + int size = buffer_get_index(buffer) - pos - 2; + buffer_set_uint16_t(buffer, (uint16_t)size, pos); +} + +static void write_cname_record(PacketBuffer* buffer, Record* record) { + CNAMERecord data = record->data.cname; + int pos = buffer_get_index(buffer); + buffer_write_short(buffer, 0); + + buffer_write_qname(buffer, data.host); + + int size = buffer_get_index(buffer) - pos - 2; + buffer_set_uint16_t(buffer, (uint16_t)size, pos); +} + +static void write_soa_record(PacketBuffer* buffer, Record* record) { + SOARecord data = record->data.soa; + int pos = buffer_get_index(buffer); + buffer_write_short(buffer, 0); + + buffer_write_qname(buffer, data.mname); + buffer_write_qname(buffer, data.nname); + buffer_write_int(buffer, data.serial); + buffer_write_int(buffer, data.refresh); + buffer_write_int(buffer, data.retry); + buffer_write_int(buffer, data.expire); + buffer_write_int(buffer, data.minimum); + + int size = buffer_get_index(buffer) - pos - 2; + buffer_set_uint16_t(buffer, (uint16_t)size, pos); +} + +static void write_ptr_record(PacketBuffer* buffer, Record* record) { + PTRRecord data = record->data.ptr; + int pos = buffer_get_index(buffer); + buffer_write_short(buffer, 0); + + buffer_write_qname(buffer, data.pointer); + + int size = buffer_get_index(buffer) - pos - 2; + buffer_set_uint16_t(buffer, (uint16_t)size, pos); +} + +static void write_mx_record(PacketBuffer* buffer, Record* record) { + MXRecord data = record->data.mx; + int pos = buffer_get_index(buffer); + buffer_write_short(buffer, 0); + + buffer_write_short(buffer, data.priority); + buffer_write_qname(buffer, data.host); + + int size = buffer_get_index(buffer) - pos - 2; + buffer_set_uint16_t(buffer, (uint16_t)size, pos); +} + +static void write_txt_record(PacketBuffer* buffer, Record* record) { + TXTRecord data = record->data.txt; + int pos = buffer_get_index(buffer); + buffer_write_short(buffer, 0); + + if(data.len == 0) { + return; + } + + for(uint8_t i = 0; i < data.len; i++) { + buffer_write_string(buffer, data.text[i]); + } + + int size = buffer_get_index(buffer) - pos - 2; + buffer_set_uint16_t(buffer, (uint16_t)size, pos); +} + +static void write_aaaa_record(PacketBuffer* buffer, Record* record) { + AAAARecord data = record->data.aaaa; + + buffer_write_short(buffer, 16); + + for (int i = 0; i < 16; i++) { + buffer_write(buffer, data.addr[i]); + } +} + +static void write_srv_record(PacketBuffer* buffer, Record* record) { + SRVRecord data = record->data.srv; + int pos = buffer_get_index(buffer); + buffer_write_short(buffer, 0); + + buffer_write_short(buffer, data.priority); + buffer_write_short(buffer, data.weight); + buffer_write_short(buffer, data.port); + buffer_write_qname(buffer, data.target); + + int size = buffer_get_index(buffer) - pos - 2; + buffer_set_uint16_t(buffer, (uint16_t)size, pos); +} + +static void write_caa_record(PacketBuffer* buffer, Record* record) { + CAARecord data = record->data.caa; + int pos = buffer_get_index(buffer); + buffer_write_short(buffer, 0); + buffer_write(buffer, data.flags); + buffer_write(buffer, data.length); + buffer_write_n(buffer, data.tag + 1, data.tag[0]); + buffer_write_n(buffer, data.value + 1, data.value[0]); + + int size = buffer_get_index(buffer) - pos - 2; + buffer_set_uint16_t(buffer, (uint16_t)size, pos); +} + +static void write_record_header(PacketBuffer* buffer, Record* record) { + buffer_write_qname(buffer, record->domain); + uint16_t id = record_to_id(record->type); + buffer_write_short(buffer, id); + buffer_write_short(buffer, record->cls); + buffer_write_int(buffer, record->ttl); +} + +void write_record(PacketBuffer* buffer, Record* record) { + switch(record->type) { + case A: + write_record_header(buffer, record); + write_a_record(buffer, record); + break; + case NS: + write_record_header(buffer, record); + write_ns_record(buffer, record); + break; + case CNAME: + write_record_header(buffer, record); + write_cname_record(buffer, record); + break; + case SOA: + write_record_header(buffer, record); + write_soa_record(buffer, record); + break; + case PTR: + write_record_header(buffer, record); + write_ptr_record(buffer, record); + break; + case MX: + write_record_header(buffer, record); + write_mx_record(buffer, record); + break; + case TXT: + write_record_header(buffer, record); + write_txt_record(buffer, record); + break; + case AAAA: + write_record_header(buffer, record); + write_aaaa_record(buffer, record); + break; + case SRV: + write_record_header(buffer, record); + write_srv_record(buffer, record); + break; + case CAA: + write_record_header(buffer, record); + write_caa_record(buffer, record); + break; + default: + break; + } + + INIT_LOG_BUFFER(log) + LOGONLY(print_record(record, log);) + TRACE("Writing record: %s", log); +} + +void free_record(Record* record) { + free(record->domain); + switch (record->type) { + case NS: + free(record->data.ns.host); + break; + case CNAME: + free(record->data.cname.host); + break; + case SOA: + free(record->data.soa.mname); + free(record->data.soa.nname); + break; + case PTR: + free(record->data.ptr.pointer); + break; + case MX: + free(record->data.mx.host); + break; + case TXT: + for (uint8_t i = 0; i < record->data.txt.len; i++) { + free(record->data.txt.text[i]); + } + free(record->data.txt.text); + break; + case SRV: + free(record->data.srv.target); + break; + case CAA: + free(record->data.caa.value); + free(record->data.caa.tag); + break; + default: + break; + } +} + +void print_record(Record* record, char* buffer) { + INIT_LOG_BOUNDS + switch(record->type) { + case UNKOWN: + APPEND(buffer, "UNKOWN"); + break; + case A: + APPEND(buffer, "A (%hhu.%hhu.%hhu.%hhu)", + record->data.a.addr[0], + record->data.a.addr[1], + record->data.a.addr[2], + record->data.a.addr[3] + ); + break; + case NS: + APPEND(buffer, "NS (%.*s)", + record->data.ns.host[0], + record->data.ns.host + 1 + ); + break; + case CNAME: + APPEND(buffer, "CNAME (%.*s)", + record->data.cname.host[0], + record->data.cname.host + 1 + ); + break; + case SOA: + APPEND(buffer, "SOA (%.*s %.*s %u %u %u %u %u)", + record->data.soa.mname[0], + record->data.soa.mname + 1, + record->data.soa.nname[0], + record->data.soa.nname + 1, + record->data.soa.serial, + record->data.soa.refresh, + record->data.soa.retry, + record->data.soa.expire, + record->data.soa.minimum + ); + break; + case PTR: + APPEND(buffer, "PTR (%.*s)", + record->data.ptr.pointer[0], + record->data.ptr.pointer + 1 + ); + break; + case MX: + APPEND(buffer, "MX (%.*s %hu)", + record->data.mx.host[0], + record->data.mx.host + 1, + record->data.mx.priority + ); + break; + case TXT: + APPEND(buffer, "TXT ("); + for(uint8_t i = 0; i < record->data.txt.len; i++) { + APPEND(buffer, "\"%.*s\"", + record->data.txt.text[i][0], + record->data.txt.text[i] + 1 + ); + } + APPEND(buffer, ")"); + break; + case AAAA: + APPEND(buffer, "AAAA ("); + for(int i = 0; i < 8; i++) { + APPEND(buffer, "%02hhx%02hhx:", + record->data.a.addr[i*2 + 0], + record->data.a.addr[i*2 + 1] + ); + } + APPEND(buffer, ":)"); + break; + case SRV: + APPEND(buffer, "SRV (%hu %hu %hu %.*s)", + record->data.srv.priority, + record->data.srv.weight, + record->data.srv.port, + record->data.srv.target[0], + record->data.srv.target + 1 + ); + break; + case CAA: + APPEND(buffer, "CAA (%hhu %.*s %.*s)", + record->data.caa.flags, + record->data.caa.tag[0], + record->data.caa.tag + 1, + record->data.caa.value[0], + record->data.caa.value + 1 + ); + break; + } +} \ No newline at end of file diff --git a/src/packet/record.h b/src/packet/record.h new file mode 100644 index 0000000..95bbbbe --- /dev/null +++ b/src/packet/record.h @@ -0,0 +1,101 @@ +#pragma once + +#include "buffer.h" + +#include + +typedef enum { + UNKOWN, + A, // 1 + NS, // 2 + CNAME, // 5 + SOA, // 6 + PTR, // 12 + MX, // 15 + TXT, // 16 + AAAA, // 28 + SRV, // 33 + CAA // 257 +} RecordType; + +uint16_t record_to_id(RecordType type); +void record_from_id(uint16_t i, RecordType* type); + +typedef struct { + uint8_t addr[4]; +} ARecord; + +typedef struct { + uint8_t* host; +} NSRecord; + +typedef struct { + uint8_t* host; +} CNAMERecord; + +typedef struct { + uint8_t* mname; + uint8_t* nname; + uint32_t serial; + uint32_t refresh; + uint32_t retry; + uint32_t expire; + uint32_t minimum; +} SOARecord; + +typedef struct { + uint8_t* pointer; +} PTRRecord; + +typedef struct { + uint16_t priority; + uint8_t* host; +} MXRecord; + +typedef struct TXTRecord { + uint8_t** text; + uint8_t len; +} TXTRecord; + +typedef struct { + uint8_t addr[16]; +} AAAARecord; + +typedef struct { + uint16_t priority; + uint16_t weight; + uint16_t port; + uint8_t* target; +} SRVRecord; + +typedef struct { + uint8_t flags; + uint8_t length; + uint8_t* tag; + uint8_t* value; +} CAARecord; + +typedef struct { + uint32_t ttl; + uint16_t cls; + uint16_t len; + uint8_t* domain; + RecordType type; + union data { + ARecord a; + NSRecord ns; + CNAMERecord cname; + SOARecord soa; + PTRRecord ptr; + MXRecord mx; + TXTRecord txt; + AAAARecord aaaa; + SRVRecord srv; + CAARecord caa; + } data; +} Record; + +void read_record(PacketBuffer* buffer, Record* record); +void write_record(PacketBuffer* buffer, Record* record); +void free_record(Record* record); +void print_record(Record* record, char* buffer); \ No newline at end of file diff --git a/src/server/addr.c b/src/server/addr.c new file mode 100644 index 0000000..982da13 --- /dev/null +++ b/src/server/addr.c @@ -0,0 +1,233 @@ +#include +#include +#include +#include +#include + +#include "addr.h" +#include "../io/log.h" + +void create_ip_addr(char* domain, IpAddr* addr) { + addr->type = V4; + memcpy(&addr->data.v4.s_addr, domain, 4); +} + +void create_ip_addr6(char* domain, IpAddr* addr) { + addr->type = V6; + memcpy(&addr->data.v6.__in6_u.__u6_addr8, domain, 16); +} + +void ip_addr_any(IpAddr* addr) { + addr->type = V4; + addr->data.v4.s_addr = htonl(INADDR_ANY); +} + +void ip_addr_any6(IpAddr* addr) { + addr->type = V6; + addr->data.v6 = in6addr_any; +} + +static struct sockaddr_in create_socket_addr_v4(IpAddr addr, uint16_t port) { + struct sockaddr_in socketaddr; + memset(&socketaddr, 0, sizeof(socketaddr)); + socketaddr.sin_family = AF_INET; + socketaddr.sin_port = htons(port); + socketaddr.sin_addr = addr.data.v4; + return socketaddr; +} + +static struct sockaddr_in6 create_socket_addr_v6(IpAddr addr, uint16_t port) { + struct sockaddr_in6 socketaddr; + memset(&socketaddr, 0, sizeof(socketaddr)); + socketaddr.sin6_family = AF_INET6; + socketaddr.sin6_port = htons(port); + socketaddr.sin6_addr = addr.data.v6; + return socketaddr; +} + +static size_t get_addr_len(AddrType type) { + if (type == V4) { + return sizeof(struct sockaddr_in); + } else if (type == V6) { + return sizeof(struct sockaddr_in6); + } else { + return 0; + } +} + +void create_socket_addr(uint16_t port, IpAddr addr, SocketAddr* socket) { + socket->type = addr.type; + if (addr.type == V4) { + socket->data.v4 = create_socket_addr_v4(addr, port); + } else if(addr.type == V6) { + socket->data.v6 = create_socket_addr_v6(addr, port); + } else { + ERROR("Tried to create socketaddr with invalid protocol type"); + exit(EXIT_FAILURE); + } + socket->len = get_addr_len(addr.type); +} + +void print_socket_addr(SocketAddr* addr, char* buffer) { + INIT_LOG_BOUNDS + if(addr->type == V4) { + APPEND(buffer, "%hhu.%hhu.%hhu.%hhu:%hu", + (uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 24), + (uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 16), + (uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 8), + (uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr), + addr->data.v4.sin_port + ); + } else { + for(int i = 0; i < 8; i++) { + APPEND(buffer, "%02hhx%02hhx:", + addr->data.v6.sin6_addr.__in6_u.__u6_addr8[i*2 + 0], + addr->data.v6.sin6_addr.__in6_u.__u6_addr8[i*2 + 1] + ); + } + APPEND(buffer, ":[%hu]", addr->data.v6.sin6_port); + } +} + +#define ADDR_DOMAIN(addr, var) \ + struct sockaddr* var; \ + if (addr->type == V4) { \ + var = (struct sockaddr*) &addr->data.v4; \ + } else if (addr->type == V6) { \ + var = (struct sockaddr*) &addr->data.v6; \ + } else { \ + return -1; \ + } + +#define ADDR_AFNET(type, var) \ + int var; \ + if (type == V4) { \ + var = AF_INET; \ + } else if (type == V6) { \ + var = AF_INET6; \ + } else { \ + return -1; \ + } + +int32_t create_udp_socket(AddrType type, UdpSocket* sock) { + ADDR_AFNET(type, __domain) + sock->type = type; + sock->sockfd = socket(__domain, SOCK_DGRAM, 0); + return sock->sockfd; +} + +int32_t bind_udp_socket(SocketAddr* addr, UdpSocket* sock) { + if (addr->type == V6) { + int v6OnlyEnabled = 0; + int32_t res = setsockopt( + sock->sockfd, + IPPROTO_IPV6, + IPV6_V6ONLY, + &v6OnlyEnabled, + sizeof(v6OnlyEnabled) + ); + if (res < 0) return res; + } + ADDR_DOMAIN(addr, __addr) + return bind(sock->sockfd, __addr, addr->len); +} + +int32_t read_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr) { + clientaddr->type = socket->type; + clientaddr->len = get_addr_len(socket->type); + ADDR_DOMAIN(clientaddr, __addr) + return recvfrom( + socket->sockfd, + buffer, + (size_t) len, + MSG_WAITALL, + __addr, + (uint32_t*) &clientaddr->len + ); +} + +int32_t write_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr) { + ADDR_DOMAIN(clientaddr, __addr) + return sendto( + socket->sockfd, + buffer, + (size_t) len, + MSG_CONFIRM, + __addr, + (uint32_t) clientaddr->len + ); +} + +int32_t close_udp_socket(UdpSocket* socket) { + return close(socket->sockfd); +} + +int32_t create_tcp_socket(AddrType type, TcpSocket* sock) { + ADDR_AFNET(type, __domain) + sock->type = type; + sock->sockfd = socket(__domain, SOCK_STREAM, 0); + return sock->sockfd; +} + +int32_t bind_tcp_socket(SocketAddr* addr, TcpSocket* sock) { + if (addr->type == V6) { + int v6OnlyEnabled = 0; + int32_t res = setsockopt( + sock->sockfd, + IPPROTO_IPV6, + IPV6_V6ONLY, + &v6OnlyEnabled, + sizeof(v6OnlyEnabled) + ); + if (res < 0) return res; + } + ADDR_DOMAIN(addr, __addr) + return bind(sock->sockfd, __addr, addr->len); +} + +int32_t listen_tcp_socket(TcpSocket* socket, uint32_t max) { + return listen(socket->sockfd, max); +} + +int32_t accept_tcp_socket(TcpSocket* socket, TcpStream* stream) { + stream->clientaddr.type = socket->type; + memset(&stream->clientaddr, 0, sizeof(SocketAddr)); + SocketAddr* addr = &stream->clientaddr; + ADDR_DOMAIN(addr, __addr) + stream->streamfd = accept( + socket->sockfd, + __addr, + (uint32_t*) &stream->clientaddr.len + ); + return stream->streamfd; +} + +int32_t close_tcp_socket(TcpSocket* socket) { + return close(socket->sockfd); +} + +int32_t connect_tcp_stream(SocketAddr* servaddr, TcpStream* stream) { + TcpSocket socket; + int32_t res = create_tcp_socket(servaddr->type, &socket); + if (res < 0) return res; + stream->clientaddr = *servaddr; + stream->streamfd = socket.sockfd; + ADDR_DOMAIN(servaddr, __addr) + return connect( + socket.sockfd, + __addr, + servaddr->len + ); +} + +int32_t read_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) { + return recv(stream->streamfd, buffer, len, 0); +} + +int32_t write_tcp_stream(TcpStream* stream, void* buffer, uint16_t len) { + return send(stream->streamfd, buffer, len, MSG_NOSIGNAL); +} + +int32_t close_tcp_stream(TcpStream* stream) { + return close(stream->streamfd); +} diff --git a/src/server/addr.h b/src/server/addr.h new file mode 100644 index 0000000..173c7fd --- /dev/null +++ b/src/server/addr.h @@ -0,0 +1,69 @@ +#pragma once + +#include "../packet/record.h" + +#include +#include +#include +#include + +typedef enum { + V4, + V6 +} AddrType; + +typedef struct { + AddrType type; + union { + struct in_addr v4; + struct in6_addr v6; + } data; +} IpAddr; + +void create_ip_addr(char* domain, IpAddr* addr); +void create_ip_addr6(char* domain, IpAddr* addr); +void ip_addr_any(IpAddr* addr); +void ip_addr_any6(IpAddr* addr); + +typedef struct { + AddrType type; + union { + struct sockaddr_in v4; + struct sockaddr_in6 v6; + } data; + size_t len; +} SocketAddr; + +void create_socket_addr(uint16_t port, IpAddr addr, SocketAddr* socket); +void print_socket_addr(SocketAddr* addr, char* buffer); + +typedef struct { + AddrType type; + uint32_t sockfd; +} UdpSocket; + +int32_t create_udp_socket(AddrType type, UdpSocket* socket); +int32_t bind_udp_socket(SocketAddr* addr, UdpSocket* socket); +int32_t read_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr); +int32_t write_udp_socket(UdpSocket* socket, void* buffer, uint16_t len, SocketAddr* clientaddr); +int32_t close_udp_socket(UdpSocket* socket); + +typedef struct { + AddrType type; + uint32_t sockfd; +} TcpSocket; + +typedef struct { + SocketAddr clientaddr; + uint32_t streamfd; +} TcpStream; + +int32_t create_tcp_socket(AddrType type, TcpSocket* socket); +int32_t bind_tcp_socket(SocketAddr* addr, TcpSocket* socket); +int32_t listen_tcp_socket(TcpSocket* socket, uint32_t max); +int32_t accept_tcp_socket(TcpSocket* socket, TcpStream* stream); +int32_t close_tcp_socket(TcpSocket* socket); +int32_t connect_tcp_stream(SocketAddr* servaddr, TcpStream* stream); +int32_t read_tcp_stream(TcpStream* stream, void* buffer, uint16_t len); +int32_t write_tcp_stream(TcpStream* stream, void* buffer, uint16_t len); +int32_t close_tcp_stream(TcpStream* stream); diff --git a/src/server/binding.c b/src/server/binding.c new file mode 100644 index 0000000..47c62c6 --- /dev/null +++ b/src/server/binding.c @@ -0,0 +1,245 @@ +#include +#include +#include +#include +#include +#include + +#include "addr.h" +#include "binding.h" +#include "../io/log.h" + +static void create_udp_binding(UdpSocket* socket, uint16_t port) { + if (create_udp_socket(V6, socket) < 0) { + ERROR("Failed to create UDP socket: %s", strerror(errno)); + exit(EXIT_FAILURE); + } + + IpAddr addr; + ip_addr_any6(&addr); + + SocketAddr socketaddr; + create_socket_addr(port, addr, &socketaddr); + + if (bind_udp_socket(&socketaddr, socket) < 0) { + ERROR("Failed to bind UDP socket on port %hu: %s", port, strerror(errno)); + exit(EXIT_FAILURE); + } +} + +static void create_tcp_binding(TcpSocket* socket, uint16_t port) { + if (create_tcp_socket(V6, socket) < 0) { + ERROR("Failed to create TCP socket: %s", strerror(errno)); + exit(EXIT_FAILURE); + } + + IpAddr addr; + ip_addr_any6(&addr); + + SocketAddr socketaddr; + create_socket_addr(port, addr, &socketaddr); + + if (bind_tcp_socket(&socketaddr, socket) < 0) { + ERROR("Failed to bind TCP socket on port %hu: %s", port, strerror(errno)); + exit(EXIT_FAILURE); + } + + if (listen_tcp_socket(socket, 5) < 0) { + ERROR("Failed to listen on TCP socket: %s", strerror(errno)); + exit(EXIT_FAILURE); + } +} + +void create_binding(BindingType type, uint16_t port, Binding* binding) { + binding->type = type; + if (type == UDP) { + create_udp_binding(&binding->sock.udp, port); + } else if(type == TCP) { + create_tcp_binding(&binding->sock.tcp, port); + } else { + exit(EXIT_FAILURE); + } +} + +void free_binding(Binding* binding) { + if (binding->type == UDP) { + close_udp_socket(&binding->sock.udp); + } else if(binding->type == TCP) { + close_tcp_socket(&binding->sock.tcp); + } +} + +bool accept_connection(Binding* binding, Connection* connection) { + connection->type = binding->type; + + if(binding->type == UDP) { + connection->sock.udp.udp = binding->sock.udp; + memset(&connection->sock.udp.clientaddr, 0, sizeof(SocketAddr)); + return true; + } + + if (accept_tcp_socket(&binding->sock.tcp, &connection->sock.tcp) < 0) { + ERROR("Failed to accept TCP connection: %s", strerror(errno)); + return false; + } + + return true; +} + +static void read_to_packet(uint8_t* buf, uint16_t len, Packet* packet) { + PacketBuffer* pkbuffer = buffer_create(len); + for (int i = 0; i < len; i++) { + buffer_write(pkbuffer, buf[i]); + } + buffer_seek(pkbuffer, 0); + read_packet(pkbuffer, packet); + buffer_free(pkbuffer); +} + +static bool read_udp(Connection* connection, Packet* packet) { + uint8_t buffer[512]; + int32_t n = read_udp_socket( + &connection->sock.udp.udp, + buffer, + 512, + &connection->sock.udp.clientaddr + ); + if (n < 0) { + return false; + } + read_to_packet(buffer, n, packet); + return true; +} + +static bool read_tcp(Connection* connection, Packet* packet) { + uint16_t len; + if ( read_tcp_stream( + &connection->sock.tcp, + &len, + sizeof(uint16_t) + ) < 0) { + return false; + } + + uint8_t buffer[len]; + if ( read_tcp_stream( + &connection->sock.tcp, + buffer, + len + ) < 0) { + return false; + } + + read_to_packet(buffer, len, packet); + return true; +} + +bool read_connection(Connection* connection, Packet* packet) { + if (connection->type == UDP) { + return read_udp(connection, packet); + } else if (connection->type == TCP) { + return read_tcp(connection, packet); + } + return false; +} + +static bool write_udp(Connection* connection, uint8_t* buf, uint16_t len) { + //if (len > 512) { + buf[2] = buf[2] | 0x03; + // len = 512; + // } + return write_udp_socket( + &connection->sock.udp.udp, + buf, + len, + &connection->sock.udp.clientaddr + ) == len; +} + +static bool write_tcp(Connection* connection, uint8_t* buf, uint16_t len) { + len = htons(len); + if (write_tcp_stream( + &connection->sock.tcp, + &len, + sizeof(uint16_t) + ) < 0) { + return false; + } + + if (write_tcp_stream( + &connection->sock.tcp, + buf, + len + ) < 0) { + return false; + } + + return true; +} + +bool write_connection(Connection* connection, Packet* packet) { + PacketBuffer* pkbuffer = buffer_create(64); + write_packet(pkbuffer, packet); + uint16_t len = buffer_get_size(pkbuffer); + uint8_t* buffer = buffer_get_ptr(pkbuffer); + bool success = false; + if(connection->type == UDP) { + success = write_udp(connection, buffer, len); + } else if(connection->type == TCP) { + success = write_tcp(connection, buffer, len); + }; + buffer_free(pkbuffer); + return success; +} + +void free_connection(Connection* connection) { + if (connection->type == TCP) { + close_tcp_stream(&connection->sock.tcp); + } +} + +static bool create_udp_request(SocketAddr* addr, Connection* request) { + if ( create_udp_socket(addr->type, &request->sock.udp.udp) < 0) { + ERROR("Failed to connect to UDP socket: %s", strerror(errno)); + return false; + } + request->sock.udp.clientaddr = *addr; + return true; +} + +static bool create_tcp_request(SocketAddr* addr, Connection* request) { + if( connect_tcp_stream(addr, &request->sock.tcp) < 0) { + ERROR("Failed to connect to TCP socket: %s", strerror(errno)); + return false; + } + return true; +} + +bool create_request(BindingType type, SocketAddr* addr, Connection* request) { + request->type = type; + if (type == UDP) { + return create_udp_request(addr, request); + } else if (type == TCP) { + return create_tcp_request(addr, request); + } else { + return true; + } +} + +bool request_packet(Connection* request, Packet* in, Packet* out) { + if (!write_connection(request, in)) { + return false; + } + if (!read_connection(request, out)) { + return false; + } + return true; +} + +void free_request(Connection* connection) { + if (connection->type == UDP) { + close_udp_socket(&connection->sock.udp.udp); + } else if (connection->type == TCP) { + close_tcp_stream(&connection->sock.tcp); + } +} diff --git a/src/server/binding.h b/src/server/binding.h new file mode 100644 index 0000000..e2e6160 --- /dev/null +++ b/src/server/binding.h @@ -0,0 +1,42 @@ +#pragma once + +#include "../packet/packet.h" +#include "addr.h" + +#include + +typedef enum { + UDP, + TCP +} BindingType; + +typedef struct { + BindingType type; + union { + UdpSocket udp; + TcpSocket tcp; + } sock; +} Binding; + +void create_binding(BindingType type, uint16_t port, Binding* binding); +void free_binding(Binding* binding); + +typedef struct { + BindingType type; + union { + struct { + UdpSocket udp; + SocketAddr clientaddr; + } udp; + TcpStream tcp; + } sock; +} Connection; + +bool accept_connection(Binding* binding, Connection* connection); +bool read_connection(Connection* connection, Packet* packet); +bool write_connection(Connection* connection, Packet* packet); +void free_connection(Connection* connection); + +bool create_request(BindingType type, SocketAddr* addr, Connection* request); +bool request_packet(Connection* request, Packet* in, Packet* out); +void free_request(Connection* connection); diff --git a/src/server/resolver.c b/src/server/resolver.c new file mode 100644 index 0000000..e05f365 --- /dev/null +++ b/src/server/resolver.c @@ -0,0 +1,166 @@ +#include +#include +#include + +#include "resolver.h" +#include "addr.h" +#include "binding.h" +#include "../io/log.h" + +static bool lookup( + Question* question, + Packet* response, + BindingType type, + SocketAddr addr +) { + INIT_LOG_BUFFER(log) + LOGONLY(print_socket_addr(&addr, log);) + TRACE("Attempting lookup on fallback dns %s", log); + + Connection request; + if (!create_request(type, &addr, &request)) { + return false; + } + + Packet req; + memset(&req, 0, sizeof(Packet)); + req.header.id = response->header.id; + req.header.opcode = response->header.opcode; + req.header.questions = 1; + req.header.recursion_desired = true; + req.questions = malloc(sizeof(Question)); + req.questions[0] = *question; + + if (!request_packet(&request, &req, response)) { + free_request(&request); + free(req.questions); + ERROR("Failed to request fallback dns: %s", strerror(errno)); + return false; + } + + free_request(&request); + free(req.questions); + return true; +} + +static bool search(Question* question, Packet* result, BindingType type) { + IpAddr addr; + char ip[4] = {1, 1, 1, 1}; + create_ip_addr(ip, &addr); + + uint16_t port = 53; + SocketAddr saddr; + create_socket_addr(port, addr, &saddr); + + while(1) { + if (!lookup(question, result, type, saddr)) { + return false; + } + + if (result->header.answers > 0 && result->header.rescode == NOERROR) { + return true; + } + + if (result->header.rescode == NXDOMAIN) { + return true; + } + + if (get_resolved_ns(result, question->domain, &addr)) { + continue; + } + + Question new_question; + if (!get_unresoled_ns(result, question->domain, &new_question)) { + return true; + } + + Packet recurse; + if (!search(&new_question, &recurse, type)) { + return false; + } + + free_question(&new_question); + + IpAddr random; + if (!get_random_a(&recurse, &random)) { + free_packet(&recurse); + return true; + } else { + free_packet(&recurse); + addr = random; + } + } +} + +static void push_records(Record* from, uint8_t from_len, Record** to, uint8_t to_len) { + if(from_len < 1) return; + *to = realloc(*to, sizeof(Record) * (from_len + to_len)); + memcpy(*to + to_len, from, from_len * sizeof(Record)); +} + +static void push_questions(Question* from, uint8_t from_len, Question** to, uint8_t to_len) { + if(from_len < 1) return; + *to = realloc(*to, sizeof(Question) * (from_len + to_len)); + memcpy(*to + to_len, from, from_len * sizeof(Question)); +} + +void handle_query(Packet* request, Packet* response, BindingType type) { + memset(response, 0, sizeof(Packet)); + response->header.id = request->header.id; + response->header.opcode = request->header.opcode; + response->header.recursion_desired = true; + response->header.recursion_available = true; + response->header.response = true; + + if (request->header.questions < 1) { + response->header.response = FORMERR; + return; + } + + for (uint16_t i = 0; i < request->header.questions; i++) { + Packet result; + memset(&result, 0, sizeof(Packet)); + result.header.id = response->header.id; + if (!search(&request->questions[i], &result, type)) { + response->header.response = SERVFAIL; + break; + } + + push_questions( + result.questions, + result.header.questions, + &response->questions, + response->header.questions + ); + response->header.questions += result.header.questions; + + push_records( + result.answers, + result.header.answers, + &response->answers, + response->header.answers + ); + response->header.answers += result.header.answers; + + push_records( + result.authorities, + result.header.authoritative_entries, + &response->authorities, + response->header.authoritative_entries + ); + response->header.authoritative_entries += result.header.authoritative_entries; + + push_records( + result.resources, + result.header.resource_entries, + &response->resources, + response->header.resource_entries + ); + response->header.resource_entries += result.header.resource_entries; + + free(result.questions); + free(result.answers); + free(result.authorities); + free(result.resources); + } +} \ No newline at end of file diff --git a/src/server/resolver.h b/src/server/resolver.h new file mode 100644 index 0000000..79b4825 --- /dev/null +++ b/src/server/resolver.h @@ -0,0 +1,6 @@ +#pragma once + +#include "../packet/packet.h" +#include "binding.h" + +void handle_query(Packet* request, Packet* response, BindingType type); \ No newline at end of file diff --git a/src/server/server.c b/src/server/server.c new file mode 100644 index 0000000..c8975ee --- /dev/null +++ b/src/server/server.c @@ -0,0 +1,100 @@ +#define _POSIX_SOURCE +#include +#include +#include +#include +#include +#include + +#include "addr.h" +#include "server.h" +#include "resolver.h" +#include "../io/log.h" + +static pid_t udp, tcp; + +void server_init(uint16_t port, Server* server) { + INFO("Server port set to %hu", port); + create_binding(UDP, port, &server->udp); + create_binding(TCP, port, &server->tcp); +} + +static void server_listen(Binding* binding) { + while(1) { + + Connection connection; + if (!accept_connection(binding, &connection)) { + ERROR("Failed to accept connection"); + continue; + } + + Packet request; + if (!read_connection(&connection, &request)) { + ERROR("Failed to read connection"); + free_connection(&connection); + continue; + } + + if(fork() != 0) { + free_packet(&request); + free_connection(&connection); + continue; + } + + INFO("Recieved packet request ID %hu", request.header.id); + + Packet response; + handle_query(&request, &response, connection.type); + + if (!write_connection(&connection, &response)) { + ERROR("Failed to respond to connection ID %hu: %s", request.header.id, strerror(errno)); + } + + free_packet(&request); + free_packet(&response); + free_connection(&connection); + exit(EXIT_SUCCESS); + } +} + +static void signal_handler() { + printf("\n"); + kill(udp, SIGTERM); + kill(tcp, SIGTERM); +} + +void server_run(Server* server) { + if ((udp = fork()) == 0) { + INFO("Listening for connections on UDP"); + server_listen(&server->udp); + exit(EXIT_SUCCESS); + } + + if ((tcp = fork()) == 0) { + INFO("Listening for connections on TCP"); + server_listen(&server->tcp); + exit(EXIT_SUCCESS); + } + + signal(SIGINT, signal_handler); + + int status; + waitpid(udp, &status, 0); + if (status == 0) { + INFO("UDP process closed successfully"); + } else { + ERROR("UDP process failed with error code %d", status); + } + + waitpid(tcp, &status, 0); + if (status == 0) { + INFO("TCP process closed successfully"); + } else { + ERROR("TCP process failed with error code %d", status); + } +} + +void server_free(Server* server) { + free_binding(&server->udp); + free_binding(&server->tcp); +} diff --git a/src/server/server.h b/src/server/server.h new file mode 100644 index 0000000..c9509f2 --- /dev/null +++ b/src/server/server.h @@ -0,0 +1,12 @@ +#pragma once + +#include "binding.h" + +typedef struct { + Binding udp; + Binding tcp; +} Server; + +void server_init(uint16_t port, Server* server); +void server_run(Server* server); +void server_free(Server* server); \ No newline at end of file diff --git a/src/web/api.rs b/src/web/api.rs deleted file mode 100644 index 1fddb5f..0000000 --- a/src/web/api.rs +++ /dev/null @@ -1,156 +0,0 @@ -use std::net::IpAddr; - -use axum::{ - extract::Query, - response::Response, - routing::{get, post, put, delete}, - Extension, Router, -}; -use moka::future::Cache; -use rand::distributions::{Alphanumeric, DistString}; -use serde::Deserialize; -use tower_cookies::{Cookie, Cookies}; - -use crate::{config::Config, database::Database, dns::packet::record::DnsRecord}; - -use super::{ - extract::{Authorized, Body, RequestIp}, - http::{json, text}, -}; - -pub fn router() -> Router { - Router::new() - .route("/login", post(login)) - .route("/domains", get(list_domains)) - .route("/domains", delete(delete_domain)) - .route("/records", get(get_domain)) - .route("/records", put(add_record)) -} - -async fn list_domains(_: Authorized, Extension(database): Extension) -> Response { - let domains = match database.get_domains().await { - Ok(domains) => domains, - Err(err) => return text(500, &format!("{err}")), - }; - - let Ok(domains) = serde_json::to_string(&domains) else { - return text(500, "Failed to fetch domains") - }; - - json(200, &domains) -} - -#[derive(Deserialize)] -struct DomainRequest { - domain: String, -} - -async fn get_domain( - _: Authorized, - Extension(database): Extension, - Query(query): Query, -) -> Response { - let records = match database.get_domain(&query.domain).await { - Ok(records) => records, - Err(err) => return text(500, &format!("{err}")), - }; - - let Ok(records) = serde_json::to_string(&records) else { - return text(500, "Failed to fetch records") - }; - - json(200, &records) -} - -async fn delete_domain( - _: Authorized, - Extension(database): Extension, - Body(body): Body, -) -> Response { - - let Ok(request) = serde_json::from_str::(&body) else { - return text(400, "Missing request parameters") - }; - - let Ok(domains) = database.get_domains().await else { - return text(500, "Failed to delete domain") - }; - - if !domains.contains(&request.domain) { - return text(400, "Domain does not exist") - } - - if database.delete_domain(request.domain).await.is_err() { - return text(500, "Failed to delete domain") - }; - - return text(204, "Successfully deleted domain") -} - -async fn add_record( - _: Authorized, - Extension(database): Extension, - Body(body): Body, -) -> Response { - let Ok(record) = serde_json::from_str::(&body) else { - return text(400, "Invalid DNS record") - }; - - let allowed = record.get_qtype().allowed_actions(); - if !allowed.1 { - return text(400, "Not allowed to create record") - } - - let Ok(records) = database.get_records(&record.get_domain(), record.get_qtype()).await else { - return text(500, "Failed to complete record check"); - }; - - if !records.is_empty() && !allowed.0 { - return text(400, "Not allowed to create duplicate record") - }; - - if records.contains(&record) { - return text(400, "Not allowed to create duplicate record") - } - - if let Err(err) = database.add_record(record).await { - return text(500, &format!("{err}")); - } - - return text(201, "Added record to database successfully"); -} - -#[derive(Deserialize)] -struct LoginRequest { - user: String, - pass: String, -} - -async fn login( - Extension(config): Extension, - Extension(cache): Extension>, - RequestIp(ip): RequestIp, - cookies: Cookies, - Body(body): Body, -) -> Response { - let Ok(request) = serde_json::from_str::(&body) else { - return text(400, "Missing request parameters") - }; - - if request.user != config.web_user || request.pass != config.web_pass { - return text(400, "Invalid credentials"); - }; - - let token = Alphanumeric.sample_string(&mut rand::thread_rng(), 128); - - cache.insert(token.clone(), ip).await; - - let mut cookie = Cookie::new("auth", token); - cookie.set_secure(true); - cookie.set_http_only(true); - cookie.set_path("/"); - - cookies.add(cookie); - - text(200, "Successfully logged in") -} diff --git a/src/web/extract.rs b/src/web/extract.rs deleted file mode 100644 index 4b6cd7c..0000000 --- a/src/web/extract.rs +++ /dev/null @@ -1,139 +0,0 @@ -use std::{ - io::Read, - net::{IpAddr, SocketAddr}, -}; - -use axum::{ - async_trait, - body::HttpBody, - extract::{ConnectInfo, FromRequest, FromRequestParts}, - http::{request::Parts, Request}, - response::Response, - BoxError, -}; -use bytes::Bytes; -use moka::future::Cache; -use tower_cookies::Cookies; - -use super::http::text; - -pub struct Authorized; - -#[async_trait] -impl FromRequestParts for Authorized -where - S: Send + Sync, -{ - type Rejection = Response; - - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let Ok(Some(cookies)) = Option::::from_request_parts(parts, state).await else { - return Err(text(403, "No cookies provided")) - }; - - let Some(token) = cookies.get("auth") else { - return Err(text(403, "No auth token provided")) - }; - - let auth_ip: IpAddr; - { - let Some(cache) = parts.extensions.get::>() else { - return Err(text(500, "Failed to load auth store")) - }; - - let Some(ip) = cache.get(token.value()) else { - return Err(text(401, "Unauthorized")) - }; - - auth_ip = ip - } - - let Ok(Some(RequestIp(ip))) = Option::::from_request_parts(parts, state).await else { - return Err(text(403, "You have no ip")) - }; - - if auth_ip != ip { - return Err(text(403, "Auth token does not match current ip")); - } - - Ok(Self) - } -} - -pub struct RequestIp(pub IpAddr); - -#[async_trait] -impl FromRequestParts for RequestIp -where - S: Send + Sync, -{ - type Rejection = Response; - - async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { - let headers = &parts.headers; - - let forwardedfor = headers - .get("x-forwarded-for") - .and_then(|h| h.to_str().ok()) - .and_then(|h| { - h.split(',') - .rev() - .find_map(|s| s.trim().parse::().ok()) - }); - - if let Some(forwardedfor) = forwardedfor { - return Ok(Self(forwardedfor)); - } - - let realip = headers - .get("x-real-ip") - .and_then(|hv| hv.to_str().ok()) - .and_then(|s| s.parse::().ok()); - - if let Some(realip) = realip { - return Ok(Self(realip)); - } - - let realip = headers - .get("x-real-ip") - .and_then(|hv| hv.to_str().ok()) - .and_then(|s| s.parse::().ok()); - - if let Some(realip) = realip { - return Ok(Self(realip)); - } - - let info = parts.extensions.get::>(); - - if let Some(info) = info { - return Ok(Self(info.0.ip())); - } - - Err(text(403, "You have no ip")) - } -} - -pub struct Body(pub String); - -#[async_trait] -impl FromRequest for Body -where - B: HttpBody + Sync + Send + 'static, - B::Data: Send, - B::Error: Into, - S: Send + Sync, -{ - type Rejection = Response; - - async fn from_request(req: Request, state: &S) -> Result { - let Ok(bytes) = Bytes::from_request(req, state).await else { - return Err(text(413, "Payload too large")); - }; - - let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else { - return Err(text(400, "Invalid utf8 body")) - }; - - Ok(Self(body)) - } -} diff --git a/src/web/file.rs b/src/web/file.rs deleted file mode 100644 index 73ecdc9..0000000 --- a/src/web/file.rs +++ /dev/null @@ -1,31 +0,0 @@ -use axum::{extract::Path, response::Response}; - -use super::http::serve; - -pub async fn js(Path(path): Path) -> Response { - let path = format!("/js/{path}"); - serve(&path).await -} - -pub async fn css(Path(path): Path) -> Response { - let path = format!("/css/{path}"); - serve(&path).await -} - -pub async fn fonts(Path(path): Path) -> Response { - let path = format!("/fonts/{path}"); - serve(&path).await -} - -pub async fn image(Path(path): Path) -> Response { - let path = format!("/image/{path}"); - serve(&path).await -} - -pub async fn favicon() -> Response { - serve("/favicon.ico").await -} - -pub async fn robots() -> Response { - serve("/robots.txt").await -} diff --git a/src/web/http.rs b/src/web/http.rs deleted file mode 100644 index 7ab1b11..0000000 --- a/src/web/http.rs +++ /dev/null @@ -1,50 +0,0 @@ -use axum::{ - body::Body, - http::{header::HeaderName, HeaderValue, Request, StatusCode}, - response::{IntoResponse, Response}, -}; -use std::str; -use tower::ServiceExt; -use tower_http::services::ServeFile; - -pub fn text(code: u16, msg: &str) -> Response { - (status_code(code), msg.to_owned()).into_response() -} - -pub fn json(code: u16, json: &str) -> Response { - let mut res = (status_code(code), json.to_owned()).into_response(); - res.headers_mut().insert( - HeaderName::from_static("content-type"), - HeaderValue::from_static("application/json"), - ); - res -} - -pub async fn serve(path: &str) -> Response { - if !path.chars().any(|c| c == '.') { - return text(403, "Invalid file path"); - } - - let path = format!("public{path}"); - let file = ServeFile::new(path); - - let Ok(mut res) = file.oneshot(Request::new(Body::empty())).await else { - tracing::error!("Error while fetching file"); - return text(500, "Error when fetching file") - }; - - if res.status() != StatusCode::OK { - return text(404, "File not found"); - } - - res.headers_mut().insert( - HeaderName::from_static("cache-control"), - HeaderValue::from_static("max-age=300"), - ); - - res.into_response() -} - -fn status_code(code: u16) -> StatusCode { - StatusCode::from_u16(code).map_or(StatusCode::OK, |code| code) -} diff --git a/src/web/mod.rs b/src/web/mod.rs deleted file mode 100644 index 530a3f9..0000000 --- a/src/web/mod.rs +++ /dev/null @@ -1,82 +0,0 @@ -use std::net::{IpAddr, SocketAddr, TcpListener}; -use std::time::Duration; - -use axum::routing::get; -use axum::{Extension, Router}; -use moka::future::Cache; -use tokio::task::JoinHandle; -use tower_cookies::CookieManagerLayer; -use tracing::{error, info}; - -use crate::config::Config; -use crate::database::Database; -use crate::Result; - -mod api; -mod extract; -mod file; -mod http; -mod pages; - -pub struct WebServer { - config: Config, - database: Database, - addr: SocketAddr, -} - -impl WebServer { - pub async fn new(config: Config, database: Database) -> Result { - let addr = format!("[::]:{}", config.web_port).parse::()?; - Ok(Self { - config, - database, - addr, - }) - } - - pub async fn run(&self) -> Result> { - let config = self.config.clone(); - let database = self.database.clone(); - let listener = TcpListener::bind(self.addr)?; - - info!( - "Listening for HTTP traffic on [::]:{}", - self.config.web_port - ); - - let app = Self::router(config, database); - let server = axum::Server::from_tcp(listener)?; - - let web_handle = tokio::spawn(async move { - if let Err(err) = server - .serve(app.into_make_service_with_connect_info::()) - .await - { - error!("{err}"); - } - }); - - Ok(web_handle) - } - - fn router(config: Config, database: Database) -> Router { - let cache: Cache = Cache::builder() - .time_to_live(Duration::from_secs(60 * 15)) - .max_capacity(config.dns_cache_size) - .build(); - - Router::new() - .nest("/", pages::router()) - .nest("/api", api::router()) - .layer(Extension(config)) - .layer(Extension(cache)) - .layer(Extension(database)) - .layer(CookieManagerLayer::new()) - .route("/js/*path", get(file::js)) - .route("/css/*path", get(file::css)) - .route("/fonts/*path", get(file::fonts)) - .route("/image/*path", get(file::image)) - .route("/favicon.ico", get(file::favicon)) - .route("/robots.txt", get(file::robots)) - } -} diff --git a/src/web/pages.rs b/src/web/pages.rs deleted file mode 100644 index a8605ef..0000000 --- a/src/web/pages.rs +++ /dev/null @@ -1,31 +0,0 @@ -use axum::{response::Response, routing::get, Router}; - -use super::{extract::Authorized, http::serve}; - -pub fn router() -> Router { - Router::new() - .route("/", get(root)) - .route("/login", get(login)) - .route("/home", get(home)) - .route("/domain", get(domain)) -} - -async fn root(user: Option) -> Response { - if user.is_some() { - home().await - } else { - login().await - } -} - -async fn login() -> Response { - serve("/login.html").await -} - -async fn home() -> Response { - serve("/home.html").await -} - -async fn domain() -> Response { - serve("/domain.html").await -} -- cgit v1.2.3-freya