diff options
author | Tyler Murphy <tylermurphy534@gmail.com> | 2023-02-15 00:01:44 -0500 |
---|---|---|
committer | Tyler Murphy <tylermurphy534@gmail.com> | 2023-02-15 00:01:44 -0500 |
commit | aec4fdecc10be35cde5dc42308960f10bc452187 (patch) | |
tree | 67233229c6839c78d1bd3db0147467da30843f44 /src/types | |
parent | bug fixes (diff) | |
download | xssbook-aec4fdecc10be35cde5dc42308960f10bc452187.tar.gz xssbook-aec4fdecc10be35cde5dc42308960f10bc452187.tar.bz2 xssbook-aec4fdecc10be35cde5dc42308960f10bc452187.zip |
make database calls 1 conn
Diffstat (limited to 'src/types')
-rw-r--r-- | src/types/comment.rs | 20 | ||||
-rw-r--r-- | src/types/extract.rs | 32 | ||||
-rw-r--r-- | src/types/like.rs | 20 | ||||
-rw-r--r-- | src/types/post.rs | 43 | ||||
-rw-r--r-- | src/types/session.rs | 26 | ||||
-rw-r--r-- | src/types/user.rs | 68 |
6 files changed, 123 insertions, 86 deletions
diff --git a/src/types/comment.rs b/src/types/comment.rs index cf94bd3..0836950 100644 --- a/src/types/comment.rs +++ b/src/types/comment.rs @@ -2,7 +2,7 @@ use serde::Serialize; use tracing::instrument; use crate::{ - database::{self, comments}, + database::Database, types::http::{ResponseCode, Result}, }; @@ -16,9 +16,9 @@ pub struct Comment { } impl Comment { - #[instrument()] - pub fn new(user_id: u64, post_id: u64, content: &str) -> Result<Self> { - let Ok(comment) = comments::add_comment(user_id, post_id, content) else { + #[instrument(skip(db))] + pub fn new(db: &Database, user_id: u64, post_id: u64, content: &str) -> Result<Self> { + let Ok(comment) = db.add_comment(user_id, post_id, content) else { tracing::error!("Failed to create comment"); return Err(ResponseCode::InternalServerError.text("Failed to create post")) }; @@ -26,17 +26,17 @@ impl Comment { Ok(comment) } - #[instrument()] - pub fn from_comment_page(page: u64, post_id: u64) -> Result<Vec<Self>> { - let Ok(posts) = database::comments::get_comments_page(page, post_id) else { + #[instrument(skip(db))] + pub fn from_comment_page(db: &Database, page: u64, post_id: u64) -> Result<Vec<Self>> { + let Ok(posts) = db.get_comments_page(page, post_id) else { return Err(ResponseCode::BadRequest.text("Failed to fetch comments")) }; Ok(posts) } - #[instrument()] - pub fn reterieve_all() -> Result<Vec<Self>> { - let Ok(posts) = database::comments::get_all_comments() else { + #[instrument(skip(db))] + pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> { + let Ok(posts) = db.get_all_comments() else { return Err(ResponseCode::InternalServerError.text("Failed to fetch comments")) }; Ok(posts) diff --git a/src/types/extract.rs b/src/types/extract.rs index 65d9f1a..f05215f 100644 --- a/src/types/extract.rs +++ b/src/types/extract.rs @@ -14,9 +14,11 @@ use axum::{ use bytes::Bytes; use image::{io::Reader, DynamicImage, ImageFormat}; use serde::de::DeserializeOwned; +use tokio::sync::Mutex; use tower_cookies::Cookies; use crate::{ + database, public::admin, public::console, types::{ @@ -97,11 +99,17 @@ where return Err(ResponseCode::Forbidden.text("No auth token provided")) }; - let Ok(session) = Session::from_token(token.value()) else { + let Some(db) = parts.extensions.get::<DatabaseExtention>() else { + return Err(ResponseCode::Forbidden.text("Could not connect to database")) + }; + + let db = db.0.lock().await; + + let Ok(session) = Session::from_token(&db, token.value()) else { return Err(ResponseCode::Unauthorized.text("Auth token invalid")) }; - let Ok(user) = User::from_user_id(session.user_id, true) else { + let Ok(user) = User::from_user_id(&db, session.user_id, true) else { tracing::error!("Valid token but no valid user"); return Err(ResponseCode::InternalServerError.text("Valid token but no valid user")) }; @@ -260,6 +268,26 @@ where } } +pub struct DatabaseExtention(pub Mutex<database::Database>); +pub struct Database(pub database::Database); + +#[async_trait] +impl<S> FromRequestParts<S> for Database +where + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self> { + let db = parts.extensions.remove::<DatabaseExtention>(); + let Some(db) = db else { + return Err(ResponseCode::InternalServerError.text("Database is not loaded")) + }; + + Ok(Self(db.0.into_inner())) + } +} + async fn read_body<S, B>(mut req: Request<B>, state: &S) -> Result<Vec<u8>> where B: HttpBody + Sync + Send + 'static, diff --git a/src/types/like.rs b/src/types/like.rs index 1c113c1..8eec941 100644 --- a/src/types/like.rs +++ b/src/types/like.rs @@ -1,7 +1,7 @@ use serde::Serialize; use tracing::instrument; -use crate::database; +use crate::database::Database; use crate::types::http::{ResponseCode, Result}; #[derive(Serialize)] @@ -11,9 +11,9 @@ pub struct Like { } impl Like { - #[instrument()] - pub fn add_liked(user_id: u64, post_id: u64) -> Result<()> { - let Ok(liked) = database::likes::add_liked(user_id, post_id) else { + #[instrument(skip(db))] + pub fn add_liked(db: &Database, user_id: u64, post_id: u64) -> Result<()> { + let Ok(liked) = db.add_liked(user_id, post_id) else { return Err(ResponseCode::BadRequest.text("Failed to add like status")) }; @@ -24,9 +24,9 @@ impl Like { Ok(()) } - #[instrument()] - pub fn remove_liked(user_id: u64, post_id: u64) -> Result<()> { - let Ok(liked) = database::likes::remove_liked(user_id, post_id) else { + #[instrument(skip(db))] + pub fn remove_liked(db: &Database, user_id: u64, post_id: u64) -> Result<()> { + let Ok(liked) = db.remove_liked(user_id, post_id) else { return Err(ResponseCode::BadRequest.text("Failed to remove like status")) }; @@ -37,9 +37,9 @@ impl Like { Ok(()) } - #[instrument()] - pub fn reterieve_all() -> Result<Vec<Self>> { - let Ok(likes) = database::likes::get_all_likes() else { + #[instrument(skip(db))] + pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> { + let Ok(likes) = db.get_all_likes() else { return Err(ResponseCode::InternalServerError.text("Failed to fetch likes")) }; Ok(likes) diff --git a/src/types/post.rs b/src/types/post.rs index 09f2f50..bff68c7 100644 --- a/src/types/post.rs +++ b/src/types/post.rs @@ -2,7 +2,7 @@ use core::fmt; use serde::Serialize; use tracing::instrument; -use crate::database; +use crate::database::Database; use crate::types::http::{ResponseCode, Result}; use super::comment::Comment; @@ -27,57 +27,62 @@ impl fmt::Debug for Post { } impl Post { - #[instrument()] - pub fn from_post_id(self_id: u64, post_id: u64) -> Result<Self> { - let Ok(Some(mut post)) = database::posts::get_post(post_id) else { + #[instrument(skip(db))] + pub fn from_post_id(db: &Database, self_id: u64, post_id: u64) -> Result<Self> { + let Ok(Some(mut post)) = db.get_post(post_id) else { return Err(ResponseCode::BadRequest.text("Post does not exist")) }; - let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false); + let liked = db.get_liked(self_id, post.post_id).unwrap_or(false); post.liked = liked; Ok(post) } - #[instrument()] - pub fn from_post_page(self_id: u64, page: u64) -> Result<Vec<Self>> { - let Ok(mut posts) = database::posts::get_post_page(page) else { + #[instrument(skip(db))] + pub fn from_post_page(db: &Database, self_id: u64, page: u64) -> Result<Vec<Self>> { + let Ok(mut posts) = db.get_post_page(page) else { return Err(ResponseCode::BadRequest.text("Failed to fetch posts")) }; for post in &mut posts { - let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false); + let liked = db.get_liked(self_id, post.post_id).unwrap_or(false); post.liked = liked; } Ok(posts) } - #[instrument()] - pub fn from_user_post_page(self_id: u64, user_id: u64, page: u64) -> Result<Vec<Self>> { - let Ok(mut posts) = database::posts::get_users_post_page(user_id, page) else { + #[instrument(skip(db))] + pub fn from_user_post_page( + db: &Database, + self_id: u64, + user_id: u64, + page: u64, + ) -> Result<Vec<Self>> { + let Ok(mut posts) = db.get_users_post_page(user_id, page) else { return Err(ResponseCode::BadRequest.text("Failed to fetch posts")) }; for post in &mut posts { - let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false); + let liked = db.get_liked(self_id, post.post_id).unwrap_or(false); post.liked = liked; } Ok(posts) } - #[instrument()] - pub fn reterieve_all() -> Result<Vec<Self>> { - let Ok(posts) = database::posts::get_all_posts() else { + #[instrument(skip(db))] + pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> { + let Ok(posts) = db.get_all_posts() else { return Err(ResponseCode::InternalServerError.text("Failed to fetch posts")) }; Ok(posts) } - #[instrument()] - pub fn new(user_id: u64, content: String) -> Result<Self> { - let Ok(post) = database::posts::add_post(user_id, &content) else { + #[instrument(skip(db))] + pub fn new(db: &Database, user_id: u64, content: String) -> Result<Self> { + let Ok(post) = db.add_post(user_id, &content) else { tracing::error!("Failed to create post"); return Err(ResponseCode::InternalServerError.text("Failed to create post")) }; diff --git a/src/types/session.rs b/src/types/session.rs index a9073aa..27c5c66 100644 --- a/src/types/session.rs +++ b/src/types/session.rs @@ -2,7 +2,7 @@ use rand::{distributions::Alphanumeric, Rng}; use serde::Serialize; use tracing::instrument; -use crate::database; +use crate::database::Database; use crate::types::http::{ResponseCode, Result}; #[derive(Serialize)] @@ -12,39 +12,39 @@ pub struct Session { } impl Session { - #[instrument()] - pub fn from_token(token: &str) -> Result<Self> { - let Ok(Some(session)) = database::sessions::get_session(token) else { + #[instrument(skip(db))] + pub fn from_token(db: &Database, token: &str) -> Result<Self> { + let Ok(Some(session)) = db.get_session(token) else { return Err(ResponseCode::BadRequest.text("Invalid auth token")); }; Ok(session) } - #[instrument()] - pub fn reterieve_all() -> Result<Vec<Self>> { - let Ok(sessions) = database::sessions::get_all_sessions() else { + #[instrument(skip(db))] + pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> { + let Ok(sessions) = db.get_all_sessions() else { return Err(ResponseCode::InternalServerError.text("Failed to fetch sessions")) }; Ok(sessions) } - #[instrument()] - pub fn new(user_id: u64) -> Result<Self> { + #[instrument(skip(db))] + pub fn new(db: &Database, user_id: u64) -> Result<Self> { let token: String = rand::thread_rng() .sample_iter(&Alphanumeric) .take(32) .map(char::from) .collect(); - match database::sessions::set_session(user_id, &token) { + match db.set_session(user_id, &token) { Err(_) => Err(ResponseCode::BadRequest.text("Failed to create session")), Ok(_) => Ok(Self { user_id, token }), } } - #[instrument()] - pub fn delete(user_id: u64) -> Result<()> { - if database::sessions::delete_session(user_id).is_err() { + #[instrument(skip(db))] + pub fn delete(db: &Database, user_id: u64) -> Result<()> { + if db.delete_session(user_id).is_err() { tracing::error!("Failed to logout user"); return Err(ResponseCode::InternalServerError.text("Failed to logout")); }; diff --git a/src/types/user.rs b/src/types/user.rs index 245e9b7..3c4cd6a 100644 --- a/src/types/user.rs +++ b/src/types/user.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; use tracing::instrument; use crate::api::RegistrationRequet; -use crate::database; +use crate::database::Database; use crate::types::http::{ResponseCode, Result}; #[derive(Serialize, Deserialize, Debug)] @@ -24,21 +24,21 @@ pub const FOLLOWING: u8 = 1; pub const FOLLOWED: u8 = 2; impl User { - #[instrument()] - pub fn from_user_id(user_id: u64, hide_password: bool) -> Result<Self> { - let Ok(Some(user)) = database::users::get_user_by_id(user_id, hide_password) else { + #[instrument(skip(db))] + pub fn from_user_id(db: &Database, user_id: u64, hide_password: bool) -> Result<Self> { + let Ok(Some(user)) = db.get_user_by_id(user_id, hide_password) else { return Err(ResponseCode::BadRequest.text("User does not exist")) }; Ok(user) } - #[instrument()] - pub fn from_user_ids(user_ids: Vec<u64>) -> Vec<Self> { + #[instrument(skip(db))] + pub fn from_user_ids(db: &Database, user_ids: Vec<u64>) -> Vec<Self> { user_ids .iter() .filter_map(|user_id| { - let Ok(Some(user)) = database::users::get_user_by_id(*user_id, true) else { + let Ok(Some(user)) = db.get_user_by_id(*user_id, true) else { return None; }; Some(user) @@ -46,53 +46,53 @@ impl User { .collect() } - #[instrument()] - pub fn from_user_page(page: u64) -> Result<Vec<Self>> { - let Ok(users) = database::users::get_user_page(page, true) else { + #[instrument(skip(db))] + pub fn from_user_page(db: &Database, page: u64) -> Result<Vec<Self>> { + let Ok(users) = db.get_user_page(page, true) else { return Err(ResponseCode::BadRequest.text("Failed to fetch users")) }; Ok(users) } - #[instrument()] - pub fn from_email(email: &str) -> Result<Self> { - let Ok(Some(user)) = database::users::get_user_by_email(email, false) else { + #[instrument(skip(db))] + pub fn from_email(db: &Database, email: &str) -> Result<Self> { + let Ok(Some(user)) = db.get_user_by_email(email, false) else { return Err(ResponseCode::BadRequest.text("User does not exist")) }; Ok(user) } - #[instrument()] - pub fn from_password(password: &str) -> Result<Self> { - let Ok(Some(user)) = database::users::get_user_by_password(password, true) else { + #[instrument(skip(db))] + pub fn from_password(db: &Database, password: &str) -> Result<Self> { + let Ok(Some(user)) = db.get_user_by_password(password, true) else { return Err(ResponseCode::BadRequest.text("User does not exist")) }; Ok(user) } - #[instrument()] - pub fn reterieve_all() -> Result<Vec<Self>> { - let Ok(users) = database::users::get_all_users() else { + #[instrument(skip(db))] + pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> { + let Ok(users) = db.get_all_users() else { return Err(ResponseCode::InternalServerError.text("Failed to fetch users")) }; Ok(users) } - #[instrument()] - pub fn new(request: RegistrationRequet) -> Result<Self> { - if Self::from_email(&request.email).is_ok() { + #[instrument(skip(db))] + pub fn new(db: &Database, request: RegistrationRequet) -> Result<Self> { + if Self::from_email(db, &request.email).is_ok() { return Err(ResponseCode::BadRequest .text(&format!("Email is already in use by {}", &request.email))); } - if let Ok(user) = Self::from_password(&request.password) { + if let Ok(user) = Self::from_password(db, &request.password) { return Err(ResponseCode::BadRequest .text(&format!("Password is already in use by {}", user.email))); } - let Ok(user) = database::users::add_user(request) else { + let Ok(user) = db.add_user(request) else { tracing::error!("Failed to create new user"); return Err(ResponseCode::InternalServerError.text("Failed to create new uesr")) }; @@ -100,8 +100,9 @@ impl User { Ok(user) } - pub fn add_following(user_id_1: u64, user_id_2: u64) -> Result<()> { - let Ok(followed) = database::friends::set_following(user_id_1, user_id_2) else { + #[instrument(skip(db))] + pub fn add_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<()> { + let Ok(followed) = db.set_following(user_id_1, user_id_2) else { return Err(ResponseCode::BadRequest.text("Failed to add follow status")) }; @@ -112,8 +113,9 @@ impl User { Ok(()) } - pub fn remove_following(user_id_1: u64, user_id_2: u64) -> Result<()> { - let Ok(followed) = database::friends::remove_following(user_id_1, user_id_2) else { + #[instrument(skip(db))] + pub fn remove_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<()> { + let Ok(followed) = db.remove_following(user_id_1, user_id_2) else { return Err(ResponseCode::BadRequest.text("Failed to remove follow status")) }; @@ -124,15 +126,17 @@ impl User { Ok(()) } - pub fn get_following(user_id_1: u64, user_id_2: u64) -> Result<u8> { - let Ok(followed) = database::friends::get_friend_status(user_id_1, user_id_2) else { + #[instrument(skip(db))] + pub fn get_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<u8> { + let Ok(followed) = db.get_friend_status(user_id_1, user_id_2) else { return Err(ResponseCode::InternalServerError.text("Failed to get follow status")) }; Ok(followed) } - pub fn get_friends(user_id: u64) -> Result<Vec<Self>> { - let Ok(users) = database::friends::get_friends(user_id) else { + #[instrument(skip(db))] + pub fn get_friends(db: &Database, user_id: u64) -> Result<Vec<Self>> { + let Ok(users) = db.get_friends(user_id) else { return Err(ResponseCode::InternalServerError.text("Failed to fetch friends")) }; Ok(users) |