use std::net::IpAddr; use async_recursion::async_recursion; use packet::{PacketType, Packet, Result, PacketQuestion, PacketBuffer, ResultCode}; use tokio::net::UdpSocket; use crate::config::Config; async fn lookup(qname: &str, qtype: PacketType, server: (IpAddr, u16)) -> Result { let socket = UdpSocket::bind("0.0.0.0:43210").await?; let mut packet = Packet::new(); packet.header.id = 6666; packet.header.questions = 1; packet.header.recursion_desired = true; packet .questions .push(PacketQuestion::new(qname.to_string(), qtype)); let mut req_buffer = PacketBuffer::new(); packet.write(&mut req_buffer)?; socket.send_to(&req_buffer.buf[0..req_buffer.pos], server).await?; let mut res_buffer = PacketBuffer::new(); socket.recv_from(&mut res_buffer.buf).await?; Packet::from_buffer(&mut res_buffer) } #[async_recursion] async fn recursive_lookup(qname: &str, qtype: PacketType, config: &Config) -> Result { let mut ns = config.get_fallback_ns().clone(); loop { println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns); let ns_copy = ns; let server = (ns_copy, 53); let response = lookup(qname, qtype, server).await?; if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { return Ok(response); } if response.header.rescode == ResultCode::NXDOMAIN { return Ok(response); } if let Some(new_ns) = response.get_resolved_ns(qname) { ns = new_ns; continue; } let new_ns_name = match response.get_unresolved_ns(qname) { Some(x) => x, None => return Ok(response), }; let recursive_response = recursive_lookup(&new_ns_name, PacketType::A, config).await?; if let Some(new_ns) = recursive_response.get_random_a() { ns = new_ns; } else { return Ok(response); } } } pub async fn handle_query(socket: &UdpSocket, config: &Config) -> Result<()> { let mut req_buffer = PacketBuffer::new(); let (_, src) = socket.recv_from(&mut req_buffer.buf).await?; let mut request = Packet::from_buffer(&mut req_buffer)?; let mut packet = Packet::new(); packet.header.id = request.header.id; packet.header.recursion_desired = true; packet.header.recursion_available = true; packet.header.response = true; if let Some(question) = request.questions.pop() { println!("Received query: {:?}", question); if let Ok(result) = recursive_lookup(&question.name, question.qtype, config).await { packet.questions.push(question.clone()); packet.header.rescode = result.header.rescode; for rec in result.answers { println!("Answer: {:?}", rec); packet.answers.push(rec); } for rec in result.authorities { println!("Authority: {:?}", rec); packet.authorities.push(rec); } for rec in result.resources { println!("Resource: {:?}", rec); packet.resources.push(rec); } } else { packet.header.rescode = ResultCode::SERVFAIL; } } else { packet.header.rescode = ResultCode::FORMERR; } let mut res_buffer = PacketBuffer::new(); packet.write(&mut res_buffer)?; let len = res_buffer.pos(); let data = res_buffer.get_range(0, len)?; socket.send_to(data, src).await?; Ok(()) }