summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/api/admin.rs27
-rw-r--r--src/api/auth.rs31
-rw-r--r--src/api/mod.rs27
-rw-r--r--src/api/posts.rs27
-rw-r--r--src/api/users.rs26
-rw-r--r--src/database/comments.rs165
-rw-r--r--src/database/friends.rs173
-rw-r--r--src/database/likes.rs136
-rw-r--r--src/database/mod.rs37
-rw-r--r--src/database/posts.rs205
-rw-r--r--src/database/sessions.rs105
-rw-r--r--src/database/users.rs307
-rw-r--r--src/public/admin.rs21
-rw-r--r--src/types/comment.rs20
-rw-r--r--src/types/extract.rs32
-rw-r--r--src/types/like.rs20
-rw-r--r--src/types/post.rs43
-rw-r--r--src/types/session.rs26
-rw-r--r--src/types/user.rs68
19 files changed, 815 insertions, 681 deletions
diff --git a/src/api/admin.rs b/src/api/admin.rs
index 6030315..f412d75 100644
--- a/src/api/admin.rs
+++ b/src/api/admin.rs
@@ -5,13 +5,12 @@ use serde::Deserialize;
use tower_cookies::{Cookie, Cookies};
use crate::{
- database,
public::{
admin,
docs::{EndpointDocumentation, EndpointMethod},
},
types::{
- extract::{AdminUser, Check, CheckResult, Json},
+ extract::{AdminUser, Check, CheckResult, Database, Json},
http::ResponseCode,
},
};
@@ -92,8 +91,8 @@ impl Check for QueryRequest {
}
}
-async fn query(_: AdminUser, Json(body): Json<QueryRequest>) -> Response {
- match database::query(body.query) {
+async fn query(_: AdminUser, Database(db): Database, Json(body): Json<QueryRequest>) -> Response {
+ match db.query(body.query) {
Ok(changes) => ResponseCode::Success.text(&format!(
"Query executed successfully. {changes} lines changed."
)),
@@ -114,8 +113,8 @@ pub const ADMIN_POSTS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"),
};
-async fn posts(_: AdminUser) -> Response {
- admin::generate_posts()
+async fn posts(_: AdminUser, Database(db): Database) -> Response {
+ admin::generate_posts(&db)
}
pub const ADMIN_USERS: EndpointDocumentation = EndpointDocumentation {
@@ -131,8 +130,8 @@ pub const ADMIN_USERS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"),
};
-async fn users(_: AdminUser) -> Response {
- admin::generate_users()
+async fn users(_: AdminUser, Database(db): Database) -> Response {
+ admin::generate_users(&db)
}
pub const ADMIN_SESSIONS: EndpointDocumentation = EndpointDocumentation {
@@ -148,8 +147,8 @@ pub const ADMIN_SESSIONS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"),
};
-async fn sessions(_: AdminUser) -> Response {
- admin::generate_sessions()
+async fn sessions(_: AdminUser, Database(db): Database) -> Response {
+ admin::generate_sessions(&db)
}
pub const ADMIN_COMMENTS: EndpointDocumentation = EndpointDocumentation {
@@ -165,8 +164,8 @@ pub const ADMIN_COMMENTS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"),
};
-async fn comments(_: AdminUser) -> Response {
- admin::generate_comments()
+async fn comments(_: AdminUser, Database(db): Database) -> Response {
+ admin::generate_comments(&db)
}
pub const ADMIN_LIKES: EndpointDocumentation = EndpointDocumentation {
@@ -182,8 +181,8 @@ pub const ADMIN_LIKES: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"),
};
-async fn likes(_: AdminUser) -> Response {
- admin::generate_likes()
+async fn likes(_: AdminUser, Database(db): Database) -> Response {
+ admin::generate_likes(&db)
}
async fn check(check: Option<AdminUser>) -> Response {
diff --git a/src/api/auth.rs b/src/api/auth.rs
index 60ddc80..c48b583 100644
--- a/src/api/auth.rs
+++ b/src/api/auth.rs
@@ -6,7 +6,7 @@ use tower_cookies::{Cookie, Cookies};
use crate::{
public::docs::{EndpointDocumentation, EndpointMethod},
types::{
- extract::{AuthorizedUser, Check, CheckResult, Json, Log},
+ extract::{AuthorizedUser, Check, CheckResult, Database, Json, Log},
http::ResponseCode,
session::Session,
user::User,
@@ -99,13 +99,17 @@ impl Check for RegistrationRequet {
}
}
-async fn register(cookies: Cookies, Json(body): Json<RegistrationRequet>) -> Response {
- let user = match User::new(body) {
+async fn register(
+ cookies: Cookies,
+ Database(db): Database,
+ Json(body): Json<RegistrationRequet>,
+) -> Response {
+ let user = match User::new(&db, body) {
Ok(user) => user,
Err(err) => return err,
};
- let session = match Session::new(user.user_id) {
+ let session = match Session::new(&db, user.user_id) {
Ok(session) => session,
Err(err) => return err,
};
@@ -158,8 +162,12 @@ impl Check for LoginRequest {
}
}
-async fn login(cookies: Cookies, Json(body): Json<LoginRequest>) -> Response {
- let Ok(user) = User::from_email(&body.email) else {
+async fn login(
+ cookies: Cookies,
+ Database(db): Database,
+ Json(body): Json<LoginRequest>,
+) -> Response {
+ let Ok(user) = User::from_email(&db, &body.email) else {
return ResponseCode::BadRequest.text("Email is not registered")
};
@@ -167,7 +175,7 @@ async fn login(cookies: Cookies, Json(body): Json<LoginRequest>) -> Response {
return ResponseCode::BadRequest.text("Password is not correct");
}
- let session = match Session::new(user.user_id) {
+ let session = match Session::new(&db, user.user_id) {
Ok(session) => session,
Err(err) => return err,
};
@@ -199,10 +207,15 @@ pub const AUTH_LOGOUT: EndpointDocumentation = EndpointDocumentation {
cookie: None,
};
-async fn logout(cookies: Cookies, AuthorizedUser(user): AuthorizedUser, _: Log) -> Response {
+async fn logout(
+ cookies: Cookies,
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ _: Log,
+) -> Response {
cookies.remove(Cookie::new("auth", ""));
- if let Err(err) = Session::delete(user.user_id) {
+ if let Err(err) = Session::delete(&db, user.user_id) {
return err;
}
diff --git a/src/api/mod.rs b/src/api/mod.rs
index cd2190c..eeaaa0a 100644
--- a/src/api/mod.rs
+++ b/src/api/mod.rs
@@ -1,5 +1,15 @@
-use crate::types::extract::RouterURI;
-use axum::{error_handling::HandleErrorLayer, BoxError, Extension, Router};
+use crate::{
+ database,
+ types::extract::{DatabaseExtention, RouterURI},
+};
+use axum::{
+ error_handling::HandleErrorLayer,
+ http::Request,
+ middleware::{self, Next},
+ response::Response,
+ BoxError, Extension, Router,
+};
+use tokio::sync::Mutex;
use tower::ServiceBuilder;
use tower_governor::{
errors::display_error, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor,
@@ -13,6 +23,18 @@ pub mod users;
pub use auth::RegistrationRequet;
+async fn connect<B>(mut req: Request<B>, next: Next<B>) -> Response
+where
+ B: Send,
+{
+ if let Ok(db) = database::Database::connect() {
+ let ex = DatabaseExtention(Mutex::new(db));
+ req.extensions_mut().insert(ex);
+ }
+
+ next.run(req).await
+}
+
pub fn router() -> Router {
let governor_conf = Box::new(
GovernorConfigBuilder::default()
@@ -49,4 +71,5 @@ pub fn router() -> Router {
config: Box::leak(governor_conf),
}),
)
+ .layer(middleware::from_fn(connect))
}
diff --git a/src/api/posts.rs b/src/api/posts.rs
index 57b2ca8..bd7e665 100644
--- a/src/api/posts.rs
+++ b/src/api/posts.rs
@@ -9,7 +9,7 @@ use crate::{
public::docs::{EndpointDocumentation, EndpointMethod},
types::{
comment::Comment,
- extract::{AuthorizedUser, Check, CheckResult, Json},
+ extract::{AuthorizedUser, Check, CheckResult, Database, Json},
http::ResponseCode,
like::Like,
post::Post,
@@ -55,9 +55,10 @@ impl Check for PostCreateRequest {
async fn create(
AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
Json(body): Json<PostCreateRequest>,
) -> Response {
- let Ok(post) = Post::new(user.user_id, body.content) else {
+ let Ok(post) = Post::new(&db, user.user_id, body.content) else {
return ResponseCode::InternalServerError.text("Failed to create post")
};
@@ -101,9 +102,10 @@ impl Check for PostPageRequest {
async fn page(
AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
Json(body): Json<PostPageRequest>,
) -> Response {
- let Ok(posts) = Post::from_post_page(user.user_id, body.page) else {
+ let Ok(posts) = Post::from_post_page(&db, user.user_id, body.page) else {
return ResponseCode::InternalServerError.text("Failed to fetch posts")
};
@@ -149,9 +151,10 @@ impl Check for CommentsPageRequest {
async fn comments(
AuthorizedUser(_user): AuthorizedUser,
+ Database(db): Database,
Json(body): Json<CommentsPageRequest>,
) -> Response {
- let Ok(comments) = Comment::from_comment_page(body.page, body.post_id) else {
+ let Ok(comments) = Comment::from_comment_page(&db, body.page, body.post_id) else {
return ResponseCode::InternalServerError.text("Failed to fetch comments")
};
@@ -197,9 +200,10 @@ impl Check for UsersPostsRequest {
async fn user(
AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
Json(body): Json<UsersPostsRequest>,
) -> Response {
- let Ok(posts) = Post::from_user_post_page(user.user_id, body.user_id, body.page) else {
+ let Ok(posts) = Post::from_user_post_page(&db, user.user_id, body.user_id, body.page) else {
return ResponseCode::InternalServerError.text("Failed to fetch posts")
};
@@ -251,9 +255,10 @@ impl Check for PostCommentRequest {
async fn comment(
AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
Json(body): Json<PostCommentRequest>,
) -> Response {
- if let Err(err) = Comment::new(user.user_id, body.post_id, &body.content) {
+ if let Err(err) = Comment::new(&db, user.user_id, body.post_id, &body.content) {
return err;
}
@@ -293,12 +298,16 @@ impl Check for PostLikeRequest {
}
}
-async fn like(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostLikeRequest>) -> Response {
+async fn like(
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ Json(body): Json<PostLikeRequest>,
+) -> Response {
if body.state {
- if let Err(err) = Like::add_liked(user.user_id, body.post_id) {
+ if let Err(err) = Like::add_liked(&db, user.user_id, body.post_id) {
return err;
}
- } else if let Err(err) = Like::remove_liked(user.user_id, body.post_id) {
+ } else if let Err(err) = Like::remove_liked(&db, user.user_id, body.post_id) {
return err;
}
diff --git a/src/api/users.rs b/src/api/users.rs
index 082926e..71305c5 100644
--- a/src/api/users.rs
+++ b/src/api/users.rs
@@ -1,7 +1,7 @@
use crate::{
public::docs::{EndpointDocumentation, EndpointMethod},
types::{
- extract::{AuthorizedUser, Check, CheckResult, Json, Log, Png},
+ extract::{AuthorizedUser, Check, CheckResult, Database, Json, Log, Png},
http::ResponseCode,
user::User,
},
@@ -46,9 +46,10 @@ impl Check for UserLoadRequest {
async fn load_batch(
AuthorizedUser(_user): AuthorizedUser,
+ Database(db): Database,
Json(body): Json<UserLoadRequest>,
) -> Response {
- let users = User::from_user_ids(body.ids);
+ let users = User::from_user_ids(&db, body.ids);
let Ok(json) = serde_json::to_string(&users) else {
return ResponseCode::InternalServerError.text("Failed to fetch users")
};
@@ -90,9 +91,10 @@ impl Check for UserPageReqiest {
async fn load_page(
AuthorizedUser(_user): AuthorizedUser,
+ Database(db): Database,
Json(body): Json<UserPageReqiest>,
) -> Response {
- let Ok(users) = User::from_user_page(body.page) else {
+ let Ok(users) = User::from_user_page(&db, body.page) else {
return ResponseCode::InternalServerError.text("Failed to fetch users")
};
@@ -207,17 +209,18 @@ impl Check for UserFollowRequest {
async fn follow(
AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
Json(body): Json<UserFollowRequest>,
) -> Response {
if body.state {
- if let Err(err) = User::add_following(user.user_id, body.user_id) {
+ if let Err(err) = User::add_following(&db, user.user_id, body.user_id) {
return err;
}
- } else if let Err(err) = User::remove_following(user.user_id, body.user_id) {
+ } else if let Err(err) = User::remove_following(&db, user.user_id, body.user_id) {
return err;
}
- match User::get_following(user.user_id, body.user_id) {
+ match User::get_following(&db, user.user_id, body.user_id) {
Ok(status) => ResponseCode::Success.text(&format!("{status}")),
Err(err) => err,
}
@@ -259,9 +262,10 @@ impl Check for UserFollowStatusRequest {
async fn follow_status(
AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
Json(body): Json<UserFollowStatusRequest>,
) -> Response {
- match User::get_following(user.user_id, body.user_id) {
+ match User::get_following(&db, user.user_id, body.user_id) {
Ok(status) => ResponseCode::Success.text(&format!("{status}")),
Err(err) => err,
}
@@ -297,8 +301,12 @@ impl Check for UserFriendsRequest {
}
}
-async fn friends(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UserFriendsRequest>) -> Response {
- let Ok(users) = User::get_friends(body.user_id) else {
+async fn friends(
+ AuthorizedUser(_user): AuthorizedUser,
+ Database(db): Database,
+ Json(body): Json<UserFriendsRequest>,
+) -> Response {
+ let Ok(users) = User::get_friends(&db, body.user_id) else {
return ResponseCode::InternalServerError.text("Failed to fetch user")
};
diff --git a/src/database/comments.rs b/src/database/comments.rs
index 9e0eaf9..5a1b39d 100644
--- a/src/database/comments.rs
+++ b/src/database/comments.rs
@@ -3,89 +3,100 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
use rusqlite::Row;
use tracing::instrument;
-use crate::{database, types::comment::Comment};
+use crate::types::comment::Comment;
-pub fn init() -> Result<(), rusqlite::Error> {
- let sql = "
- CREATE TABLE IF NOT EXISTS comments (
- comment_id INTEGER PRIMARY KEY AUTOINCREMENT,
- user_id INTEGER NOT NULL,
- post_id INTEGER NOT NULL,
- date INTEGER NOT NULL,
- content VARCHAR(255) NOT NULL,
- FOREIGN KEY(user_id) REFERENCES users(user_id),
- FOREIGN KEY(post_id) REFERENCES posts(post_id)
- );
- ";
- let conn = database::connect()?;
- conn.execute(sql, ())?;
+use super::Database;
- let sql2 = "CREATE INDEX IF NOT EXISTS post_ids on comments (post_id);";
- conn.execute(sql2, ())?;
+impl Database {
+ pub fn init_comments(&self) -> Result<(), rusqlite::Error> {
+ let sql = "
+ CREATE TABLE IF NOT EXISTS comments (
+ comment_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ post_id INTEGER NOT NULL,
+ date INTEGER NOT NULL,
+ content VARCHAR(255) NOT NULL,
+ FOREIGN KEY(user_id) REFERENCES users(user_id),
+ FOREIGN KEY(post_id) REFERENCES posts(post_id)
+ );
+ ";
+ self.0.execute(sql, ())?;
- Ok(())
-}
+ let sql2 = "CREATE INDEX IF NOT EXISTS post_ids on comments (post_id);";
+ self.0.execute(sql2, ())?;
-fn comment_from_row(row: &Row) -> Result<Comment, rusqlite::Error> {
- let comment_id = row.get(0)?;
- let user_id = row.get(1)?;
- let post_id = row.get(2)?;
- let date = row.get(3)?;
- let content = row.get(4)?;
+ Ok(())
+ }
- Ok(Comment {
- comment_id,
- user_id,
- post_id,
- date,
- content,
- })
-}
+ fn comment_from_row(row: &Row) -> Result<Comment, rusqlite::Error> {
+ let comment_id = row.get(0)?;
+ let user_id = row.get(1)?;
+ let post_id = row.get(2)?;
+ let date = row.get(3)?;
+ let content = row.get(4)?;
-#[instrument()]
-pub fn get_comments_page(page: u64, post_id: u64) -> Result<Vec<Comment>, rusqlite::Error> {
- tracing::trace!("Retrieving comments page");
- let page_size = 5;
- let conn = database::connect()?;
- let mut stmt = conn.prepare(
- "SELECT * FROM comments WHERE post_id = ? ORDER BY comment_id ASC LIMIT ? OFFSET ?",
- )?;
- let row = stmt.query_map([post_id, page_size, page_size * page], |row| {
- let row = comment_from_row(row)?;
- Ok(row)
- })?;
- Ok(row.into_iter().flatten().collect())
-}
+ Ok(Comment {
+ comment_id,
+ user_id,
+ post_id,
+ date,
+ content,
+ })
+ }
-#[instrument()]
-pub fn get_all_comments() -> Result<Vec<Comment>, rusqlite::Error> {
- tracing::trace!("Retrieving comments page");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("SELECT * FROM comments ORDER BY comment_id DESC")?;
- let row = stmt.query_map([], |row| {
- let row = comment_from_row(row)?;
- Ok(row)
- })?;
- Ok(row.into_iter().flatten().collect())
-}
+ #[instrument(skip(self))]
+ pub fn get_comments_page(
+ &self,
+ page: u64,
+ post_id: u64,
+ ) -> Result<Vec<Comment>, rusqlite::Error> {
+ tracing::trace!("Retrieving comments page");
+ let page_size = 5;
+ let mut stmt = self.0.prepare(
+ "SELECT * FROM comments WHERE post_id = ? ORDER BY comment_id ASC LIMIT ? OFFSET ?",
+ )?;
+ let row = stmt.query_map([post_id, page_size, page_size * page], |row| {
+ let row = Self::comment_from_row(row)?;
+ Ok(row)
+ })?;
+ Ok(row.into_iter().flatten().collect())
+ }
+
+ #[instrument(skip(self))]
+ pub fn get_all_comments(&self) -> Result<Vec<Comment>, rusqlite::Error> {
+ tracing::trace!("Retrieving comments page");
+ let mut stmt = self
+ .0
+ .prepare("SELECT * FROM comments ORDER BY comment_id DESC")?;
+ let row = stmt.query_map([], |row| {
+ let row = Self::comment_from_row(row)?;
+ Ok(row)
+ })?;
+ Ok(row.into_iter().flatten().collect())
+ }
-#[instrument()]
-pub fn add_comment(user_id: u64, post_id: u64, content: &str) -> Result<Comment, rusqlite::Error> {
- tracing::trace!("Adding comment");
- let date = u64::try_from(
- SystemTime::now()
- .duration_since(UNIX_EPOCH)
- .unwrap_or(Duration::ZERO)
- .as_millis(),
- )
- .unwrap_or(0);
- let conn = database::connect()?;
- let mut stmt = conn.prepare(
- "INSERT INTO comments (user_id, post_id, date, content) VALUES(?,?,?,?) RETURNING *;",
- )?;
- let post = stmt.query_row((user_id, post_id, date, content), |row| {
- let row = comment_from_row(row)?;
- Ok(row)
- })?;
- Ok(post)
+ #[instrument(skip(self))]
+ pub fn add_comment(
+ &self,
+ user_id: u64,
+ post_id: u64,
+ content: &str,
+ ) -> Result<Comment, rusqlite::Error> {
+ tracing::trace!("Adding comment");
+ let date = u64::try_from(
+ SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .unwrap_or(Duration::ZERO)
+ .as_millis(),
+ )
+ .unwrap_or(0);
+ let mut stmt = self.0.prepare(
+ "INSERT INTO comments (user_id, post_id, date, content) VALUES(?,?,?,?) RETURNING *;",
+ )?;
+ let post = stmt.query_row((user_id, post_id, date, content), |row| {
+ let row = Self::comment_from_row(row)?;
+ Ok(row)
+ })?;
+ Ok(post)
+ }
}
diff --git a/src/database/friends.rs b/src/database/friends.rs
index 0b78488..31434d4 100644
--- a/src/database/friends.rs
+++ b/src/database/friends.rs
@@ -1,97 +1,100 @@
use tracing::instrument;
-use crate::{
- database::{self, users::user_from_row},
- types::user::{User, FOLLOWED, FOLLOWING, NO_RELATION},
-};
+use crate::types::user::{User, FOLLOWED, FOLLOWING, NO_RELATION};
-pub fn init() -> Result<(), rusqlite::Error> {
- let sql = "
- CREATE TABLE IF NOT EXISTS friends (
- follower_id INTEGER NOT NULL,
- followee_id INTEGER NOT NULL,
- FOREIGN KEY(follower_id) REFERENCES users(user_id),
- FOREIGN KEY(followee_id) REFERENCES users(user_id),
- PRIMARY KEY (follower_id, followee_id)
- );
- ";
- let conn = database::connect()?;
- conn.execute(sql, ())?;
- Ok(())
-}
+use super::Database;
-#[instrument()]
-pub fn get_friend_status(user_id_1: u64, user_id_2: u64) -> Result<u8, rusqlite::Error> {
- tracing::trace!("Retrieving friend status");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("SELECT * FROM friends WHERE (follower_id = ? AND followee_id = ?) OR (follower_id = ? AND followee_id = ?);")?;
- let mut status = NO_RELATION;
- let rows: Vec<u64> = stmt
- .query_map([user_id_1, user_id_2, user_id_2, user_id_1], |row| {
- let id: u64 = row.get(0)?;
- Ok(id)
- })?
- .into_iter()
- .flatten()
- .collect();
+impl Database {
+ pub fn init_friends(&self) -> Result<(), rusqlite::Error> {
+ let sql = "
+ CREATE TABLE IF NOT EXISTS friends (
+ follower_id INTEGER NOT NULL,
+ followee_id INTEGER NOT NULL,
+ FOREIGN KEY(follower_id) REFERENCES users(user_id),
+ FOREIGN KEY(followee_id) REFERENCES users(user_id),
+ PRIMARY KEY (follower_id, followee_id)
+ );
+ ";
+ self.0.execute(sql, ())?;
+ Ok(())
+ }
- for follower in rows {
- if follower == user_id_1 {
- status |= FOLLOWING;
- }
+ #[instrument(skip(self))]
+ pub fn get_friend_status(&self, user_id_1: u64, user_id_2: u64) -> Result<u8, rusqlite::Error> {
+ tracing::trace!("Retrieving friend status");
+ let mut stmt = self.0.prepare("SELECT * FROM friends WHERE (follower_id = ? AND followee_id = ?) OR (follower_id = ? AND followee_id = ?);")?;
+ let mut status = NO_RELATION;
+ let rows: Vec<u64> = stmt
+ .query_map([user_id_1, user_id_2, user_id_2, user_id_1], |row| {
+ let id: u64 = row.get(0)?;
+ Ok(id)
+ })?
+ .into_iter()
+ .flatten()
+ .collect();
- if follower == user_id_2 {
- status |= FOLLOWED;
+ for follower in rows {
+ if follower == user_id_1 {
+ status |= FOLLOWING;
+ }
+
+ if follower == user_id_2 {
+ status |= FOLLOWED;
+ }
}
- }
- Ok(status)
-}
+ Ok(status)
+ }
-#[instrument()]
-pub fn get_friends(user_id: u64) -> Result<Vec<User>, rusqlite::Error> {
- tracing::trace!("Retrieving friends");
- let conn = database::connect()?;
- let mut stmt = conn.prepare(
- "
- SELECT *
- FROM users u
- WHERE EXISTS (
- SELECT NULL
- FROM friends f
- WHERE u.user_id = f.follower_id
- AND f.followee_id = ?
- )
- AND EXISTS (
- SELECT NULL
- FROM friends f
- WHERE u.user_id = f.followee_id
- AND f.follower_id = ?
- )
- ",
- )?;
- let row = stmt.query_map([user_id, user_id], |row| {
- let row = user_from_row(row, true)?;
- Ok(row)
- })?;
- Ok(row.into_iter().flatten().collect())
-}
+ #[instrument(skip(self))]
+ pub fn get_friends(&self, user_id: u64) -> Result<Vec<User>, rusqlite::Error> {
+ tracing::trace!("Retrieving friends");
+ let mut stmt = self.0.prepare(
+ "
+ SELECT *
+ FROM users u
+ WHERE EXISTS (
+ SELECT NULL
+ FROM friends f
+ WHERE u.user_id = f.follower_id
+ AND f.followee_id = ?
+ )
+ AND EXISTS (
+ SELECT NULL
+ FROM friends f
+ WHERE u.user_id = f.followee_id
+ AND f.follower_id = ?
+ )
+ ",
+ )?;
+ let row = stmt.query_map([user_id, user_id], |row| {
+ let row = Self::user_from_row(row, true)?;
+ Ok(row)
+ })?;
+ Ok(row.into_iter().flatten().collect())
+ }
-#[instrument()]
-pub fn set_following(user_id_1: u64, user_id_2: u64) -> Result<bool, rusqlite::Error> {
- tracing::trace!("Setting following");
- let conn = database::connect()?;
- let mut stmt =
- conn.prepare("INSERT OR REPLACE INTO friends (follower_id, followee_id) VALUES (?,?)")?;
- let changes = stmt.execute([user_id_1, user_id_2])?;
- Ok(changes == 1)
-}
+ #[instrument(skip(self))]
+ pub fn set_following(&self, user_id_1: u64, user_id_2: u64) -> Result<bool, rusqlite::Error> {
+ tracing::trace!("Setting following");
+ let mut stmt = self
+ .0
+ .prepare("INSERT OR REPLACE INTO friends (follower_id, followee_id) VALUES (?,?)")?;
+ let changes = stmt.execute([user_id_1, user_id_2])?;
+ Ok(changes == 1)
+ }
-#[instrument()]
-pub fn remove_following(user_id_1: u64, user_id_2: u64) -> Result<bool, rusqlite::Error> {
- tracing::trace!("Removing following");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("DELETE FROM friends WHERE follower_id = ? AND followee_id = ?")?;
- let changes = stmt.execute([user_id_1, user_id_2])?;
- Ok(changes == 1)
+ #[instrument(skip(self))]
+ pub fn remove_following(
+ &self,
+ user_id_1: u64,
+ user_id_2: u64,
+ ) -> Result<bool, rusqlite::Error> {
+ tracing::trace!("Removing following");
+ let mut stmt = self
+ .0
+ .prepare("DELETE FROM friends WHERE follower_id = ? AND followee_id = ?")?;
+ let changes = stmt.execute([user_id_1, user_id_2])?;
+ Ok(changes == 1)
+ }
}
diff --git a/src/database/likes.rs b/src/database/likes.rs
index f6a130b..b313c97 100644
--- a/src/database/likes.rs
+++ b/src/database/likes.rs
@@ -1,75 +1,81 @@
use rusqlite::OptionalExtension;
use tracing::instrument;
-use crate::{database, types::like::Like};
+use crate::types::like::Like;
-pub fn init() -> Result<(), rusqlite::Error> {
- let sql = "
- CREATE TABLE IF NOT EXISTS likes (
- user_id INTEGER NOT NULL,
- post_id INTEGER NOT NULL,
- FOREIGN KEY(user_id) REFERENCES users(user_id),
- FOREIGN KEY(post_id) REFERENCES posts(post_id),
- PRIMARY KEY (user_id, post_id)
- );
- ";
- let conn = database::connect()?;
- conn.execute(sql, ())?;
- Ok(())
-}
+use super::Database;
-#[instrument()]
-pub fn get_like_count(post_id: u64) -> Result<Option<u64>, rusqlite::Error> {
- tracing::trace!("Retrieving like count");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("SELECT COUNT(post_id) FROM likes WHERE post_id = ?")?;
- let row = stmt
- .query_row([post_id], |row| {
- let row = row.get(0)?;
- Ok(row)
- })
- .optional()?;
- Ok(row)
-}
+impl Database {
+ pub fn init_likes(&self) -> Result<(), rusqlite::Error> {
+ let sql = "
+ CREATE TABLE IF NOT EXISTS likes (
+ user_id INTEGER NOT NULL,
+ post_id INTEGER NOT NULL,
+ FOREIGN KEY(user_id) REFERENCES users(user_id),
+ FOREIGN KEY(post_id) REFERENCES posts(post_id),
+ PRIMARY KEY (user_id, post_id)
+ );
+ ";
+ self.0.execute(sql, ())?;
+ Ok(())
+ }
-#[instrument()]
-pub fn get_liked(user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
- tracing::trace!("Retrieving if liked");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("SELECT * FROM likes WHERE user_id = ? AND post_id = ?")?;
- let liked = stmt.query_row([user_id, post_id], |_| Ok(())).optional()?;
- Ok(liked.is_some())
-}
+ #[instrument(skip(self))]
+ pub fn get_like_count(&self, post_id: u64) -> Result<Option<u64>, rusqlite::Error> {
+ tracing::trace!("Retrieving like count");
+ let mut stmt = self
+ .0
+ .prepare("SELECT COUNT(post_id) FROM likes WHERE post_id = ?")?;
+ let row = stmt
+ .query_row([post_id], |row| {
+ let row = row.get(0)?;
+ Ok(row)
+ })
+ .optional()?;
+ Ok(row)
+ }
-#[instrument()]
-pub fn add_liked(user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
- tracing::trace!("Adding like");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("INSERT OR REPLACE INTO likes (user_id, post_id) VALUES (?,?)")?;
- let changes = stmt.execute([user_id, post_id])?;
- Ok(changes == 1)
-}
+ #[instrument(skip(self))]
+ pub fn get_liked(&self, user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
+ tracing::trace!("Retrieving if liked");
+ let mut stmt = self
+ .0
+ .prepare("SELECT * FROM likes WHERE user_id = ? AND post_id = ?")?;
+ let liked = stmt.query_row([user_id, post_id], |_| Ok(())).optional()?;
+ Ok(liked.is_some())
+ }
-#[instrument()]
-pub fn remove_liked(user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
- tracing::trace!("Removing like");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("DELETE FROM likes WHERE user_id = ? AND post_id = ?;")?;
- let changes = stmt.execute((user_id, post_id))?;
- Ok(changes == 1)
-}
+ #[instrument(skip(self))]
+ pub fn add_liked(&self, user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
+ tracing::trace!("Adding like");
+ let mut stmt = self
+ .0
+ .prepare("INSERT OR REPLACE INTO likes (user_id, post_id) VALUES (?,?)")?;
+ let changes = stmt.execute([user_id, post_id])?;
+ Ok(changes == 1)
+ }
+
+ #[instrument(skip(self))]
+ pub fn remove_liked(&self, user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
+ tracing::trace!("Removing like");
+ let mut stmt = self
+ .0
+ .prepare("DELETE FROM likes WHERE user_id = ? AND post_id = ?;")?;
+ let changes = stmt.execute((user_id, post_id))?;
+ Ok(changes == 1)
+ }
-#[instrument()]
-pub fn get_all_likes() -> Result<Vec<Like>, rusqlite::Error> {
- tracing::trace!("Retrieving comments page");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("SELECT * FROM likes")?;
- let row = stmt.query_map([], |row| {
- let like = Like {
- user_id: row.get(0)?,
- post_id: row.get(1)?,
- };
- Ok(like)
- })?;
- Ok(row.into_iter().flatten().collect())
+ #[instrument(skip(self))]
+ pub fn get_all_likes(&self) -> Result<Vec<Like>, rusqlite::Error> {
+ tracing::trace!("Retrieving comments page");
+ let mut stmt = self.0.prepare("SELECT * FROM likes")?;
+ let row = stmt.query_map([], |row| {
+ let like = Like {
+ user_id: row.get(0)?,
+ post_id: row.get(1)?,
+ };
+ Ok(like)
+ })?;
+ Ok(row.into_iter().flatten().collect())
+ }
}
diff --git a/src/database/mod.rs b/src/database/mod.rs
index d22a350..67e05c6 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -1,3 +1,4 @@
+use rusqlite::Connection;
use tracing::instrument;
pub mod comments;
@@ -7,23 +8,29 @@ pub mod posts;
pub mod sessions;
pub mod users;
-pub fn connect() -> Result<rusqlite::Connection, rusqlite::Error> {
- rusqlite::Connection::open("xssbook.db")
+#[derive(Debug)]
+pub struct Database(Connection);
+
+impl Database {
+ pub fn connect() -> Result<Self, rusqlite::Error> {
+ let conn = rusqlite::Connection::open("xssbook.db")?;
+ Ok(Self(conn))
+ }
+
+ #[instrument(skip(self))]
+ pub fn query(&self, query: String) -> Result<usize, rusqlite::Error> {
+ tracing::trace!("Running custom query");
+ self.0.execute(&query, [])
+ }
}
pub fn init() -> Result<(), rusqlite::Error> {
- users::init()?;
- posts::init()?;
- sessions::init()?;
- likes::init()?;
- comments::init()?;
- friends::init()?;
+ let db = Database::connect()?;
+ db.init_users()?;
+ db.init_posts()?;
+ db.init_sessions()?;
+ db.init_likes()?;
+ db.init_comments()?;
+ db.init_friends()?;
Ok(())
}
-
-#[instrument()]
-pub fn query(query: String) -> Result<usize, rusqlite::Error> {
- tracing::trace!("Running custom query");
- let conn = connect()?;
- conn.execute(&query, [])
-}
diff --git a/src/database/posts.rs b/src/database/posts.rs
index c33e7e7..fa0fd3c 100644
--- a/src/database/posts.rs
+++ b/src/database/posts.rs
@@ -3,115 +3,122 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
use rusqlite::{OptionalExtension, Row};
use tracing::instrument;
-use crate::database;
use crate::types::post::Post;
-use super::{comments, likes};
+use super::Database;
-pub fn init() -> Result<(), rusqlite::Error> {
- let sql = "
- CREATE TABLE IF NOT EXISTS posts (
- post_id INTEGER PRIMARY KEY AUTOINCREMENT,
- user_id INTEGER NOT NULL,
- content VARCHAR(500) NOT NULL,
- date INTEGER NOT NULL,
- FOREIGN KEY(user_id) REFERENCES users(user_id)
- );
- ";
- let conn = database::connect()?;
- conn.execute(sql, ())?;
- Ok(())
-}
-
-fn post_from_row(row: &Row) -> Result<Post, rusqlite::Error> {
- let post_id = row.get(0)?;
- let user_id = row.get(1)?;
- let content = row.get(2)?;
- let date = row.get(3)?;
+impl Database {
+ pub fn init_posts(&self) -> Result<(), rusqlite::Error> {
+ let sql = "
+ CREATE TABLE IF NOT EXISTS posts (
+ post_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ content VARCHAR(500) NOT NULL,
+ date INTEGER NOT NULL,
+ FOREIGN KEY(user_id) REFERENCES users(user_id)
+ );
+ ";
+ self.0.execute(sql, ())?;
+ Ok(())
+ }
- let comments = comments::get_comments_page(0, post_id).unwrap_or_else(|_| Vec::new());
- let likes = likes::get_like_count(post_id).unwrap_or(None).unwrap_or(0);
+ fn post_from_row(&self, row: &Row) -> Result<Post, rusqlite::Error> {
+ let post_id = row.get(0)?;
+ let user_id = row.get(1)?;
+ let content = row.get(2)?;
+ let date = row.get(3)?;
- Ok(Post {
- post_id,
- user_id,
- content,
- date,
- likes,
- liked: false,
- comments,
- })
-}
+ let comments = self
+ .get_comments_page(0, post_id)
+ .unwrap_or_else(|_| Vec::new());
+ let likes = self.get_like_count(post_id).unwrap_or(None).unwrap_or(0);
-#[instrument()]
-pub fn get_post(post_id: u64) -> Result<Option<Post>, rusqlite::Error> {
- tracing::trace!("Retrieving post");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("SELECT * FROM posts WHERE post_id = ?")?;
- let row = stmt
- .query_row([post_id], |row| {
- let row = post_from_row(row)?;
- Ok(row)
+ Ok(Post {
+ post_id,
+ user_id,
+ content,
+ date,
+ likes,
+ liked: false,
+ comments,
})
- .optional()?;
- Ok(row)
-}
+ }
-#[instrument()]
-pub fn get_post_page(page: u64) -> Result<Vec<Post>, rusqlite::Error> {
- tracing::trace!("Retrieving posts page");
- let page_size = 10;
- let conn = database::connect()?;
- let mut stmt = conn.prepare("SELECT * FROM posts ORDER BY post_id DESC LIMIT ? OFFSET ?")?;
- let row = stmt.query_map([page_size, page_size * page], |row| {
- let row = post_from_row(row)?;
+ #[instrument(skip(self))]
+ pub fn get_post(&self, post_id: u64) -> Result<Option<Post>, rusqlite::Error> {
+ tracing::trace!("Retrieving post");
+ let mut stmt = self.0.prepare("SELECT * FROM posts WHERE post_id = ?")?;
+ let row = stmt
+ .query_row([post_id], |row| {
+ let row = self.post_from_row(row)?;
+ Ok(row)
+ })
+ .optional()?;
Ok(row)
- })?;
- Ok(row.into_iter().flatten().collect())
-}
+ }
-#[instrument()]
-pub fn get_all_posts() -> Result<Vec<Post>, rusqlite::Error> {
- tracing::trace!("Retrieving posts page");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("SELECT * FROM posts ORDER BY post_id DESC")?;
- let row = stmt.query_map([], |row| {
- let row = post_from_row(row)?;
- Ok(row)
- })?;
- Ok(row.into_iter().flatten().collect())
-}
+ #[instrument(skip(self))]
+ pub fn get_post_page(&self, page: u64) -> Result<Vec<Post>, rusqlite::Error> {
+ tracing::trace!("Retrieving posts page");
+ let page_size = 10;
+ let mut stmt = self
+ .0
+ .prepare("SELECT * FROM posts ORDER BY post_id DESC LIMIT ? OFFSET ?")?;
+ let row = stmt.query_map([page_size, page_size * page], |row| {
+ let row = self.post_from_row(row)?;
+ Ok(row)
+ })?;
+ Ok(row.into_iter().flatten().collect())
+ }
-#[instrument()]
-pub fn get_users_post_page(user_id: u64, page: u64) -> Result<Vec<Post>, rusqlite::Error> {
- tracing::trace!("Retrieving users posts");
- let page_size = 10;
- let conn = database::connect()?;
- let mut stmt = conn
- .prepare("SELECT * FROM posts WHERE user_id = ? ORDER BY post_id DESC LIMIT ? OFFSET ?")?;
- let row = stmt.query_map([user_id, page_size, page_size * page], |row| {
- let row = post_from_row(row)?;
- Ok(row)
- })?;
- Ok(row.into_iter().flatten().collect())
-}
+ #[instrument(skip(self))]
+ pub fn get_all_posts(&self) -> Result<Vec<Post>, rusqlite::Error> {
+ tracing::trace!("Retrieving posts page");
+ let mut stmt = self
+ .0
+ .prepare("SELECT * FROM posts ORDER BY post_id DESC")?;
+ let row = stmt.query_map([], |row| {
+ let row = self.post_from_row(row)?;
+ Ok(row)
+ })?;
+ Ok(row.into_iter().flatten().collect())
+ }
-#[instrument()]
-pub fn add_post(user_id: u64, content: &str) -> Result<Post, rusqlite::Error> {
- tracing::trace!("Adding post");
- let date = u64::try_from(
- SystemTime::now()
- .duration_since(UNIX_EPOCH)
- .unwrap_or(Duration::ZERO)
- .as_millis(),
- )
- .unwrap_or(0);
- let conn = database::connect()?;
- let mut stmt =
- conn.prepare("INSERT INTO posts (user_id, content, date) VALUES(?,?,?) RETURNING *;")?;
- let post = stmt.query_row((user_id, content, date), |row| {
- let row = post_from_row(row)?;
- Ok(row)
- })?;
- Ok(post)
+ #[instrument(skip(self))]
+ pub fn get_users_post_page(
+ &self,
+ user_id: u64,
+ page: u64,
+ ) -> Result<Vec<Post>, rusqlite::Error> {
+ tracing::trace!("Retrieving users posts");
+ let page_size = 10;
+ let mut stmt = self.0.prepare(
+ "SELECT * FROM posts WHERE user_id = ? ORDER BY post_id DESC LIMIT ? OFFSET ?",
+ )?;
+ let row = stmt.query_map([user_id, page_size, page_size * page], |row| {
+ let row = self.post_from_row(row)?;
+ Ok(row)
+ })?;
+ Ok(row.into_iter().flatten().collect())
+ }
+
+ #[instrument(skip(self))]
+ pub fn add_post(&self, user_id: u64, content: &str) -> Result<Post, rusqlite::Error> {
+ tracing::trace!("Adding post");
+ let date = u64::try_from(
+ SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .unwrap_or(Duration::ZERO)
+ .as_millis(),
+ )
+ .unwrap_or(0);
+ let mut stmt = self
+ .0
+ .prepare("INSERT INTO posts (user_id, content, date) VALUES(?,?,?) RETURNING *;")?;
+ let post = stmt.query_row((user_id, content, date), |row| {
+ let row = self.post_from_row(row)?;
+ Ok(row)
+ })?;
+ Ok(post)
+ }
}
diff --git a/src/database/sessions.rs b/src/database/sessions.rs
index 9adccd4..a50bb51 100644
--- a/src/database/sessions.rs
+++ b/src/database/sessions.rs
@@ -1,65 +1,64 @@
use rusqlite::OptionalExtension;
use tracing::instrument;
-use crate::{database, types::session::Session};
+use crate::types::session::Session;
-pub fn init() -> Result<(), rusqlite::Error> {
- let sql = "
- CREATE TABLE IF NOT EXISTS sessions (
- user_id INTEGER PRIMARY KEY NOT NULL,
- token TEXT NOT NULL,
- FOREIGN KEY(user_id) REFERENCES users(user_id)
- );
- ";
- let conn = database::connect()?;
- conn.execute(sql, ())?;
- Ok(())
-}
+use super::Database;
+
+impl Database {
+ pub fn init_sessions(&self) -> Result<(), rusqlite::Error> {
+ let sql = "
+ CREATE TABLE IF NOT EXISTS sessions (
+ user_id INTEGER PRIMARY KEY NOT NULL,
+ token TEXT NOT NULL,
+ FOREIGN KEY(user_id) REFERENCES users(user_id)
+ );
+ ";
+ self.0.execute(sql, ())?;
+ Ok(())
+ }
+
+ #[instrument(skip(self))]
+ pub fn get_session(&self, token: &str) -> Result<Option<Session>, rusqlite::Error> {
+ tracing::trace!("Retrieving session");
+ let mut stmt = self.0.prepare("SELECT * FROM sessions WHERE token = ?")?;
+ let row = stmt
+ .query_row([token], |row| {
+ Ok(Session {
+ user_id: row.get(0)?,
+ token: row.get(1)?,
+ })
+ })
+ .optional()?;
+ Ok(row)
+ }
-#[instrument()]
-pub fn get_session(token: &str) -> Result<Option<Session>, rusqlite::Error> {
- tracing::trace!("Retrieving session");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("SELECT * FROM sessions WHERE token = ?")?;
- let row = stmt
- .query_row([token], |row| {
+ #[instrument(skip(self))]
+ pub fn get_all_sessions(&self) -> Result<Vec<Session>, rusqlite::Error> {
+ tracing::trace!("Retrieving session");
+ let mut stmt = self.0.prepare("SELECT * FROM sessions")?;
+ let row = stmt.query_map([], |row| {
Ok(Session {
user_id: row.get(0)?,
token: row.get(1)?,
})
- })
- .optional()?;
- Ok(row)
-}
-
-#[instrument()]
-pub fn get_all_sessions() -> Result<Vec<Session>, rusqlite::Error> {
- tracing::trace!("Retrieving session");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("SELECT * FROM sessions")?;
- let row = stmt.query_map([], |row| {
- Ok(Session {
- user_id: row.get(0)?,
- token: row.get(1)?,
- })
- })?;
- Ok(row.into_iter().flatten().collect())
-}
+ })?;
+ Ok(row.into_iter().flatten().collect())
+ }
-#[instrument()]
-pub fn set_session(user_id: u64, token: &str) -> Result<(), Box<dyn std::error::Error>> {
- tracing::trace!("Setting new session");
- let conn = database::connect()?;
- let sql = "INSERT OR REPLACE INTO sessions (user_id, token) VALUES (?, ?);";
- conn.execute(sql, (user_id, token))?;
- Ok(())
-}
+ #[instrument(skip(self))]
+ pub fn set_session(&self, user_id: u64, token: &str) -> Result<(), Box<dyn std::error::Error>> {
+ tracing::trace!("Setting new session");
+ let sql = "INSERT OR REPLACE INTO sessions (user_id, token) VALUES (?, ?);";
+ self.0.execute(sql, (user_id, token))?;
+ Ok(())
+ }
-#[instrument()]
-pub fn delete_session(user_id: u64) -> Result<(), Box<dyn std::error::Error>> {
- tracing::trace!("Deleting session");
- let conn = database::connect()?;
- let sql = "DELETE FROM sessions WHERE user_id = ?;";
- conn.execute(sql, [user_id])?;
- Ok(())
+ #[instrument(skip(self))]
+ pub fn delete_session(&self, user_id: u64) -> Result<(), Box<dyn std::error::Error>> {
+ tracing::trace!("Deleting session");
+ let sql = "DELETE FROM sessions WHERE user_id = ?;";
+ self.0.execute(sql, [user_id])?;
+ Ok(())
+ }
}
diff --git a/src/database/users.rs b/src/database/users.rs
index 6062ea8..9df69ee 100644
--- a/src/database/users.rs
+++ b/src/database/users.rs
@@ -2,169 +2,180 @@ use rusqlite::{OptionalExtension, Row};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tracing::instrument;
-use crate::{api::RegistrationRequet, database, types::user::User};
+use crate::{api::RegistrationRequet, types::user::User};
-pub fn init() -> Result<(), rusqlite::Error> {
- let sql = "
- CREATE TABLE IF NOT EXISTS users (
- user_id INTEGER PRIMARY KEY AUTOINCREMENT,
- firstname VARCHAR(20) NOT NULL,
- lastname VARCHAR(20) NOT NULL,
- email VARCHAR(50) NOT NULL,
- password VARCHAR(50) NOT NULL,
- gender VARCHAR(100) NOT NULL,
- date BIGINT NOT NULL,
- day TINYINT NOT NULL,
- month TINYINT NOT NULL,
- year INTEGER NOT NULL
- );
- ";
- let conn = database::connect()?;
- conn.execute(sql, ())?;
+use super::Database;
- let sql2 = "CREATE UNIQUE INDEX IF NOT EXISTS emails on users (email);";
- conn.execute(sql2, ())?;
+impl Database {
+ pub fn init_users(&self) -> Result<(), rusqlite::Error> {
+ let sql = "
+ CREATE TABLE IF NOT EXISTS users (
+ user_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ firstname VARCHAR(20) NOT NULL,
+ lastname VARCHAR(20) NOT NULL,
+ email VARCHAR(50) NOT NULL,
+ password VARCHAR(50) NOT NULL,
+ gender VARCHAR(100) NOT NULL,
+ date BIGINT NOT NULL,
+ day TINYINT NOT NULL,
+ month TINYINT NOT NULL,
+ year INTEGER NOT NULL
+ );
+ ";
+ self.0.execute(sql, ())?;
- let sql3 = "CREATE UNIQUE INDEX IF NOT EXISTS passwords on users (password);";
- conn.execute(sql3, ())?;
+ let sql2 = "CREATE UNIQUE INDEX IF NOT EXISTS emails on users (email);";
+ self.0.execute(sql2, ())?;
- Ok(())
-}
-
-pub fn user_from_row(row: &Row, hide_password: bool) -> Result<User, rusqlite::Error> {
- let user_id = row.get(0)?;
- let firstname = row.get(1)?;
- let lastname = row.get(2)?;
- let email = row.get(3)?;
- let password = row.get(4)?;
- let gender = row.get(5)?;
- let date = row.get(6)?;
- let day = row.get(7)?;
- let month = row.get(8)?;
- let year = row.get(9)?;
+ let sql3 = "CREATE UNIQUE INDEX IF NOT EXISTS passwords on users (password);";
+ self.0.execute(sql3, ())?;
- let password = if hide_password {
- String::new()
- } else {
- password
- };
+ Ok(())
+ }
- Ok(User {
- user_id,
- firstname,
- lastname,
- email,
- password,
- gender,
- date,
- day,
- month,
- year,
- })
-}
+ pub fn user_from_row(row: &Row, hide_password: bool) -> Result<User, rusqlite::Error> {
+ let user_id = row.get(0)?;
+ let firstname = row.get(1)?;
+ let lastname = row.get(2)?;
+ let email = row.get(3)?;
+ let password = row.get(4)?;
+ let gender = row.get(5)?;
+ let date = row.get(6)?;
+ let day = row.get(7)?;
+ let month = row.get(8)?;
+ let year = row.get(9)?;
-#[instrument()]
-pub fn get_user_by_id(user_id: u64, hide_password: bool) -> Result<Option<User>, rusqlite::Error> {
- tracing::trace!("Retrieving user by id");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("SELECT * FROM users WHERE user_id = ?")?;
- let row = stmt
- .query_row([user_id], |row| {
- let row = user_from_row(row, hide_password)?;
- Ok(row)
- })
- .optional()?;
- Ok(row)
-}
+ let password = if hide_password {
+ String::new()
+ } else {
+ password
+ };
-#[instrument()]
-pub fn get_user_by_email(
- email: &str,
- hide_password: bool,
-) -> Result<Option<User>, rusqlite::Error> {
- tracing::trace!("Retrieving user by email");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("SELECT * FROM users WHERE email = ?")?;
- let row = stmt
- .query_row([email], |row| {
- let row = user_from_row(row, hide_password)?;
- Ok(row)
+ Ok(User {
+ user_id,
+ firstname,
+ lastname,
+ email,
+ password,
+ gender,
+ date,
+ day,
+ month,
+ year,
})
- .optional()?;
- Ok(row)
-}
+ }
-#[instrument()]
-pub fn get_user_by_password(
- password: &str,
- hide_password: bool,
-) -> Result<Option<User>, rusqlite::Error> {
- tracing::trace!("Retrieving user by password");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("SELECT * FROM users WHERE password = ?")?;
- let row = stmt
- .query_row([password], |row| {
- let row = user_from_row(row, hide_password)?;
- Ok(row)
- })
- .optional()?;
- Ok(row)
-}
+ #[instrument(skip(self))]
+ pub fn get_user_by_id(
+ &self,
+ user_id: u64,
+ hide_password: bool,
+ ) -> Result<Option<User>, rusqlite::Error> {
+ tracing::trace!("Retrieving user by id");
+ let mut stmt = self.0.prepare("SELECT * FROM users WHERE user_id = ?")?;
+ let row = stmt
+ .query_row([user_id], |row| {
+ let row = Self::user_from_row(row, hide_password)?;
+ Ok(row)
+ })
+ .optional()?;
+ Ok(row)
+ }
-#[instrument()]
-pub fn get_user_page(page: u64, hide_password: bool) -> Result<Vec<User>, rusqlite::Error> {
- tracing::trace!("Retrieving user page");
- let page_size = 5;
- let conn = database::connect()?;
- let mut stmt = conn.prepare("SELECT * FROM users ORDER BY user_id DESC LIMIT ? OFFSET ?")?;
- let row = stmt.query_map([page_size, page_size * page], |row| {
- let row = user_from_row(row, hide_password)?;
+ #[instrument(skip(self))]
+ pub fn get_user_by_email(
+ &self,
+ email: &str,
+ hide_password: bool,
+ ) -> Result<Option<User>, rusqlite::Error> {
+ tracing::trace!("Retrieving user by email");
+ let mut stmt = self.0.prepare("SELECT * FROM users WHERE email = ?")?;
+ let row = stmt
+ .query_row([email], |row| {
+ let row = Self::user_from_row(row, hide_password)?;
+ Ok(row)
+ })
+ .optional()?;
Ok(row)
- })?;
- Ok(row.into_iter().flatten().collect())
-}
+ }
-#[instrument()]
-pub fn get_all_users() -> Result<Vec<User>, rusqlite::Error> {
- tracing::trace!("Retrieving user page");
- let conn = database::connect()?;
- let mut stmt = conn.prepare("SELECT * FROM users ORDER BY user_id DESC")?;
- let row = stmt.query_map([], |row| {
- let row = user_from_row(row, false)?;
+ #[instrument(skip(self))]
+ pub fn get_user_by_password(
+ &self,
+ password: &str,
+ hide_password: bool,
+ ) -> Result<Option<User>, rusqlite::Error> {
+ tracing::trace!("Retrieving user by password");
+ let mut stmt = self.0.prepare("SELECT * FROM users WHERE password = ?")?;
+ let row = stmt
+ .query_row([password], |row| {
+ let row = Self::user_from_row(row, hide_password)?;
+ Ok(row)
+ })
+ .optional()?;
Ok(row)
- })?;
- Ok(row.into_iter().flatten().collect())
-}
+ }
-#[instrument()]
-pub fn add_user(request: RegistrationRequet) -> Result<User, rusqlite::Error> {
- tracing::trace!("Adding new user");
- let date = u64::try_from(
- SystemTime::now()
- .duration_since(UNIX_EPOCH)
- .unwrap_or(Duration::ZERO)
- .as_millis(),
- )
- .unwrap_or(0);
+ #[instrument(skip(self))]
+ pub fn get_user_page(
+ &self,
+ page: u64,
+ hide_password: bool,
+ ) -> Result<Vec<User>, rusqlite::Error> {
+ tracing::trace!("Retrieving user page");
+ let page_size = 5;
+ let mut stmt = self
+ .0
+ .prepare("SELECT * FROM users ORDER BY user_id DESC LIMIT ? OFFSET ?")?;
+ let row = stmt.query_map([page_size, page_size * page], |row| {
+ let row = Self::user_from_row(row, hide_password)?;
+ Ok(row)
+ })?;
+ Ok(row.into_iter().flatten().collect())
+ }
- let conn = database::connect()?;
- let mut stmt = conn.prepare("INSERT INTO users (firstname, lastname, email, password, gender, date, day, month, year) VALUES(?,?,?,?,?,?,?,?,?) RETURNING *;")?;
- let user = stmt.query_row(
- (
- request.firstname,
- request.lastname,
- request.email,
- request.password,
- request.gender,
- date,
- request.day,
- request.month,
- request.year,
- ),
- |row| {
- let row = user_from_row(row, false)?;
+ #[instrument(skip(self))]
+ pub fn get_all_users(&self) -> Result<Vec<User>, rusqlite::Error> {
+ tracing::trace!("Retrieving user page");
+ let mut stmt = self
+ .0
+ .prepare("SELECT * FROM users ORDER BY user_id DESC")?;
+ let row = stmt.query_map([], |row| {
+ let row = Self::user_from_row(row, false)?;
Ok(row)
- },
- )?;
- Ok(user)
+ })?;
+ Ok(row.into_iter().flatten().collect())
+ }
+
+ #[instrument(skip(self))]
+ pub fn add_user(&self, request: RegistrationRequet) -> Result<User, rusqlite::Error> {
+ tracing::trace!("Adding new user");
+ let date = u64::try_from(
+ SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .unwrap_or(Duration::ZERO)
+ .as_millis(),
+ )
+ .unwrap_or(0);
+
+ let mut stmt = self.0.prepare("INSERT INTO users (firstname, lastname, email, password, gender, date, day, month, year) VALUES(?,?,?,?,?,?,?,?,?) RETURNING *;")?;
+ let user = stmt.query_row(
+ (
+ request.firstname,
+ request.lastname,
+ request.email,
+ request.password,
+ request.gender,
+ date,
+ request.day,
+ request.month,
+ request.year,
+ ),
+ |row| {
+ let row = Self::user_from_row(row, false)?;
+ Ok(row)
+ },
+ )?;
+ Ok(user)
+ }
}
diff --git a/src/public/admin.rs b/src/public/admin.rs
index bf0a155..6e32152 100644
--- a/src/public/admin.rs
+++ b/src/public/admin.rs
@@ -5,6 +5,7 @@ use tokio::sync::Mutex;
use crate::{
console::sanatize,
+ database::Database,
types::{
comment::Comment, http::ResponseCode, like::Like, post::Post, session::Session, user::User,
},
@@ -36,8 +37,8 @@ pub async fn regen_secret() -> String {
secret.clone()
}
-pub fn generate_users() -> Response {
- let users = match User::reterieve_all() {
+pub fn generate_users(db: &Database) -> Response {
+ let users = match User::reterieve_all(db) {
Ok(users) => users,
Err(err) => return err,
};
@@ -70,8 +71,8 @@ pub fn generate_users() -> Response {
ResponseCode::Success.text(&html)
}
-pub fn generate_posts() -> Response {
- let posts = match Post::reterieve_all() {
+pub fn generate_posts(db: &Database) -> Response {
+ let posts = match Post::reterieve_all(db) {
Ok(posts) => posts,
Err(err) => return err,
};
@@ -99,8 +100,8 @@ pub fn generate_posts() -> Response {
ResponseCode::Success.text(&html)
}
-pub fn generate_sessions() -> Response {
- let sessions = match Session::reterieve_all() {
+pub fn generate_sessions(db: &Database) -> Response {
+ let sessions = match Session::reterieve_all(db) {
Ok(sessions) => sessions,
Err(err) => return err,
};
@@ -123,8 +124,8 @@ pub fn generate_sessions() -> Response {
ResponseCode::Success.text(&html)
}
-pub fn generate_comments() -> Response {
- let comments = match Comment::reterieve_all() {
+pub fn generate_comments(db: &Database) -> Response {
+ let comments = match Comment::reterieve_all(db) {
Ok(comments) => comments,
Err(err) => return err,
};
@@ -154,8 +155,8 @@ pub fn generate_comments() -> Response {
ResponseCode::Success.text(&html)
}
-pub fn generate_likes() -> Response {
- let likes = match Like::reterieve_all() {
+pub fn generate_likes(db: &Database) -> Response {
+ let likes = match Like::reterieve_all(db) {
Ok(likes) => likes,
Err(err) => return err,
};
diff --git a/src/types/comment.rs b/src/types/comment.rs
index cf94bd3..0836950 100644
--- a/src/types/comment.rs
+++ b/src/types/comment.rs
@@ -2,7 +2,7 @@ use serde::Serialize;
use tracing::instrument;
use crate::{
- database::{self, comments},
+ database::Database,
types::http::{ResponseCode, Result},
};
@@ -16,9 +16,9 @@ pub struct Comment {
}
impl Comment {
- #[instrument()]
- pub fn new(user_id: u64, post_id: u64, content: &str) -> Result<Self> {
- let Ok(comment) = comments::add_comment(user_id, post_id, content) else {
+ #[instrument(skip(db))]
+ pub fn new(db: &Database, user_id: u64, post_id: u64, content: &str) -> Result<Self> {
+ let Ok(comment) = db.add_comment(user_id, post_id, content) else {
tracing::error!("Failed to create comment");
return Err(ResponseCode::InternalServerError.text("Failed to create post"))
};
@@ -26,17 +26,17 @@ impl Comment {
Ok(comment)
}
- #[instrument()]
- pub fn from_comment_page(page: u64, post_id: u64) -> Result<Vec<Self>> {
- let Ok(posts) = database::comments::get_comments_page(page, post_id) else {
+ #[instrument(skip(db))]
+ pub fn from_comment_page(db: &Database, page: u64, post_id: u64) -> Result<Vec<Self>> {
+ let Ok(posts) = db.get_comments_page(page, post_id) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch comments"))
};
Ok(posts)
}
- #[instrument()]
- pub fn reterieve_all() -> Result<Vec<Self>> {
- let Ok(posts) = database::comments::get_all_comments() else {
+ #[instrument(skip(db))]
+ pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
+ let Ok(posts) = db.get_all_comments() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch comments"))
};
Ok(posts)
diff --git a/src/types/extract.rs b/src/types/extract.rs
index 65d9f1a..f05215f 100644
--- a/src/types/extract.rs
+++ b/src/types/extract.rs
@@ -14,9 +14,11 @@ use axum::{
use bytes::Bytes;
use image::{io::Reader, DynamicImage, ImageFormat};
use serde::de::DeserializeOwned;
+use tokio::sync::Mutex;
use tower_cookies::Cookies;
use crate::{
+ database,
public::admin,
public::console,
types::{
@@ -97,11 +99,17 @@ where
return Err(ResponseCode::Forbidden.text("No auth token provided"))
};
- let Ok(session) = Session::from_token(token.value()) else {
+ let Some(db) = parts.extensions.get::<DatabaseExtention>() else {
+ return Err(ResponseCode::Forbidden.text("Could not connect to database"))
+ };
+
+ let db = db.0.lock().await;
+
+ let Ok(session) = Session::from_token(&db, token.value()) else {
return Err(ResponseCode::Unauthorized.text("Auth token invalid"))
};
- let Ok(user) = User::from_user_id(session.user_id, true) else {
+ let Ok(user) = User::from_user_id(&db, session.user_id, true) else {
tracing::error!("Valid token but no valid user");
return Err(ResponseCode::InternalServerError.text("Valid token but no valid user"))
};
@@ -260,6 +268,26 @@ where
}
}
+pub struct DatabaseExtention(pub Mutex<database::Database>);
+pub struct Database(pub database::Database);
+
+#[async_trait]
+impl<S> FromRequestParts<S> for Database
+where
+ S: Send + Sync,
+{
+ type Rejection = Response;
+
+ async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self> {
+ let db = parts.extensions.remove::<DatabaseExtention>();
+ let Some(db) = db else {
+ return Err(ResponseCode::InternalServerError.text("Database is not loaded"))
+ };
+
+ Ok(Self(db.0.into_inner()))
+ }
+}
+
async fn read_body<S, B>(mut req: Request<B>, state: &S) -> Result<Vec<u8>>
where
B: HttpBody + Sync + Send + 'static,
diff --git a/src/types/like.rs b/src/types/like.rs
index 1c113c1..8eec941 100644
--- a/src/types/like.rs
+++ b/src/types/like.rs
@@ -1,7 +1,7 @@
use serde::Serialize;
use tracing::instrument;
-use crate::database;
+use crate::database::Database;
use crate::types::http::{ResponseCode, Result};
#[derive(Serialize)]
@@ -11,9 +11,9 @@ pub struct Like {
}
impl Like {
- #[instrument()]
- pub fn add_liked(user_id: u64, post_id: u64) -> Result<()> {
- let Ok(liked) = database::likes::add_liked(user_id, post_id) else {
+ #[instrument(skip(db))]
+ pub fn add_liked(db: &Database, user_id: u64, post_id: u64) -> Result<()> {
+ let Ok(liked) = db.add_liked(user_id, post_id) else {
return Err(ResponseCode::BadRequest.text("Failed to add like status"))
};
@@ -24,9 +24,9 @@ impl Like {
Ok(())
}
- #[instrument()]
- pub fn remove_liked(user_id: u64, post_id: u64) -> Result<()> {
- let Ok(liked) = database::likes::remove_liked(user_id, post_id) else {
+ #[instrument(skip(db))]
+ pub fn remove_liked(db: &Database, user_id: u64, post_id: u64) -> Result<()> {
+ let Ok(liked) = db.remove_liked(user_id, post_id) else {
return Err(ResponseCode::BadRequest.text("Failed to remove like status"))
};
@@ -37,9 +37,9 @@ impl Like {
Ok(())
}
- #[instrument()]
- pub fn reterieve_all() -> Result<Vec<Self>> {
- let Ok(likes) = database::likes::get_all_likes() else {
+ #[instrument(skip(db))]
+ pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
+ let Ok(likes) = db.get_all_likes() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch likes"))
};
Ok(likes)
diff --git a/src/types/post.rs b/src/types/post.rs
index 09f2f50..bff68c7 100644
--- a/src/types/post.rs
+++ b/src/types/post.rs
@@ -2,7 +2,7 @@ use core::fmt;
use serde::Serialize;
use tracing::instrument;
-use crate::database;
+use crate::database::Database;
use crate::types::http::{ResponseCode, Result};
use super::comment::Comment;
@@ -27,57 +27,62 @@ impl fmt::Debug for Post {
}
impl Post {
- #[instrument()]
- pub fn from_post_id(self_id: u64, post_id: u64) -> Result<Self> {
- let Ok(Some(mut post)) = database::posts::get_post(post_id) else {
+ #[instrument(skip(db))]
+ pub fn from_post_id(db: &Database, self_id: u64, post_id: u64) -> Result<Self> {
+ let Ok(Some(mut post)) = db.get_post(post_id) else {
return Err(ResponseCode::BadRequest.text("Post does not exist"))
};
- let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false);
+ let liked = db.get_liked(self_id, post.post_id).unwrap_or(false);
post.liked = liked;
Ok(post)
}
- #[instrument()]
- pub fn from_post_page(self_id: u64, page: u64) -> Result<Vec<Self>> {
- let Ok(mut posts) = database::posts::get_post_page(page) else {
+ #[instrument(skip(db))]
+ pub fn from_post_page(db: &Database, self_id: u64, page: u64) -> Result<Vec<Self>> {
+ let Ok(mut posts) = db.get_post_page(page) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch posts"))
};
for post in &mut posts {
- let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false);
+ let liked = db.get_liked(self_id, post.post_id).unwrap_or(false);
post.liked = liked;
}
Ok(posts)
}
- #[instrument()]
- pub fn from_user_post_page(self_id: u64, user_id: u64, page: u64) -> Result<Vec<Self>> {
- let Ok(mut posts) = database::posts::get_users_post_page(user_id, page) else {
+ #[instrument(skip(db))]
+ pub fn from_user_post_page(
+ db: &Database,
+ self_id: u64,
+ user_id: u64,
+ page: u64,
+ ) -> Result<Vec<Self>> {
+ let Ok(mut posts) = db.get_users_post_page(user_id, page) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch posts"))
};
for post in &mut posts {
- let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false);
+ let liked = db.get_liked(self_id, post.post_id).unwrap_or(false);
post.liked = liked;
}
Ok(posts)
}
- #[instrument()]
- pub fn reterieve_all() -> Result<Vec<Self>> {
- let Ok(posts) = database::posts::get_all_posts() else {
+ #[instrument(skip(db))]
+ pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
+ let Ok(posts) = db.get_all_posts() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch posts"))
};
Ok(posts)
}
- #[instrument()]
- pub fn new(user_id: u64, content: String) -> Result<Self> {
- let Ok(post) = database::posts::add_post(user_id, &content) else {
+ #[instrument(skip(db))]
+ pub fn new(db: &Database, user_id: u64, content: String) -> Result<Self> {
+ let Ok(post) = db.add_post(user_id, &content) else {
tracing::error!("Failed to create post");
return Err(ResponseCode::InternalServerError.text("Failed to create post"))
};
diff --git a/src/types/session.rs b/src/types/session.rs
index a9073aa..27c5c66 100644
--- a/src/types/session.rs
+++ b/src/types/session.rs
@@ -2,7 +2,7 @@ use rand::{distributions::Alphanumeric, Rng};
use serde::Serialize;
use tracing::instrument;
-use crate::database;
+use crate::database::Database;
use crate::types::http::{ResponseCode, Result};
#[derive(Serialize)]
@@ -12,39 +12,39 @@ pub struct Session {
}
impl Session {
- #[instrument()]
- pub fn from_token(token: &str) -> Result<Self> {
- let Ok(Some(session)) = database::sessions::get_session(token) else {
+ #[instrument(skip(db))]
+ pub fn from_token(db: &Database, token: &str) -> Result<Self> {
+ let Ok(Some(session)) = db.get_session(token) else {
return Err(ResponseCode::BadRequest.text("Invalid auth token"));
};
Ok(session)
}
- #[instrument()]
- pub fn reterieve_all() -> Result<Vec<Self>> {
- let Ok(sessions) = database::sessions::get_all_sessions() else {
+ #[instrument(skip(db))]
+ pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
+ let Ok(sessions) = db.get_all_sessions() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch sessions"))
};
Ok(sessions)
}
- #[instrument()]
- pub fn new(user_id: u64) -> Result<Self> {
+ #[instrument(skip(db))]
+ pub fn new(db: &Database, user_id: u64) -> Result<Self> {
let token: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect();
- match database::sessions::set_session(user_id, &token) {
+ match db.set_session(user_id, &token) {
Err(_) => Err(ResponseCode::BadRequest.text("Failed to create session")),
Ok(_) => Ok(Self { user_id, token }),
}
}
- #[instrument()]
- pub fn delete(user_id: u64) -> Result<()> {
- if database::sessions::delete_session(user_id).is_err() {
+ #[instrument(skip(db))]
+ pub fn delete(db: &Database, user_id: u64) -> Result<()> {
+ if db.delete_session(user_id).is_err() {
tracing::error!("Failed to logout user");
return Err(ResponseCode::InternalServerError.text("Failed to logout"));
};
diff --git a/src/types/user.rs b/src/types/user.rs
index 245e9b7..3c4cd6a 100644
--- a/src/types/user.rs
+++ b/src/types/user.rs
@@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize};
use tracing::instrument;
use crate::api::RegistrationRequet;
-use crate::database;
+use crate::database::Database;
use crate::types::http::{ResponseCode, Result};
#[derive(Serialize, Deserialize, Debug)]
@@ -24,21 +24,21 @@ pub const FOLLOWING: u8 = 1;
pub const FOLLOWED: u8 = 2;
impl User {
- #[instrument()]
- pub fn from_user_id(user_id: u64, hide_password: bool) -> Result<Self> {
- let Ok(Some(user)) = database::users::get_user_by_id(user_id, hide_password) else {
+ #[instrument(skip(db))]
+ pub fn from_user_id(db: &Database, user_id: u64, hide_password: bool) -> Result<Self> {
+ let Ok(Some(user)) = db.get_user_by_id(user_id, hide_password) else {
return Err(ResponseCode::BadRequest.text("User does not exist"))
};
Ok(user)
}
- #[instrument()]
- pub fn from_user_ids(user_ids: Vec<u64>) -> Vec<Self> {
+ #[instrument(skip(db))]
+ pub fn from_user_ids(db: &Database, user_ids: Vec<u64>) -> Vec<Self> {
user_ids
.iter()
.filter_map(|user_id| {
- let Ok(Some(user)) = database::users::get_user_by_id(*user_id, true) else {
+ let Ok(Some(user)) = db.get_user_by_id(*user_id, true) else {
return None;
};
Some(user)
@@ -46,53 +46,53 @@ impl User {
.collect()
}
- #[instrument()]
- pub fn from_user_page(page: u64) -> Result<Vec<Self>> {
- let Ok(users) = database::users::get_user_page(page, true) else {
+ #[instrument(skip(db))]
+ pub fn from_user_page(db: &Database, page: u64) -> Result<Vec<Self>> {
+ let Ok(users) = db.get_user_page(page, true) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch users"))
};
Ok(users)
}
- #[instrument()]
- pub fn from_email(email: &str) -> Result<Self> {
- let Ok(Some(user)) = database::users::get_user_by_email(email, false) else {
+ #[instrument(skip(db))]
+ pub fn from_email(db: &Database, email: &str) -> Result<Self> {
+ let Ok(Some(user)) = db.get_user_by_email(email, false) else {
return Err(ResponseCode::BadRequest.text("User does not exist"))
};
Ok(user)
}
- #[instrument()]
- pub fn from_password(password: &str) -> Result<Self> {
- let Ok(Some(user)) = database::users::get_user_by_password(password, true) else {
+ #[instrument(skip(db))]
+ pub fn from_password(db: &Database, password: &str) -> Result<Self> {
+ let Ok(Some(user)) = db.get_user_by_password(password, true) else {
return Err(ResponseCode::BadRequest.text("User does not exist"))
};
Ok(user)
}
- #[instrument()]
- pub fn reterieve_all() -> Result<Vec<Self>> {
- let Ok(users) = database::users::get_all_users() else {
+ #[instrument(skip(db))]
+ pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
+ let Ok(users) = db.get_all_users() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch users"))
};
Ok(users)
}
- #[instrument()]
- pub fn new(request: RegistrationRequet) -> Result<Self> {
- if Self::from_email(&request.email).is_ok() {
+ #[instrument(skip(db))]
+ pub fn new(db: &Database, request: RegistrationRequet) -> Result<Self> {
+ if Self::from_email(db, &request.email).is_ok() {
return Err(ResponseCode::BadRequest
.text(&format!("Email is already in use by {}", &request.email)));
}
- if let Ok(user) = Self::from_password(&request.password) {
+ if let Ok(user) = Self::from_password(db, &request.password) {
return Err(ResponseCode::BadRequest
.text(&format!("Password is already in use by {}", user.email)));
}
- let Ok(user) = database::users::add_user(request) else {
+ let Ok(user) = db.add_user(request) else {
tracing::error!("Failed to create new user");
return Err(ResponseCode::InternalServerError.text("Failed to create new uesr"))
};
@@ -100,8 +100,9 @@ impl User {
Ok(user)
}
- pub fn add_following(user_id_1: u64, user_id_2: u64) -> Result<()> {
- let Ok(followed) = database::friends::set_following(user_id_1, user_id_2) else {
+ #[instrument(skip(db))]
+ pub fn add_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<()> {
+ let Ok(followed) = db.set_following(user_id_1, user_id_2) else {
return Err(ResponseCode::BadRequest.text("Failed to add follow status"))
};
@@ -112,8 +113,9 @@ impl User {
Ok(())
}
- pub fn remove_following(user_id_1: u64, user_id_2: u64) -> Result<()> {
- let Ok(followed) = database::friends::remove_following(user_id_1, user_id_2) else {
+ #[instrument(skip(db))]
+ pub fn remove_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<()> {
+ let Ok(followed) = db.remove_following(user_id_1, user_id_2) else {
return Err(ResponseCode::BadRequest.text("Failed to remove follow status"))
};
@@ -124,15 +126,17 @@ impl User {
Ok(())
}
- pub fn get_following(user_id_1: u64, user_id_2: u64) -> Result<u8> {
- let Ok(followed) = database::friends::get_friend_status(user_id_1, user_id_2) else {
+ #[instrument(skip(db))]
+ pub fn get_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<u8> {
+ let Ok(followed) = db.get_friend_status(user_id_1, user_id_2) else {
return Err(ResponseCode::InternalServerError.text("Failed to get follow status"))
};
Ok(followed)
}
- pub fn get_friends(user_id: u64) -> Result<Vec<Self>> {
- let Ok(users) = database::friends::get_friends(user_id) else {
+ #[instrument(skip(db))]
+ pub fn get_friends(db: &Database, user_id: u64) -> Result<Vec<Self>> {
+ let Ok(users) = db.get_friends(user_id) else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch friends"))
};
Ok(users)