From b1fb410affb7bcd2e714abac01d22c4a5332c344 Mon Sep 17 00:00:00 2001 From: Tyler Murphy Date: Mon, 6 Mar 2023 18:50:08 -0500 Subject: finish dns and start webserver --- src/database/mod.rs | 146 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 src/database/mod.rs (limited to 'src/database/mod.rs') diff --git a/src/database/mod.rs b/src/database/mod.rs new file mode 100644 index 0000000..0d81dc3 --- /dev/null +++ b/src/database/mod.rs @@ -0,0 +1,146 @@ +use futures::TryStreamExt; +use mongodb::{ + bson::doc, + options::{ClientOptions, Credential, ServerAddress}, + Client, +}; +use serde::{Deserialize, Serialize}; +use tracing::info; + +use crate::{ + config::Config, + dns::packet::{query::QueryType, record::DnsRecord}, +}; + +use crate::Result; + +#[derive(Clone)] +pub struct Database { + client: Client, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct StoredRecord { + record: DnsRecord, + domain: String, + prefix: String, +} + +impl StoredRecord { + fn get_domain_parts(domain: &str) -> (String, String) { + let parts: Vec<&str> = domain.split(".").collect(); + let len = parts.len(); + if len == 1 { + (String::new(), String::from(parts[0])) + } else if len == 2 { + (String::new(), String::from(parts.join("."))) + } else { + ( + String::from(parts[0..len - 2].join(".")), + String::from(parts[len - 2..len].join(".")), + ) + } + } +} + +impl From for StoredRecord { + fn from(record: DnsRecord) -> Self { + let (prefix, domain) = Self::get_domain_parts(&record.get_domain()); + Self { + record, + domain, + prefix, + } + } +} + +impl Into for StoredRecord { + fn into(self) -> DnsRecord { + self.record + } +} + +impl Database { + pub async fn new(config: Config) -> Result { + let options = ClientOptions::builder() + .hosts(vec![ServerAddress::Tcp { + host: config.db_host, + port: Some(config.db_port), + }]) + .credential( + Credential::builder() + .username(config.db_user) + .password(config.db_pass) + .build(), + ) + .max_pool_size(100) + .app_name(String::from("wrapper")) + .build(); + + let client = Client::with_options(options)?; + + client + .database("wrapper") + .run_command(doc! {"ping": 1}, None) + .await?; + + info!("Connection to mongodb successfully"); + + Ok(Database { client }) + } + + pub async fn get_records(&self, domain: &str, qtype: QueryType) -> Result> { + let (prefix, domain) = StoredRecord::get_domain_parts(domain); + Ok(self + .get_domain(&domain) + .await? + .into_iter() + .filter(|r| r.prefix == prefix) + .filter(|r| { + let rqtype = r.record.get_qtype(); + if qtype == QueryType::A { + return rqtype == QueryType::A || rqtype == QueryType::AR; + } else if qtype == QueryType::AAAA { + return rqtype == QueryType::AAAA || rqtype == QueryType::AAAAR; + } else { + r.record.get_qtype() == qtype + } + }) + .map(|r| r.into()) + .collect()) + } + + pub async fn get_domain(&self, domain: &str) -> Result> { + let db = self.client.database("wrapper"); + let col = db.collection::(domain); + + let filter = doc! { "domain": domain }; + let mut cursor = col.find(filter, None).await?; + + let mut records = Vec::new(); + while let Some(record) = cursor.try_next().await? { + records.push(record); + } + + Ok(records) + } + + pub async fn add_record(&self, record: DnsRecord) -> Result<()> { + let record = StoredRecord::from(record); + let db = self.client.database("wrapper"); + let col = db.collection::(&record.domain); + col.insert_one(record, None).await?; + Ok(()) + } + + pub async fn get_domains(&self) -> Result> { + let db = self.client.database("wrapper"); + Ok(db.list_collection_names(None).await?) + } + + pub async fn delete_domain(&self, domain: String) -> Result<()> { + let db = self.client.database("wrapper"); + let col = db.collection::(&domain); + Ok(col.drop(None).await?) + } +} -- cgit v1.2.3-freya