make database calls 1 conn

This commit is contained in:
Tyler Murphy 2023-02-15 00:01:44 -05:00
parent 192be95f84
commit aec4fdecc1
19 changed files with 829 additions and 695 deletions

View file

@ -5,13 +5,12 @@ use serde::Deserialize;
use tower_cookies::{Cookie, Cookies}; use tower_cookies::{Cookie, Cookies};
use crate::{ use crate::{
database,
public::{ public::{
admin, admin,
docs::{EndpointDocumentation, EndpointMethod}, docs::{EndpointDocumentation, EndpointMethod},
}, },
types::{ types::{
extract::{AdminUser, Check, CheckResult, Json}, extract::{AdminUser, Check, CheckResult, Database, Json},
http::ResponseCode, http::ResponseCode,
}, },
}; };
@ -92,8 +91,8 @@ impl Check for QueryRequest {
} }
} }
async fn query(_: AdminUser, Json(body): Json<QueryRequest>) -> Response { async fn query(_: AdminUser, Database(db): Database, Json(body): Json<QueryRequest>) -> Response {
match database::query(body.query) { match db.query(body.query) {
Ok(changes) => ResponseCode::Success.text(&format!( Ok(changes) => ResponseCode::Success.text(&format!(
"Query executed successfully. {changes} lines changed." "Query executed successfully. {changes} lines changed."
)), )),
@ -114,8 +113,8 @@ pub const ADMIN_POSTS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"), cookie: Some("admin"),
}; };
async fn posts(_: AdminUser) -> Response { async fn posts(_: AdminUser, Database(db): Database) -> Response {
admin::generate_posts() admin::generate_posts(&db)
} }
pub const ADMIN_USERS: EndpointDocumentation = EndpointDocumentation { pub const ADMIN_USERS: EndpointDocumentation = EndpointDocumentation {
@ -131,8 +130,8 @@ pub const ADMIN_USERS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"), cookie: Some("admin"),
}; };
async fn users(_: AdminUser) -> Response { async fn users(_: AdminUser, Database(db): Database) -> Response {
admin::generate_users() admin::generate_users(&db)
} }
pub const ADMIN_SESSIONS: EndpointDocumentation = EndpointDocumentation { pub const ADMIN_SESSIONS: EndpointDocumentation = EndpointDocumentation {
@ -148,8 +147,8 @@ pub const ADMIN_SESSIONS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"), cookie: Some("admin"),
}; };
async fn sessions(_: AdminUser) -> Response { async fn sessions(_: AdminUser, Database(db): Database) -> Response {
admin::generate_sessions() admin::generate_sessions(&db)
} }
pub const ADMIN_COMMENTS: EndpointDocumentation = EndpointDocumentation { pub const ADMIN_COMMENTS: EndpointDocumentation = EndpointDocumentation {
@ -165,8 +164,8 @@ pub const ADMIN_COMMENTS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"), cookie: Some("admin"),
}; };
async fn comments(_: AdminUser) -> Response { async fn comments(_: AdminUser, Database(db): Database) -> Response {
admin::generate_comments() admin::generate_comments(&db)
} }
pub const ADMIN_LIKES: EndpointDocumentation = EndpointDocumentation { pub const ADMIN_LIKES: EndpointDocumentation = EndpointDocumentation {
@ -182,8 +181,8 @@ pub const ADMIN_LIKES: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"), cookie: Some("admin"),
}; };
async fn likes(_: AdminUser) -> Response { async fn likes(_: AdminUser, Database(db): Database) -> Response {
admin::generate_likes() admin::generate_likes(&db)
} }
async fn check(check: Option<AdminUser>) -> Response { async fn check(check: Option<AdminUser>) -> Response {

View file

@ -6,7 +6,7 @@ use tower_cookies::{Cookie, Cookies};
use crate::{ use crate::{
public::docs::{EndpointDocumentation, EndpointMethod}, public::docs::{EndpointDocumentation, EndpointMethod},
types::{ types::{
extract::{AuthorizedUser, Check, CheckResult, Json, Log}, extract::{AuthorizedUser, Check, CheckResult, Database, Json, Log},
http::ResponseCode, http::ResponseCode,
session::Session, session::Session,
user::User, user::User,
@ -99,13 +99,17 @@ impl Check for RegistrationRequet {
} }
} }
async fn register(cookies: Cookies, Json(body): Json<RegistrationRequet>) -> Response { async fn register(
let user = match User::new(body) { cookies: Cookies,
Database(db): Database,
Json(body): Json<RegistrationRequet>,
) -> Response {
let user = match User::new(&db, body) {
Ok(user) => user, Ok(user) => user,
Err(err) => return err, Err(err) => return err,
}; };
let session = match Session::new(user.user_id) { let session = match Session::new(&db, user.user_id) {
Ok(session) => session, Ok(session) => session,
Err(err) => return err, Err(err) => return err,
}; };
@ -158,8 +162,12 @@ impl Check for LoginRequest {
} }
} }
async fn login(cookies: Cookies, Json(body): Json<LoginRequest>) -> Response { async fn login(
let Ok(user) = User::from_email(&body.email) else { cookies: Cookies,
Database(db): Database,
Json(body): Json<LoginRequest>,
) -> Response {
let Ok(user) = User::from_email(&db, &body.email) else {
return ResponseCode::BadRequest.text("Email is not registered") return ResponseCode::BadRequest.text("Email is not registered")
}; };
@ -167,7 +175,7 @@ async fn login(cookies: Cookies, Json(body): Json<LoginRequest>) -> Response {
return ResponseCode::BadRequest.text("Password is not correct"); return ResponseCode::BadRequest.text("Password is not correct");
} }
let session = match Session::new(user.user_id) { let session = match Session::new(&db, user.user_id) {
Ok(session) => session, Ok(session) => session,
Err(err) => return err, Err(err) => return err,
}; };
@ -199,10 +207,15 @@ pub const AUTH_LOGOUT: EndpointDocumentation = EndpointDocumentation {
cookie: None, cookie: None,
}; };
async fn logout(cookies: Cookies, AuthorizedUser(user): AuthorizedUser, _: Log) -> Response { async fn logout(
cookies: Cookies,
AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
_: Log,
) -> Response {
cookies.remove(Cookie::new("auth", "")); cookies.remove(Cookie::new("auth", ""));
if let Err(err) = Session::delete(user.user_id) { if let Err(err) = Session::delete(&db, user.user_id) {
return err; return err;
} }

View file

@ -1,5 +1,15 @@
use crate::types::extract::RouterURI; use crate::{
use axum::{error_handling::HandleErrorLayer, BoxError, Extension, Router}; database,
types::extract::{DatabaseExtention, RouterURI},
};
use axum::{
error_handling::HandleErrorLayer,
http::Request,
middleware::{self, Next},
response::Response,
BoxError, Extension, Router,
};
use tokio::sync::Mutex;
use tower::ServiceBuilder; use tower::ServiceBuilder;
use tower_governor::{ use tower_governor::{
errors::display_error, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, errors::display_error, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor,
@ -13,6 +23,18 @@ pub mod users;
pub use auth::RegistrationRequet; pub use auth::RegistrationRequet;
async fn connect<B>(mut req: Request<B>, next: Next<B>) -> Response
where
B: Send,
{
if let Ok(db) = database::Database::connect() {
let ex = DatabaseExtention(Mutex::new(db));
req.extensions_mut().insert(ex);
}
next.run(req).await
}
pub fn router() -> Router { pub fn router() -> Router {
let governor_conf = Box::new( let governor_conf = Box::new(
GovernorConfigBuilder::default() GovernorConfigBuilder::default()
@ -49,4 +71,5 @@ pub fn router() -> Router {
config: Box::leak(governor_conf), config: Box::leak(governor_conf),
}), }),
) )
.layer(middleware::from_fn(connect))
} }

View file

@ -9,7 +9,7 @@ use crate::{
public::docs::{EndpointDocumentation, EndpointMethod}, public::docs::{EndpointDocumentation, EndpointMethod},
types::{ types::{
comment::Comment, comment::Comment,
extract::{AuthorizedUser, Check, CheckResult, Json}, extract::{AuthorizedUser, Check, CheckResult, Database, Json},
http::ResponseCode, http::ResponseCode,
like::Like, like::Like,
post::Post, post::Post,
@ -55,9 +55,10 @@ impl Check for PostCreateRequest {
async fn create( async fn create(
AuthorizedUser(user): AuthorizedUser, AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<PostCreateRequest>, Json(body): Json<PostCreateRequest>,
) -> Response { ) -> Response {
let Ok(post) = Post::new(user.user_id, body.content) else { let Ok(post) = Post::new(&db, user.user_id, body.content) else {
return ResponseCode::InternalServerError.text("Failed to create post") return ResponseCode::InternalServerError.text("Failed to create post")
}; };
@ -101,9 +102,10 @@ impl Check for PostPageRequest {
async fn page( async fn page(
AuthorizedUser(user): AuthorizedUser, AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<PostPageRequest>, Json(body): Json<PostPageRequest>,
) -> Response { ) -> Response {
let Ok(posts) = Post::from_post_page(user.user_id, body.page) else { let Ok(posts) = Post::from_post_page(&db, user.user_id, body.page) else {
return ResponseCode::InternalServerError.text("Failed to fetch posts") return ResponseCode::InternalServerError.text("Failed to fetch posts")
}; };
@ -149,9 +151,10 @@ impl Check for CommentsPageRequest {
async fn comments( async fn comments(
AuthorizedUser(_user): AuthorizedUser, AuthorizedUser(_user): AuthorizedUser,
Database(db): Database,
Json(body): Json<CommentsPageRequest>, Json(body): Json<CommentsPageRequest>,
) -> Response { ) -> Response {
let Ok(comments) = Comment::from_comment_page(body.page, body.post_id) else { let Ok(comments) = Comment::from_comment_page(&db, body.page, body.post_id) else {
return ResponseCode::InternalServerError.text("Failed to fetch comments") return ResponseCode::InternalServerError.text("Failed to fetch comments")
}; };
@ -197,9 +200,10 @@ impl Check for UsersPostsRequest {
async fn user( async fn user(
AuthorizedUser(user): AuthorizedUser, AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UsersPostsRequest>, Json(body): Json<UsersPostsRequest>,
) -> Response { ) -> Response {
let Ok(posts) = Post::from_user_post_page(user.user_id, body.user_id, body.page) else { let Ok(posts) = Post::from_user_post_page(&db, user.user_id, body.user_id, body.page) else {
return ResponseCode::InternalServerError.text("Failed to fetch posts") return ResponseCode::InternalServerError.text("Failed to fetch posts")
}; };
@ -251,9 +255,10 @@ impl Check for PostCommentRequest {
async fn comment( async fn comment(
AuthorizedUser(user): AuthorizedUser, AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<PostCommentRequest>, Json(body): Json<PostCommentRequest>,
) -> Response { ) -> Response {
if let Err(err) = Comment::new(user.user_id, body.post_id, &body.content) { if let Err(err) = Comment::new(&db, user.user_id, body.post_id, &body.content) {
return err; return err;
} }
@ -293,12 +298,16 @@ impl Check for PostLikeRequest {
} }
} }
async fn like(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostLikeRequest>) -> Response { async fn like(
AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<PostLikeRequest>,
) -> Response {
if body.state { if body.state {
if let Err(err) = Like::add_liked(user.user_id, body.post_id) { if let Err(err) = Like::add_liked(&db, user.user_id, body.post_id) {
return err; return err;
} }
} else if let Err(err) = Like::remove_liked(user.user_id, body.post_id) { } else if let Err(err) = Like::remove_liked(&db, user.user_id, body.post_id) {
return err; return err;
} }

View file

@ -1,7 +1,7 @@
use crate::{ use crate::{
public::docs::{EndpointDocumentation, EndpointMethod}, public::docs::{EndpointDocumentation, EndpointMethod},
types::{ types::{
extract::{AuthorizedUser, Check, CheckResult, Json, Log, Png}, extract::{AuthorizedUser, Check, CheckResult, Database, Json, Log, Png},
http::ResponseCode, http::ResponseCode,
user::User, user::User,
}, },
@ -46,9 +46,10 @@ impl Check for UserLoadRequest {
async fn load_batch( async fn load_batch(
AuthorizedUser(_user): AuthorizedUser, AuthorizedUser(_user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UserLoadRequest>, Json(body): Json<UserLoadRequest>,
) -> Response { ) -> Response {
let users = User::from_user_ids(body.ids); let users = User::from_user_ids(&db, body.ids);
let Ok(json) = serde_json::to_string(&users) else { let Ok(json) = serde_json::to_string(&users) else {
return ResponseCode::InternalServerError.text("Failed to fetch users") return ResponseCode::InternalServerError.text("Failed to fetch users")
}; };
@ -90,9 +91,10 @@ impl Check for UserPageReqiest {
async fn load_page( async fn load_page(
AuthorizedUser(_user): AuthorizedUser, AuthorizedUser(_user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UserPageReqiest>, Json(body): Json<UserPageReqiest>,
) -> Response { ) -> Response {
let Ok(users) = User::from_user_page(body.page) else { let Ok(users) = User::from_user_page(&db, body.page) else {
return ResponseCode::InternalServerError.text("Failed to fetch users") return ResponseCode::InternalServerError.text("Failed to fetch users")
}; };
@ -207,17 +209,18 @@ impl Check for UserFollowRequest {
async fn follow( async fn follow(
AuthorizedUser(user): AuthorizedUser, AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UserFollowRequest>, Json(body): Json<UserFollowRequest>,
) -> Response { ) -> Response {
if body.state { if body.state {
if let Err(err) = User::add_following(user.user_id, body.user_id) { if let Err(err) = User::add_following(&db, user.user_id, body.user_id) {
return err; return err;
} }
} else if let Err(err) = User::remove_following(user.user_id, body.user_id) { } else if let Err(err) = User::remove_following(&db, user.user_id, body.user_id) {
return err; return err;
} }
match User::get_following(user.user_id, body.user_id) { match User::get_following(&db, user.user_id, body.user_id) {
Ok(status) => ResponseCode::Success.text(&format!("{status}")), Ok(status) => ResponseCode::Success.text(&format!("{status}")),
Err(err) => err, Err(err) => err,
} }
@ -259,9 +262,10 @@ impl Check for UserFollowStatusRequest {
async fn follow_status( async fn follow_status(
AuthorizedUser(user): AuthorizedUser, AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UserFollowStatusRequest>, Json(body): Json<UserFollowStatusRequest>,
) -> Response { ) -> Response {
match User::get_following(user.user_id, body.user_id) { match User::get_following(&db, user.user_id, body.user_id) {
Ok(status) => ResponseCode::Success.text(&format!("{status}")), Ok(status) => ResponseCode::Success.text(&format!("{status}")),
Err(err) => err, Err(err) => err,
} }
@ -297,8 +301,12 @@ impl Check for UserFriendsRequest {
} }
} }
async fn friends(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UserFriendsRequest>) -> Response { async fn friends(
let Ok(users) = User::get_friends(body.user_id) else { AuthorizedUser(_user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UserFriendsRequest>,
) -> Response {
let Ok(users) = User::get_friends(&db, body.user_id) else {
return ResponseCode::InternalServerError.text("Failed to fetch user") return ResponseCode::InternalServerError.text("Failed to fetch user")
}; };

View file

@ -3,89 +3,100 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
use rusqlite::Row; use rusqlite::Row;
use tracing::instrument; use tracing::instrument;
use crate::{database, types::comment::Comment}; use crate::types::comment::Comment;
pub fn init() -> Result<(), rusqlite::Error> { use super::Database;
let sql = "
CREATE TABLE IF NOT EXISTS comments (
comment_id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
post_id INTEGER NOT NULL,
date INTEGER NOT NULL,
content VARCHAR(255) NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(user_id),
FOREIGN KEY(post_id) REFERENCES posts(post_id)
);
";
let conn = database::connect()?;
conn.execute(sql, ())?;
let sql2 = "CREATE INDEX IF NOT EXISTS post_ids on comments (post_id);"; impl Database {
conn.execute(sql2, ())?; pub fn init_comments(&self) -> Result<(), rusqlite::Error> {
let sql = "
CREATE TABLE IF NOT EXISTS comments (
comment_id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
post_id INTEGER NOT NULL,
date INTEGER NOT NULL,
content VARCHAR(255) NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(user_id),
FOREIGN KEY(post_id) REFERENCES posts(post_id)
);
";
self.0.execute(sql, ())?;
Ok(()) let sql2 = "CREATE INDEX IF NOT EXISTS post_ids on comments (post_id);";
} self.0.execute(sql2, ())?;
fn comment_from_row(row: &Row) -> Result<Comment, rusqlite::Error> { Ok(())
let comment_id = row.get(0)?; }
let user_id = row.get(1)?;
let post_id = row.get(2)?; fn comment_from_row(row: &Row) -> Result<Comment, rusqlite::Error> {
let date = row.get(3)?; let comment_id = row.get(0)?;
let content = row.get(4)?; let user_id = row.get(1)?;
let post_id = row.get(2)?;
Ok(Comment { let date = row.get(3)?;
comment_id, let content = row.get(4)?;
user_id,
post_id, Ok(Comment {
date, comment_id,
content, user_id,
}) post_id,
} date,
content,
#[instrument()] })
pub fn get_comments_page(page: u64, post_id: u64) -> Result<Vec<Comment>, rusqlite::Error> { }
tracing::trace!("Retrieving comments page");
let page_size = 5; #[instrument(skip(self))]
let conn = database::connect()?; pub fn get_comments_page(
let mut stmt = conn.prepare( &self,
"SELECT * FROM comments WHERE post_id = ? ORDER BY comment_id ASC LIMIT ? OFFSET ?", page: u64,
)?; post_id: u64,
let row = stmt.query_map([post_id, page_size, page_size * page], |row| { ) -> Result<Vec<Comment>, rusqlite::Error> {
let row = comment_from_row(row)?; tracing::trace!("Retrieving comments page");
Ok(row) let page_size = 5;
})?; let mut stmt = self.0.prepare(
Ok(row.into_iter().flatten().collect()) "SELECT * FROM comments WHERE post_id = ? ORDER BY comment_id ASC LIMIT ? OFFSET ?",
} )?;
let row = stmt.query_map([post_id, page_size, page_size * page], |row| {
#[instrument()] let row = Self::comment_from_row(row)?;
pub fn get_all_comments() -> Result<Vec<Comment>, rusqlite::Error> { Ok(row)
tracing::trace!("Retrieving comments page"); })?;
let conn = database::connect()?; Ok(row.into_iter().flatten().collect())
let mut stmt = conn.prepare("SELECT * FROM comments ORDER BY comment_id DESC")?; }
let row = stmt.query_map([], |row| {
let row = comment_from_row(row)?; #[instrument(skip(self))]
Ok(row) pub fn get_all_comments(&self) -> Result<Vec<Comment>, rusqlite::Error> {
})?; tracing::trace!("Retrieving comments page");
Ok(row.into_iter().flatten().collect()) let mut stmt = self
} .0
.prepare("SELECT * FROM comments ORDER BY comment_id DESC")?;
#[instrument()] let row = stmt.query_map([], |row| {
pub fn add_comment(user_id: u64, post_id: u64, content: &str) -> Result<Comment, rusqlite::Error> { let row = Self::comment_from_row(row)?;
tracing::trace!("Adding comment"); Ok(row)
let date = u64::try_from( })?;
SystemTime::now() Ok(row.into_iter().flatten().collect())
.duration_since(UNIX_EPOCH) }
.unwrap_or(Duration::ZERO)
.as_millis(), #[instrument(skip(self))]
) pub fn add_comment(
.unwrap_or(0); &self,
let conn = database::connect()?; user_id: u64,
let mut stmt = conn.prepare( post_id: u64,
"INSERT INTO comments (user_id, post_id, date, content) VALUES(?,?,?,?) RETURNING *;", content: &str,
)?; ) -> Result<Comment, rusqlite::Error> {
let post = stmt.query_row((user_id, post_id, date, content), |row| { tracing::trace!("Adding comment");
let row = comment_from_row(row)?; let date = u64::try_from(
Ok(row) SystemTime::now()
})?; .duration_since(UNIX_EPOCH)
Ok(post) .unwrap_or(Duration::ZERO)
.as_millis(),
)
.unwrap_or(0);
let mut stmt = self.0.prepare(
"INSERT INTO comments (user_id, post_id, date, content) VALUES(?,?,?,?) RETURNING *;",
)?;
let post = stmt.query_row((user_id, post_id, date, content), |row| {
let row = Self::comment_from_row(row)?;
Ok(row)
})?;
Ok(post)
}
} }

View file

@ -1,97 +1,100 @@
use tracing::instrument; use tracing::instrument;
use crate::{ use crate::types::user::{User, FOLLOWED, FOLLOWING, NO_RELATION};
database::{self, users::user_from_row},
types::user::{User, FOLLOWED, FOLLOWING, NO_RELATION},
};
pub fn init() -> Result<(), rusqlite::Error> { use super::Database;
let sql = "
CREATE TABLE IF NOT EXISTS friends (
follower_id INTEGER NOT NULL,
followee_id INTEGER NOT NULL,
FOREIGN KEY(follower_id) REFERENCES users(user_id),
FOREIGN KEY(followee_id) REFERENCES users(user_id),
PRIMARY KEY (follower_id, followee_id)
);
";
let conn = database::connect()?;
conn.execute(sql, ())?;
Ok(())
}
#[instrument()] impl Database {
pub fn get_friend_status(user_id_1: u64, user_id_2: u64) -> Result<u8, rusqlite::Error> { pub fn init_friends(&self) -> Result<(), rusqlite::Error> {
tracing::trace!("Retrieving friend status"); let sql = "
let conn = database::connect()?; CREATE TABLE IF NOT EXISTS friends (
let mut stmt = conn.prepare("SELECT * FROM friends WHERE (follower_id = ? AND followee_id = ?) OR (follower_id = ? AND followee_id = ?);")?; follower_id INTEGER NOT NULL,
let mut status = NO_RELATION; followee_id INTEGER NOT NULL,
let rows: Vec<u64> = stmt FOREIGN KEY(follower_id) REFERENCES users(user_id),
.query_map([user_id_1, user_id_2, user_id_2, user_id_1], |row| { FOREIGN KEY(followee_id) REFERENCES users(user_id),
let id: u64 = row.get(0)?; PRIMARY KEY (follower_id, followee_id)
Ok(id) );
})? ";
.into_iter() self.0.execute(sql, ())?;
.flatten() Ok(())
.collect();
for follower in rows {
if follower == user_id_1 {
status |= FOLLOWING;
}
if follower == user_id_2 {
status |= FOLLOWED;
}
} }
Ok(status) #[instrument(skip(self))]
} pub fn get_friend_status(&self, user_id_1: u64, user_id_2: u64) -> Result<u8, rusqlite::Error> {
tracing::trace!("Retrieving friend status");
let mut stmt = self.0.prepare("SELECT * FROM friends WHERE (follower_id = ? AND followee_id = ?) OR (follower_id = ? AND followee_id = ?);")?;
let mut status = NO_RELATION;
let rows: Vec<u64> = stmt
.query_map([user_id_1, user_id_2, user_id_2, user_id_1], |row| {
let id: u64 = row.get(0)?;
Ok(id)
})?
.into_iter()
.flatten()
.collect();
#[instrument()] for follower in rows {
pub fn get_friends(user_id: u64) -> Result<Vec<User>, rusqlite::Error> { if follower == user_id_1 {
tracing::trace!("Retrieving friends"); status |= FOLLOWING;
let conn = database::connect()?; }
let mut stmt = conn.prepare(
"
SELECT *
FROM users u
WHERE EXISTS (
SELECT NULL
FROM friends f
WHERE u.user_id = f.follower_id
AND f.followee_id = ?
)
AND EXISTS (
SELECT NULL
FROM friends f
WHERE u.user_id = f.followee_id
AND f.follower_id = ?
)
",
)?;
let row = stmt.query_map([user_id, user_id], |row| {
let row = user_from_row(row, true)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument()] if follower == user_id_2 {
pub fn set_following(user_id_1: u64, user_id_2: u64) -> Result<bool, rusqlite::Error> { status |= FOLLOWED;
tracing::trace!("Setting following"); }
let conn = database::connect()?; }
let mut stmt =
conn.prepare("INSERT OR REPLACE INTO friends (follower_id, followee_id) VALUES (?,?)")?;
let changes = stmt.execute([user_id_1, user_id_2])?;
Ok(changes == 1)
}
#[instrument()] Ok(status)
pub fn remove_following(user_id_1: u64, user_id_2: u64) -> Result<bool, rusqlite::Error> { }
tracing::trace!("Removing following");
let conn = database::connect()?; #[instrument(skip(self))]
let mut stmt = conn.prepare("DELETE FROM friends WHERE follower_id = ? AND followee_id = ?")?; pub fn get_friends(&self, user_id: u64) -> Result<Vec<User>, rusqlite::Error> {
let changes = stmt.execute([user_id_1, user_id_2])?; tracing::trace!("Retrieving friends");
Ok(changes == 1) let mut stmt = self.0.prepare(
"
SELECT *
FROM users u
WHERE EXISTS (
SELECT NULL
FROM friends f
WHERE u.user_id = f.follower_id
AND f.followee_id = ?
)
AND EXISTS (
SELECT NULL
FROM friends f
WHERE u.user_id = f.followee_id
AND f.follower_id = ?
)
",
)?;
let row = stmt.query_map([user_id, user_id], |row| {
let row = Self::user_from_row(row, true)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument(skip(self))]
pub fn set_following(&self, user_id_1: u64, user_id_2: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Setting following");
let mut stmt = self
.0
.prepare("INSERT OR REPLACE INTO friends (follower_id, followee_id) VALUES (?,?)")?;
let changes = stmt.execute([user_id_1, user_id_2])?;
Ok(changes == 1)
}
#[instrument(skip(self))]
pub fn remove_following(
&self,
user_id_1: u64,
user_id_2: u64,
) -> Result<bool, rusqlite::Error> {
tracing::trace!("Removing following");
let mut stmt = self
.0
.prepare("DELETE FROM friends WHERE follower_id = ? AND followee_id = ?")?;
let changes = stmt.execute([user_id_1, user_id_2])?;
Ok(changes == 1)
}
} }

View file

@ -1,75 +1,81 @@
use rusqlite::OptionalExtension; use rusqlite::OptionalExtension;
use tracing::instrument; use tracing::instrument;
use crate::{database, types::like::Like}; use crate::types::like::Like;
pub fn init() -> Result<(), rusqlite::Error> { use super::Database;
let sql = "
CREATE TABLE IF NOT EXISTS likes (
user_id INTEGER NOT NULL,
post_id INTEGER NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(user_id),
FOREIGN KEY(post_id) REFERENCES posts(post_id),
PRIMARY KEY (user_id, post_id)
);
";
let conn = database::connect()?;
conn.execute(sql, ())?;
Ok(())
}
#[instrument()] impl Database {
pub fn get_like_count(post_id: u64) -> Result<Option<u64>, rusqlite::Error> { pub fn init_likes(&self) -> Result<(), rusqlite::Error> {
tracing::trace!("Retrieving like count"); let sql = "
let conn = database::connect()?; CREATE TABLE IF NOT EXISTS likes (
let mut stmt = conn.prepare("SELECT COUNT(post_id) FROM likes WHERE post_id = ?")?; user_id INTEGER NOT NULL,
let row = stmt post_id INTEGER NOT NULL,
.query_row([post_id], |row| { FOREIGN KEY(user_id) REFERENCES users(user_id),
let row = row.get(0)?; FOREIGN KEY(post_id) REFERENCES posts(post_id),
Ok(row) PRIMARY KEY (user_id, post_id)
}) );
.optional()?; ";
Ok(row) self.0.execute(sql, ())?;
} Ok(())
}
#[instrument()] #[instrument(skip(self))]
pub fn get_liked(user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> { pub fn get_like_count(&self, post_id: u64) -> Result<Option<u64>, rusqlite::Error> {
tracing::trace!("Retrieving if liked"); tracing::trace!("Retrieving like count");
let conn = database::connect()?; let mut stmt = self
let mut stmt = conn.prepare("SELECT * FROM likes WHERE user_id = ? AND post_id = ?")?; .0
let liked = stmt.query_row([user_id, post_id], |_| Ok(())).optional()?; .prepare("SELECT COUNT(post_id) FROM likes WHERE post_id = ?")?;
Ok(liked.is_some()) let row = stmt
} .query_row([post_id], |row| {
let row = row.get(0)?;
Ok(row)
})
.optional()?;
Ok(row)
}
#[instrument()] #[instrument(skip(self))]
pub fn add_liked(user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> { pub fn get_liked(&self, user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Adding like"); tracing::trace!("Retrieving if liked");
let conn = database::connect()?; let mut stmt = self
let mut stmt = conn.prepare("INSERT OR REPLACE INTO likes (user_id, post_id) VALUES (?,?)")?; .0
let changes = stmt.execute([user_id, post_id])?; .prepare("SELECT * FROM likes WHERE user_id = ? AND post_id = ?")?;
Ok(changes == 1) let liked = stmt.query_row([user_id, post_id], |_| Ok(())).optional()?;
} Ok(liked.is_some())
}
#[instrument()] #[instrument(skip(self))]
pub fn remove_liked(user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> { pub fn add_liked(&self, user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Removing like"); tracing::trace!("Adding like");
let conn = database::connect()?; let mut stmt = self
let mut stmt = conn.prepare("DELETE FROM likes WHERE user_id = ? AND post_id = ?;")?; .0
let changes = stmt.execute((user_id, post_id))?; .prepare("INSERT OR REPLACE INTO likes (user_id, post_id) VALUES (?,?)")?;
Ok(changes == 1) let changes = stmt.execute([user_id, post_id])?;
} Ok(changes == 1)
}
#[instrument()] #[instrument(skip(self))]
pub fn get_all_likes() -> Result<Vec<Like>, rusqlite::Error> { pub fn remove_liked(&self, user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Retrieving comments page"); tracing::trace!("Removing like");
let conn = database::connect()?; let mut stmt = self
let mut stmt = conn.prepare("SELECT * FROM likes")?; .0
let row = stmt.query_map([], |row| { .prepare("DELETE FROM likes WHERE user_id = ? AND post_id = ?;")?;
let like = Like { let changes = stmt.execute((user_id, post_id))?;
user_id: row.get(0)?, Ok(changes == 1)
post_id: row.get(1)?, }
};
Ok(like) #[instrument(skip(self))]
})?; pub fn get_all_likes(&self) -> Result<Vec<Like>, rusqlite::Error> {
Ok(row.into_iter().flatten().collect()) tracing::trace!("Retrieving comments page");
let mut stmt = self.0.prepare("SELECT * FROM likes")?;
let row = stmt.query_map([], |row| {
let like = Like {
user_id: row.get(0)?,
post_id: row.get(1)?,
};
Ok(like)
})?;
Ok(row.into_iter().flatten().collect())
}
} }

View file

@ -1,3 +1,4 @@
use rusqlite::Connection;
use tracing::instrument; use tracing::instrument;
pub mod comments; pub mod comments;
@ -7,23 +8,29 @@ pub mod posts;
pub mod sessions; pub mod sessions;
pub mod users; pub mod users;
pub fn connect() -> Result<rusqlite::Connection, rusqlite::Error> { #[derive(Debug)]
rusqlite::Connection::open("xssbook.db") pub struct Database(Connection);
impl Database {
pub fn connect() -> Result<Self, rusqlite::Error> {
let conn = rusqlite::Connection::open("xssbook.db")?;
Ok(Self(conn))
}
#[instrument(skip(self))]
pub fn query(&self, query: String) -> Result<usize, rusqlite::Error> {
tracing::trace!("Running custom query");
self.0.execute(&query, [])
}
} }
pub fn init() -> Result<(), rusqlite::Error> { pub fn init() -> Result<(), rusqlite::Error> {
users::init()?; let db = Database::connect()?;
posts::init()?; db.init_users()?;
sessions::init()?; db.init_posts()?;
likes::init()?; db.init_sessions()?;
comments::init()?; db.init_likes()?;
friends::init()?; db.init_comments()?;
db.init_friends()?;
Ok(()) Ok(())
} }
#[instrument()]
pub fn query(query: String) -> Result<usize, rusqlite::Error> {
tracing::trace!("Running custom query");
let conn = connect()?;
conn.execute(&query, [])
}

View file

@ -3,115 +3,122 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
use rusqlite::{OptionalExtension, Row}; use rusqlite::{OptionalExtension, Row};
use tracing::instrument; use tracing::instrument;
use crate::database;
use crate::types::post::Post; use crate::types::post::Post;
use super::{comments, likes}; use super::Database;
pub fn init() -> Result<(), rusqlite::Error> { impl Database {
let sql = " pub fn init_posts(&self) -> Result<(), rusqlite::Error> {
CREATE TABLE IF NOT EXISTS posts ( let sql = "
post_id INTEGER PRIMARY KEY AUTOINCREMENT, CREATE TABLE IF NOT EXISTS posts (
user_id INTEGER NOT NULL, post_id INTEGER PRIMARY KEY AUTOINCREMENT,
content VARCHAR(500) NOT NULL, user_id INTEGER NOT NULL,
date INTEGER NOT NULL, content VARCHAR(500) NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(user_id) date INTEGER NOT NULL,
); FOREIGN KEY(user_id) REFERENCES users(user_id)
"; );
let conn = database::connect()?; ";
conn.execute(sql, ())?; self.0.execute(sql, ())?;
Ok(()) Ok(())
} }
fn post_from_row(row: &Row) -> Result<Post, rusqlite::Error> { fn post_from_row(&self, row: &Row) -> Result<Post, rusqlite::Error> {
let post_id = row.get(0)?; let post_id = row.get(0)?;
let user_id = row.get(1)?; let user_id = row.get(1)?;
let content = row.get(2)?; let content = row.get(2)?;
let date = row.get(3)?; let date = row.get(3)?;
let comments = comments::get_comments_page(0, post_id).unwrap_or_else(|_| Vec::new()); let comments = self
let likes = likes::get_like_count(post_id).unwrap_or(None).unwrap_or(0); .get_comments_page(0, post_id)
.unwrap_or_else(|_| Vec::new());
let likes = self.get_like_count(post_id).unwrap_or(None).unwrap_or(0);
Ok(Post { Ok(Post {
post_id, post_id,
user_id, user_id,
content, content,
date, date,
likes, likes,
liked: false, liked: false,
comments, comments,
})
}
#[instrument()]
pub fn get_post(post_id: u64) -> Result<Option<Post>, rusqlite::Error> {
tracing::trace!("Retrieving post");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM posts WHERE post_id = ?")?;
let row = stmt
.query_row([post_id], |row| {
let row = post_from_row(row)?;
Ok(row)
}) })
.optional()?; }
Ok(row)
}
#[instrument()] #[instrument(skip(self))]
pub fn get_post_page(page: u64) -> Result<Vec<Post>, rusqlite::Error> { pub fn get_post(&self, post_id: u64) -> Result<Option<Post>, rusqlite::Error> {
tracing::trace!("Retrieving posts page"); tracing::trace!("Retrieving post");
let page_size = 10; let mut stmt = self.0.prepare("SELECT * FROM posts WHERE post_id = ?")?;
let conn = database::connect()?; let row = stmt
let mut stmt = conn.prepare("SELECT * FROM posts ORDER BY post_id DESC LIMIT ? OFFSET ?")?; .query_row([post_id], |row| {
let row = stmt.query_map([page_size, page_size * page], |row| { let row = self.post_from_row(row)?;
let row = post_from_row(row)?; Ok(row)
})
.optional()?;
Ok(row) Ok(row)
})?; }
Ok(row.into_iter().flatten().collect())
}
#[instrument()] #[instrument(skip(self))]
pub fn get_all_posts() -> Result<Vec<Post>, rusqlite::Error> { pub fn get_post_page(&self, page: u64) -> Result<Vec<Post>, rusqlite::Error> {
tracing::trace!("Retrieving posts page"); tracing::trace!("Retrieving posts page");
let conn = database::connect()?; let page_size = 10;
let mut stmt = conn.prepare("SELECT * FROM posts ORDER BY post_id DESC")?; let mut stmt = self
let row = stmt.query_map([], |row| { .0
let row = post_from_row(row)?; .prepare("SELECT * FROM posts ORDER BY post_id DESC LIMIT ? OFFSET ?")?;
Ok(row) let row = stmt.query_map([page_size, page_size * page], |row| {
})?; let row = self.post_from_row(row)?;
Ok(row.into_iter().flatten().collect()) Ok(row)
} })?;
Ok(row.into_iter().flatten().collect())
}
#[instrument()] #[instrument(skip(self))]
pub fn get_users_post_page(user_id: u64, page: u64) -> Result<Vec<Post>, rusqlite::Error> { pub fn get_all_posts(&self) -> Result<Vec<Post>, rusqlite::Error> {
tracing::trace!("Retrieving users posts"); tracing::trace!("Retrieving posts page");
let page_size = 10; let mut stmt = self
let conn = database::connect()?; .0
let mut stmt = conn .prepare("SELECT * FROM posts ORDER BY post_id DESC")?;
.prepare("SELECT * FROM posts WHERE user_id = ? ORDER BY post_id DESC LIMIT ? OFFSET ?")?; let row = stmt.query_map([], |row| {
let row = stmt.query_map([user_id, page_size, page_size * page], |row| { let row = self.post_from_row(row)?;
let row = post_from_row(row)?; Ok(row)
Ok(row) })?;
})?; Ok(row.into_iter().flatten().collect())
Ok(row.into_iter().flatten().collect()) }
}
#[instrument()] #[instrument(skip(self))]
pub fn add_post(user_id: u64, content: &str) -> Result<Post, rusqlite::Error> { pub fn get_users_post_page(
tracing::trace!("Adding post"); &self,
let date = u64::try_from( user_id: u64,
SystemTime::now() page: u64,
.duration_since(UNIX_EPOCH) ) -> Result<Vec<Post>, rusqlite::Error> {
.unwrap_or(Duration::ZERO) tracing::trace!("Retrieving users posts");
.as_millis(), let page_size = 10;
) let mut stmt = self.0.prepare(
.unwrap_or(0); "SELECT * FROM posts WHERE user_id = ? ORDER BY post_id DESC LIMIT ? OFFSET ?",
let conn = database::connect()?; )?;
let mut stmt = let row = stmt.query_map([user_id, page_size, page_size * page], |row| {
conn.prepare("INSERT INTO posts (user_id, content, date) VALUES(?,?,?) RETURNING *;")?; let row = self.post_from_row(row)?;
let post = stmt.query_row((user_id, content, date), |row| { Ok(row)
let row = post_from_row(row)?; })?;
Ok(row) Ok(row.into_iter().flatten().collect())
})?; }
Ok(post)
#[instrument(skip(self))]
pub fn add_post(&self, user_id: u64, content: &str) -> Result<Post, rusqlite::Error> {
tracing::trace!("Adding post");
let date = u64::try_from(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::ZERO)
.as_millis(),
)
.unwrap_or(0);
let mut stmt = self
.0
.prepare("INSERT INTO posts (user_id, content, date) VALUES(?,?,?) RETURNING *;")?;
let post = stmt.query_row((user_id, content, date), |row| {
let row = self.post_from_row(row)?;
Ok(row)
})?;
Ok(post)
}
} }

View file

@ -1,65 +1,64 @@
use rusqlite::OptionalExtension; use rusqlite::OptionalExtension;
use tracing::instrument; use tracing::instrument;
use crate::{database, types::session::Session}; use crate::types::session::Session;
pub fn init() -> Result<(), rusqlite::Error> { use super::Database;
let sql = "
CREATE TABLE IF NOT EXISTS sessions (
user_id INTEGER PRIMARY KEY NOT NULL,
token TEXT NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(user_id)
);
";
let conn = database::connect()?;
conn.execute(sql, ())?;
Ok(())
}
#[instrument()] impl Database {
pub fn get_session(token: &str) -> Result<Option<Session>, rusqlite::Error> { pub fn init_sessions(&self) -> Result<(), rusqlite::Error> {
tracing::trace!("Retrieving session"); let sql = "
let conn = database::connect()?; CREATE TABLE IF NOT EXISTS sessions (
let mut stmt = conn.prepare("SELECT * FROM sessions WHERE token = ?")?; user_id INTEGER PRIMARY KEY NOT NULL,
let row = stmt token TEXT NOT NULL,
.query_row([token], |row| { FOREIGN KEY(user_id) REFERENCES users(user_id)
);
";
self.0.execute(sql, ())?;
Ok(())
}
#[instrument(skip(self))]
pub fn get_session(&self, token: &str) -> Result<Option<Session>, rusqlite::Error> {
tracing::trace!("Retrieving session");
let mut stmt = self.0.prepare("SELECT * FROM sessions WHERE token = ?")?;
let row = stmt
.query_row([token], |row| {
Ok(Session {
user_id: row.get(0)?,
token: row.get(1)?,
})
})
.optional()?;
Ok(row)
}
#[instrument(skip(self))]
pub fn get_all_sessions(&self) -> Result<Vec<Session>, rusqlite::Error> {
tracing::trace!("Retrieving session");
let mut stmt = self.0.prepare("SELECT * FROM sessions")?;
let row = stmt.query_map([], |row| {
Ok(Session { Ok(Session {
user_id: row.get(0)?, user_id: row.get(0)?,
token: row.get(1)?, token: row.get(1)?,
}) })
}) })?;
.optional()?; Ok(row.into_iter().flatten().collect())
Ok(row) }
}
#[instrument()] #[instrument(skip(self))]
pub fn get_all_sessions() -> Result<Vec<Session>, rusqlite::Error> { pub fn set_session(&self, user_id: u64, token: &str) -> Result<(), Box<dyn std::error::Error>> {
tracing::trace!("Retrieving session"); tracing::trace!("Setting new session");
let conn = database::connect()?; let sql = "INSERT OR REPLACE INTO sessions (user_id, token) VALUES (?, ?);";
let mut stmt = conn.prepare("SELECT * FROM sessions")?; self.0.execute(sql, (user_id, token))?;
let row = stmt.query_map([], |row| { Ok(())
Ok(Session { }
user_id: row.get(0)?,
token: row.get(1)?,
})
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument()] #[instrument(skip(self))]
pub fn set_session(user_id: u64, token: &str) -> Result<(), Box<dyn std::error::Error>> { pub fn delete_session(&self, user_id: u64) -> Result<(), Box<dyn std::error::Error>> {
tracing::trace!("Setting new session"); tracing::trace!("Deleting session");
let conn = database::connect()?; let sql = "DELETE FROM sessions WHERE user_id = ?;";
let sql = "INSERT OR REPLACE INTO sessions (user_id, token) VALUES (?, ?);"; self.0.execute(sql, [user_id])?;
conn.execute(sql, (user_id, token))?; Ok(())
Ok(()) }
}
#[instrument()]
pub fn delete_session(user_id: u64) -> Result<(), Box<dyn std::error::Error>> {
tracing::trace!("Deleting session");
let conn = database::connect()?;
let sql = "DELETE FROM sessions WHERE user_id = ?;";
conn.execute(sql, [user_id])?;
Ok(())
} }

View file

@ -2,169 +2,180 @@ use rusqlite::{OptionalExtension, Row};
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tracing::instrument; use tracing::instrument;
use crate::{api::RegistrationRequet, database, types::user::User}; use crate::{api::RegistrationRequet, types::user::User};
pub fn init() -> Result<(), rusqlite::Error> { use super::Database;
let sql = "
CREATE TABLE IF NOT EXISTS users (
user_id INTEGER PRIMARY KEY AUTOINCREMENT,
firstname VARCHAR(20) NOT NULL,
lastname VARCHAR(20) NOT NULL,
email VARCHAR(50) NOT NULL,
password VARCHAR(50) NOT NULL,
gender VARCHAR(100) NOT NULL,
date BIGINT NOT NULL,
day TINYINT NOT NULL,
month TINYINT NOT NULL,
year INTEGER NOT NULL
);
";
let conn = database::connect()?;
conn.execute(sql, ())?;
let sql2 = "CREATE UNIQUE INDEX IF NOT EXISTS emails on users (email);"; impl Database {
conn.execute(sql2, ())?; pub fn init_users(&self) -> Result<(), rusqlite::Error> {
let sql = "
CREATE TABLE IF NOT EXISTS users (
user_id INTEGER PRIMARY KEY AUTOINCREMENT,
firstname VARCHAR(20) NOT NULL,
lastname VARCHAR(20) NOT NULL,
email VARCHAR(50) NOT NULL,
password VARCHAR(50) NOT NULL,
gender VARCHAR(100) NOT NULL,
date BIGINT NOT NULL,
day TINYINT NOT NULL,
month TINYINT NOT NULL,
year INTEGER NOT NULL
);
";
self.0.execute(sql, ())?;
let sql3 = "CREATE UNIQUE INDEX IF NOT EXISTS passwords on users (password);"; let sql2 = "CREATE UNIQUE INDEX IF NOT EXISTS emails on users (email);";
conn.execute(sql3, ())?; self.0.execute(sql2, ())?;
Ok(()) let sql3 = "CREATE UNIQUE INDEX IF NOT EXISTS passwords on users (password);";
} self.0.execute(sql3, ())?;
pub fn user_from_row(row: &Row, hide_password: bool) -> Result<User, rusqlite::Error> { Ok(())
let user_id = row.get(0)?; }
let firstname = row.get(1)?;
let lastname = row.get(2)?;
let email = row.get(3)?;
let password = row.get(4)?;
let gender = row.get(5)?;
let date = row.get(6)?;
let day = row.get(7)?;
let month = row.get(8)?;
let year = row.get(9)?;
let password = if hide_password { pub fn user_from_row(row: &Row, hide_password: bool) -> Result<User, rusqlite::Error> {
String::new() let user_id = row.get(0)?;
} else { let firstname = row.get(1)?;
password let lastname = row.get(2)?;
}; let email = row.get(3)?;
let password = row.get(4)?;
let gender = row.get(5)?;
let date = row.get(6)?;
let day = row.get(7)?;
let month = row.get(8)?;
let year = row.get(9)?;
Ok(User { let password = if hide_password {
user_id, String::new()
firstname, } else {
lastname, password
email, };
password,
gender,
date,
day,
month,
year,
})
}
#[instrument()] Ok(User {
pub fn get_user_by_id(user_id: u64, hide_password: bool) -> Result<Option<User>, rusqlite::Error> { user_id,
tracing::trace!("Retrieving user by id"); firstname,
let conn = database::connect()?; lastname,
let mut stmt = conn.prepare("SELECT * FROM users WHERE user_id = ?")?; email,
let row = stmt password,
.query_row([user_id], |row| { gender,
let row = user_from_row(row, hide_password)?;
Ok(row)
})
.optional()?;
Ok(row)
}
#[instrument()]
pub fn get_user_by_email(
email: &str,
hide_password: bool,
) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by email");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM users WHERE email = ?")?;
let row = stmt
.query_row([email], |row| {
let row = user_from_row(row, hide_password)?;
Ok(row)
})
.optional()?;
Ok(row)
}
#[instrument()]
pub fn get_user_by_password(
password: &str,
hide_password: bool,
) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by password");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM users WHERE password = ?")?;
let row = stmt
.query_row([password], |row| {
let row = user_from_row(row, hide_password)?;
Ok(row)
})
.optional()?;
Ok(row)
}
#[instrument()]
pub fn get_user_page(page: u64, hide_password: bool) -> Result<Vec<User>, rusqlite::Error> {
tracing::trace!("Retrieving user page");
let page_size = 5;
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM users ORDER BY user_id DESC LIMIT ? OFFSET ?")?;
let row = stmt.query_map([page_size, page_size * page], |row| {
let row = user_from_row(row, hide_password)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument()]
pub fn get_all_users() -> Result<Vec<User>, rusqlite::Error> {
tracing::trace!("Retrieving user page");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM users ORDER BY user_id DESC")?;
let row = stmt.query_map([], |row| {
let row = user_from_row(row, false)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument()]
pub fn add_user(request: RegistrationRequet) -> Result<User, rusqlite::Error> {
tracing::trace!("Adding new user");
let date = u64::try_from(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::ZERO)
.as_millis(),
)
.unwrap_or(0);
let conn = database::connect()?;
let mut stmt = conn.prepare("INSERT INTO users (firstname, lastname, email, password, gender, date, day, month, year) VALUES(?,?,?,?,?,?,?,?,?) RETURNING *;")?;
let user = stmt.query_row(
(
request.firstname,
request.lastname,
request.email,
request.password,
request.gender,
date, date,
request.day, day,
request.month, month,
request.year, year,
), })
|row| { }
let row = user_from_row(row, false)?;
#[instrument(skip(self))]
pub fn get_user_by_id(
&self,
user_id: u64,
hide_password: bool,
) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by id");
let mut stmt = self.0.prepare("SELECT * FROM users WHERE user_id = ?")?;
let row = stmt
.query_row([user_id], |row| {
let row = Self::user_from_row(row, hide_password)?;
Ok(row)
})
.optional()?;
Ok(row)
}
#[instrument(skip(self))]
pub fn get_user_by_email(
&self,
email: &str,
hide_password: bool,
) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by email");
let mut stmt = self.0.prepare("SELECT * FROM users WHERE email = ?")?;
let row = stmt
.query_row([email], |row| {
let row = Self::user_from_row(row, hide_password)?;
Ok(row)
})
.optional()?;
Ok(row)
}
#[instrument(skip(self))]
pub fn get_user_by_password(
&self,
password: &str,
hide_password: bool,
) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by password");
let mut stmt = self.0.prepare("SELECT * FROM users WHERE password = ?")?;
let row = stmt
.query_row([password], |row| {
let row = Self::user_from_row(row, hide_password)?;
Ok(row)
})
.optional()?;
Ok(row)
}
#[instrument(skip(self))]
pub fn get_user_page(
&self,
page: u64,
hide_password: bool,
) -> Result<Vec<User>, rusqlite::Error> {
tracing::trace!("Retrieving user page");
let page_size = 5;
let mut stmt = self
.0
.prepare("SELECT * FROM users ORDER BY user_id DESC LIMIT ? OFFSET ?")?;
let row = stmt.query_map([page_size, page_size * page], |row| {
let row = Self::user_from_row(row, hide_password)?;
Ok(row) Ok(row)
}, })?;
)?; Ok(row.into_iter().flatten().collect())
Ok(user) }
#[instrument(skip(self))]
pub fn get_all_users(&self) -> Result<Vec<User>, rusqlite::Error> {
tracing::trace!("Retrieving user page");
let mut stmt = self
.0
.prepare("SELECT * FROM users ORDER BY user_id DESC")?;
let row = stmt.query_map([], |row| {
let row = Self::user_from_row(row, false)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument(skip(self))]
pub fn add_user(&self, request: RegistrationRequet) -> Result<User, rusqlite::Error> {
tracing::trace!("Adding new user");
let date = u64::try_from(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::ZERO)
.as_millis(),
)
.unwrap_or(0);
let mut stmt = self.0.prepare("INSERT INTO users (firstname, lastname, email, password, gender, date, day, month, year) VALUES(?,?,?,?,?,?,?,?,?) RETURNING *;")?;
let user = stmt.query_row(
(
request.firstname,
request.lastname,
request.email,
request.password,
request.gender,
date,
request.day,
request.month,
request.year,
),
|row| {
let row = Self::user_from_row(row, false)?;
Ok(row)
},
)?;
Ok(user)
}
} }

View file

@ -5,6 +5,7 @@ use tokio::sync::Mutex;
use crate::{ use crate::{
console::sanatize, console::sanatize,
database::Database,
types::{ types::{
comment::Comment, http::ResponseCode, like::Like, post::Post, session::Session, user::User, comment::Comment, http::ResponseCode, like::Like, post::Post, session::Session, user::User,
}, },
@ -36,8 +37,8 @@ pub async fn regen_secret() -> String {
secret.clone() secret.clone()
} }
pub fn generate_users() -> Response { pub fn generate_users(db: &Database) -> Response {
let users = match User::reterieve_all() { let users = match User::reterieve_all(db) {
Ok(users) => users, Ok(users) => users,
Err(err) => return err, Err(err) => return err,
}; };
@ -70,8 +71,8 @@ pub fn generate_users() -> Response {
ResponseCode::Success.text(&html) ResponseCode::Success.text(&html)
} }
pub fn generate_posts() -> Response { pub fn generate_posts(db: &Database) -> Response {
let posts = match Post::reterieve_all() { let posts = match Post::reterieve_all(db) {
Ok(posts) => posts, Ok(posts) => posts,
Err(err) => return err, Err(err) => return err,
}; };
@ -99,8 +100,8 @@ pub fn generate_posts() -> Response {
ResponseCode::Success.text(&html) ResponseCode::Success.text(&html)
} }
pub fn generate_sessions() -> Response { pub fn generate_sessions(db: &Database) -> Response {
let sessions = match Session::reterieve_all() { let sessions = match Session::reterieve_all(db) {
Ok(sessions) => sessions, Ok(sessions) => sessions,
Err(err) => return err, Err(err) => return err,
}; };
@ -123,8 +124,8 @@ pub fn generate_sessions() -> Response {
ResponseCode::Success.text(&html) ResponseCode::Success.text(&html)
} }
pub fn generate_comments() -> Response { pub fn generate_comments(db: &Database) -> Response {
let comments = match Comment::reterieve_all() { let comments = match Comment::reterieve_all(db) {
Ok(comments) => comments, Ok(comments) => comments,
Err(err) => return err, Err(err) => return err,
}; };
@ -154,8 +155,8 @@ pub fn generate_comments() -> Response {
ResponseCode::Success.text(&html) ResponseCode::Success.text(&html)
} }
pub fn generate_likes() -> Response { pub fn generate_likes(db: &Database) -> Response {
let likes = match Like::reterieve_all() { let likes = match Like::reterieve_all(db) {
Ok(likes) => likes, Ok(likes) => likes,
Err(err) => return err, Err(err) => return err,
}; };

View file

@ -2,7 +2,7 @@ use serde::Serialize;
use tracing::instrument; use tracing::instrument;
use crate::{ use crate::{
database::{self, comments}, database::Database,
types::http::{ResponseCode, Result}, types::http::{ResponseCode, Result},
}; };
@ -16,9 +16,9 @@ pub struct Comment {
} }
impl Comment { impl Comment {
#[instrument()] #[instrument(skip(db))]
pub fn new(user_id: u64, post_id: u64, content: &str) -> Result<Self> { pub fn new(db: &Database, user_id: u64, post_id: u64, content: &str) -> Result<Self> {
let Ok(comment) = comments::add_comment(user_id, post_id, content) else { let Ok(comment) = db.add_comment(user_id, post_id, content) else {
tracing::error!("Failed to create comment"); tracing::error!("Failed to create comment");
return Err(ResponseCode::InternalServerError.text("Failed to create post")) return Err(ResponseCode::InternalServerError.text("Failed to create post"))
}; };
@ -26,17 +26,17 @@ impl Comment {
Ok(comment) Ok(comment)
} }
#[instrument()] #[instrument(skip(db))]
pub fn from_comment_page(page: u64, post_id: u64) -> Result<Vec<Self>> { pub fn from_comment_page(db: &Database, page: u64, post_id: u64) -> Result<Vec<Self>> {
let Ok(posts) = database::comments::get_comments_page(page, post_id) else { let Ok(posts) = db.get_comments_page(page, post_id) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch comments")) return Err(ResponseCode::BadRequest.text("Failed to fetch comments"))
}; };
Ok(posts) Ok(posts)
} }
#[instrument()] #[instrument(skip(db))]
pub fn reterieve_all() -> Result<Vec<Self>> { pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
let Ok(posts) = database::comments::get_all_comments() else { let Ok(posts) = db.get_all_comments() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch comments")) return Err(ResponseCode::InternalServerError.text("Failed to fetch comments"))
}; };
Ok(posts) Ok(posts)

View file

@ -14,9 +14,11 @@ use axum::{
use bytes::Bytes; use bytes::Bytes;
use image::{io::Reader, DynamicImage, ImageFormat}; use image::{io::Reader, DynamicImage, ImageFormat};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use tokio::sync::Mutex;
use tower_cookies::Cookies; use tower_cookies::Cookies;
use crate::{ use crate::{
database,
public::admin, public::admin,
public::console, public::console,
types::{ types::{
@ -97,11 +99,17 @@ where
return Err(ResponseCode::Forbidden.text("No auth token provided")) return Err(ResponseCode::Forbidden.text("No auth token provided"))
}; };
let Ok(session) = Session::from_token(token.value()) else { let Some(db) = parts.extensions.get::<DatabaseExtention>() else {
return Err(ResponseCode::Forbidden.text("Could not connect to database"))
};
let db = db.0.lock().await;
let Ok(session) = Session::from_token(&db, token.value()) else {
return Err(ResponseCode::Unauthorized.text("Auth token invalid")) return Err(ResponseCode::Unauthorized.text("Auth token invalid"))
}; };
let Ok(user) = User::from_user_id(session.user_id, true) else { let Ok(user) = User::from_user_id(&db, session.user_id, true) else {
tracing::error!("Valid token but no valid user"); tracing::error!("Valid token but no valid user");
return Err(ResponseCode::InternalServerError.text("Valid token but no valid user")) return Err(ResponseCode::InternalServerError.text("Valid token but no valid user"))
}; };
@ -260,6 +268,26 @@ where
} }
} }
pub struct DatabaseExtention(pub Mutex<database::Database>);
pub struct Database(pub database::Database);
#[async_trait]
impl<S> FromRequestParts<S> for Database
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self> {
let db = parts.extensions.remove::<DatabaseExtention>();
let Some(db) = db else {
return Err(ResponseCode::InternalServerError.text("Database is not loaded"))
};
Ok(Self(db.0.into_inner()))
}
}
async fn read_body<S, B>(mut req: Request<B>, state: &S) -> Result<Vec<u8>> async fn read_body<S, B>(mut req: Request<B>, state: &S) -> Result<Vec<u8>>
where where
B: HttpBody + Sync + Send + 'static, B: HttpBody + Sync + Send + 'static,

View file

@ -1,7 +1,7 @@
use serde::Serialize; use serde::Serialize;
use tracing::instrument; use tracing::instrument;
use crate::database; use crate::database::Database;
use crate::types::http::{ResponseCode, Result}; use crate::types::http::{ResponseCode, Result};
#[derive(Serialize)] #[derive(Serialize)]
@ -11,9 +11,9 @@ pub struct Like {
} }
impl Like { impl Like {
#[instrument()] #[instrument(skip(db))]
pub fn add_liked(user_id: u64, post_id: u64) -> Result<()> { pub fn add_liked(db: &Database, user_id: u64, post_id: u64) -> Result<()> {
let Ok(liked) = database::likes::add_liked(user_id, post_id) else { let Ok(liked) = db.add_liked(user_id, post_id) else {
return Err(ResponseCode::BadRequest.text("Failed to add like status")) return Err(ResponseCode::BadRequest.text("Failed to add like status"))
}; };
@ -24,9 +24,9 @@ impl Like {
Ok(()) Ok(())
} }
#[instrument()] #[instrument(skip(db))]
pub fn remove_liked(user_id: u64, post_id: u64) -> Result<()> { pub fn remove_liked(db: &Database, user_id: u64, post_id: u64) -> Result<()> {
let Ok(liked) = database::likes::remove_liked(user_id, post_id) else { let Ok(liked) = db.remove_liked(user_id, post_id) else {
return Err(ResponseCode::BadRequest.text("Failed to remove like status")) return Err(ResponseCode::BadRequest.text("Failed to remove like status"))
}; };
@ -37,9 +37,9 @@ impl Like {
Ok(()) Ok(())
} }
#[instrument()] #[instrument(skip(db))]
pub fn reterieve_all() -> Result<Vec<Self>> { pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
let Ok(likes) = database::likes::get_all_likes() else { let Ok(likes) = db.get_all_likes() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch likes")) return Err(ResponseCode::InternalServerError.text("Failed to fetch likes"))
}; };
Ok(likes) Ok(likes)

View file

@ -2,7 +2,7 @@ use core::fmt;
use serde::Serialize; use serde::Serialize;
use tracing::instrument; use tracing::instrument;
use crate::database; use crate::database::Database;
use crate::types::http::{ResponseCode, Result}; use crate::types::http::{ResponseCode, Result};
use super::comment::Comment; use super::comment::Comment;
@ -27,57 +27,62 @@ impl fmt::Debug for Post {
} }
impl Post { impl Post {
#[instrument()] #[instrument(skip(db))]
pub fn from_post_id(self_id: u64, post_id: u64) -> Result<Self> { pub fn from_post_id(db: &Database, self_id: u64, post_id: u64) -> Result<Self> {
let Ok(Some(mut post)) = database::posts::get_post(post_id) else { let Ok(Some(mut post)) = db.get_post(post_id) else {
return Err(ResponseCode::BadRequest.text("Post does not exist")) return Err(ResponseCode::BadRequest.text("Post does not exist"))
}; };
let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false); let liked = db.get_liked(self_id, post.post_id).unwrap_or(false);
post.liked = liked; post.liked = liked;
Ok(post) Ok(post)
} }
#[instrument()] #[instrument(skip(db))]
pub fn from_post_page(self_id: u64, page: u64) -> Result<Vec<Self>> { pub fn from_post_page(db: &Database, self_id: u64, page: u64) -> Result<Vec<Self>> {
let Ok(mut posts) = database::posts::get_post_page(page) else { let Ok(mut posts) = db.get_post_page(page) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch posts")) return Err(ResponseCode::BadRequest.text("Failed to fetch posts"))
}; };
for post in &mut posts { for post in &mut posts {
let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false); let liked = db.get_liked(self_id, post.post_id).unwrap_or(false);
post.liked = liked; post.liked = liked;
} }
Ok(posts) Ok(posts)
} }
#[instrument()] #[instrument(skip(db))]
pub fn from_user_post_page(self_id: u64, user_id: u64, page: u64) -> Result<Vec<Self>> { pub fn from_user_post_page(
let Ok(mut posts) = database::posts::get_users_post_page(user_id, page) else { db: &Database,
self_id: u64,
user_id: u64,
page: u64,
) -> Result<Vec<Self>> {
let Ok(mut posts) = db.get_users_post_page(user_id, page) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch posts")) return Err(ResponseCode::BadRequest.text("Failed to fetch posts"))
}; };
for post in &mut posts { for post in &mut posts {
let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false); let liked = db.get_liked(self_id, post.post_id).unwrap_or(false);
post.liked = liked; post.liked = liked;
} }
Ok(posts) Ok(posts)
} }
#[instrument()] #[instrument(skip(db))]
pub fn reterieve_all() -> Result<Vec<Self>> { pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
let Ok(posts) = database::posts::get_all_posts() else { let Ok(posts) = db.get_all_posts() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch posts")) return Err(ResponseCode::InternalServerError.text("Failed to fetch posts"))
}; };
Ok(posts) Ok(posts)
} }
#[instrument()] #[instrument(skip(db))]
pub fn new(user_id: u64, content: String) -> Result<Self> { pub fn new(db: &Database, user_id: u64, content: String) -> Result<Self> {
let Ok(post) = database::posts::add_post(user_id, &content) else { let Ok(post) = db.add_post(user_id, &content) else {
tracing::error!("Failed to create post"); tracing::error!("Failed to create post");
return Err(ResponseCode::InternalServerError.text("Failed to create post")) return Err(ResponseCode::InternalServerError.text("Failed to create post"))
}; };

View file

@ -2,7 +2,7 @@ use rand::{distributions::Alphanumeric, Rng};
use serde::Serialize; use serde::Serialize;
use tracing::instrument; use tracing::instrument;
use crate::database; use crate::database::Database;
use crate::types::http::{ResponseCode, Result}; use crate::types::http::{ResponseCode, Result};
#[derive(Serialize)] #[derive(Serialize)]
@ -12,39 +12,39 @@ pub struct Session {
} }
impl Session { impl Session {
#[instrument()] #[instrument(skip(db))]
pub fn from_token(token: &str) -> Result<Self> { pub fn from_token(db: &Database, token: &str) -> Result<Self> {
let Ok(Some(session)) = database::sessions::get_session(token) else { let Ok(Some(session)) = db.get_session(token) else {
return Err(ResponseCode::BadRequest.text("Invalid auth token")); return Err(ResponseCode::BadRequest.text("Invalid auth token"));
}; };
Ok(session) Ok(session)
} }
#[instrument()] #[instrument(skip(db))]
pub fn reterieve_all() -> Result<Vec<Self>> { pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
let Ok(sessions) = database::sessions::get_all_sessions() else { let Ok(sessions) = db.get_all_sessions() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch sessions")) return Err(ResponseCode::InternalServerError.text("Failed to fetch sessions"))
}; };
Ok(sessions) Ok(sessions)
} }
#[instrument()] #[instrument(skip(db))]
pub fn new(user_id: u64) -> Result<Self> { pub fn new(db: &Database, user_id: u64) -> Result<Self> {
let token: String = rand::thread_rng() let token: String = rand::thread_rng()
.sample_iter(&Alphanumeric) .sample_iter(&Alphanumeric)
.take(32) .take(32)
.map(char::from) .map(char::from)
.collect(); .collect();
match database::sessions::set_session(user_id, &token) { match db.set_session(user_id, &token) {
Err(_) => Err(ResponseCode::BadRequest.text("Failed to create session")), Err(_) => Err(ResponseCode::BadRequest.text("Failed to create session")),
Ok(_) => Ok(Self { user_id, token }), Ok(_) => Ok(Self { user_id, token }),
} }
} }
#[instrument()] #[instrument(skip(db))]
pub fn delete(user_id: u64) -> Result<()> { pub fn delete(db: &Database, user_id: u64) -> Result<()> {
if database::sessions::delete_session(user_id).is_err() { if db.delete_session(user_id).is_err() {
tracing::error!("Failed to logout user"); tracing::error!("Failed to logout user");
return Err(ResponseCode::InternalServerError.text("Failed to logout")); return Err(ResponseCode::InternalServerError.text("Failed to logout"));
}; };

View file

@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize};
use tracing::instrument; use tracing::instrument;
use crate::api::RegistrationRequet; use crate::api::RegistrationRequet;
use crate::database; use crate::database::Database;
use crate::types::http::{ResponseCode, Result}; use crate::types::http::{ResponseCode, Result};
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
@ -24,21 +24,21 @@ pub const FOLLOWING: u8 = 1;
pub const FOLLOWED: u8 = 2; pub const FOLLOWED: u8 = 2;
impl User { impl User {
#[instrument()] #[instrument(skip(db))]
pub fn from_user_id(user_id: u64, hide_password: bool) -> Result<Self> { pub fn from_user_id(db: &Database, user_id: u64, hide_password: bool) -> Result<Self> {
let Ok(Some(user)) = database::users::get_user_by_id(user_id, hide_password) else { let Ok(Some(user)) = db.get_user_by_id(user_id, hide_password) else {
return Err(ResponseCode::BadRequest.text("User does not exist")) return Err(ResponseCode::BadRequest.text("User does not exist"))
}; };
Ok(user) Ok(user)
} }
#[instrument()] #[instrument(skip(db))]
pub fn from_user_ids(user_ids: Vec<u64>) -> Vec<Self> { pub fn from_user_ids(db: &Database, user_ids: Vec<u64>) -> Vec<Self> {
user_ids user_ids
.iter() .iter()
.filter_map(|user_id| { .filter_map(|user_id| {
let Ok(Some(user)) = database::users::get_user_by_id(*user_id, true) else { let Ok(Some(user)) = db.get_user_by_id(*user_id, true) else {
return None; return None;
}; };
Some(user) Some(user)
@ -46,53 +46,53 @@ impl User {
.collect() .collect()
} }
#[instrument()] #[instrument(skip(db))]
pub fn from_user_page(page: u64) -> Result<Vec<Self>> { pub fn from_user_page(db: &Database, page: u64) -> Result<Vec<Self>> {
let Ok(users) = database::users::get_user_page(page, true) else { let Ok(users) = db.get_user_page(page, true) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch users")) return Err(ResponseCode::BadRequest.text("Failed to fetch users"))
}; };
Ok(users) Ok(users)
} }
#[instrument()] #[instrument(skip(db))]
pub fn from_email(email: &str) -> Result<Self> { pub fn from_email(db: &Database, email: &str) -> Result<Self> {
let Ok(Some(user)) = database::users::get_user_by_email(email, false) else { let Ok(Some(user)) = db.get_user_by_email(email, false) else {
return Err(ResponseCode::BadRequest.text("User does not exist")) return Err(ResponseCode::BadRequest.text("User does not exist"))
}; };
Ok(user) Ok(user)
} }
#[instrument()] #[instrument(skip(db))]
pub fn from_password(password: &str) -> Result<Self> { pub fn from_password(db: &Database, password: &str) -> Result<Self> {
let Ok(Some(user)) = database::users::get_user_by_password(password, true) else { let Ok(Some(user)) = db.get_user_by_password(password, true) else {
return Err(ResponseCode::BadRequest.text("User does not exist")) return Err(ResponseCode::BadRequest.text("User does not exist"))
}; };
Ok(user) Ok(user)
} }
#[instrument()] #[instrument(skip(db))]
pub fn reterieve_all() -> Result<Vec<Self>> { pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
let Ok(users) = database::users::get_all_users() else { let Ok(users) = db.get_all_users() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch users")) return Err(ResponseCode::InternalServerError.text("Failed to fetch users"))
}; };
Ok(users) Ok(users)
} }
#[instrument()] #[instrument(skip(db))]
pub fn new(request: RegistrationRequet) -> Result<Self> { pub fn new(db: &Database, request: RegistrationRequet) -> Result<Self> {
if Self::from_email(&request.email).is_ok() { if Self::from_email(db, &request.email).is_ok() {
return Err(ResponseCode::BadRequest return Err(ResponseCode::BadRequest
.text(&format!("Email is already in use by {}", &request.email))); .text(&format!("Email is already in use by {}", &request.email)));
} }
if let Ok(user) = Self::from_password(&request.password) { if let Ok(user) = Self::from_password(db, &request.password) {
return Err(ResponseCode::BadRequest return Err(ResponseCode::BadRequest
.text(&format!("Password is already in use by {}", user.email))); .text(&format!("Password is already in use by {}", user.email)));
} }
let Ok(user) = database::users::add_user(request) else { let Ok(user) = db.add_user(request) else {
tracing::error!("Failed to create new user"); tracing::error!("Failed to create new user");
return Err(ResponseCode::InternalServerError.text("Failed to create new uesr")) return Err(ResponseCode::InternalServerError.text("Failed to create new uesr"))
}; };
@ -100,8 +100,9 @@ impl User {
Ok(user) Ok(user)
} }
pub fn add_following(user_id_1: u64, user_id_2: u64) -> Result<()> { #[instrument(skip(db))]
let Ok(followed) = database::friends::set_following(user_id_1, user_id_2) else { pub fn add_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<()> {
let Ok(followed) = db.set_following(user_id_1, user_id_2) else {
return Err(ResponseCode::BadRequest.text("Failed to add follow status")) return Err(ResponseCode::BadRequest.text("Failed to add follow status"))
}; };
@ -112,8 +113,9 @@ impl User {
Ok(()) Ok(())
} }
pub fn remove_following(user_id_1: u64, user_id_2: u64) -> Result<()> { #[instrument(skip(db))]
let Ok(followed) = database::friends::remove_following(user_id_1, user_id_2) else { pub fn remove_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<()> {
let Ok(followed) = db.remove_following(user_id_1, user_id_2) else {
return Err(ResponseCode::BadRequest.text("Failed to remove follow status")) return Err(ResponseCode::BadRequest.text("Failed to remove follow status"))
}; };
@ -124,15 +126,17 @@ impl User {
Ok(()) Ok(())
} }
pub fn get_following(user_id_1: u64, user_id_2: u64) -> Result<u8> { #[instrument(skip(db))]
let Ok(followed) = database::friends::get_friend_status(user_id_1, user_id_2) else { pub fn get_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<u8> {
let Ok(followed) = db.get_friend_status(user_id_1, user_id_2) else {
return Err(ResponseCode::InternalServerError.text("Failed to get follow status")) return Err(ResponseCode::InternalServerError.text("Failed to get follow status"))
}; };
Ok(followed) Ok(followed)
} }
pub fn get_friends(user_id: u64) -> Result<Vec<Self>> { #[instrument(skip(db))]
let Ok(users) = database::friends::get_friends(user_id) else { pub fn get_friends(db: &Database, user_id: u64) -> Result<Vec<Self>> {
let Ok(users) = db.get_friends(user_id) else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch friends")) return Err(ResponseCode::InternalServerError.text("Failed to fetch friends"))
}; };
Ok(users) Ok(users)