summaryrefslogtreecommitdiff
path: root/src/database
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/database/mod.rs146
1 files changed, 146 insertions, 0 deletions
diff --git a/src/database/mod.rs b/src/database/mod.rs
new file mode 100644
index 0000000..0d81dc3
--- /dev/null
+++ b/src/database/mod.rs
@@ -0,0 +1,146 @@
+use futures::TryStreamExt;
+use mongodb::{
+ bson::doc,
+ options::{ClientOptions, Credential, ServerAddress},
+ Client,
+};
+use serde::{Deserialize, Serialize};
+use tracing::info;
+
+use crate::{
+ config::Config,
+ dns::packet::{query::QueryType, record::DnsRecord},
+};
+
+use crate::Result;
+
+#[derive(Clone)]
+pub struct Database {
+ client: Client,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct StoredRecord {
+ record: DnsRecord,
+ domain: String,
+ prefix: String,
+}
+
+impl StoredRecord {
+ fn get_domain_parts(domain: &str) -> (String, String) {
+ let parts: Vec<&str> = domain.split(".").collect();
+ let len = parts.len();
+ if len == 1 {
+ (String::new(), String::from(parts[0]))
+ } else if len == 2 {
+ (String::new(), String::from(parts.join(".")))
+ } else {
+ (
+ String::from(parts[0..len - 2].join(".")),
+ String::from(parts[len - 2..len].join(".")),
+ )
+ }
+ }
+}
+
+impl From<DnsRecord> for StoredRecord {
+ fn from(record: DnsRecord) -> Self {
+ let (prefix, domain) = Self::get_domain_parts(&record.get_domain());
+ Self {
+ record,
+ domain,
+ prefix,
+ }
+ }
+}
+
+impl Into<DnsRecord> for StoredRecord {
+ fn into(self) -> DnsRecord {
+ self.record
+ }
+}
+
+impl Database {
+ pub async fn new(config: Config) -> Result<Self> {
+ let options = ClientOptions::builder()
+ .hosts(vec![ServerAddress::Tcp {
+ host: config.db_host,
+ port: Some(config.db_port),
+ }])
+ .credential(
+ Credential::builder()
+ .username(config.db_user)
+ .password(config.db_pass)
+ .build(),
+ )
+ .max_pool_size(100)
+ .app_name(String::from("wrapper"))
+ .build();
+
+ let client = Client::with_options(options)?;
+
+ client
+ .database("wrapper")
+ .run_command(doc! {"ping": 1}, None)
+ .await?;
+
+ info!("Connection to mongodb successfully");
+
+ Ok(Database { client })
+ }
+
+ pub async fn get_records(&self, domain: &str, qtype: QueryType) -> Result<Vec<DnsRecord>> {
+ let (prefix, domain) = StoredRecord::get_domain_parts(domain);
+ Ok(self
+ .get_domain(&domain)
+ .await?
+ .into_iter()
+ .filter(|r| r.prefix == prefix)
+ .filter(|r| {
+ let rqtype = r.record.get_qtype();
+ if qtype == QueryType::A {
+ return rqtype == QueryType::A || rqtype == QueryType::AR;
+ } else if qtype == QueryType::AAAA {
+ return rqtype == QueryType::AAAA || rqtype == QueryType::AAAAR;
+ } else {
+ r.record.get_qtype() == qtype
+ }
+ })
+ .map(|r| r.into())
+ .collect())
+ }
+
+ pub async fn get_domain(&self, domain: &str) -> Result<Vec<StoredRecord>> {
+ let db = self.client.database("wrapper");
+ let col = db.collection::<StoredRecord>(domain);
+
+ let filter = doc! { "domain": domain };
+ let mut cursor = col.find(filter, None).await?;
+
+ let mut records = Vec::new();
+ while let Some(record) = cursor.try_next().await? {
+ records.push(record);
+ }
+
+ Ok(records)
+ }
+
+ pub async fn add_record(&self, record: DnsRecord) -> Result<()> {
+ let record = StoredRecord::from(record);
+ let db = self.client.database("wrapper");
+ let col = db.collection::<StoredRecord>(&record.domain);
+ col.insert_one(record, None).await?;
+ Ok(())
+ }
+
+ pub async fn get_domains(&self) -> Result<Vec<String>> {
+ let db = self.client.database("wrapper");
+ Ok(db.list_collection_names(None).await?)
+ }
+
+ pub async fn delete_domain(&self, domain: String) -> Result<()> {
+ let db = self.client.database("wrapper");
+ let col = db.collection::<StoredRecord>(&domain);
+ Ok(col.drop(None).await?)
+ }
+}