diff options
Diffstat (limited to 'src/server/binding.rs')
-rw-r--r-- | src/server/binding.rs | 150 |
1 files changed, 150 insertions, 0 deletions
diff --git a/src/server/binding.rs b/src/server/binding.rs new file mode 100644 index 0000000..1c69651 --- /dev/null +++ b/src/server/binding.rs @@ -0,0 +1,150 @@ +use std::{ + net::{IpAddr, SocketAddr}, + sync::Arc, +}; + +use crate::packet::{buffer::PacketBuffer, Packet, Result}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpListener, TcpStream, UdpSocket}, +}; +use tracing::trace; + +pub enum Binding { + UDP(Arc<UdpSocket>), + TCP(TcpListener), +} + +impl Binding { + pub async fn udp(addr: SocketAddr) -> Result<Self> { + let socket = UdpSocket::bind(addr).await?; + Ok(Self::UDP(Arc::new(socket))) + } + + pub async fn tcp(addr: SocketAddr) -> Result<Self> { + let socket = TcpListener::bind(addr).await?; + Ok(Self::TCP(socket)) + } + + pub fn name(&self) -> &str { + match self { + Binding::UDP(_) => "UDP", + Binding::TCP(_) => "TCP", + } + } + + pub async fn connect(&mut self) -> Result<Connection> { + match self { + Self::UDP(socket) => { + let mut buf = [0; 512]; + let (_, addr) = socket.recv_from(&mut buf).await?; + Ok(Connection::UDP(socket.clone(), addr, buf)) + } + Self::TCP(socket) => { + let (stream, _) = socket.accept().await?; + Ok(Connection::TCP(stream)) + } + } + } +} + +pub enum Connection { + UDP(Arc<UdpSocket>, SocketAddr, [u8; 512]), + TCP(TcpStream), +} + +impl Connection { + pub async fn read_packet(&mut self) -> Result<Packet> { + let data = self.read().await?; + let mut packet_buffer = PacketBuffer::new(data); + + let packet = Packet::from_buffer(&mut packet_buffer)?; + Ok(packet) + } + + pub async fn write_packet(self, mut packet: Packet) -> Result<()> { + let mut packet_buffer = PacketBuffer::new(Vec::new()); + packet.write(&mut packet_buffer)?; + + self.write(packet_buffer.buf).await?; + Ok(()) + } + + pub async fn request_packet(&self, mut packet: Packet, dest: (IpAddr, u16)) -> Result<Packet> { + let mut packet_buffer = PacketBuffer::new(Vec::new()); + packet.write(&mut packet_buffer)?; + + let data = self.request(packet_buffer.buf, dest).await?; + let mut packet_buffer = PacketBuffer::new(data); + + let packet = Packet::from_buffer(&mut packet_buffer)?; + Ok(packet) + } + + async fn read(&mut self) -> Result<Vec<u8>> { + trace!("Reading DNS packet"); + match self { + Self::UDP(_, _, src) => Ok(Vec::from(*src)), + Self::TCP(stream) => { + let size = stream.read_u16().await?; + let mut buf = Vec::with_capacity(size as usize); + stream.read_buf(&mut buf).await?; + Ok(buf) + } + } + } + + async fn write(self, mut buf: Vec<u8>) -> Result<()> { + trace!("Returning DNS packet"); + match self { + Self::UDP(socket, addr, _) => { + if buf.len() > 512 { + buf[2] = buf[2] | 0x03; + socket.send_to(&buf[0..512], addr).await?; + } else { + socket.send_to(&buf, addr).await?; + } + Ok(()) + } + Self::TCP(mut stream) => { + stream.write_u16(buf.len() as u16).await?; + stream.write(&buf[0..buf.len()]).await?; + Ok(()) + } + } + } + + async fn request(&self, buf: Vec<u8>, dest: (IpAddr, u16)) -> Result<Vec<u8>> { + match self { + Self::UDP(_socket, _addr, _src) => { + let local_addr = "[::]:0".parse::<SocketAddr>()?; + let socket = UdpSocket::bind(local_addr).await?; + socket.send_to(&buf, dest).await?; + + let mut buf = [0; 512]; + socket.recv_from(&mut buf).await?; + + Ok(Vec::from(buf)) + } + Self::TCP(_stream) => { + let mut stream = TcpStream::connect(dest).await?; + stream.write_u16((buf.len()) as u16).await?; + stream.write_all(&buf[0..buf.len()]).await?; + + stream.readable().await?; + let size = stream.read_u16().await?; + let mut buf = Vec::with_capacity(size as usize); + stream.read_buf(&mut buf).await?; + + Ok(buf) + } + } + } + + // fn pb(buf: &[u8]) { + // for i in 0..buf.len() { + // print!("{:02X?} ", buf[i]); + // } + // println!(""); + // } +} |