summaryrefslogtreecommitdiff
path: root/src/database/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/mod.rs')
-rw-r--r--src/database/mod.rs146
1 files changed, 0 insertions, 146 deletions
diff --git a/src/database/mod.rs b/src/database/mod.rs
deleted file mode 100644
index 0d81dc3..0000000
--- a/src/database/mod.rs
+++ /dev/null
@@ -1,146 +0,0 @@
-use futures::TryStreamExt;
-use mongodb::{
- bson::doc,
- options::{ClientOptions, Credential, ServerAddress},
- Client,
-};
-use serde::{Deserialize, Serialize};
-use tracing::info;
-
-use crate::{
- config::Config,
- dns::packet::{query::QueryType, record::DnsRecord},
-};
-
-use crate::Result;
-
-#[derive(Clone)]
-pub struct Database {
- client: Client,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct StoredRecord {
- record: DnsRecord,
- domain: String,
- prefix: String,
-}
-
-impl StoredRecord {
- fn get_domain_parts(domain: &str) -> (String, String) {
- let parts: Vec<&str> = domain.split(".").collect();
- let len = parts.len();
- if len == 1 {
- (String::new(), String::from(parts[0]))
- } else if len == 2 {
- (String::new(), String::from(parts.join(".")))
- } else {
- (
- String::from(parts[0..len - 2].join(".")),
- String::from(parts[len - 2..len].join(".")),
- )
- }
- }
-}
-
-impl From<DnsRecord> for StoredRecord {
- fn from(record: DnsRecord) -> Self {
- let (prefix, domain) = Self::get_domain_parts(&record.get_domain());
- Self {
- record,
- domain,
- prefix,
- }
- }
-}
-
-impl Into<DnsRecord> for StoredRecord {
- fn into(self) -> DnsRecord {
- self.record
- }
-}
-
-impl Database {
- pub async fn new(config: Config) -> Result<Self> {
- let options = ClientOptions::builder()
- .hosts(vec![ServerAddress::Tcp {
- host: config.db_host,
- port: Some(config.db_port),
- }])
- .credential(
- Credential::builder()
- .username(config.db_user)
- .password(config.db_pass)
- .build(),
- )
- .max_pool_size(100)
- .app_name(String::from("wrapper"))
- .build();
-
- let client = Client::with_options(options)?;
-
- client
- .database("wrapper")
- .run_command(doc! {"ping": 1}, None)
- .await?;
-
- info!("Connection to mongodb successfully");
-
- Ok(Database { client })
- }
-
- pub async fn get_records(&self, domain: &str, qtype: QueryType) -> Result<Vec<DnsRecord>> {
- let (prefix, domain) = StoredRecord::get_domain_parts(domain);
- Ok(self
- .get_domain(&domain)
- .await?
- .into_iter()
- .filter(|r| r.prefix == prefix)
- .filter(|r| {
- let rqtype = r.record.get_qtype();
- if qtype == QueryType::A {
- return rqtype == QueryType::A || rqtype == QueryType::AR;
- } else if qtype == QueryType::AAAA {
- return rqtype == QueryType::AAAA || rqtype == QueryType::AAAAR;
- } else {
- r.record.get_qtype() == qtype
- }
- })
- .map(|r| r.into())
- .collect())
- }
-
- pub async fn get_domain(&self, domain: &str) -> Result<Vec<StoredRecord>> {
- let db = self.client.database("wrapper");
- let col = db.collection::<StoredRecord>(domain);
-
- let filter = doc! { "domain": domain };
- let mut cursor = col.find(filter, None).await?;
-
- let mut records = Vec::new();
- while let Some(record) = cursor.try_next().await? {
- records.push(record);
- }
-
- Ok(records)
- }
-
- pub async fn add_record(&self, record: DnsRecord) -> Result<()> {
- let record = StoredRecord::from(record);
- let db = self.client.database("wrapper");
- let col = db.collection::<StoredRecord>(&record.domain);
- col.insert_one(record, None).await?;
- Ok(())
- }
-
- pub async fn get_domains(&self) -> Result<Vec<String>> {
- let db = self.client.database("wrapper");
- Ok(db.list_collection_names(None).await?)
- }
-
- pub async fn delete_domain(&self, domain: String) -> Result<()> {
- let db = self.client.database("wrapper");
- let col = db.collection::<StoredRecord>(&domain);
- Ok(col.drop(None).await?)
- }
-}