From aec4fdecc10be35cde5dc42308960f10bc452187 Mon Sep 17 00:00:00 2001 From: Tyler Murphy Date: Wed, 15 Feb 2023 00:01:44 -0500 Subject: [PATCH] make database calls 1 conn --- src/api/admin.rs | 27 ++-- src/api/auth.rs | 31 ++-- src/api/mod.rs | 27 +++- src/api/posts.rs | 27 ++-- src/api/users.rs | 26 ++-- src/database/comments.rs | 175 +++++++++++---------- src/database/friends.rs | 177 ++++++++++----------- src/database/likes.rs | 136 +++++++++-------- src/database/mod.rs | 37 +++-- src/database/posts.rs | 205 +++++++++++++------------ src/database/sessions.rs | 105 +++++++------ src/database/users.rs | 321 ++++++++++++++++++++------------------- src/public/admin.rs | 21 +-- src/types/comment.rs | 20 +-- src/types/extract.rs | 32 +++- src/types/like.rs | 20 +-- src/types/post.rs | 43 +++--- src/types/session.rs | 26 ++-- src/types/user.rs | 68 +++++---- 19 files changed, 829 insertions(+), 695 deletions(-) diff --git a/src/api/admin.rs b/src/api/admin.rs index 6030315..f412d75 100644 --- a/src/api/admin.rs +++ b/src/api/admin.rs @@ -5,13 +5,12 @@ use serde::Deserialize; use tower_cookies::{Cookie, Cookies}; use crate::{ - database, public::{ admin, docs::{EndpointDocumentation, EndpointMethod}, }, types::{ - extract::{AdminUser, Check, CheckResult, Json}, + extract::{AdminUser, Check, CheckResult, Database, Json}, http::ResponseCode, }, }; @@ -92,8 +91,8 @@ impl Check for QueryRequest { } } -async fn query(_: AdminUser, Json(body): Json) -> Response { - match database::query(body.query) { +async fn query(_: AdminUser, Database(db): Database, Json(body): Json) -> Response { + match db.query(body.query) { Ok(changes) => ResponseCode::Success.text(&format!( "Query executed successfully. {changes} lines changed." )), @@ -114,8 +113,8 @@ pub const ADMIN_POSTS: EndpointDocumentation = EndpointDocumentation { cookie: Some("admin"), }; -async fn posts(_: AdminUser) -> Response { - admin::generate_posts() +async fn posts(_: AdminUser, Database(db): Database) -> Response { + admin::generate_posts(&db) } pub const ADMIN_USERS: EndpointDocumentation = EndpointDocumentation { @@ -131,8 +130,8 @@ pub const ADMIN_USERS: EndpointDocumentation = EndpointDocumentation { cookie: Some("admin"), }; -async fn users(_: AdminUser) -> Response { - admin::generate_users() +async fn users(_: AdminUser, Database(db): Database) -> Response { + admin::generate_users(&db) } pub const ADMIN_SESSIONS: EndpointDocumentation = EndpointDocumentation { @@ -148,8 +147,8 @@ pub const ADMIN_SESSIONS: EndpointDocumentation = EndpointDocumentation { cookie: Some("admin"), }; -async fn sessions(_: AdminUser) -> Response { - admin::generate_sessions() +async fn sessions(_: AdminUser, Database(db): Database) -> Response { + admin::generate_sessions(&db) } pub const ADMIN_COMMENTS: EndpointDocumentation = EndpointDocumentation { @@ -165,8 +164,8 @@ pub const ADMIN_COMMENTS: EndpointDocumentation = EndpointDocumentation { cookie: Some("admin"), }; -async fn comments(_: AdminUser) -> Response { - admin::generate_comments() +async fn comments(_: AdminUser, Database(db): Database) -> Response { + admin::generate_comments(&db) } pub const ADMIN_LIKES: EndpointDocumentation = EndpointDocumentation { @@ -182,8 +181,8 @@ pub const ADMIN_LIKES: EndpointDocumentation = EndpointDocumentation { cookie: Some("admin"), }; -async fn likes(_: AdminUser) -> Response { - admin::generate_likes() +async fn likes(_: AdminUser, Database(db): Database) -> Response { + admin::generate_likes(&db) } async fn check(check: Option) -> Response { diff --git a/src/api/auth.rs b/src/api/auth.rs index 60ddc80..c48b583 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -6,7 +6,7 @@ use tower_cookies::{Cookie, Cookies}; use crate::{ public::docs::{EndpointDocumentation, EndpointMethod}, types::{ - extract::{AuthorizedUser, Check, CheckResult, Json, Log}, + extract::{AuthorizedUser, Check, CheckResult, Database, Json, Log}, http::ResponseCode, session::Session, user::User, @@ -99,13 +99,17 @@ impl Check for RegistrationRequet { } } -async fn register(cookies: Cookies, Json(body): Json) -> Response { - let user = match User::new(body) { +async fn register( + cookies: Cookies, + Database(db): Database, + Json(body): Json, +) -> Response { + let user = match User::new(&db, body) { Ok(user) => user, Err(err) => return err, }; - let session = match Session::new(user.user_id) { + let session = match Session::new(&db, user.user_id) { Ok(session) => session, Err(err) => return err, }; @@ -158,8 +162,12 @@ impl Check for LoginRequest { } } -async fn login(cookies: Cookies, Json(body): Json) -> Response { - let Ok(user) = User::from_email(&body.email) else { +async fn login( + cookies: Cookies, + Database(db): Database, + Json(body): Json, +) -> Response { + let Ok(user) = User::from_email(&db, &body.email) else { return ResponseCode::BadRequest.text("Email is not registered") }; @@ -167,7 +175,7 @@ async fn login(cookies: Cookies, Json(body): Json) -> Response { return ResponseCode::BadRequest.text("Password is not correct"); } - let session = match Session::new(user.user_id) { + let session = match Session::new(&db, user.user_id) { Ok(session) => session, Err(err) => return err, }; @@ -199,10 +207,15 @@ pub const AUTH_LOGOUT: EndpointDocumentation = EndpointDocumentation { cookie: None, }; -async fn logout(cookies: Cookies, AuthorizedUser(user): AuthorizedUser, _: Log) -> Response { +async fn logout( + cookies: Cookies, + AuthorizedUser(user): AuthorizedUser, + Database(db): Database, + _: Log, +) -> Response { cookies.remove(Cookie::new("auth", "")); - if let Err(err) = Session::delete(user.user_id) { + if let Err(err) = Session::delete(&db, user.user_id) { return err; } diff --git a/src/api/mod.rs b/src/api/mod.rs index cd2190c..eeaaa0a 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,5 +1,15 @@ -use crate::types::extract::RouterURI; -use axum::{error_handling::HandleErrorLayer, BoxError, Extension, Router}; +use crate::{ + database, + types::extract::{DatabaseExtention, RouterURI}, +}; +use axum::{ + error_handling::HandleErrorLayer, + http::Request, + middleware::{self, Next}, + response::Response, + BoxError, Extension, Router, +}; +use tokio::sync::Mutex; use tower::ServiceBuilder; use tower_governor::{ errors::display_error, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, @@ -13,6 +23,18 @@ pub mod users; pub use auth::RegistrationRequet; +async fn connect(mut req: Request, next: Next) -> Response +where + B: Send, +{ + if let Ok(db) = database::Database::connect() { + let ex = DatabaseExtention(Mutex::new(db)); + req.extensions_mut().insert(ex); + } + + next.run(req).await +} + pub fn router() -> Router { let governor_conf = Box::new( GovernorConfigBuilder::default() @@ -49,4 +71,5 @@ pub fn router() -> Router { config: Box::leak(governor_conf), }), ) + .layer(middleware::from_fn(connect)) } diff --git a/src/api/posts.rs b/src/api/posts.rs index 57b2ca8..bd7e665 100644 --- a/src/api/posts.rs +++ b/src/api/posts.rs @@ -9,7 +9,7 @@ use crate::{ public::docs::{EndpointDocumentation, EndpointMethod}, types::{ comment::Comment, - extract::{AuthorizedUser, Check, CheckResult, Json}, + extract::{AuthorizedUser, Check, CheckResult, Database, Json}, http::ResponseCode, like::Like, post::Post, @@ -55,9 +55,10 @@ impl Check for PostCreateRequest { async fn create( AuthorizedUser(user): AuthorizedUser, + Database(db): Database, Json(body): Json, ) -> Response { - let Ok(post) = Post::new(user.user_id, body.content) else { + let Ok(post) = Post::new(&db, user.user_id, body.content) else { return ResponseCode::InternalServerError.text("Failed to create post") }; @@ -101,9 +102,10 @@ impl Check for PostPageRequest { async fn page( AuthorizedUser(user): AuthorizedUser, + Database(db): Database, Json(body): Json, ) -> Response { - let Ok(posts) = Post::from_post_page(user.user_id, body.page) else { + let Ok(posts) = Post::from_post_page(&db, user.user_id, body.page) else { return ResponseCode::InternalServerError.text("Failed to fetch posts") }; @@ -149,9 +151,10 @@ impl Check for CommentsPageRequest { async fn comments( AuthorizedUser(_user): AuthorizedUser, + Database(db): Database, Json(body): Json, ) -> Response { - let Ok(comments) = Comment::from_comment_page(body.page, body.post_id) else { + let Ok(comments) = Comment::from_comment_page(&db, body.page, body.post_id) else { return ResponseCode::InternalServerError.text("Failed to fetch comments") }; @@ -197,9 +200,10 @@ impl Check for UsersPostsRequest { async fn user( AuthorizedUser(user): AuthorizedUser, + Database(db): Database, Json(body): Json, ) -> Response { - let Ok(posts) = Post::from_user_post_page(user.user_id, body.user_id, body.page) else { + let Ok(posts) = Post::from_user_post_page(&db, user.user_id, body.user_id, body.page) else { return ResponseCode::InternalServerError.text("Failed to fetch posts") }; @@ -251,9 +255,10 @@ impl Check for PostCommentRequest { async fn comment( AuthorizedUser(user): AuthorizedUser, + Database(db): Database, Json(body): Json, ) -> Response { - if let Err(err) = Comment::new(user.user_id, body.post_id, &body.content) { + if let Err(err) = Comment::new(&db, user.user_id, body.post_id, &body.content) { return err; } @@ -293,12 +298,16 @@ impl Check for PostLikeRequest { } } -async fn like(AuthorizedUser(user): AuthorizedUser, Json(body): Json) -> Response { +async fn like( + AuthorizedUser(user): AuthorizedUser, + Database(db): Database, + Json(body): Json, +) -> Response { if body.state { - if let Err(err) = Like::add_liked(user.user_id, body.post_id) { + if let Err(err) = Like::add_liked(&db, user.user_id, body.post_id) { return err; } - } else if let Err(err) = Like::remove_liked(user.user_id, body.post_id) { + } else if let Err(err) = Like::remove_liked(&db, user.user_id, body.post_id) { return err; } diff --git a/src/api/users.rs b/src/api/users.rs index 082926e..71305c5 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -1,7 +1,7 @@ use crate::{ public::docs::{EndpointDocumentation, EndpointMethod}, types::{ - extract::{AuthorizedUser, Check, CheckResult, Json, Log, Png}, + extract::{AuthorizedUser, Check, CheckResult, Database, Json, Log, Png}, http::ResponseCode, user::User, }, @@ -46,9 +46,10 @@ impl Check for UserLoadRequest { async fn load_batch( AuthorizedUser(_user): AuthorizedUser, + Database(db): Database, Json(body): Json, ) -> Response { - let users = User::from_user_ids(body.ids); + let users = User::from_user_ids(&db, body.ids); let Ok(json) = serde_json::to_string(&users) else { return ResponseCode::InternalServerError.text("Failed to fetch users") }; @@ -90,9 +91,10 @@ impl Check for UserPageReqiest { async fn load_page( AuthorizedUser(_user): AuthorizedUser, + Database(db): Database, Json(body): Json, ) -> Response { - let Ok(users) = User::from_user_page(body.page) else { + let Ok(users) = User::from_user_page(&db, body.page) else { return ResponseCode::InternalServerError.text("Failed to fetch users") }; @@ -207,17 +209,18 @@ impl Check for UserFollowRequest { async fn follow( AuthorizedUser(user): AuthorizedUser, + Database(db): Database, Json(body): Json, ) -> Response { if body.state { - if let Err(err) = User::add_following(user.user_id, body.user_id) { + if let Err(err) = User::add_following(&db, user.user_id, body.user_id) { return err; } - } else if let Err(err) = User::remove_following(user.user_id, body.user_id) { + } else if let Err(err) = User::remove_following(&db, user.user_id, body.user_id) { return err; } - match User::get_following(user.user_id, body.user_id) { + match User::get_following(&db, user.user_id, body.user_id) { Ok(status) => ResponseCode::Success.text(&format!("{status}")), Err(err) => err, } @@ -259,9 +262,10 @@ impl Check for UserFollowStatusRequest { async fn follow_status( AuthorizedUser(user): AuthorizedUser, + Database(db): Database, Json(body): Json, ) -> Response { - match User::get_following(user.user_id, body.user_id) { + match User::get_following(&db, user.user_id, body.user_id) { Ok(status) => ResponseCode::Success.text(&format!("{status}")), Err(err) => err, } @@ -297,8 +301,12 @@ impl Check for UserFriendsRequest { } } -async fn friends(AuthorizedUser(_user): AuthorizedUser, Json(body): Json) -> Response { - let Ok(users) = User::get_friends(body.user_id) else { +async fn friends( + AuthorizedUser(_user): AuthorizedUser, + Database(db): Database, + Json(body): Json, +) -> Response { + let Ok(users) = User::get_friends(&db, body.user_id) else { return ResponseCode::InternalServerError.text("Failed to fetch user") }; diff --git a/src/database/comments.rs b/src/database/comments.rs index 9e0eaf9..5a1b39d 100644 --- a/src/database/comments.rs +++ b/src/database/comments.rs @@ -3,89 +3,100 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; use rusqlite::Row; use tracing::instrument; -use crate::{database, types::comment::Comment}; +use crate::types::comment::Comment; -pub fn init() -> Result<(), rusqlite::Error> { - let sql = " - CREATE TABLE IF NOT EXISTS comments ( - comment_id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id INTEGER NOT NULL, - post_id INTEGER NOT NULL, - date INTEGER NOT NULL, - content VARCHAR(255) NOT NULL, - FOREIGN KEY(user_id) REFERENCES users(user_id), - FOREIGN KEY(post_id) REFERENCES posts(post_id) - ); - "; - let conn = database::connect()?; - conn.execute(sql, ())?; +use super::Database; - let sql2 = "CREATE INDEX IF NOT EXISTS post_ids on comments (post_id);"; - conn.execute(sql2, ())?; +impl Database { + pub fn init_comments(&self) -> Result<(), rusqlite::Error> { + let sql = " + CREATE TABLE IF NOT EXISTS comments ( + comment_id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + post_id INTEGER NOT NULL, + date INTEGER NOT NULL, + content VARCHAR(255) NOT NULL, + FOREIGN KEY(user_id) REFERENCES users(user_id), + FOREIGN KEY(post_id) REFERENCES posts(post_id) + ); + "; + self.0.execute(sql, ())?; - Ok(()) -} - -fn comment_from_row(row: &Row) -> Result { - let comment_id = row.get(0)?; - let user_id = row.get(1)?; - let post_id = row.get(2)?; - let date = row.get(3)?; - let content = row.get(4)?; - - Ok(Comment { - comment_id, - user_id, - post_id, - date, - content, - }) -} - -#[instrument()] -pub fn get_comments_page(page: u64, post_id: u64) -> Result, rusqlite::Error> { - tracing::trace!("Retrieving comments page"); - let page_size = 5; - let conn = database::connect()?; - let mut stmt = conn.prepare( - "SELECT * FROM comments WHERE post_id = ? ORDER BY comment_id ASC LIMIT ? OFFSET ?", - )?; - let row = stmt.query_map([post_id, page_size, page_size * page], |row| { - let row = comment_from_row(row)?; - Ok(row) - })?; - Ok(row.into_iter().flatten().collect()) -} - -#[instrument()] -pub fn get_all_comments() -> Result, rusqlite::Error> { - tracing::trace!("Retrieving comments page"); - let conn = database::connect()?; - let mut stmt = conn.prepare("SELECT * FROM comments ORDER BY comment_id DESC")?; - let row = stmt.query_map([], |row| { - let row = comment_from_row(row)?; - Ok(row) - })?; - Ok(row.into_iter().flatten().collect()) -} - -#[instrument()] -pub fn add_comment(user_id: u64, post_id: u64, content: &str) -> Result { - tracing::trace!("Adding comment"); - let date = u64::try_from( - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or(Duration::ZERO) - .as_millis(), - ) - .unwrap_or(0); - let conn = database::connect()?; - let mut stmt = conn.prepare( - "INSERT INTO comments (user_id, post_id, date, content) VALUES(?,?,?,?) RETURNING *;", - )?; - let post = stmt.query_row((user_id, post_id, date, content), |row| { - let row = comment_from_row(row)?; - Ok(row) - })?; - Ok(post) + let sql2 = "CREATE INDEX IF NOT EXISTS post_ids on comments (post_id);"; + self.0.execute(sql2, ())?; + + Ok(()) + } + + fn comment_from_row(row: &Row) -> Result { + let comment_id = row.get(0)?; + let user_id = row.get(1)?; + let post_id = row.get(2)?; + let date = row.get(3)?; + let content = row.get(4)?; + + Ok(Comment { + comment_id, + user_id, + post_id, + date, + content, + }) + } + + #[instrument(skip(self))] + pub fn get_comments_page( + &self, + page: u64, + post_id: u64, + ) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving comments page"); + let page_size = 5; + let mut stmt = self.0.prepare( + "SELECT * FROM comments WHERE post_id = ? ORDER BY comment_id ASC LIMIT ? OFFSET ?", + )?; + let row = stmt.query_map([post_id, page_size, page_size * page], |row| { + let row = Self::comment_from_row(row)?; + Ok(row) + })?; + Ok(row.into_iter().flatten().collect()) + } + + #[instrument(skip(self))] + pub fn get_all_comments(&self) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving comments page"); + let mut stmt = self + .0 + .prepare("SELECT * FROM comments ORDER BY comment_id DESC")?; + let row = stmt.query_map([], |row| { + let row = Self::comment_from_row(row)?; + Ok(row) + })?; + Ok(row.into_iter().flatten().collect()) + } + + #[instrument(skip(self))] + pub fn add_comment( + &self, + user_id: u64, + post_id: u64, + content: &str, + ) -> Result { + tracing::trace!("Adding comment"); + let date = u64::try_from( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::ZERO) + .as_millis(), + ) + .unwrap_or(0); + let mut stmt = self.0.prepare( + "INSERT INTO comments (user_id, post_id, date, content) VALUES(?,?,?,?) RETURNING *;", + )?; + let post = stmt.query_row((user_id, post_id, date, content), |row| { + let row = Self::comment_from_row(row)?; + Ok(row) + })?; + Ok(post) + } } diff --git a/src/database/friends.rs b/src/database/friends.rs index 0b78488..31434d4 100644 --- a/src/database/friends.rs +++ b/src/database/friends.rs @@ -1,97 +1,100 @@ use tracing::instrument; -use crate::{ - database::{self, users::user_from_row}, - types::user::{User, FOLLOWED, FOLLOWING, NO_RELATION}, -}; +use crate::types::user::{User, FOLLOWED, FOLLOWING, NO_RELATION}; -pub fn init() -> Result<(), rusqlite::Error> { - let sql = " - CREATE TABLE IF NOT EXISTS friends ( - follower_id INTEGER NOT NULL, - followee_id INTEGER NOT NULL, - FOREIGN KEY(follower_id) REFERENCES users(user_id), - FOREIGN KEY(followee_id) REFERENCES users(user_id), - PRIMARY KEY (follower_id, followee_id) - ); - "; - let conn = database::connect()?; - conn.execute(sql, ())?; - Ok(()) -} +use super::Database; -#[instrument()] -pub fn get_friend_status(user_id_1: u64, user_id_2: u64) -> Result { - tracing::trace!("Retrieving friend status"); - let conn = database::connect()?; - let mut stmt = conn.prepare("SELECT * FROM friends WHERE (follower_id = ? AND followee_id = ?) OR (follower_id = ? AND followee_id = ?);")?; - let mut status = NO_RELATION; - let rows: Vec = stmt - .query_map([user_id_1, user_id_2, user_id_2, user_id_1], |row| { - let id: u64 = row.get(0)?; - Ok(id) - })? - .into_iter() - .flatten() - .collect(); - - for follower in rows { - if follower == user_id_1 { - status |= FOLLOWING; - } - - if follower == user_id_2 { - status |= FOLLOWED; - } +impl Database { + pub fn init_friends(&self) -> Result<(), rusqlite::Error> { + let sql = " + CREATE TABLE IF NOT EXISTS friends ( + follower_id INTEGER NOT NULL, + followee_id INTEGER NOT NULL, + FOREIGN KEY(follower_id) REFERENCES users(user_id), + FOREIGN KEY(followee_id) REFERENCES users(user_id), + PRIMARY KEY (follower_id, followee_id) + ); + "; + self.0.execute(sql, ())?; + Ok(()) } - Ok(status) -} + #[instrument(skip(self))] + pub fn get_friend_status(&self, user_id_1: u64, user_id_2: u64) -> Result { + tracing::trace!("Retrieving friend status"); + let mut stmt = self.0.prepare("SELECT * FROM friends WHERE (follower_id = ? AND followee_id = ?) OR (follower_id = ? AND followee_id = ?);")?; + let mut status = NO_RELATION; + let rows: Vec = stmt + .query_map([user_id_1, user_id_2, user_id_2, user_id_1], |row| { + let id: u64 = row.get(0)?; + Ok(id) + })? + .into_iter() + .flatten() + .collect(); -#[instrument()] -pub fn get_friends(user_id: u64) -> Result, rusqlite::Error> { - tracing::trace!("Retrieving friends"); - let conn = database::connect()?; - let mut stmt = conn.prepare( - " - SELECT * - FROM users u - WHERE EXISTS ( - SELECT NULL - FROM friends f - WHERE u.user_id = f.follower_id - AND f.followee_id = ? - ) - AND EXISTS ( - SELECT NULL - FROM friends f - WHERE u.user_id = f.followee_id - AND f.follower_id = ? - ) - ", - )?; - let row = stmt.query_map([user_id, user_id], |row| { - let row = user_from_row(row, true)?; - Ok(row) - })?; - Ok(row.into_iter().flatten().collect()) -} + for follower in rows { + if follower == user_id_1 { + status |= FOLLOWING; + } -#[instrument()] -pub fn set_following(user_id_1: u64, user_id_2: u64) -> Result { - tracing::trace!("Setting following"); - let conn = database::connect()?; - let mut stmt = - conn.prepare("INSERT OR REPLACE INTO friends (follower_id, followee_id) VALUES (?,?)")?; - let changes = stmt.execute([user_id_1, user_id_2])?; - Ok(changes == 1) -} + if follower == user_id_2 { + status |= FOLLOWED; + } + } -#[instrument()] -pub fn remove_following(user_id_1: u64, user_id_2: u64) -> Result { - tracing::trace!("Removing following"); - let conn = database::connect()?; - let mut stmt = conn.prepare("DELETE FROM friends WHERE follower_id = ? AND followee_id = ?")?; - let changes = stmt.execute([user_id_1, user_id_2])?; - Ok(changes == 1) + Ok(status) + } + + #[instrument(skip(self))] + pub fn get_friends(&self, user_id: u64) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving friends"); + let mut stmt = self.0.prepare( + " + SELECT * + FROM users u + WHERE EXISTS ( + SELECT NULL + FROM friends f + WHERE u.user_id = f.follower_id + AND f.followee_id = ? + ) + AND EXISTS ( + SELECT NULL + FROM friends f + WHERE u.user_id = f.followee_id + AND f.follower_id = ? + ) + ", + )?; + let row = stmt.query_map([user_id, user_id], |row| { + let row = Self::user_from_row(row, true)?; + Ok(row) + })?; + Ok(row.into_iter().flatten().collect()) + } + + #[instrument(skip(self))] + pub fn set_following(&self, user_id_1: u64, user_id_2: u64) -> Result { + tracing::trace!("Setting following"); + let mut stmt = self + .0 + .prepare("INSERT OR REPLACE INTO friends (follower_id, followee_id) VALUES (?,?)")?; + let changes = stmt.execute([user_id_1, user_id_2])?; + Ok(changes == 1) + } + + #[instrument(skip(self))] + pub fn remove_following( + &self, + user_id_1: u64, + user_id_2: u64, + ) -> Result { + tracing::trace!("Removing following"); + let mut stmt = self + .0 + .prepare("DELETE FROM friends WHERE follower_id = ? AND followee_id = ?")?; + let changes = stmt.execute([user_id_1, user_id_2])?; + Ok(changes == 1) + } } diff --git a/src/database/likes.rs b/src/database/likes.rs index f6a130b..b313c97 100644 --- a/src/database/likes.rs +++ b/src/database/likes.rs @@ -1,75 +1,81 @@ use rusqlite::OptionalExtension; use tracing::instrument; -use crate::{database, types::like::Like}; +use crate::types::like::Like; -pub fn init() -> Result<(), rusqlite::Error> { - let sql = " - CREATE TABLE IF NOT EXISTS likes ( - user_id INTEGER NOT NULL, - post_id INTEGER NOT NULL, - FOREIGN KEY(user_id) REFERENCES users(user_id), - FOREIGN KEY(post_id) REFERENCES posts(post_id), - PRIMARY KEY (user_id, post_id) - ); - "; - let conn = database::connect()?; - conn.execute(sql, ())?; - Ok(()) -} +use super::Database; -#[instrument()] -pub fn get_like_count(post_id: u64) -> Result, rusqlite::Error> { - tracing::trace!("Retrieving like count"); - let conn = database::connect()?; - let mut stmt = conn.prepare("SELECT COUNT(post_id) FROM likes WHERE post_id = ?")?; - let row = stmt - .query_row([post_id], |row| { - let row = row.get(0)?; - Ok(row) - }) - .optional()?; - Ok(row) -} +impl Database { + pub fn init_likes(&self) -> Result<(), rusqlite::Error> { + let sql = " + CREATE TABLE IF NOT EXISTS likes ( + user_id INTEGER NOT NULL, + post_id INTEGER NOT NULL, + FOREIGN KEY(user_id) REFERENCES users(user_id), + FOREIGN KEY(post_id) REFERENCES posts(post_id), + PRIMARY KEY (user_id, post_id) + ); + "; + self.0.execute(sql, ())?; + Ok(()) + } -#[instrument()] -pub fn get_liked(user_id: u64, post_id: u64) -> Result { - tracing::trace!("Retrieving if liked"); - let conn = database::connect()?; - let mut stmt = conn.prepare("SELECT * FROM likes WHERE user_id = ? AND post_id = ?")?; - let liked = stmt.query_row([user_id, post_id], |_| Ok(())).optional()?; - Ok(liked.is_some()) -} + #[instrument(skip(self))] + pub fn get_like_count(&self, post_id: u64) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving like count"); + let mut stmt = self + .0 + .prepare("SELECT COUNT(post_id) FROM likes WHERE post_id = ?")?; + let row = stmt + .query_row([post_id], |row| { + let row = row.get(0)?; + Ok(row) + }) + .optional()?; + Ok(row) + } -#[instrument()] -pub fn add_liked(user_id: u64, post_id: u64) -> Result { - tracing::trace!("Adding like"); - let conn = database::connect()?; - let mut stmt = conn.prepare("INSERT OR REPLACE INTO likes (user_id, post_id) VALUES (?,?)")?; - let changes = stmt.execute([user_id, post_id])?; - Ok(changes == 1) -} + #[instrument(skip(self))] + pub fn get_liked(&self, user_id: u64, post_id: u64) -> Result { + tracing::trace!("Retrieving if liked"); + let mut stmt = self + .0 + .prepare("SELECT * FROM likes WHERE user_id = ? AND post_id = ?")?; + let liked = stmt.query_row([user_id, post_id], |_| Ok(())).optional()?; + Ok(liked.is_some()) + } -#[instrument()] -pub fn remove_liked(user_id: u64, post_id: u64) -> Result { - tracing::trace!("Removing like"); - let conn = database::connect()?; - let mut stmt = conn.prepare("DELETE FROM likes WHERE user_id = ? AND post_id = ?;")?; - let changes = stmt.execute((user_id, post_id))?; - Ok(changes == 1) -} + #[instrument(skip(self))] + pub fn add_liked(&self, user_id: u64, post_id: u64) -> Result { + tracing::trace!("Adding like"); + let mut stmt = self + .0 + .prepare("INSERT OR REPLACE INTO likes (user_id, post_id) VALUES (?,?)")?; + let changes = stmt.execute([user_id, post_id])?; + Ok(changes == 1) + } -#[instrument()] -pub fn get_all_likes() -> Result, rusqlite::Error> { - tracing::trace!("Retrieving comments page"); - let conn = database::connect()?; - let mut stmt = conn.prepare("SELECT * FROM likes")?; - let row = stmt.query_map([], |row| { - let like = Like { - user_id: row.get(0)?, - post_id: row.get(1)?, - }; - Ok(like) - })?; - Ok(row.into_iter().flatten().collect()) + #[instrument(skip(self))] + pub fn remove_liked(&self, user_id: u64, post_id: u64) -> Result { + tracing::trace!("Removing like"); + let mut stmt = self + .0 + .prepare("DELETE FROM likes WHERE user_id = ? AND post_id = ?;")?; + let changes = stmt.execute((user_id, post_id))?; + Ok(changes == 1) + } + + #[instrument(skip(self))] + pub fn get_all_likes(&self) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving comments page"); + let mut stmt = self.0.prepare("SELECT * FROM likes")?; + let row = stmt.query_map([], |row| { + let like = Like { + user_id: row.get(0)?, + post_id: row.get(1)?, + }; + Ok(like) + })?; + Ok(row.into_iter().flatten().collect()) + } } diff --git a/src/database/mod.rs b/src/database/mod.rs index d22a350..67e05c6 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,3 +1,4 @@ +use rusqlite::Connection; use tracing::instrument; pub mod comments; @@ -7,23 +8,29 @@ pub mod posts; pub mod sessions; pub mod users; -pub fn connect() -> Result { - rusqlite::Connection::open("xssbook.db") +#[derive(Debug)] +pub struct Database(Connection); + +impl Database { + pub fn connect() -> Result { + let conn = rusqlite::Connection::open("xssbook.db")?; + Ok(Self(conn)) + } + + #[instrument(skip(self))] + pub fn query(&self, query: String) -> Result { + tracing::trace!("Running custom query"); + self.0.execute(&query, []) + } } pub fn init() -> Result<(), rusqlite::Error> { - users::init()?; - posts::init()?; - sessions::init()?; - likes::init()?; - comments::init()?; - friends::init()?; + let db = Database::connect()?; + db.init_users()?; + db.init_posts()?; + db.init_sessions()?; + db.init_likes()?; + db.init_comments()?; + db.init_friends()?; Ok(()) } - -#[instrument()] -pub fn query(query: String) -> Result { - tracing::trace!("Running custom query"); - let conn = connect()?; - conn.execute(&query, []) -} diff --git a/src/database/posts.rs b/src/database/posts.rs index c33e7e7..fa0fd3c 100644 --- a/src/database/posts.rs +++ b/src/database/posts.rs @@ -3,115 +3,122 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; use rusqlite::{OptionalExtension, Row}; use tracing::instrument; -use crate::database; use crate::types::post::Post; -use super::{comments, likes}; +use super::Database; -pub fn init() -> Result<(), rusqlite::Error> { - let sql = " - CREATE TABLE IF NOT EXISTS posts ( - post_id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id INTEGER NOT NULL, - content VARCHAR(500) NOT NULL, - date INTEGER NOT NULL, - FOREIGN KEY(user_id) REFERENCES users(user_id) - ); - "; - let conn = database::connect()?; - conn.execute(sql, ())?; - Ok(()) -} +impl Database { + pub fn init_posts(&self) -> Result<(), rusqlite::Error> { + let sql = " + CREATE TABLE IF NOT EXISTS posts ( + post_id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + content VARCHAR(500) NOT NULL, + date INTEGER NOT NULL, + FOREIGN KEY(user_id) REFERENCES users(user_id) + ); + "; + self.0.execute(sql, ())?; + Ok(()) + } -fn post_from_row(row: &Row) -> Result { - let post_id = row.get(0)?; - let user_id = row.get(1)?; - let content = row.get(2)?; - let date = row.get(3)?; + fn post_from_row(&self, row: &Row) -> Result { + let post_id = row.get(0)?; + let user_id = row.get(1)?; + let content = row.get(2)?; + let date = row.get(3)?; - let comments = comments::get_comments_page(0, post_id).unwrap_or_else(|_| Vec::new()); - let likes = likes::get_like_count(post_id).unwrap_or(None).unwrap_or(0); + let comments = self + .get_comments_page(0, post_id) + .unwrap_or_else(|_| Vec::new()); + let likes = self.get_like_count(post_id).unwrap_or(None).unwrap_or(0); - Ok(Post { - post_id, - user_id, - content, - date, - likes, - liked: false, - comments, - }) -} - -#[instrument()] -pub fn get_post(post_id: u64) -> Result, rusqlite::Error> { - tracing::trace!("Retrieving post"); - let conn = database::connect()?; - let mut stmt = conn.prepare("SELECT * FROM posts WHERE post_id = ?")?; - let row = stmt - .query_row([post_id], |row| { - let row = post_from_row(row)?; - Ok(row) + Ok(Post { + post_id, + user_id, + content, + date, + likes, + liked: false, + comments, }) - .optional()?; - Ok(row) -} + } -#[instrument()] -pub fn get_post_page(page: u64) -> Result, rusqlite::Error> { - tracing::trace!("Retrieving posts page"); - let page_size = 10; - let conn = database::connect()?; - let mut stmt = conn.prepare("SELECT * FROM posts ORDER BY post_id DESC LIMIT ? OFFSET ?")?; - let row = stmt.query_map([page_size, page_size * page], |row| { - let row = post_from_row(row)?; + #[instrument(skip(self))] + pub fn get_post(&self, post_id: u64) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving post"); + let mut stmt = self.0.prepare("SELECT * FROM posts WHERE post_id = ?")?; + let row = stmt + .query_row([post_id], |row| { + let row = self.post_from_row(row)?; + Ok(row) + }) + .optional()?; Ok(row) - })?; - Ok(row.into_iter().flatten().collect()) -} + } -#[instrument()] -pub fn get_all_posts() -> Result, rusqlite::Error> { - tracing::trace!("Retrieving posts page"); - let conn = database::connect()?; - let mut stmt = conn.prepare("SELECT * FROM posts ORDER BY post_id DESC")?; - let row = stmt.query_map([], |row| { - let row = post_from_row(row)?; - Ok(row) - })?; - Ok(row.into_iter().flatten().collect()) -} + #[instrument(skip(self))] + pub fn get_post_page(&self, page: u64) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving posts page"); + let page_size = 10; + let mut stmt = self + .0 + .prepare("SELECT * FROM posts ORDER BY post_id DESC LIMIT ? OFFSET ?")?; + let row = stmt.query_map([page_size, page_size * page], |row| { + let row = self.post_from_row(row)?; + Ok(row) + })?; + Ok(row.into_iter().flatten().collect()) + } -#[instrument()] -pub fn get_users_post_page(user_id: u64, page: u64) -> Result, rusqlite::Error> { - tracing::trace!("Retrieving users posts"); - let page_size = 10; - let conn = database::connect()?; - let mut stmt = conn - .prepare("SELECT * FROM posts WHERE user_id = ? ORDER BY post_id DESC LIMIT ? OFFSET ?")?; - let row = stmt.query_map([user_id, page_size, page_size * page], |row| { - let row = post_from_row(row)?; - Ok(row) - })?; - Ok(row.into_iter().flatten().collect()) -} + #[instrument(skip(self))] + pub fn get_all_posts(&self) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving posts page"); + let mut stmt = self + .0 + .prepare("SELECT * FROM posts ORDER BY post_id DESC")?; + let row = stmt.query_map([], |row| { + let row = self.post_from_row(row)?; + Ok(row) + })?; + Ok(row.into_iter().flatten().collect()) + } -#[instrument()] -pub fn add_post(user_id: u64, content: &str) -> Result { - tracing::trace!("Adding post"); - let date = u64::try_from( - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or(Duration::ZERO) - .as_millis(), - ) - .unwrap_or(0); - let conn = database::connect()?; - let mut stmt = - conn.prepare("INSERT INTO posts (user_id, content, date) VALUES(?,?,?) RETURNING *;")?; - let post = stmt.query_row((user_id, content, date), |row| { - let row = post_from_row(row)?; - Ok(row) - })?; - Ok(post) + #[instrument(skip(self))] + pub fn get_users_post_page( + &self, + user_id: u64, + page: u64, + ) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving users posts"); + let page_size = 10; + let mut stmt = self.0.prepare( + "SELECT * FROM posts WHERE user_id = ? ORDER BY post_id DESC LIMIT ? OFFSET ?", + )?; + let row = stmt.query_map([user_id, page_size, page_size * page], |row| { + let row = self.post_from_row(row)?; + Ok(row) + })?; + Ok(row.into_iter().flatten().collect()) + } + + #[instrument(skip(self))] + pub fn add_post(&self, user_id: u64, content: &str) -> Result { + tracing::trace!("Adding post"); + let date = u64::try_from( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::ZERO) + .as_millis(), + ) + .unwrap_or(0); + let mut stmt = self + .0 + .prepare("INSERT INTO posts (user_id, content, date) VALUES(?,?,?) RETURNING *;")?; + let post = stmt.query_row((user_id, content, date), |row| { + let row = self.post_from_row(row)?; + Ok(row) + })?; + Ok(post) + } } diff --git a/src/database/sessions.rs b/src/database/sessions.rs index 9adccd4..a50bb51 100644 --- a/src/database/sessions.rs +++ b/src/database/sessions.rs @@ -1,65 +1,64 @@ use rusqlite::OptionalExtension; use tracing::instrument; -use crate::{database, types::session::Session}; +use crate::types::session::Session; -pub fn init() -> Result<(), rusqlite::Error> { - let sql = " - CREATE TABLE IF NOT EXISTS sessions ( - user_id INTEGER PRIMARY KEY NOT NULL, - token TEXT NOT NULL, - FOREIGN KEY(user_id) REFERENCES users(user_id) - ); - "; - let conn = database::connect()?; - conn.execute(sql, ())?; - Ok(()) -} +use super::Database; -#[instrument()] -pub fn get_session(token: &str) -> Result, rusqlite::Error> { - tracing::trace!("Retrieving session"); - let conn = database::connect()?; - let mut stmt = conn.prepare("SELECT * FROM sessions WHERE token = ?")?; - let row = stmt - .query_row([token], |row| { +impl Database { + pub fn init_sessions(&self) -> Result<(), rusqlite::Error> { + let sql = " + CREATE TABLE IF NOT EXISTS sessions ( + user_id INTEGER PRIMARY KEY NOT NULL, + token TEXT NOT NULL, + FOREIGN KEY(user_id) REFERENCES users(user_id) + ); + "; + self.0.execute(sql, ())?; + Ok(()) + } + + #[instrument(skip(self))] + pub fn get_session(&self, token: &str) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving session"); + let mut stmt = self.0.prepare("SELECT * FROM sessions WHERE token = ?")?; + let row = stmt + .query_row([token], |row| { + Ok(Session { + user_id: row.get(0)?, + token: row.get(1)?, + }) + }) + .optional()?; + Ok(row) + } + + #[instrument(skip(self))] + pub fn get_all_sessions(&self) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving session"); + let mut stmt = self.0.prepare("SELECT * FROM sessions")?; + let row = stmt.query_map([], |row| { Ok(Session { user_id: row.get(0)?, token: row.get(1)?, }) - }) - .optional()?; - Ok(row) -} + })?; + Ok(row.into_iter().flatten().collect()) + } -#[instrument()] -pub fn get_all_sessions() -> Result, rusqlite::Error> { - tracing::trace!("Retrieving session"); - let conn = database::connect()?; - let mut stmt = conn.prepare("SELECT * FROM sessions")?; - let row = stmt.query_map([], |row| { - Ok(Session { - user_id: row.get(0)?, - token: row.get(1)?, - }) - })?; - Ok(row.into_iter().flatten().collect()) -} + #[instrument(skip(self))] + pub fn set_session(&self, user_id: u64, token: &str) -> Result<(), Box> { + tracing::trace!("Setting new session"); + let sql = "INSERT OR REPLACE INTO sessions (user_id, token) VALUES (?, ?);"; + self.0.execute(sql, (user_id, token))?; + Ok(()) + } -#[instrument()] -pub fn set_session(user_id: u64, token: &str) -> Result<(), Box> { - tracing::trace!("Setting new session"); - let conn = database::connect()?; - let sql = "INSERT OR REPLACE INTO sessions (user_id, token) VALUES (?, ?);"; - conn.execute(sql, (user_id, token))?; - Ok(()) -} - -#[instrument()] -pub fn delete_session(user_id: u64) -> Result<(), Box> { - tracing::trace!("Deleting session"); - let conn = database::connect()?; - let sql = "DELETE FROM sessions WHERE user_id = ?;"; - conn.execute(sql, [user_id])?; - Ok(()) + #[instrument(skip(self))] + pub fn delete_session(&self, user_id: u64) -> Result<(), Box> { + tracing::trace!("Deleting session"); + let sql = "DELETE FROM sessions WHERE user_id = ?;"; + self.0.execute(sql, [user_id])?; + Ok(()) + } } diff --git a/src/database/users.rs b/src/database/users.rs index 6062ea8..9df69ee 100644 --- a/src/database/users.rs +++ b/src/database/users.rs @@ -2,169 +2,180 @@ use rusqlite::{OptionalExtension, Row}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tracing::instrument; -use crate::{api::RegistrationRequet, database, types::user::User}; +use crate::{api::RegistrationRequet, types::user::User}; -pub fn init() -> Result<(), rusqlite::Error> { - let sql = " - CREATE TABLE IF NOT EXISTS users ( - user_id INTEGER PRIMARY KEY AUTOINCREMENT, - firstname VARCHAR(20) NOT NULL, - lastname VARCHAR(20) NOT NULL, - email VARCHAR(50) NOT NULL, - password VARCHAR(50) NOT NULL, - gender VARCHAR(100) NOT NULL, - date BIGINT NOT NULL, - day TINYINT NOT NULL, - month TINYINT NOT NULL, - year INTEGER NOT NULL - ); - "; - let conn = database::connect()?; - conn.execute(sql, ())?; +use super::Database; - let sql2 = "CREATE UNIQUE INDEX IF NOT EXISTS emails on users (email);"; - conn.execute(sql2, ())?; +impl Database { + pub fn init_users(&self) -> Result<(), rusqlite::Error> { + let sql = " + CREATE TABLE IF NOT EXISTS users ( + user_id INTEGER PRIMARY KEY AUTOINCREMENT, + firstname VARCHAR(20) NOT NULL, + lastname VARCHAR(20) NOT NULL, + email VARCHAR(50) NOT NULL, + password VARCHAR(50) NOT NULL, + gender VARCHAR(100) NOT NULL, + date BIGINT NOT NULL, + day TINYINT NOT NULL, + month TINYINT NOT NULL, + year INTEGER NOT NULL + ); + "; + self.0.execute(sql, ())?; - let sql3 = "CREATE UNIQUE INDEX IF NOT EXISTS passwords on users (password);"; - conn.execute(sql3, ())?; + let sql2 = "CREATE UNIQUE INDEX IF NOT EXISTS emails on users (email);"; + self.0.execute(sql2, ())?; - Ok(()) -} + let sql3 = "CREATE UNIQUE INDEX IF NOT EXISTS passwords on users (password);"; + self.0.execute(sql3, ())?; -pub fn user_from_row(row: &Row, hide_password: bool) -> Result { - let user_id = row.get(0)?; - let firstname = row.get(1)?; - let lastname = row.get(2)?; - let email = row.get(3)?; - let password = row.get(4)?; - let gender = row.get(5)?; - let date = row.get(6)?; - let day = row.get(7)?; - let month = row.get(8)?; - let year = row.get(9)?; + Ok(()) + } - let password = if hide_password { - String::new() - } else { - password - }; + pub fn user_from_row(row: &Row, hide_password: bool) -> Result { + let user_id = row.get(0)?; + let firstname = row.get(1)?; + let lastname = row.get(2)?; + let email = row.get(3)?; + let password = row.get(4)?; + let gender = row.get(5)?; + let date = row.get(6)?; + let day = row.get(7)?; + let month = row.get(8)?; + let year = row.get(9)?; - Ok(User { - user_id, - firstname, - lastname, - email, - password, - gender, - date, - day, - month, - year, - }) -} + let password = if hide_password { + String::new() + } else { + password + }; -#[instrument()] -pub fn get_user_by_id(user_id: u64, hide_password: bool) -> Result, rusqlite::Error> { - tracing::trace!("Retrieving user by id"); - let conn = database::connect()?; - let mut stmt = conn.prepare("SELECT * FROM users WHERE user_id = ?")?; - let row = stmt - .query_row([user_id], |row| { - let row = user_from_row(row, hide_password)?; - Ok(row) - }) - .optional()?; - Ok(row) -} - -#[instrument()] -pub fn get_user_by_email( - email: &str, - hide_password: bool, -) -> Result, rusqlite::Error> { - tracing::trace!("Retrieving user by email"); - let conn = database::connect()?; - let mut stmt = conn.prepare("SELECT * FROM users WHERE email = ?")?; - let row = stmt - .query_row([email], |row| { - let row = user_from_row(row, hide_password)?; - Ok(row) - }) - .optional()?; - Ok(row) -} - -#[instrument()] -pub fn get_user_by_password( - password: &str, - hide_password: bool, -) -> Result, rusqlite::Error> { - tracing::trace!("Retrieving user by password"); - let conn = database::connect()?; - let mut stmt = conn.prepare("SELECT * FROM users WHERE password = ?")?; - let row = stmt - .query_row([password], |row| { - let row = user_from_row(row, hide_password)?; - Ok(row) - }) - .optional()?; - Ok(row) -} - -#[instrument()] -pub fn get_user_page(page: u64, hide_password: bool) -> Result, rusqlite::Error> { - tracing::trace!("Retrieving user page"); - let page_size = 5; - let conn = database::connect()?; - let mut stmt = conn.prepare("SELECT * FROM users ORDER BY user_id DESC LIMIT ? OFFSET ?")?; - let row = stmt.query_map([page_size, page_size * page], |row| { - let row = user_from_row(row, hide_password)?; - Ok(row) - })?; - Ok(row.into_iter().flatten().collect()) -} - -#[instrument()] -pub fn get_all_users() -> Result, rusqlite::Error> { - tracing::trace!("Retrieving user page"); - let conn = database::connect()?; - let mut stmt = conn.prepare("SELECT * FROM users ORDER BY user_id DESC")?; - let row = stmt.query_map([], |row| { - let row = user_from_row(row, false)?; - Ok(row) - })?; - Ok(row.into_iter().flatten().collect()) -} - -#[instrument()] -pub fn add_user(request: RegistrationRequet) -> Result { - tracing::trace!("Adding new user"); - let date = u64::try_from( - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or(Duration::ZERO) - .as_millis(), - ) - .unwrap_or(0); - - let conn = database::connect()?; - let mut stmt = conn.prepare("INSERT INTO users (firstname, lastname, email, password, gender, date, day, month, year) VALUES(?,?,?,?,?,?,?,?,?) RETURNING *;")?; - let user = stmt.query_row( - ( - request.firstname, - request.lastname, - request.email, - request.password, - request.gender, + Ok(User { + user_id, + firstname, + lastname, + email, + password, + gender, date, - request.day, - request.month, - request.year, - ), - |row| { - let row = user_from_row(row, false)?; + day, + month, + year, + }) + } + + #[instrument(skip(self))] + pub fn get_user_by_id( + &self, + user_id: u64, + hide_password: bool, + ) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving user by id"); + let mut stmt = self.0.prepare("SELECT * FROM users WHERE user_id = ?")?; + let row = stmt + .query_row([user_id], |row| { + let row = Self::user_from_row(row, hide_password)?; + Ok(row) + }) + .optional()?; + Ok(row) + } + + #[instrument(skip(self))] + pub fn get_user_by_email( + &self, + email: &str, + hide_password: bool, + ) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving user by email"); + let mut stmt = self.0.prepare("SELECT * FROM users WHERE email = ?")?; + let row = stmt + .query_row([email], |row| { + let row = Self::user_from_row(row, hide_password)?; + Ok(row) + }) + .optional()?; + Ok(row) + } + + #[instrument(skip(self))] + pub fn get_user_by_password( + &self, + password: &str, + hide_password: bool, + ) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving user by password"); + let mut stmt = self.0.prepare("SELECT * FROM users WHERE password = ?")?; + let row = stmt + .query_row([password], |row| { + let row = Self::user_from_row(row, hide_password)?; + Ok(row) + }) + .optional()?; + Ok(row) + } + + #[instrument(skip(self))] + pub fn get_user_page( + &self, + page: u64, + hide_password: bool, + ) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving user page"); + let page_size = 5; + let mut stmt = self + .0 + .prepare("SELECT * FROM users ORDER BY user_id DESC LIMIT ? OFFSET ?")?; + let row = stmt.query_map([page_size, page_size * page], |row| { + let row = Self::user_from_row(row, hide_password)?; Ok(row) - }, - )?; - Ok(user) + })?; + Ok(row.into_iter().flatten().collect()) + } + + #[instrument(skip(self))] + pub fn get_all_users(&self) -> Result, rusqlite::Error> { + tracing::trace!("Retrieving user page"); + let mut stmt = self + .0 + .prepare("SELECT * FROM users ORDER BY user_id DESC")?; + let row = stmt.query_map([], |row| { + let row = Self::user_from_row(row, false)?; + Ok(row) + })?; + Ok(row.into_iter().flatten().collect()) + } + + #[instrument(skip(self))] + pub fn add_user(&self, request: RegistrationRequet) -> Result { + tracing::trace!("Adding new user"); + let date = u64::try_from( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::ZERO) + .as_millis(), + ) + .unwrap_or(0); + + let mut stmt = self.0.prepare("INSERT INTO users (firstname, lastname, email, password, gender, date, day, month, year) VALUES(?,?,?,?,?,?,?,?,?) RETURNING *;")?; + let user = stmt.query_row( + ( + request.firstname, + request.lastname, + request.email, + request.password, + request.gender, + date, + request.day, + request.month, + request.year, + ), + |row| { + let row = Self::user_from_row(row, false)?; + Ok(row) + }, + )?; + Ok(user) + } } diff --git a/src/public/admin.rs b/src/public/admin.rs index bf0a155..6e32152 100644 --- a/src/public/admin.rs +++ b/src/public/admin.rs @@ -5,6 +5,7 @@ use tokio::sync::Mutex; use crate::{ console::sanatize, + database::Database, types::{ comment::Comment, http::ResponseCode, like::Like, post::Post, session::Session, user::User, }, @@ -36,8 +37,8 @@ pub async fn regen_secret() -> String { secret.clone() } -pub fn generate_users() -> Response { - let users = match User::reterieve_all() { +pub fn generate_users(db: &Database) -> Response { + let users = match User::reterieve_all(db) { Ok(users) => users, Err(err) => return err, }; @@ -70,8 +71,8 @@ pub fn generate_users() -> Response { ResponseCode::Success.text(&html) } -pub fn generate_posts() -> Response { - let posts = match Post::reterieve_all() { +pub fn generate_posts(db: &Database) -> Response { + let posts = match Post::reterieve_all(db) { Ok(posts) => posts, Err(err) => return err, }; @@ -99,8 +100,8 @@ pub fn generate_posts() -> Response { ResponseCode::Success.text(&html) } -pub fn generate_sessions() -> Response { - let sessions = match Session::reterieve_all() { +pub fn generate_sessions(db: &Database) -> Response { + let sessions = match Session::reterieve_all(db) { Ok(sessions) => sessions, Err(err) => return err, }; @@ -123,8 +124,8 @@ pub fn generate_sessions() -> Response { ResponseCode::Success.text(&html) } -pub fn generate_comments() -> Response { - let comments = match Comment::reterieve_all() { +pub fn generate_comments(db: &Database) -> Response { + let comments = match Comment::reterieve_all(db) { Ok(comments) => comments, Err(err) => return err, }; @@ -154,8 +155,8 @@ pub fn generate_comments() -> Response { ResponseCode::Success.text(&html) } -pub fn generate_likes() -> Response { - let likes = match Like::reterieve_all() { +pub fn generate_likes(db: &Database) -> Response { + let likes = match Like::reterieve_all(db) { Ok(likes) => likes, Err(err) => return err, }; diff --git a/src/types/comment.rs b/src/types/comment.rs index cf94bd3..0836950 100644 --- a/src/types/comment.rs +++ b/src/types/comment.rs @@ -2,7 +2,7 @@ use serde::Serialize; use tracing::instrument; use crate::{ - database::{self, comments}, + database::Database, types::http::{ResponseCode, Result}, }; @@ -16,9 +16,9 @@ pub struct Comment { } impl Comment { - #[instrument()] - pub fn new(user_id: u64, post_id: u64, content: &str) -> Result { - let Ok(comment) = comments::add_comment(user_id, post_id, content) else { + #[instrument(skip(db))] + pub fn new(db: &Database, user_id: u64, post_id: u64, content: &str) -> Result { + let Ok(comment) = db.add_comment(user_id, post_id, content) else { tracing::error!("Failed to create comment"); return Err(ResponseCode::InternalServerError.text("Failed to create post")) }; @@ -26,17 +26,17 @@ impl Comment { Ok(comment) } - #[instrument()] - pub fn from_comment_page(page: u64, post_id: u64) -> Result> { - let Ok(posts) = database::comments::get_comments_page(page, post_id) else { + #[instrument(skip(db))] + pub fn from_comment_page(db: &Database, page: u64, post_id: u64) -> Result> { + let Ok(posts) = db.get_comments_page(page, post_id) else { return Err(ResponseCode::BadRequest.text("Failed to fetch comments")) }; Ok(posts) } - #[instrument()] - pub fn reterieve_all() -> Result> { - let Ok(posts) = database::comments::get_all_comments() else { + #[instrument(skip(db))] + pub fn reterieve_all(db: &Database) -> Result> { + let Ok(posts) = db.get_all_comments() else { return Err(ResponseCode::InternalServerError.text("Failed to fetch comments")) }; Ok(posts) diff --git a/src/types/extract.rs b/src/types/extract.rs index 65d9f1a..f05215f 100644 --- a/src/types/extract.rs +++ b/src/types/extract.rs @@ -14,9 +14,11 @@ use axum::{ use bytes::Bytes; use image::{io::Reader, DynamicImage, ImageFormat}; use serde::de::DeserializeOwned; +use tokio::sync::Mutex; use tower_cookies::Cookies; use crate::{ + database, public::admin, public::console, types::{ @@ -97,11 +99,17 @@ where return Err(ResponseCode::Forbidden.text("No auth token provided")) }; - let Ok(session) = Session::from_token(token.value()) else { + let Some(db) = parts.extensions.get::() else { + return Err(ResponseCode::Forbidden.text("Could not connect to database")) + }; + + let db = db.0.lock().await; + + let Ok(session) = Session::from_token(&db, token.value()) else { return Err(ResponseCode::Unauthorized.text("Auth token invalid")) }; - let Ok(user) = User::from_user_id(session.user_id, true) else { + let Ok(user) = User::from_user_id(&db, session.user_id, true) else { tracing::error!("Valid token but no valid user"); return Err(ResponseCode::InternalServerError.text("Valid token but no valid user")) }; @@ -260,6 +268,26 @@ where } } +pub struct DatabaseExtention(pub Mutex); +pub struct Database(pub database::Database); + +#[async_trait] +impl FromRequestParts for Database +where + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let db = parts.extensions.remove::(); + let Some(db) = db else { + return Err(ResponseCode::InternalServerError.text("Database is not loaded")) + }; + + Ok(Self(db.0.into_inner())) + } +} + async fn read_body(mut req: Request, state: &S) -> Result> where B: HttpBody + Sync + Send + 'static, diff --git a/src/types/like.rs b/src/types/like.rs index 1c113c1..8eec941 100644 --- a/src/types/like.rs +++ b/src/types/like.rs @@ -1,7 +1,7 @@ use serde::Serialize; use tracing::instrument; -use crate::database; +use crate::database::Database; use crate::types::http::{ResponseCode, Result}; #[derive(Serialize)] @@ -11,9 +11,9 @@ pub struct Like { } impl Like { - #[instrument()] - pub fn add_liked(user_id: u64, post_id: u64) -> Result<()> { - let Ok(liked) = database::likes::add_liked(user_id, post_id) else { + #[instrument(skip(db))] + pub fn add_liked(db: &Database, user_id: u64, post_id: u64) -> Result<()> { + let Ok(liked) = db.add_liked(user_id, post_id) else { return Err(ResponseCode::BadRequest.text("Failed to add like status")) }; @@ -24,9 +24,9 @@ impl Like { Ok(()) } - #[instrument()] - pub fn remove_liked(user_id: u64, post_id: u64) -> Result<()> { - let Ok(liked) = database::likes::remove_liked(user_id, post_id) else { + #[instrument(skip(db))] + pub fn remove_liked(db: &Database, user_id: u64, post_id: u64) -> Result<()> { + let Ok(liked) = db.remove_liked(user_id, post_id) else { return Err(ResponseCode::BadRequest.text("Failed to remove like status")) }; @@ -37,9 +37,9 @@ impl Like { Ok(()) } - #[instrument()] - pub fn reterieve_all() -> Result> { - let Ok(likes) = database::likes::get_all_likes() else { + #[instrument(skip(db))] + pub fn reterieve_all(db: &Database) -> Result> { + let Ok(likes) = db.get_all_likes() else { return Err(ResponseCode::InternalServerError.text("Failed to fetch likes")) }; Ok(likes) diff --git a/src/types/post.rs b/src/types/post.rs index 09f2f50..bff68c7 100644 --- a/src/types/post.rs +++ b/src/types/post.rs @@ -2,7 +2,7 @@ use core::fmt; use serde::Serialize; use tracing::instrument; -use crate::database; +use crate::database::Database; use crate::types::http::{ResponseCode, Result}; use super::comment::Comment; @@ -27,57 +27,62 @@ impl fmt::Debug for Post { } impl Post { - #[instrument()] - pub fn from_post_id(self_id: u64, post_id: u64) -> Result { - let Ok(Some(mut post)) = database::posts::get_post(post_id) else { + #[instrument(skip(db))] + pub fn from_post_id(db: &Database, self_id: u64, post_id: u64) -> Result { + let Ok(Some(mut post)) = db.get_post(post_id) else { return Err(ResponseCode::BadRequest.text("Post does not exist")) }; - let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false); + let liked = db.get_liked(self_id, post.post_id).unwrap_or(false); post.liked = liked; Ok(post) } - #[instrument()] - pub fn from_post_page(self_id: u64, page: u64) -> Result> { - let Ok(mut posts) = database::posts::get_post_page(page) else { + #[instrument(skip(db))] + pub fn from_post_page(db: &Database, self_id: u64, page: u64) -> Result> { + let Ok(mut posts) = db.get_post_page(page) else { return Err(ResponseCode::BadRequest.text("Failed to fetch posts")) }; for post in &mut posts { - let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false); + let liked = db.get_liked(self_id, post.post_id).unwrap_or(false); post.liked = liked; } Ok(posts) } - #[instrument()] - pub fn from_user_post_page(self_id: u64, user_id: u64, page: u64) -> Result> { - let Ok(mut posts) = database::posts::get_users_post_page(user_id, page) else { + #[instrument(skip(db))] + pub fn from_user_post_page( + db: &Database, + self_id: u64, + user_id: u64, + page: u64, + ) -> Result> { + let Ok(mut posts) = db.get_users_post_page(user_id, page) else { return Err(ResponseCode::BadRequest.text("Failed to fetch posts")) }; for post in &mut posts { - let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false); + let liked = db.get_liked(self_id, post.post_id).unwrap_or(false); post.liked = liked; } Ok(posts) } - #[instrument()] - pub fn reterieve_all() -> Result> { - let Ok(posts) = database::posts::get_all_posts() else { + #[instrument(skip(db))] + pub fn reterieve_all(db: &Database) -> Result> { + let Ok(posts) = db.get_all_posts() else { return Err(ResponseCode::InternalServerError.text("Failed to fetch posts")) }; Ok(posts) } - #[instrument()] - pub fn new(user_id: u64, content: String) -> Result { - let Ok(post) = database::posts::add_post(user_id, &content) else { + #[instrument(skip(db))] + pub fn new(db: &Database, user_id: u64, content: String) -> Result { + let Ok(post) = db.add_post(user_id, &content) else { tracing::error!("Failed to create post"); return Err(ResponseCode::InternalServerError.text("Failed to create post")) }; diff --git a/src/types/session.rs b/src/types/session.rs index a9073aa..27c5c66 100644 --- a/src/types/session.rs +++ b/src/types/session.rs @@ -2,7 +2,7 @@ use rand::{distributions::Alphanumeric, Rng}; use serde::Serialize; use tracing::instrument; -use crate::database; +use crate::database::Database; use crate::types::http::{ResponseCode, Result}; #[derive(Serialize)] @@ -12,39 +12,39 @@ pub struct Session { } impl Session { - #[instrument()] - pub fn from_token(token: &str) -> Result { - let Ok(Some(session)) = database::sessions::get_session(token) else { + #[instrument(skip(db))] + pub fn from_token(db: &Database, token: &str) -> Result { + let Ok(Some(session)) = db.get_session(token) else { return Err(ResponseCode::BadRequest.text("Invalid auth token")); }; Ok(session) } - #[instrument()] - pub fn reterieve_all() -> Result> { - let Ok(sessions) = database::sessions::get_all_sessions() else { + #[instrument(skip(db))] + pub fn reterieve_all(db: &Database) -> Result> { + let Ok(sessions) = db.get_all_sessions() else { return Err(ResponseCode::InternalServerError.text("Failed to fetch sessions")) }; Ok(sessions) } - #[instrument()] - pub fn new(user_id: u64) -> Result { + #[instrument(skip(db))] + pub fn new(db: &Database, user_id: u64) -> Result { let token: String = rand::thread_rng() .sample_iter(&Alphanumeric) .take(32) .map(char::from) .collect(); - match database::sessions::set_session(user_id, &token) { + match db.set_session(user_id, &token) { Err(_) => Err(ResponseCode::BadRequest.text("Failed to create session")), Ok(_) => Ok(Self { user_id, token }), } } - #[instrument()] - pub fn delete(user_id: u64) -> Result<()> { - if database::sessions::delete_session(user_id).is_err() { + #[instrument(skip(db))] + pub fn delete(db: &Database, user_id: u64) -> Result<()> { + if db.delete_session(user_id).is_err() { tracing::error!("Failed to logout user"); return Err(ResponseCode::InternalServerError.text("Failed to logout")); }; diff --git a/src/types/user.rs b/src/types/user.rs index 245e9b7..3c4cd6a 100644 --- a/src/types/user.rs +++ b/src/types/user.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; use tracing::instrument; use crate::api::RegistrationRequet; -use crate::database; +use crate::database::Database; use crate::types::http::{ResponseCode, Result}; #[derive(Serialize, Deserialize, Debug)] @@ -24,21 +24,21 @@ pub const FOLLOWING: u8 = 1; pub const FOLLOWED: u8 = 2; impl User { - #[instrument()] - pub fn from_user_id(user_id: u64, hide_password: bool) -> Result { - let Ok(Some(user)) = database::users::get_user_by_id(user_id, hide_password) else { + #[instrument(skip(db))] + pub fn from_user_id(db: &Database, user_id: u64, hide_password: bool) -> Result { + let Ok(Some(user)) = db.get_user_by_id(user_id, hide_password) else { return Err(ResponseCode::BadRequest.text("User does not exist")) }; Ok(user) } - #[instrument()] - pub fn from_user_ids(user_ids: Vec) -> Vec { + #[instrument(skip(db))] + pub fn from_user_ids(db: &Database, user_ids: Vec) -> Vec { user_ids .iter() .filter_map(|user_id| { - let Ok(Some(user)) = database::users::get_user_by_id(*user_id, true) else { + let Ok(Some(user)) = db.get_user_by_id(*user_id, true) else { return None; }; Some(user) @@ -46,53 +46,53 @@ impl User { .collect() } - #[instrument()] - pub fn from_user_page(page: u64) -> Result> { - let Ok(users) = database::users::get_user_page(page, true) else { + #[instrument(skip(db))] + pub fn from_user_page(db: &Database, page: u64) -> Result> { + let Ok(users) = db.get_user_page(page, true) else { return Err(ResponseCode::BadRequest.text("Failed to fetch users")) }; Ok(users) } - #[instrument()] - pub fn from_email(email: &str) -> Result { - let Ok(Some(user)) = database::users::get_user_by_email(email, false) else { + #[instrument(skip(db))] + pub fn from_email(db: &Database, email: &str) -> Result { + let Ok(Some(user)) = db.get_user_by_email(email, false) else { return Err(ResponseCode::BadRequest.text("User does not exist")) }; Ok(user) } - #[instrument()] - pub fn from_password(password: &str) -> Result { - let Ok(Some(user)) = database::users::get_user_by_password(password, true) else { + #[instrument(skip(db))] + pub fn from_password(db: &Database, password: &str) -> Result { + let Ok(Some(user)) = db.get_user_by_password(password, true) else { return Err(ResponseCode::BadRequest.text("User does not exist")) }; Ok(user) } - #[instrument()] - pub fn reterieve_all() -> Result> { - let Ok(users) = database::users::get_all_users() else { + #[instrument(skip(db))] + pub fn reterieve_all(db: &Database) -> Result> { + let Ok(users) = db.get_all_users() else { return Err(ResponseCode::InternalServerError.text("Failed to fetch users")) }; Ok(users) } - #[instrument()] - pub fn new(request: RegistrationRequet) -> Result { - if Self::from_email(&request.email).is_ok() { + #[instrument(skip(db))] + pub fn new(db: &Database, request: RegistrationRequet) -> Result { + if Self::from_email(db, &request.email).is_ok() { return Err(ResponseCode::BadRequest .text(&format!("Email is already in use by {}", &request.email))); } - if let Ok(user) = Self::from_password(&request.password) { + if let Ok(user) = Self::from_password(db, &request.password) { return Err(ResponseCode::BadRequest .text(&format!("Password is already in use by {}", user.email))); } - let Ok(user) = database::users::add_user(request) else { + let Ok(user) = db.add_user(request) else { tracing::error!("Failed to create new user"); return Err(ResponseCode::InternalServerError.text("Failed to create new uesr")) }; @@ -100,8 +100,9 @@ impl User { Ok(user) } - pub fn add_following(user_id_1: u64, user_id_2: u64) -> Result<()> { - let Ok(followed) = database::friends::set_following(user_id_1, user_id_2) else { + #[instrument(skip(db))] + pub fn add_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<()> { + let Ok(followed) = db.set_following(user_id_1, user_id_2) else { return Err(ResponseCode::BadRequest.text("Failed to add follow status")) }; @@ -112,8 +113,9 @@ impl User { Ok(()) } - pub fn remove_following(user_id_1: u64, user_id_2: u64) -> Result<()> { - let Ok(followed) = database::friends::remove_following(user_id_1, user_id_2) else { + #[instrument(skip(db))] + pub fn remove_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<()> { + let Ok(followed) = db.remove_following(user_id_1, user_id_2) else { return Err(ResponseCode::BadRequest.text("Failed to remove follow status")) }; @@ -124,15 +126,17 @@ impl User { Ok(()) } - pub fn get_following(user_id_1: u64, user_id_2: u64) -> Result { - let Ok(followed) = database::friends::get_friend_status(user_id_1, user_id_2) else { + #[instrument(skip(db))] + pub fn get_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result { + let Ok(followed) = db.get_friend_status(user_id_1, user_id_2) else { return Err(ResponseCode::InternalServerError.text("Failed to get follow status")) }; Ok(followed) } - pub fn get_friends(user_id: u64) -> Result> { - let Ok(users) = database::friends::get_friends(user_id) else { + #[instrument(skip(db))] + pub fn get_friends(db: &Database, user_id: u64) -> Result> { + let Ok(users) = db.get_friends(user_id) else { return Err(ResponseCode::InternalServerError.text("Failed to fetch friends")) }; Ok(users)