use futures::TryStreamExt; use mongodb::{ bson::doc, options::{ClientOptions, Credential, ServerAddress}, Client, }; use serde::{Deserialize, Serialize}; use tracing::info; use crate::{ config::Config, dns::packet::{query::QueryType, record::DnsRecord}, }; use crate::Result; #[derive(Clone)] pub struct Database { client: Client, } #[derive(Debug, Serialize, Deserialize)] pub struct StoredRecord { record: DnsRecord, domain: String, prefix: String, } impl StoredRecord { fn get_domain_parts(domain: &str) -> (String, String) { let parts: Vec<&str> = domain.split(".").collect(); let len = parts.len(); if len == 1 { (String::new(), String::from(parts[0])) } else if len == 2 { (String::new(), String::from(parts.join("."))) } else { ( String::from(parts[0..len - 2].join(".")), String::from(parts[len - 2..len].join(".")), ) } } } impl From for StoredRecord { fn from(record: DnsRecord) -> Self { let (prefix, domain) = Self::get_domain_parts(&record.get_domain()); Self { record, domain, prefix, } } } impl Into for StoredRecord { fn into(self) -> DnsRecord { self.record } } impl Database { pub async fn new(config: Config) -> Result { let options = ClientOptions::builder() .hosts(vec![ServerAddress::Tcp { host: config.db_host, port: Some(config.db_port), }]) .credential( Credential::builder() .username(config.db_user) .password(config.db_pass) .build(), ) .max_pool_size(100) .app_name(String::from("wrapper")) .build(); let client = Client::with_options(options)?; client .database("wrapper") .run_command(doc! {"ping": 1}, None) .await?; info!("Connection to mongodb successfully"); Ok(Database { client }) } pub async fn get_records(&self, domain: &str, qtype: QueryType) -> Result> { let (prefix, domain) = StoredRecord::get_domain_parts(domain); Ok(self .get_domain(&domain) .await? .into_iter() .filter(|r| r.prefix == prefix) .filter(|r| { let rqtype = r.record.get_qtype(); if qtype == QueryType::A { return rqtype == QueryType::A || rqtype == QueryType::AR; } else if qtype == QueryType::AAAA { return rqtype == QueryType::AAAA || rqtype == QueryType::AAAAR; } else { r.record.get_qtype() == qtype } }) .map(|r| r.into()) .collect()) } pub async fn get_domain(&self, domain: &str) -> Result> { let db = self.client.database("wrapper"); let col = db.collection::(domain); let filter = doc! { "domain": domain }; let mut cursor = col.find(filter, None).await?; let mut records = Vec::new(); while let Some(record) = cursor.try_next().await? { records.push(record); } Ok(records) } pub async fn add_record(&self, record: DnsRecord) -> Result<()> { let record = StoredRecord::from(record); let db = self.client.database("wrapper"); let col = db.collection::(&record.domain); col.insert_one(record, None).await?; Ok(()) } pub async fn get_domains(&self) -> Result> { let db = self.client.database("wrapper"); Ok(db.list_collection_names(None).await?) } pub async fn delete_domain(&self, domain: String) -> Result<()> { let db = self.client.database("wrapper"); let col = db.collection::(&domain); Ok(col.drop(None).await?) } }