summaryrefslogtreecommitdiff
path: root/src/dns/server.rs
blob: 65d15df2dc8b830015201db7386e8bf97bf4158c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
use super::{
    binding::Binding,
    packet::{question::DnsQuestion, Packet},
    resolver::Resolver,
};
use crate::{config::Config, database::Database, Result};
use moka::future::Cache;
use std::{net::SocketAddr, sync::Arc, time::Duration};
use tokio::task::JoinHandle;
use tracing::{error, info};

/// DNS server that answers queries over both TCP and UDP.
///
/// The configuration, database handle and answer cache are shared with every
/// listener task spawned by [`DnsServer::run`] and with each per-request resolver.
pub struct DnsServer {
    /// Answer cache keyed by question. The value pairs a response [`Packet`] with a
    /// `u64` whose meaning is defined by the resolver (presumably a timestamp or
    /// TTL — NOTE(review): confirm against `Resolver`).
    cache: Cache<DnsQuestion, (Packet, u64)>,
    /// Runtime configuration, reference-counted so listener tasks can hold it.
    config: Arc<Config>,
    /// Database handle, reference-counted so listener tasks can hold it.
    database: Arc<Database>,
    /// Address that both the TCP and the UDP bindings are opened on.
    addr: SocketAddr,
}

impl DnsServer {
    pub async fn new(config: Config, database: Database) -> Result<Self> {
        let addr = format!("[::]:{}", config.dns_port).parse::<SocketAddr>()?;
        let cache = Cache::builder()
            .time_to_live(Duration::from_secs(60 * 60))
            .max_capacity(config.dns_cache_size)
            .build();

        info!("Created DNS cache with size of {}", config.dns_cache_size);

        Ok(Self {
            addr,
            config: Arc::new(config),
            database: Arc::new(database),
            cache,
        })
    }

    pub async fn run(&self) -> Result<(JoinHandle<()>, JoinHandle<()>)> {
        let tcp = Binding::tcp(self.addr).await?;
        let tcp_handle = self.listen(tcp);

        let udp = Binding::udp(self.addr).await?;
        let udp_handle = self.listen(udp);

        info!(
            "Fallback DNS Server is set to: {:?}",
            self.config.dns_fallback
        );
        info!(
            "Listening for TCP and UDP traffic on [::]:{}",
            self.config.dns_port
        );

        Ok((udp_handle, tcp_handle))
    }

    fn listen(&self, mut binding: Binding) -> JoinHandle<()> {
        let config = self.config.clone();
        let database = self.database.clone();
        let cache = self.cache.clone();
        tokio::spawn(async move {
            let mut id = 0;
            loop {
                let Ok(connection) = binding.connect().await else { continue };
                info!("Received request on {}", binding.name());

                let resolver = Resolver::new(
                    id,
                    connection,
                    config.clone(),
                    database.clone(),
                    cache.clone(),
                );

                let name = binding.name().to_string();
                tokio::spawn(async move {
                    if let Err(err) = resolver.handle_query().await {
                        error!("{} request {} failed: {:?}", name, id, err);
                    };
                });

                id += 1;
            }
        })
    }
}