summaryrefslogtreecommitdiff
path: root/src/database/mod.rs
blob: 0d81dc33b81a8a3c95fc3b1fc9d99d15e4adf0f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
use futures::TryStreamExt;
use mongodb::{
    bson::doc,
    options::{ClientOptions, Credential, ServerAddress},
    Client,
};
use serde::{Deserialize, Serialize};
use tracing::info;

use crate::{
    config::Config,
    dns::packet::{query::QueryType, record::DnsRecord},
};

use crate::Result;

/// Handle to the MongoDB-backed DNS record store.
///
/// All data lives in the `wrapper` database, with one collection per domain; the
/// collection name is the domain part produced by `StoredRecord::get_domain_parts`.
#[derive(Clone)]
pub struct Database {
    /// MongoDB client built and ping-checked in `Database::new`.
    client: Client,
}

/// A [`DnsRecord`] as persisted in MongoDB, together with the name split used to index it.
///
/// The serde field names below are the stored document schema: `Database::get_domain`
/// filters on the `"domain"` field by name, so renaming a field breaks lookups and
/// compatibility with documents already in the database.
#[derive(Debug, Serialize, Deserialize)]
pub struct StoredRecord {
    /// The wrapped DNS record itself.
    record: DnsRecord,
    /// Last two labels of the record's name (e.g. `example.com`); also used as the name
    /// of the collection the record is inserted into.
    domain: String,
    /// Remaining leading labels (e.g. `www` for `www.example.com`); empty when the name
    /// has at most two labels.
    prefix: String,
}

impl StoredRecord {
    /// Splits a domain name into `(prefix, domain)`.
    ///
    /// `domain` is the last two dot-separated labels and `prefix` is everything before
    /// them. Names with fewer than two dots get an empty prefix:
    ///
    /// - `"com"` → `("", "com")`
    /// - `"example.com"` → `("", "example.com")`
    /// - `"a.b.example.com"` → `("a.b", "example.com")`
    ///
    /// The split is purely positional: multi-label public suffixes such as `co.uk` are
    /// not recognised, and a trailing dot is treated as an (empty) final label.
    fn get_domain_parts(domain: &str) -> (String, String) {
        // Cut at the second-to-last '.', if any. Working on byte offsets avoids building
        // an intermediate Vec of labels and re-joining it; slicing is safe because '.' is
        // ASCII and therefore always on a char boundary.
        match domain.rmatch_indices('.').nth(1) {
            Some((dot, _)) => (domain[..dot].to_owned(), domain[dot + 1..].to_owned()),
            None => (String::new(), domain.to_owned()),
        }
    }
}

impl From<DnsRecord> for StoredRecord {
    fn from(record: DnsRecord) -> Self {
        let (prefix, domain) = Self::get_domain_parts(&record.get_domain());
        Self {
            record,
            domain,
            prefix,
        }
    }
}

/// Unwraps the stored DNS record, discarding the `domain`/`prefix` index fields.
///
/// Implemented as `From` rather than `Into` (Clippy `from_over_into`): the standard
/// library's blanket impl still provides `Into<DnsRecord> for StoredRecord`, so existing
/// `.into()` call sites keep working, and `DnsRecord::from(stored)` now works as well.
impl From<StoredRecord> for DnsRecord {
    fn from(stored: StoredRecord) -> Self {
        stored.record
    }
}

impl Database {
    /// Name of the MongoDB database that holds one collection per domain.
    const DB_NAME: &'static str = "wrapper";

    /// Connects to MongoDB using the host, port and credentials from `config`.
    ///
    /// A `ping` command is sent before returning, so connection or authentication
    /// problems surface here rather than on the first query.
    ///
    /// # Errors
    ///
    /// Returns an error if the client cannot be built from the options or if the
    /// `ping` command fails.
    pub async fn new(config: Config) -> Result<Self> {
        let options = ClientOptions::builder()
            .hosts(vec![ServerAddress::Tcp {
                host: config.db_host,
                port: Some(config.db_port),
            }])
            .credential(
                Credential::builder()
                    .username(config.db_user)
                    .password(config.db_pass)
                    .build(),
            )
            .max_pool_size(100)
            .app_name(String::from("wrapper"))
            .build();

        let client = Client::with_options(options)?;

        client
            .database(Self::DB_NAME)
            .run_command(doc! {"ping": 1}, None)
            .await?;

        info!("Connection to mongodb successfully");

        Ok(Database { client })
    }

    /// Handle to the database that stores every domain collection.
    fn db(&self) -> mongodb::Database {
        self.client.database(Self::DB_NAME)
    }

    /// Handle to the collection holding the records of `domain` (one collection per domain).
    fn collection(&self, domain: &str) -> mongodb::Collection<StoredRecord> {
        self.db().collection(domain)
    }

    /// Returns the records stored for exactly `domain` whose type matches `qtype`.
    ///
    /// `domain` is split into prefix and domain part (see `StoredRecord::get_domain_parts`).
    /// All records of the domain part are fetched, then filtered in memory to those whose
    /// prefix equals the query's prefix. An `A` query also matches records of type `AR`,
    /// and an `AAAA` query also matches `AAAAR`; any other type must match exactly.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Database::get_domain`].
    pub async fn get_records(&self, domain: &str, qtype: QueryType) -> Result<Vec<DnsRecord>> {
        let (prefix, domain) = StoredRecord::get_domain_parts(domain);
        let records: Vec<DnsRecord> = self
            .get_domain(&domain)
            .await?
            .into_iter()
            .filter(|r| r.prefix == prefix)
            .filter(|r| {
                let rqtype = r.record.get_qtype();
                if qtype == QueryType::A {
                    rqtype == QueryType::A || rqtype == QueryType::AR
                } else if qtype == QueryType::AAAA {
                    rqtype == QueryType::AAAA || rqtype == QueryType::AAAAR
                } else {
                    rqtype == qtype
                }
            })
            .map(Into::into)
            .collect();
        Ok(records)
    }

    /// Fetches every stored record whose `domain` field equals `domain`.
    ///
    /// `domain` is used both as the collection name and as the `"domain"` filter value,
    /// so it is expected to be the domain part produced by `StoredRecord::get_domain_parts`.
    ///
    /// # Errors
    ///
    /// Returns an error if the `find` query fails or a document cannot be deserialized.
    pub async fn get_domain(&self, domain: &str) -> Result<Vec<StoredRecord>> {
        let filter = doc! { "domain": domain };
        let cursor = self.collection(domain).find(filter, None).await?;
        let records: Vec<StoredRecord> = cursor.try_collect().await?;
        Ok(records)
    }

    /// Stores `record` in the collection of its domain part.
    ///
    /// # Errors
    ///
    /// Returns an error if the insert fails.
    pub async fn add_record(&self, record: DnsRecord) -> Result<()> {
        let stored = StoredRecord::from(record);
        self.collection(&stored.domain)
            .insert_one(stored, None)
            .await?;
        Ok(())
    }

    /// Lists the names of all collections in the database, i.e. the known domains.
    ///
    /// # Errors
    ///
    /// Returns an error if the collection listing fails.
    pub async fn get_domains(&self) -> Result<Vec<String>> {
        Ok(self.db().list_collection_names(None).await?)
    }

    /// Drops the collection of `domain`, removing all of its records.
    ///
    /// # Errors
    ///
    /// Returns an error if the drop command fails.
    pub async fn delete_domain(&self, domain: String) -> Result<()> {
        Ok(self.collection(&domain).drop(None).await?)
    }
}