use std::{ net::{IpAddr, SocketAddr}, sync::Arc, }; use super::packet::{buffer::PacketBuffer, Packet}; use crate::Result; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream, UdpSocket}, }; use tracing::trace; pub enum Binding { UDP(Arc), TCP(TcpListener), } impl Binding { pub async fn udp(addr: SocketAddr) -> Result { let socket = UdpSocket::bind(addr).await?; Ok(Self::UDP(Arc::new(socket))) } pub async fn tcp(addr: SocketAddr) -> Result { let socket = TcpListener::bind(addr).await?; Ok(Self::TCP(socket)) } pub fn name(&self) -> &str { match self { Binding::UDP(_) => "UDP", Binding::TCP(_) => "TCP", } } pub async fn connect(&mut self) -> Result { match self { Self::UDP(socket) => { let mut buf = [0; 512]; let (_, addr) = socket.recv_from(&mut buf).await?; Ok(Connection::UDP(socket.clone(), addr, buf)) } Self::TCP(socket) => { let (stream, _) = socket.accept().await?; Ok(Connection::TCP(stream)) } } } } pub enum Connection { UDP(Arc, SocketAddr, [u8; 512]), TCP(TcpStream), } impl Connection { pub async fn read_packet(&mut self) -> Result { let data = self.read().await?; let mut packet_buffer = PacketBuffer::new(data); let packet = Packet::from_buffer(&mut packet_buffer)?; Ok(packet) } pub async fn write_packet(self, mut packet: Packet) -> Result<()> { let mut packet_buffer = PacketBuffer::new(Vec::new()); packet.write(&mut packet_buffer)?; self.write(packet_buffer.buf).await?; Ok(()) } pub async fn request_packet(&self, mut packet: Packet, dest: (IpAddr, u16)) -> Result { let mut packet_buffer = PacketBuffer::new(Vec::new()); packet.write(&mut packet_buffer)?; let data = self.request(packet_buffer.buf, dest).await?; let mut packet_buffer = PacketBuffer::new(data); let packet = Packet::from_buffer(&mut packet_buffer)?; Ok(packet) } async fn read(&mut self) -> Result> { trace!("Reading DNS packet"); match self { Self::UDP(_, _, src) => Ok(Vec::from(*src)), Self::TCP(stream) => { let size = stream.read_u16().await?; let mut buf = Vec::with_capacity(size as usize); stream.read_buf(&mut buf).await?; Ok(buf) } } } async fn write(self, mut buf: Vec) -> Result<()> { trace!("Returning DNS packet"); match self { Self::UDP(socket, addr, _) => { if buf.len() > 512 { buf[2] = buf[2] | 0x03; socket.send_to(&buf[0..512], addr).await?; } else { socket.send_to(&buf, addr).await?; } Ok(()) } Self::TCP(mut stream) => { stream.write_u16(buf.len() as u16).await?; stream.write(&buf[0..buf.len()]).await?; Ok(()) } } } async fn request(&self, buf: Vec, dest: (IpAddr, u16)) -> Result> { match self { Self::UDP(_socket, _addr, _src) => { let local_addr = "[::]:0".parse::()?; let socket = UdpSocket::bind(local_addr).await?; socket.send_to(&buf, dest).await?; let mut buf = [0; 512]; socket.recv_from(&mut buf).await?; Ok(Vec::from(buf)) } Self::TCP(_stream) => { let mut stream = TcpStream::connect(dest).await?; stream.write_u16((buf.len()) as u16).await?; stream.write_all(&buf[0..buf.len()]).await?; stream.readable().await?; let size = stream.read_u16().await?; let mut buf = Vec::with_capacity(size as usize); stream.read_buf(&mut buf).await?; Ok(buf) } } } }