summaryrefslogtreecommitdiff
path: root/src/server/binding.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/server/binding.rs')
-rw-r--r--src/server/binding.rs150
1 files changed, 150 insertions, 0 deletions
diff --git a/src/server/binding.rs b/src/server/binding.rs
new file mode 100644
index 0000000..1c69651
--- /dev/null
+++ b/src/server/binding.rs
@@ -0,0 +1,150 @@
+use std::{
+ net::{IpAddr, SocketAddr},
+ sync::Arc,
+};
+
+use crate::packet::{buffer::PacketBuffer, Packet, Result};
+use tokio::{
+ io::{AsyncReadExt, AsyncWriteExt},
+ net::{TcpListener, TcpStream, UdpSocket},
+};
+use tracing::trace;
+
+pub enum Binding {
+ UDP(Arc<UdpSocket>),
+ TCP(TcpListener),
+}
+
+impl Binding {
+ pub async fn udp(addr: SocketAddr) -> Result<Self> {
+ let socket = UdpSocket::bind(addr).await?;
+ Ok(Self::UDP(Arc::new(socket)))
+ }
+
+ pub async fn tcp(addr: SocketAddr) -> Result<Self> {
+ let socket = TcpListener::bind(addr).await?;
+ Ok(Self::TCP(socket))
+ }
+
+ pub fn name(&self) -> &str {
+ match self {
+ Binding::UDP(_) => "UDP",
+ Binding::TCP(_) => "TCP",
+ }
+ }
+
+ pub async fn connect(&mut self) -> Result<Connection> {
+ match self {
+ Self::UDP(socket) => {
+ let mut buf = [0; 512];
+ let (_, addr) = socket.recv_from(&mut buf).await?;
+ Ok(Connection::UDP(socket.clone(), addr, buf))
+ }
+ Self::TCP(socket) => {
+ let (stream, _) = socket.accept().await?;
+ Ok(Connection::TCP(stream))
+ }
+ }
+ }
+}
+
+pub enum Connection {
+ UDP(Arc<UdpSocket>, SocketAddr, [u8; 512]),
+ TCP(TcpStream),
+}
+
+impl Connection {
+ pub async fn read_packet(&mut self) -> Result<Packet> {
+ let data = self.read().await?;
+ let mut packet_buffer = PacketBuffer::new(data);
+
+ let packet = Packet::from_buffer(&mut packet_buffer)?;
+ Ok(packet)
+ }
+
+ pub async fn write_packet(self, mut packet: Packet) -> Result<()> {
+ let mut packet_buffer = PacketBuffer::new(Vec::new());
+ packet.write(&mut packet_buffer)?;
+
+ self.write(packet_buffer.buf).await?;
+ Ok(())
+ }
+
+ pub async fn request_packet(&self, mut packet: Packet, dest: (IpAddr, u16)) -> Result<Packet> {
+ let mut packet_buffer = PacketBuffer::new(Vec::new());
+ packet.write(&mut packet_buffer)?;
+
+ let data = self.request(packet_buffer.buf, dest).await?;
+ let mut packet_buffer = PacketBuffer::new(data);
+
+ let packet = Packet::from_buffer(&mut packet_buffer)?;
+ Ok(packet)
+ }
+
+ async fn read(&mut self) -> Result<Vec<u8>> {
+ trace!("Reading DNS packet");
+ match self {
+ Self::UDP(_, _, src) => Ok(Vec::from(*src)),
+ Self::TCP(stream) => {
+ let size = stream.read_u16().await?;
+ let mut buf = Vec::with_capacity(size as usize);
+ stream.read_buf(&mut buf).await?;
+ Ok(buf)
+ }
+ }
+ }
+
+ async fn write(self, mut buf: Vec<u8>) -> Result<()> {
+ trace!("Returning DNS packet");
+ match self {
+ Self::UDP(socket, addr, _) => {
+ if buf.len() > 512 {
+ buf[2] = buf[2] | 0x03;
+ socket.send_to(&buf[0..512], addr).await?;
+ } else {
+ socket.send_to(&buf, addr).await?;
+ }
+ Ok(())
+ }
+ Self::TCP(mut stream) => {
+ stream.write_u16(buf.len() as u16).await?;
+ stream.write(&buf[0..buf.len()]).await?;
+ Ok(())
+ }
+ }
+ }
+
+ async fn request(&self, buf: Vec<u8>, dest: (IpAddr, u16)) -> Result<Vec<u8>> {
+ match self {
+ Self::UDP(_socket, _addr, _src) => {
+ let local_addr = "[::]:0".parse::<SocketAddr>()?;
+ let socket = UdpSocket::bind(local_addr).await?;
+ socket.send_to(&buf, dest).await?;
+
+ let mut buf = [0; 512];
+ socket.recv_from(&mut buf).await?;
+
+ Ok(Vec::from(buf))
+ }
+ Self::TCP(_stream) => {
+ let mut stream = TcpStream::connect(dest).await?;
+ stream.write_u16((buf.len()) as u16).await?;
+ stream.write_all(&buf[0..buf.len()]).await?;
+
+ stream.readable().await?;
+ let size = stream.read_u16().await?;
+ let mut buf = Vec::with_capacity(size as usize);
+ stream.read_buf(&mut buf).await?;
+
+ Ok(buf)
+ }
+ }
+ }
+
+ // fn pb(buf: &[u8]) {
+ // for i in 0..buf.len() {
+ // print!("{:02X?} ", buf[i]);
+ // }
+ // println!("");
+ // }
+}