make database calls 1 conn

This commit is contained in:
Tyler Murphy 2023-02-15 00:01:44 -05:00
parent 192be95f84
commit aec4fdecc1
19 changed files with 829 additions and 695 deletions

View file

@ -5,13 +5,12 @@ use serde::Deserialize;
use tower_cookies::{Cookie, Cookies}; use tower_cookies::{Cookie, Cookies};
use crate::{ use crate::{
database,
public::{ public::{
admin, admin,
docs::{EndpointDocumentation, EndpointMethod}, docs::{EndpointDocumentation, EndpointMethod},
}, },
types::{ types::{
extract::{AdminUser, Check, CheckResult, Json}, extract::{AdminUser, Check, CheckResult, Database, Json},
http::ResponseCode, http::ResponseCode,
}, },
}; };
@ -92,8 +91,8 @@ impl Check for QueryRequest {
} }
} }
async fn query(_: AdminUser, Json(body): Json<QueryRequest>) -> Response { async fn query(_: AdminUser, Database(db): Database, Json(body): Json<QueryRequest>) -> Response {
match database::query(body.query) { match db.query(body.query) {
Ok(changes) => ResponseCode::Success.text(&format!( Ok(changes) => ResponseCode::Success.text(&format!(
"Query executed successfully. {changes} lines changed." "Query executed successfully. {changes} lines changed."
)), )),
@ -114,8 +113,8 @@ pub const ADMIN_POSTS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"), cookie: Some("admin"),
}; };
async fn posts(_: AdminUser) -> Response { async fn posts(_: AdminUser, Database(db): Database) -> Response {
admin::generate_posts() admin::generate_posts(&db)
} }
pub const ADMIN_USERS: EndpointDocumentation = EndpointDocumentation { pub const ADMIN_USERS: EndpointDocumentation = EndpointDocumentation {
@ -131,8 +130,8 @@ pub const ADMIN_USERS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"), cookie: Some("admin"),
}; };
async fn users(_: AdminUser) -> Response { async fn users(_: AdminUser, Database(db): Database) -> Response {
admin::generate_users() admin::generate_users(&db)
} }
pub const ADMIN_SESSIONS: EndpointDocumentation = EndpointDocumentation { pub const ADMIN_SESSIONS: EndpointDocumentation = EndpointDocumentation {
@ -148,8 +147,8 @@ pub const ADMIN_SESSIONS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"), cookie: Some("admin"),
}; };
async fn sessions(_: AdminUser) -> Response { async fn sessions(_: AdminUser, Database(db): Database) -> Response {
admin::generate_sessions() admin::generate_sessions(&db)
} }
pub const ADMIN_COMMENTS: EndpointDocumentation = EndpointDocumentation { pub const ADMIN_COMMENTS: EndpointDocumentation = EndpointDocumentation {
@ -165,8 +164,8 @@ pub const ADMIN_COMMENTS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"), cookie: Some("admin"),
}; };
async fn comments(_: AdminUser) -> Response { async fn comments(_: AdminUser, Database(db): Database) -> Response {
admin::generate_comments() admin::generate_comments(&db)
} }
pub const ADMIN_LIKES: EndpointDocumentation = EndpointDocumentation { pub const ADMIN_LIKES: EndpointDocumentation = EndpointDocumentation {
@ -182,8 +181,8 @@ pub const ADMIN_LIKES: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"), cookie: Some("admin"),
}; };
async fn likes(_: AdminUser) -> Response { async fn likes(_: AdminUser, Database(db): Database) -> Response {
admin::generate_likes() admin::generate_likes(&db)
} }
async fn check(check: Option<AdminUser>) -> Response { async fn check(check: Option<AdminUser>) -> Response {

View file

@ -6,7 +6,7 @@ use tower_cookies::{Cookie, Cookies};
use crate::{ use crate::{
public::docs::{EndpointDocumentation, EndpointMethod}, public::docs::{EndpointDocumentation, EndpointMethod},
types::{ types::{
extract::{AuthorizedUser, Check, CheckResult, Json, Log}, extract::{AuthorizedUser, Check, CheckResult, Database, Json, Log},
http::ResponseCode, http::ResponseCode,
session::Session, session::Session,
user::User, user::User,
@ -99,13 +99,17 @@ impl Check for RegistrationRequet {
} }
} }
async fn register(cookies: Cookies, Json(body): Json<RegistrationRequet>) -> Response { async fn register(
let user = match User::new(body) { cookies: Cookies,
Database(db): Database,
Json(body): Json<RegistrationRequet>,
) -> Response {
let user = match User::new(&db, body) {
Ok(user) => user, Ok(user) => user,
Err(err) => return err, Err(err) => return err,
}; };
let session = match Session::new(user.user_id) { let session = match Session::new(&db, user.user_id) {
Ok(session) => session, Ok(session) => session,
Err(err) => return err, Err(err) => return err,
}; };
@ -158,8 +162,12 @@ impl Check for LoginRequest {
} }
} }
async fn login(cookies: Cookies, Json(body): Json<LoginRequest>) -> Response { async fn login(
let Ok(user) = User::from_email(&body.email) else { cookies: Cookies,
Database(db): Database,
Json(body): Json<LoginRequest>,
) -> Response {
let Ok(user) = User::from_email(&db, &body.email) else {
return ResponseCode::BadRequest.text("Email is not registered") return ResponseCode::BadRequest.text("Email is not registered")
}; };
@ -167,7 +175,7 @@ async fn login(cookies: Cookies, Json(body): Json<LoginRequest>) -> Response {
return ResponseCode::BadRequest.text("Password is not correct"); return ResponseCode::BadRequest.text("Password is not correct");
} }
let session = match Session::new(user.user_id) { let session = match Session::new(&db, user.user_id) {
Ok(session) => session, Ok(session) => session,
Err(err) => return err, Err(err) => return err,
}; };
@ -199,10 +207,15 @@ pub const AUTH_LOGOUT: EndpointDocumentation = EndpointDocumentation {
cookie: None, cookie: None,
}; };
async fn logout(cookies: Cookies, AuthorizedUser(user): AuthorizedUser, _: Log) -> Response { async fn logout(
cookies: Cookies,
AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
_: Log,
) -> Response {
cookies.remove(Cookie::new("auth", "")); cookies.remove(Cookie::new("auth", ""));
if let Err(err) = Session::delete(user.user_id) { if let Err(err) = Session::delete(&db, user.user_id) {
return err; return err;
} }

View file

@ -1,5 +1,15 @@
use crate::types::extract::RouterURI; use crate::{
use axum::{error_handling::HandleErrorLayer, BoxError, Extension, Router}; database,
types::extract::{DatabaseExtention, RouterURI},
};
use axum::{
error_handling::HandleErrorLayer,
http::Request,
middleware::{self, Next},
response::Response,
BoxError, Extension, Router,
};
use tokio::sync::Mutex;
use tower::ServiceBuilder; use tower::ServiceBuilder;
use tower_governor::{ use tower_governor::{
errors::display_error, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, errors::display_error, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor,
@ -13,6 +23,18 @@ pub mod users;
pub use auth::RegistrationRequet; pub use auth::RegistrationRequet;
async fn connect<B>(mut req: Request<B>, next: Next<B>) -> Response
where
B: Send,
{
if let Ok(db) = database::Database::connect() {
let ex = DatabaseExtention(Mutex::new(db));
req.extensions_mut().insert(ex);
}
next.run(req).await
}
pub fn router() -> Router { pub fn router() -> Router {
let governor_conf = Box::new( let governor_conf = Box::new(
GovernorConfigBuilder::default() GovernorConfigBuilder::default()
@ -49,4 +71,5 @@ pub fn router() -> Router {
config: Box::leak(governor_conf), config: Box::leak(governor_conf),
}), }),
) )
.layer(middleware::from_fn(connect))
} }

View file

@ -9,7 +9,7 @@ use crate::{
public::docs::{EndpointDocumentation, EndpointMethod}, public::docs::{EndpointDocumentation, EndpointMethod},
types::{ types::{
comment::Comment, comment::Comment,
extract::{AuthorizedUser, Check, CheckResult, Json}, extract::{AuthorizedUser, Check, CheckResult, Database, Json},
http::ResponseCode, http::ResponseCode,
like::Like, like::Like,
post::Post, post::Post,
@ -55,9 +55,10 @@ impl Check for PostCreateRequest {
async fn create( async fn create(
AuthorizedUser(user): AuthorizedUser, AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<PostCreateRequest>, Json(body): Json<PostCreateRequest>,
) -> Response { ) -> Response {
let Ok(post) = Post::new(user.user_id, body.content) else { let Ok(post) = Post::new(&db, user.user_id, body.content) else {
return ResponseCode::InternalServerError.text("Failed to create post") return ResponseCode::InternalServerError.text("Failed to create post")
}; };
@ -101,9 +102,10 @@ impl Check for PostPageRequest {
async fn page( async fn page(
AuthorizedUser(user): AuthorizedUser, AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<PostPageRequest>, Json(body): Json<PostPageRequest>,
) -> Response { ) -> Response {
let Ok(posts) = Post::from_post_page(user.user_id, body.page) else { let Ok(posts) = Post::from_post_page(&db, user.user_id, body.page) else {
return ResponseCode::InternalServerError.text("Failed to fetch posts") return ResponseCode::InternalServerError.text("Failed to fetch posts")
}; };
@ -149,9 +151,10 @@ impl Check for CommentsPageRequest {
async fn comments( async fn comments(
AuthorizedUser(_user): AuthorizedUser, AuthorizedUser(_user): AuthorizedUser,
Database(db): Database,
Json(body): Json<CommentsPageRequest>, Json(body): Json<CommentsPageRequest>,
) -> Response { ) -> Response {
let Ok(comments) = Comment::from_comment_page(body.page, body.post_id) else { let Ok(comments) = Comment::from_comment_page(&db, body.page, body.post_id) else {
return ResponseCode::InternalServerError.text("Failed to fetch comments") return ResponseCode::InternalServerError.text("Failed to fetch comments")
}; };
@ -197,9 +200,10 @@ impl Check for UsersPostsRequest {
async fn user( async fn user(
AuthorizedUser(user): AuthorizedUser, AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UsersPostsRequest>, Json(body): Json<UsersPostsRequest>,
) -> Response { ) -> Response {
let Ok(posts) = Post::from_user_post_page(user.user_id, body.user_id, body.page) else { let Ok(posts) = Post::from_user_post_page(&db, user.user_id, body.user_id, body.page) else {
return ResponseCode::InternalServerError.text("Failed to fetch posts") return ResponseCode::InternalServerError.text("Failed to fetch posts")
}; };
@ -251,9 +255,10 @@ impl Check for PostCommentRequest {
async fn comment( async fn comment(
AuthorizedUser(user): AuthorizedUser, AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<PostCommentRequest>, Json(body): Json<PostCommentRequest>,
) -> Response { ) -> Response {
if let Err(err) = Comment::new(user.user_id, body.post_id, &body.content) { if let Err(err) = Comment::new(&db, user.user_id, body.post_id, &body.content) {
return err; return err;
} }
@ -293,12 +298,16 @@ impl Check for PostLikeRequest {
} }
} }
async fn like(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostLikeRequest>) -> Response { async fn like(
AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<PostLikeRequest>,
) -> Response {
if body.state { if body.state {
if let Err(err) = Like::add_liked(user.user_id, body.post_id) { if let Err(err) = Like::add_liked(&db, user.user_id, body.post_id) {
return err; return err;
} }
} else if let Err(err) = Like::remove_liked(user.user_id, body.post_id) { } else if let Err(err) = Like::remove_liked(&db, user.user_id, body.post_id) {
return err; return err;
} }

View file

@ -1,7 +1,7 @@
use crate::{ use crate::{
public::docs::{EndpointDocumentation, EndpointMethod}, public::docs::{EndpointDocumentation, EndpointMethod},
types::{ types::{
extract::{AuthorizedUser, Check, CheckResult, Json, Log, Png}, extract::{AuthorizedUser, Check, CheckResult, Database, Json, Log, Png},
http::ResponseCode, http::ResponseCode,
user::User, user::User,
}, },
@ -46,9 +46,10 @@ impl Check for UserLoadRequest {
async fn load_batch( async fn load_batch(
AuthorizedUser(_user): AuthorizedUser, AuthorizedUser(_user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UserLoadRequest>, Json(body): Json<UserLoadRequest>,
) -> Response { ) -> Response {
let users = User::from_user_ids(body.ids); let users = User::from_user_ids(&db, body.ids);
let Ok(json) = serde_json::to_string(&users) else { let Ok(json) = serde_json::to_string(&users) else {
return ResponseCode::InternalServerError.text("Failed to fetch users") return ResponseCode::InternalServerError.text("Failed to fetch users")
}; };
@ -90,9 +91,10 @@ impl Check for UserPageReqiest {
async fn load_page( async fn load_page(
AuthorizedUser(_user): AuthorizedUser, AuthorizedUser(_user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UserPageReqiest>, Json(body): Json<UserPageReqiest>,
) -> Response { ) -> Response {
let Ok(users) = User::from_user_page(body.page) else { let Ok(users) = User::from_user_page(&db, body.page) else {
return ResponseCode::InternalServerError.text("Failed to fetch users") return ResponseCode::InternalServerError.text("Failed to fetch users")
}; };
@ -207,17 +209,18 @@ impl Check for UserFollowRequest {
async fn follow( async fn follow(
AuthorizedUser(user): AuthorizedUser, AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UserFollowRequest>, Json(body): Json<UserFollowRequest>,
) -> Response { ) -> Response {
if body.state { if body.state {
if let Err(err) = User::add_following(user.user_id, body.user_id) { if let Err(err) = User::add_following(&db, user.user_id, body.user_id) {
return err; return err;
} }
} else if let Err(err) = User::remove_following(user.user_id, body.user_id) { } else if let Err(err) = User::remove_following(&db, user.user_id, body.user_id) {
return err; return err;
} }
match User::get_following(user.user_id, body.user_id) { match User::get_following(&db, user.user_id, body.user_id) {
Ok(status) => ResponseCode::Success.text(&format!("{status}")), Ok(status) => ResponseCode::Success.text(&format!("{status}")),
Err(err) => err, Err(err) => err,
} }
@ -259,9 +262,10 @@ impl Check for UserFollowStatusRequest {
async fn follow_status( async fn follow_status(
AuthorizedUser(user): AuthorizedUser, AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UserFollowStatusRequest>, Json(body): Json<UserFollowStatusRequest>,
) -> Response { ) -> Response {
match User::get_following(user.user_id, body.user_id) { match User::get_following(&db, user.user_id, body.user_id) {
Ok(status) => ResponseCode::Success.text(&format!("{status}")), Ok(status) => ResponseCode::Success.text(&format!("{status}")),
Err(err) => err, Err(err) => err,
} }
@ -297,8 +301,12 @@ impl Check for UserFriendsRequest {
} }
} }
async fn friends(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UserFriendsRequest>) -> Response { async fn friends(
let Ok(users) = User::get_friends(body.user_id) else { AuthorizedUser(_user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UserFriendsRequest>,
) -> Response {
let Ok(users) = User::get_friends(&db, body.user_id) else {
return ResponseCode::InternalServerError.text("Failed to fetch user") return ResponseCode::InternalServerError.text("Failed to fetch user")
}; };

View file

@ -3,9 +3,12 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
use rusqlite::Row; use rusqlite::Row;
use tracing::instrument; use tracing::instrument;
use crate::{database, types::comment::Comment}; use crate::types::comment::Comment;
pub fn init() -> Result<(), rusqlite::Error> { use super::Database;
impl Database {
pub fn init_comments(&self) -> Result<(), rusqlite::Error> {
let sql = " let sql = "
CREATE TABLE IF NOT EXISTS comments ( CREATE TABLE IF NOT EXISTS comments (
comment_id INTEGER PRIMARY KEY AUTOINCREMENT, comment_id INTEGER PRIMARY KEY AUTOINCREMENT,
@ -17,16 +20,15 @@ pub fn init() -> Result<(), rusqlite::Error> {
FOREIGN KEY(post_id) REFERENCES posts(post_id) FOREIGN KEY(post_id) REFERENCES posts(post_id)
); );
"; ";
let conn = database::connect()?; self.0.execute(sql, ())?;
conn.execute(sql, ())?;
let sql2 = "CREATE INDEX IF NOT EXISTS post_ids on comments (post_id);"; let sql2 = "CREATE INDEX IF NOT EXISTS post_ids on comments (post_id);";
conn.execute(sql2, ())?; self.0.execute(sql2, ())?;
Ok(()) Ok(())
} }
fn comment_from_row(row: &Row) -> Result<Comment, rusqlite::Error> { fn comment_from_row(row: &Row) -> Result<Comment, rusqlite::Error> {
let comment_id = row.get(0)?; let comment_id = row.get(0)?;
let user_id = row.get(1)?; let user_id = row.get(1)?;
let post_id = row.get(2)?; let post_id = row.get(2)?;
@ -40,37 +42,46 @@ fn comment_from_row(row: &Row) -> Result<Comment, rusqlite::Error> {
date, date,
content, content,
}) })
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_comments_page(page: u64, post_id: u64) -> Result<Vec<Comment>, rusqlite::Error> { pub fn get_comments_page(
&self,
page: u64,
post_id: u64,
) -> Result<Vec<Comment>, rusqlite::Error> {
tracing::trace!("Retrieving comments page"); tracing::trace!("Retrieving comments page");
let page_size = 5; let page_size = 5;
let conn = database::connect()?; let mut stmt = self.0.prepare(
let mut stmt = conn.prepare(
"SELECT * FROM comments WHERE post_id = ? ORDER BY comment_id ASC LIMIT ? OFFSET ?", "SELECT * FROM comments WHERE post_id = ? ORDER BY comment_id ASC LIMIT ? OFFSET ?",
)?; )?;
let row = stmt.query_map([post_id, page_size, page_size * page], |row| { let row = stmt.query_map([post_id, page_size, page_size * page], |row| {
let row = comment_from_row(row)?; let row = Self::comment_from_row(row)?;
Ok(row) Ok(row)
})?; })?;
Ok(row.into_iter().flatten().collect()) Ok(row.into_iter().flatten().collect())
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_all_comments() -> Result<Vec<Comment>, rusqlite::Error> { pub fn get_all_comments(&self) -> Result<Vec<Comment>, rusqlite::Error> {
tracing::trace!("Retrieving comments page"); tracing::trace!("Retrieving comments page");
let conn = database::connect()?; let mut stmt = self
let mut stmt = conn.prepare("SELECT * FROM comments ORDER BY comment_id DESC")?; .0
.prepare("SELECT * FROM comments ORDER BY comment_id DESC")?;
let row = stmt.query_map([], |row| { let row = stmt.query_map([], |row| {
let row = comment_from_row(row)?; let row = Self::comment_from_row(row)?;
Ok(row) Ok(row)
})?; })?;
Ok(row.into_iter().flatten().collect()) Ok(row.into_iter().flatten().collect())
} }
#[instrument()] #[instrument(skip(self))]
pub fn add_comment(user_id: u64, post_id: u64, content: &str) -> Result<Comment, rusqlite::Error> { pub fn add_comment(
&self,
user_id: u64,
post_id: u64,
content: &str,
) -> Result<Comment, rusqlite::Error> {
tracing::trace!("Adding comment"); tracing::trace!("Adding comment");
let date = u64::try_from( let date = u64::try_from(
SystemTime::now() SystemTime::now()
@ -79,13 +90,13 @@ pub fn add_comment(user_id: u64, post_id: u64, content: &str) -> Result<Comment,
.as_millis(), .as_millis(),
) )
.unwrap_or(0); .unwrap_or(0);
let conn = database::connect()?; let mut stmt = self.0.prepare(
let mut stmt = conn.prepare(
"INSERT INTO comments (user_id, post_id, date, content) VALUES(?,?,?,?) RETURNING *;", "INSERT INTO comments (user_id, post_id, date, content) VALUES(?,?,?,?) RETURNING *;",
)?; )?;
let post = stmt.query_row((user_id, post_id, date, content), |row| { let post = stmt.query_row((user_id, post_id, date, content), |row| {
let row = comment_from_row(row)?; let row = Self::comment_from_row(row)?;
Ok(row) Ok(row)
})?; })?;
Ok(post) Ok(post)
}
} }

View file

@ -1,11 +1,11 @@
use tracing::instrument; use tracing::instrument;
use crate::{ use crate::types::user::{User, FOLLOWED, FOLLOWING, NO_RELATION};
database::{self, users::user_from_row},
types::user::{User, FOLLOWED, FOLLOWING, NO_RELATION},
};
pub fn init() -> Result<(), rusqlite::Error> { use super::Database;
impl Database {
pub fn init_friends(&self) -> Result<(), rusqlite::Error> {
let sql = " let sql = "
CREATE TABLE IF NOT EXISTS friends ( CREATE TABLE IF NOT EXISTS friends (
follower_id INTEGER NOT NULL, follower_id INTEGER NOT NULL,
@ -15,16 +15,14 @@ pub fn init() -> Result<(), rusqlite::Error> {
PRIMARY KEY (follower_id, followee_id) PRIMARY KEY (follower_id, followee_id)
); );
"; ";
let conn = database::connect()?; self.0.execute(sql, ())?;
conn.execute(sql, ())?;
Ok(()) Ok(())
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_friend_status(user_id_1: u64, user_id_2: u64) -> Result<u8, rusqlite::Error> { pub fn get_friend_status(&self, user_id_1: u64, user_id_2: u64) -> Result<u8, rusqlite::Error> {
tracing::trace!("Retrieving friend status"); tracing::trace!("Retrieving friend status");
let conn = database::connect()?; let mut stmt = self.0.prepare("SELECT * FROM friends WHERE (follower_id = ? AND followee_id = ?) OR (follower_id = ? AND followee_id = ?);")?;
let mut stmt = conn.prepare("SELECT * FROM friends WHERE (follower_id = ? AND followee_id = ?) OR (follower_id = ? AND followee_id = ?);")?;
let mut status = NO_RELATION; let mut status = NO_RELATION;
let rows: Vec<u64> = stmt let rows: Vec<u64> = stmt
.query_map([user_id_1, user_id_2, user_id_2, user_id_1], |row| { .query_map([user_id_1, user_id_2, user_id_2, user_id_1], |row| {
@ -46,13 +44,12 @@ pub fn get_friend_status(user_id_1: u64, user_id_2: u64) -> Result<u8, rusqlite:
} }
Ok(status) Ok(status)
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_friends(user_id: u64) -> Result<Vec<User>, rusqlite::Error> { pub fn get_friends(&self, user_id: u64) -> Result<Vec<User>, rusqlite::Error> {
tracing::trace!("Retrieving friends"); tracing::trace!("Retrieving friends");
let conn = database::connect()?; let mut stmt = self.0.prepare(
let mut stmt = conn.prepare(
" "
SELECT * SELECT *
FROM users u FROM users u
@ -71,27 +68,33 @@ pub fn get_friends(user_id: u64) -> Result<Vec<User>, rusqlite::Error> {
", ",
)?; )?;
let row = stmt.query_map([user_id, user_id], |row| { let row = stmt.query_map([user_id, user_id], |row| {
let row = user_from_row(row, true)?; let row = Self::user_from_row(row, true)?;
Ok(row) Ok(row)
})?; })?;
Ok(row.into_iter().flatten().collect()) Ok(row.into_iter().flatten().collect())
} }
#[instrument()] #[instrument(skip(self))]
pub fn set_following(user_id_1: u64, user_id_2: u64) -> Result<bool, rusqlite::Error> { pub fn set_following(&self, user_id_1: u64, user_id_2: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Setting following"); tracing::trace!("Setting following");
let conn = database::connect()?; let mut stmt = self
let mut stmt = .0
conn.prepare("INSERT OR REPLACE INTO friends (follower_id, followee_id) VALUES (?,?)")?; .prepare("INSERT OR REPLACE INTO friends (follower_id, followee_id) VALUES (?,?)")?;
let changes = stmt.execute([user_id_1, user_id_2])?; let changes = stmt.execute([user_id_1, user_id_2])?;
Ok(changes == 1) Ok(changes == 1)
} }
#[instrument()] #[instrument(skip(self))]
pub fn remove_following(user_id_1: u64, user_id_2: u64) -> Result<bool, rusqlite::Error> { pub fn remove_following(
&self,
user_id_1: u64,
user_id_2: u64,
) -> Result<bool, rusqlite::Error> {
tracing::trace!("Removing following"); tracing::trace!("Removing following");
let conn = database::connect()?; let mut stmt = self
let mut stmt = conn.prepare("DELETE FROM friends WHERE follower_id = ? AND followee_id = ?")?; .0
.prepare("DELETE FROM friends WHERE follower_id = ? AND followee_id = ?")?;
let changes = stmt.execute([user_id_1, user_id_2])?; let changes = stmt.execute([user_id_1, user_id_2])?;
Ok(changes == 1) Ok(changes == 1)
}
} }

View file

@ -1,9 +1,12 @@
use rusqlite::OptionalExtension; use rusqlite::OptionalExtension;
use tracing::instrument; use tracing::instrument;
use crate::{database, types::like::Like}; use crate::types::like::Like;
pub fn init() -> Result<(), rusqlite::Error> { use super::Database;
impl Database {
pub fn init_likes(&self) -> Result<(), rusqlite::Error> {
let sql = " let sql = "
CREATE TABLE IF NOT EXISTS likes ( CREATE TABLE IF NOT EXISTS likes (
user_id INTEGER NOT NULL, user_id INTEGER NOT NULL,
@ -13,16 +16,16 @@ pub fn init() -> Result<(), rusqlite::Error> {
PRIMARY KEY (user_id, post_id) PRIMARY KEY (user_id, post_id)
); );
"; ";
let conn = database::connect()?; self.0.execute(sql, ())?;
conn.execute(sql, ())?;
Ok(()) Ok(())
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_like_count(post_id: u64) -> Result<Option<u64>, rusqlite::Error> { pub fn get_like_count(&self, post_id: u64) -> Result<Option<u64>, rusqlite::Error> {
tracing::trace!("Retrieving like count"); tracing::trace!("Retrieving like count");
let conn = database::connect()?; let mut stmt = self
let mut stmt = conn.prepare("SELECT COUNT(post_id) FROM likes WHERE post_id = ?")?; .0
.prepare("SELECT COUNT(post_id) FROM likes WHERE post_id = ?")?;
let row = stmt let row = stmt
.query_row([post_id], |row| { .query_row([post_id], |row| {
let row = row.get(0)?; let row = row.get(0)?;
@ -30,40 +33,42 @@ pub fn get_like_count(post_id: u64) -> Result<Option<u64>, rusqlite::Error> {
}) })
.optional()?; .optional()?;
Ok(row) Ok(row)
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_liked(user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> { pub fn get_liked(&self, user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Retrieving if liked"); tracing::trace!("Retrieving if liked");
let conn = database::connect()?; let mut stmt = self
let mut stmt = conn.prepare("SELECT * FROM likes WHERE user_id = ? AND post_id = ?")?; .0
.prepare("SELECT * FROM likes WHERE user_id = ? AND post_id = ?")?;
let liked = stmt.query_row([user_id, post_id], |_| Ok(())).optional()?; let liked = stmt.query_row([user_id, post_id], |_| Ok(())).optional()?;
Ok(liked.is_some()) Ok(liked.is_some())
} }
#[instrument()] #[instrument(skip(self))]
pub fn add_liked(user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> { pub fn add_liked(&self, user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Adding like"); tracing::trace!("Adding like");
let conn = database::connect()?; let mut stmt = self
let mut stmt = conn.prepare("INSERT OR REPLACE INTO likes (user_id, post_id) VALUES (?,?)")?; .0
.prepare("INSERT OR REPLACE INTO likes (user_id, post_id) VALUES (?,?)")?;
let changes = stmt.execute([user_id, post_id])?; let changes = stmt.execute([user_id, post_id])?;
Ok(changes == 1) Ok(changes == 1)
} }
#[instrument()] #[instrument(skip(self))]
pub fn remove_liked(user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> { pub fn remove_liked(&self, user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Removing like"); tracing::trace!("Removing like");
let conn = database::connect()?; let mut stmt = self
let mut stmt = conn.prepare("DELETE FROM likes WHERE user_id = ? AND post_id = ?;")?; .0
.prepare("DELETE FROM likes WHERE user_id = ? AND post_id = ?;")?;
let changes = stmt.execute((user_id, post_id))?; let changes = stmt.execute((user_id, post_id))?;
Ok(changes == 1) Ok(changes == 1)
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_all_likes() -> Result<Vec<Like>, rusqlite::Error> { pub fn get_all_likes(&self) -> Result<Vec<Like>, rusqlite::Error> {
tracing::trace!("Retrieving comments page"); tracing::trace!("Retrieving comments page");
let conn = database::connect()?; let mut stmt = self.0.prepare("SELECT * FROM likes")?;
let mut stmt = conn.prepare("SELECT * FROM likes")?;
let row = stmt.query_map([], |row| { let row = stmt.query_map([], |row| {
let like = Like { let like = Like {
user_id: row.get(0)?, user_id: row.get(0)?,
@ -72,4 +77,5 @@ pub fn get_all_likes() -> Result<Vec<Like>, rusqlite::Error> {
Ok(like) Ok(like)
})?; })?;
Ok(row.into_iter().flatten().collect()) Ok(row.into_iter().flatten().collect())
}
} }

View file

@ -1,3 +1,4 @@
use rusqlite::Connection;
use tracing::instrument; use tracing::instrument;
pub mod comments; pub mod comments;
@ -7,23 +8,29 @@ pub mod posts;
pub mod sessions; pub mod sessions;
pub mod users; pub mod users;
pub fn connect() -> Result<rusqlite::Connection, rusqlite::Error> { #[derive(Debug)]
rusqlite::Connection::open("xssbook.db") pub struct Database(Connection);
impl Database {
pub fn connect() -> Result<Self, rusqlite::Error> {
let conn = rusqlite::Connection::open("xssbook.db")?;
Ok(Self(conn))
}
#[instrument(skip(self))]
pub fn query(&self, query: String) -> Result<usize, rusqlite::Error> {
tracing::trace!("Running custom query");
self.0.execute(&query, [])
}
} }
pub fn init() -> Result<(), rusqlite::Error> { pub fn init() -> Result<(), rusqlite::Error> {
users::init()?; let db = Database::connect()?;
posts::init()?; db.init_users()?;
sessions::init()?; db.init_posts()?;
likes::init()?; db.init_sessions()?;
comments::init()?; db.init_likes()?;
friends::init()?; db.init_comments()?;
db.init_friends()?;
Ok(()) Ok(())
} }
#[instrument()]
pub fn query(query: String) -> Result<usize, rusqlite::Error> {
tracing::trace!("Running custom query");
let conn = connect()?;
conn.execute(&query, [])
}

View file

@ -3,12 +3,12 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
use rusqlite::{OptionalExtension, Row}; use rusqlite::{OptionalExtension, Row};
use tracing::instrument; use tracing::instrument;
use crate::database;
use crate::types::post::Post; use crate::types::post::Post;
use super::{comments, likes}; use super::Database;
pub fn init() -> Result<(), rusqlite::Error> { impl Database {
pub fn init_posts(&self) -> Result<(), rusqlite::Error> {
let sql = " let sql = "
CREATE TABLE IF NOT EXISTS posts ( CREATE TABLE IF NOT EXISTS posts (
post_id INTEGER PRIMARY KEY AUTOINCREMENT, post_id INTEGER PRIMARY KEY AUTOINCREMENT,
@ -18,19 +18,20 @@ pub fn init() -> Result<(), rusqlite::Error> {
FOREIGN KEY(user_id) REFERENCES users(user_id) FOREIGN KEY(user_id) REFERENCES users(user_id)
); );
"; ";
let conn = database::connect()?; self.0.execute(sql, ())?;
conn.execute(sql, ())?;
Ok(()) Ok(())
} }
fn post_from_row(row: &Row) -> Result<Post, rusqlite::Error> { fn post_from_row(&self, row: &Row) -> Result<Post, rusqlite::Error> {
let post_id = row.get(0)?; let post_id = row.get(0)?;
let user_id = row.get(1)?; let user_id = row.get(1)?;
let content = row.get(2)?; let content = row.get(2)?;
let date = row.get(3)?; let date = row.get(3)?;
let comments = comments::get_comments_page(0, post_id).unwrap_or_else(|_| Vec::new()); let comments = self
let likes = likes::get_like_count(post_id).unwrap_or(None).unwrap_or(0); .get_comments_page(0, post_id)
.unwrap_or_else(|_| Vec::new());
let likes = self.get_like_count(post_id).unwrap_or(None).unwrap_or(0);
Ok(Post { Ok(Post {
post_id, post_id,
@ -41,63 +42,68 @@ fn post_from_row(row: &Row) -> Result<Post, rusqlite::Error> {
liked: false, liked: false,
comments, comments,
}) })
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_post(post_id: u64) -> Result<Option<Post>, rusqlite::Error> { pub fn get_post(&self, post_id: u64) -> Result<Option<Post>, rusqlite::Error> {
tracing::trace!("Retrieving post"); tracing::trace!("Retrieving post");
let conn = database::connect()?; let mut stmt = self.0.prepare("SELECT * FROM posts WHERE post_id = ?")?;
let mut stmt = conn.prepare("SELECT * FROM posts WHERE post_id = ?")?;
let row = stmt let row = stmt
.query_row([post_id], |row| { .query_row([post_id], |row| {
let row = post_from_row(row)?; let row = self.post_from_row(row)?;
Ok(row) Ok(row)
}) })
.optional()?; .optional()?;
Ok(row) Ok(row)
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_post_page(page: u64) -> Result<Vec<Post>, rusqlite::Error> { pub fn get_post_page(&self, page: u64) -> Result<Vec<Post>, rusqlite::Error> {
tracing::trace!("Retrieving posts page"); tracing::trace!("Retrieving posts page");
let page_size = 10; let page_size = 10;
let conn = database::connect()?; let mut stmt = self
let mut stmt = conn.prepare("SELECT * FROM posts ORDER BY post_id DESC LIMIT ? OFFSET ?")?; .0
.prepare("SELECT * FROM posts ORDER BY post_id DESC LIMIT ? OFFSET ?")?;
let row = stmt.query_map([page_size, page_size * page], |row| { let row = stmt.query_map([page_size, page_size * page], |row| {
let row = post_from_row(row)?; let row = self.post_from_row(row)?;
Ok(row) Ok(row)
})?; })?;
Ok(row.into_iter().flatten().collect()) Ok(row.into_iter().flatten().collect())
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_all_posts() -> Result<Vec<Post>, rusqlite::Error> { pub fn get_all_posts(&self) -> Result<Vec<Post>, rusqlite::Error> {
tracing::trace!("Retrieving posts page"); tracing::trace!("Retrieving posts page");
let conn = database::connect()?; let mut stmt = self
let mut stmt = conn.prepare("SELECT * FROM posts ORDER BY post_id DESC")?; .0
.prepare("SELECT * FROM posts ORDER BY post_id DESC")?;
let row = stmt.query_map([], |row| { let row = stmt.query_map([], |row| {
let row = post_from_row(row)?; let row = self.post_from_row(row)?;
Ok(row) Ok(row)
})?; })?;
Ok(row.into_iter().flatten().collect()) Ok(row.into_iter().flatten().collect())
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_users_post_page(user_id: u64, page: u64) -> Result<Vec<Post>, rusqlite::Error> { pub fn get_users_post_page(
&self,
user_id: u64,
page: u64,
) -> Result<Vec<Post>, rusqlite::Error> {
tracing::trace!("Retrieving users posts"); tracing::trace!("Retrieving users posts");
let page_size = 10; let page_size = 10;
let conn = database::connect()?; let mut stmt = self.0.prepare(
let mut stmt = conn "SELECT * FROM posts WHERE user_id = ? ORDER BY post_id DESC LIMIT ? OFFSET ?",
.prepare("SELECT * FROM posts WHERE user_id = ? ORDER BY post_id DESC LIMIT ? OFFSET ?")?; )?;
let row = stmt.query_map([user_id, page_size, page_size * page], |row| { let row = stmt.query_map([user_id, page_size, page_size * page], |row| {
let row = post_from_row(row)?; let row = self.post_from_row(row)?;
Ok(row) Ok(row)
})?; })?;
Ok(row.into_iter().flatten().collect()) Ok(row.into_iter().flatten().collect())
} }
#[instrument()] #[instrument(skip(self))]
pub fn add_post(user_id: u64, content: &str) -> Result<Post, rusqlite::Error> { pub fn add_post(&self, user_id: u64, content: &str) -> Result<Post, rusqlite::Error> {
tracing::trace!("Adding post"); tracing::trace!("Adding post");
let date = u64::try_from( let date = u64::try_from(
SystemTime::now() SystemTime::now()
@ -106,12 +112,13 @@ pub fn add_post(user_id: u64, content: &str) -> Result<Post, rusqlite::Error> {
.as_millis(), .as_millis(),
) )
.unwrap_or(0); .unwrap_or(0);
let conn = database::connect()?; let mut stmt = self
let mut stmt = .0
conn.prepare("INSERT INTO posts (user_id, content, date) VALUES(?,?,?) RETURNING *;")?; .prepare("INSERT INTO posts (user_id, content, date) VALUES(?,?,?) RETURNING *;")?;
let post = stmt.query_row((user_id, content, date), |row| { let post = stmt.query_row((user_id, content, date), |row| {
let row = post_from_row(row)?; let row = self.post_from_row(row)?;
Ok(row) Ok(row)
})?; })?;
Ok(post) Ok(post)
}
} }

View file

@ -1,9 +1,12 @@
use rusqlite::OptionalExtension; use rusqlite::OptionalExtension;
use tracing::instrument; use tracing::instrument;
use crate::{database, types::session::Session}; use crate::types::session::Session;
pub fn init() -> Result<(), rusqlite::Error> { use super::Database;
impl Database {
pub fn init_sessions(&self) -> Result<(), rusqlite::Error> {
let sql = " let sql = "
CREATE TABLE IF NOT EXISTS sessions ( CREATE TABLE IF NOT EXISTS sessions (
user_id INTEGER PRIMARY KEY NOT NULL, user_id INTEGER PRIMARY KEY NOT NULL,
@ -11,16 +14,14 @@ pub fn init() -> Result<(), rusqlite::Error> {
FOREIGN KEY(user_id) REFERENCES users(user_id) FOREIGN KEY(user_id) REFERENCES users(user_id)
); );
"; ";
let conn = database::connect()?; self.0.execute(sql, ())?;
conn.execute(sql, ())?;
Ok(()) Ok(())
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_session(token: &str) -> Result<Option<Session>, rusqlite::Error> { pub fn get_session(&self, token: &str) -> Result<Option<Session>, rusqlite::Error> {
tracing::trace!("Retrieving session"); tracing::trace!("Retrieving session");
let conn = database::connect()?; let mut stmt = self.0.prepare("SELECT * FROM sessions WHERE token = ?")?;
let mut stmt = conn.prepare("SELECT * FROM sessions WHERE token = ?")?;
let row = stmt let row = stmt
.query_row([token], |row| { .query_row([token], |row| {
Ok(Session { Ok(Session {
@ -30,13 +31,12 @@ pub fn get_session(token: &str) -> Result<Option<Session>, rusqlite::Error> {
}) })
.optional()?; .optional()?;
Ok(row) Ok(row)
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_all_sessions() -> Result<Vec<Session>, rusqlite::Error> { pub fn get_all_sessions(&self) -> Result<Vec<Session>, rusqlite::Error> {
tracing::trace!("Retrieving session"); tracing::trace!("Retrieving session");
let conn = database::connect()?; let mut stmt = self.0.prepare("SELECT * FROM sessions")?;
let mut stmt = conn.prepare("SELECT * FROM sessions")?;
let row = stmt.query_map([], |row| { let row = stmt.query_map([], |row| {
Ok(Session { Ok(Session {
user_id: row.get(0)?, user_id: row.get(0)?,
@ -44,22 +44,21 @@ pub fn get_all_sessions() -> Result<Vec<Session>, rusqlite::Error> {
}) })
})?; })?;
Ok(row.into_iter().flatten().collect()) Ok(row.into_iter().flatten().collect())
} }
#[instrument()] #[instrument(skip(self))]
pub fn set_session(user_id: u64, token: &str) -> Result<(), Box<dyn std::error::Error>> { pub fn set_session(&self, user_id: u64, token: &str) -> Result<(), Box<dyn std::error::Error>> {
tracing::trace!("Setting new session"); tracing::trace!("Setting new session");
let conn = database::connect()?;
let sql = "INSERT OR REPLACE INTO sessions (user_id, token) VALUES (?, ?);"; let sql = "INSERT OR REPLACE INTO sessions (user_id, token) VALUES (?, ?);";
conn.execute(sql, (user_id, token))?; self.0.execute(sql, (user_id, token))?;
Ok(()) Ok(())
} }
#[instrument()] #[instrument(skip(self))]
pub fn delete_session(user_id: u64) -> Result<(), Box<dyn std::error::Error>> { pub fn delete_session(&self, user_id: u64) -> Result<(), Box<dyn std::error::Error>> {
tracing::trace!("Deleting session"); tracing::trace!("Deleting session");
let conn = database::connect()?;
let sql = "DELETE FROM sessions WHERE user_id = ?;"; let sql = "DELETE FROM sessions WHERE user_id = ?;";
conn.execute(sql, [user_id])?; self.0.execute(sql, [user_id])?;
Ok(()) Ok(())
}
} }

View file

@ -2,9 +2,12 @@ use rusqlite::{OptionalExtension, Row};
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tracing::instrument; use tracing::instrument;
use crate::{api::RegistrationRequet, database, types::user::User}; use crate::{api::RegistrationRequet, types::user::User};
pub fn init() -> Result<(), rusqlite::Error> { use super::Database;
impl Database {
pub fn init_users(&self) -> Result<(), rusqlite::Error> {
let sql = " let sql = "
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS users (
user_id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER PRIMARY KEY AUTOINCREMENT,
@ -19,19 +22,18 @@ pub fn init() -> Result<(), rusqlite::Error> {
year INTEGER NOT NULL year INTEGER NOT NULL
); );
"; ";
let conn = database::connect()?; self.0.execute(sql, ())?;
conn.execute(sql, ())?;
let sql2 = "CREATE UNIQUE INDEX IF NOT EXISTS emails on users (email);"; let sql2 = "CREATE UNIQUE INDEX IF NOT EXISTS emails on users (email);";
conn.execute(sql2, ())?; self.0.execute(sql2, ())?;
let sql3 = "CREATE UNIQUE INDEX IF NOT EXISTS passwords on users (password);"; let sql3 = "CREATE UNIQUE INDEX IF NOT EXISTS passwords on users (password);";
conn.execute(sql3, ())?; self.0.execute(sql3, ())?;
Ok(()) Ok(())
} }
pub fn user_from_row(row: &Row, hide_password: bool) -> Result<User, rusqlite::Error> { pub fn user_from_row(row: &Row, hide_password: bool) -> Result<User, rusqlite::Error> {
let user_id = row.get(0)?; let user_id = row.get(0)?;
let firstname = row.get(1)?; let firstname = row.get(1)?;
let lastname = row.get(2)?; let lastname = row.get(2)?;
@ -61,83 +63,92 @@ pub fn user_from_row(row: &Row, hide_password: bool) -> Result<User, rusqlite::E
month, month,
year, year,
}) })
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_user_by_id(user_id: u64, hide_password: bool) -> Result<Option<User>, rusqlite::Error> { pub fn get_user_by_id(
&self,
user_id: u64,
hide_password: bool,
) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by id"); tracing::trace!("Retrieving user by id");
let conn = database::connect()?; let mut stmt = self.0.prepare("SELECT * FROM users WHERE user_id = ?")?;
let mut stmt = conn.prepare("SELECT * FROM users WHERE user_id = ?")?;
let row = stmt let row = stmt
.query_row([user_id], |row| { .query_row([user_id], |row| {
let row = user_from_row(row, hide_password)?; let row = Self::user_from_row(row, hide_password)?;
Ok(row) Ok(row)
}) })
.optional()?; .optional()?;
Ok(row) Ok(row)
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_user_by_email( pub fn get_user_by_email(
&self,
email: &str, email: &str,
hide_password: bool, hide_password: bool,
) -> Result<Option<User>, rusqlite::Error> { ) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by email"); tracing::trace!("Retrieving user by email");
let conn = database::connect()?; let mut stmt = self.0.prepare("SELECT * FROM users WHERE email = ?")?;
let mut stmt = conn.prepare("SELECT * FROM users WHERE email = ?")?;
let row = stmt let row = stmt
.query_row([email], |row| { .query_row([email], |row| {
let row = user_from_row(row, hide_password)?; let row = Self::user_from_row(row, hide_password)?;
Ok(row) Ok(row)
}) })
.optional()?; .optional()?;
Ok(row) Ok(row)
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_user_by_password( pub fn get_user_by_password(
&self,
password: &str, password: &str,
hide_password: bool, hide_password: bool,
) -> Result<Option<User>, rusqlite::Error> { ) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by password"); tracing::trace!("Retrieving user by password");
let conn = database::connect()?; let mut stmt = self.0.prepare("SELECT * FROM users WHERE password = ?")?;
let mut stmt = conn.prepare("SELECT * FROM users WHERE password = ?")?;
let row = stmt let row = stmt
.query_row([password], |row| { .query_row([password], |row| {
let row = user_from_row(row, hide_password)?; let row = Self::user_from_row(row, hide_password)?;
Ok(row) Ok(row)
}) })
.optional()?; .optional()?;
Ok(row) Ok(row)
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_user_page(page: u64, hide_password: bool) -> Result<Vec<User>, rusqlite::Error> { pub fn get_user_page(
&self,
page: u64,
hide_password: bool,
) -> Result<Vec<User>, rusqlite::Error> {
tracing::trace!("Retrieving user page"); tracing::trace!("Retrieving user page");
let page_size = 5; let page_size = 5;
let conn = database::connect()?; let mut stmt = self
let mut stmt = conn.prepare("SELECT * FROM users ORDER BY user_id DESC LIMIT ? OFFSET ?")?; .0
.prepare("SELECT * FROM users ORDER BY user_id DESC LIMIT ? OFFSET ?")?;
let row = stmt.query_map([page_size, page_size * page], |row| { let row = stmt.query_map([page_size, page_size * page], |row| {
let row = user_from_row(row, hide_password)?; let row = Self::user_from_row(row, hide_password)?;
Ok(row) Ok(row)
})?; })?;
Ok(row.into_iter().flatten().collect()) Ok(row.into_iter().flatten().collect())
} }
#[instrument()] #[instrument(skip(self))]
pub fn get_all_users() -> Result<Vec<User>, rusqlite::Error> { pub fn get_all_users(&self) -> Result<Vec<User>, rusqlite::Error> {
tracing::trace!("Retrieving user page"); tracing::trace!("Retrieving user page");
let conn = database::connect()?; let mut stmt = self
let mut stmt = conn.prepare("SELECT * FROM users ORDER BY user_id DESC")?; .0
.prepare("SELECT * FROM users ORDER BY user_id DESC")?;
let row = stmt.query_map([], |row| { let row = stmt.query_map([], |row| {
let row = user_from_row(row, false)?; let row = Self::user_from_row(row, false)?;
Ok(row) Ok(row)
})?; })?;
Ok(row.into_iter().flatten().collect()) Ok(row.into_iter().flatten().collect())
} }
#[instrument()] #[instrument(skip(self))]
pub fn add_user(request: RegistrationRequet) -> Result<User, rusqlite::Error> { pub fn add_user(&self, request: RegistrationRequet) -> Result<User, rusqlite::Error> {
tracing::trace!("Adding new user"); tracing::trace!("Adding new user");
let date = u64::try_from( let date = u64::try_from(
SystemTime::now() SystemTime::now()
@ -147,8 +158,7 @@ pub fn add_user(request: RegistrationRequet) -> Result<User, rusqlite::Error> {
) )
.unwrap_or(0); .unwrap_or(0);
let conn = database::connect()?; let mut stmt = self.0.prepare("INSERT INTO users (firstname, lastname, email, password, gender, date, day, month, year) VALUES(?,?,?,?,?,?,?,?,?) RETURNING *;")?;
let mut stmt = conn.prepare("INSERT INTO users (firstname, lastname, email, password, gender, date, day, month, year) VALUES(?,?,?,?,?,?,?,?,?) RETURNING *;")?;
let user = stmt.query_row( let user = stmt.query_row(
( (
request.firstname, request.firstname,
@ -162,9 +172,10 @@ pub fn add_user(request: RegistrationRequet) -> Result<User, rusqlite::Error> {
request.year, request.year,
), ),
|row| { |row| {
let row = user_from_row(row, false)?; let row = Self::user_from_row(row, false)?;
Ok(row) Ok(row)
}, },
)?; )?;
Ok(user) Ok(user)
}
} }

View file

@ -5,6 +5,7 @@ use tokio::sync::Mutex;
use crate::{ use crate::{
console::sanatize, console::sanatize,
database::Database,
types::{ types::{
comment::Comment, http::ResponseCode, like::Like, post::Post, session::Session, user::User, comment::Comment, http::ResponseCode, like::Like, post::Post, session::Session, user::User,
}, },
@ -36,8 +37,8 @@ pub async fn regen_secret() -> String {
secret.clone() secret.clone()
} }
pub fn generate_users() -> Response { pub fn generate_users(db: &Database) -> Response {
let users = match User::reterieve_all() { let users = match User::reterieve_all(db) {
Ok(users) => users, Ok(users) => users,
Err(err) => return err, Err(err) => return err,
}; };
@ -70,8 +71,8 @@ pub fn generate_users() -> Response {
ResponseCode::Success.text(&html) ResponseCode::Success.text(&html)
} }
pub fn generate_posts() -> Response { pub fn generate_posts(db: &Database) -> Response {
let posts = match Post::reterieve_all() { let posts = match Post::reterieve_all(db) {
Ok(posts) => posts, Ok(posts) => posts,
Err(err) => return err, Err(err) => return err,
}; };
@ -99,8 +100,8 @@ pub fn generate_posts() -> Response {
ResponseCode::Success.text(&html) ResponseCode::Success.text(&html)
} }
pub fn generate_sessions() -> Response { pub fn generate_sessions(db: &Database) -> Response {
let sessions = match Session::reterieve_all() { let sessions = match Session::reterieve_all(db) {
Ok(sessions) => sessions, Ok(sessions) => sessions,
Err(err) => return err, Err(err) => return err,
}; };
@ -123,8 +124,8 @@ pub fn generate_sessions() -> Response {
ResponseCode::Success.text(&html) ResponseCode::Success.text(&html)
} }
pub fn generate_comments() -> Response { pub fn generate_comments(db: &Database) -> Response {
let comments = match Comment::reterieve_all() { let comments = match Comment::reterieve_all(db) {
Ok(comments) => comments, Ok(comments) => comments,
Err(err) => return err, Err(err) => return err,
}; };
@ -154,8 +155,8 @@ pub fn generate_comments() -> Response {
ResponseCode::Success.text(&html) ResponseCode::Success.text(&html)
} }
pub fn generate_likes() -> Response { pub fn generate_likes(db: &Database) -> Response {
let likes = match Like::reterieve_all() { let likes = match Like::reterieve_all(db) {
Ok(likes) => likes, Ok(likes) => likes,
Err(err) => return err, Err(err) => return err,
}; };

View file

@ -2,7 +2,7 @@ use serde::Serialize;
use tracing::instrument; use tracing::instrument;
use crate::{ use crate::{
database::{self, comments}, database::Database,
types::http::{ResponseCode, Result}, types::http::{ResponseCode, Result},
}; };
@ -16,9 +16,9 @@ pub struct Comment {
} }
impl Comment { impl Comment {
#[instrument()] #[instrument(skip(db))]
pub fn new(user_id: u64, post_id: u64, content: &str) -> Result<Self> { pub fn new(db: &Database, user_id: u64, post_id: u64, content: &str) -> Result<Self> {
let Ok(comment) = comments::add_comment(user_id, post_id, content) else { let Ok(comment) = db.add_comment(user_id, post_id, content) else {
tracing::error!("Failed to create comment"); tracing::error!("Failed to create comment");
return Err(ResponseCode::InternalServerError.text("Failed to create post")) return Err(ResponseCode::InternalServerError.text("Failed to create post"))
}; };
@ -26,17 +26,17 @@ impl Comment {
Ok(comment) Ok(comment)
} }
#[instrument()] #[instrument(skip(db))]
pub fn from_comment_page(page: u64, post_id: u64) -> Result<Vec<Self>> { pub fn from_comment_page(db: &Database, page: u64, post_id: u64) -> Result<Vec<Self>> {
let Ok(posts) = database::comments::get_comments_page(page, post_id) else { let Ok(posts) = db.get_comments_page(page, post_id) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch comments")) return Err(ResponseCode::BadRequest.text("Failed to fetch comments"))
}; };
Ok(posts) Ok(posts)
} }
#[instrument()] #[instrument(skip(db))]
pub fn reterieve_all() -> Result<Vec<Self>> { pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
let Ok(posts) = database::comments::get_all_comments() else { let Ok(posts) = db.get_all_comments() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch comments")) return Err(ResponseCode::InternalServerError.text("Failed to fetch comments"))
}; };
Ok(posts) Ok(posts)

View file

@ -14,9 +14,11 @@ use axum::{
use bytes::Bytes; use bytes::Bytes;
use image::{io::Reader, DynamicImage, ImageFormat}; use image::{io::Reader, DynamicImage, ImageFormat};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use tokio::sync::Mutex;
use tower_cookies::Cookies; use tower_cookies::Cookies;
use crate::{ use crate::{
database,
public::admin, public::admin,
public::console, public::console,
types::{ types::{
@ -97,11 +99,17 @@ where
return Err(ResponseCode::Forbidden.text("No auth token provided")) return Err(ResponseCode::Forbidden.text("No auth token provided"))
}; };
let Ok(session) = Session::from_token(token.value()) else { let Some(db) = parts.extensions.get::<DatabaseExtention>() else {
return Err(ResponseCode::Forbidden.text("Could not connect to database"))
};
let db = db.0.lock().await;
let Ok(session) = Session::from_token(&db, token.value()) else {
return Err(ResponseCode::Unauthorized.text("Auth token invalid")) return Err(ResponseCode::Unauthorized.text("Auth token invalid"))
}; };
let Ok(user) = User::from_user_id(session.user_id, true) else { let Ok(user) = User::from_user_id(&db, session.user_id, true) else {
tracing::error!("Valid token but no valid user"); tracing::error!("Valid token but no valid user");
return Err(ResponseCode::InternalServerError.text("Valid token but no valid user")) return Err(ResponseCode::InternalServerError.text("Valid token but no valid user"))
}; };
@ -260,6 +268,26 @@ where
} }
} }
pub struct DatabaseExtention(pub Mutex<database::Database>);
pub struct Database(pub database::Database);
#[async_trait]
impl<S> FromRequestParts<S> for Database
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self> {
let db = parts.extensions.remove::<DatabaseExtention>();
let Some(db) = db else {
return Err(ResponseCode::InternalServerError.text("Database is not loaded"))
};
Ok(Self(db.0.into_inner()))
}
}
async fn read_body<S, B>(mut req: Request<B>, state: &S) -> Result<Vec<u8>> async fn read_body<S, B>(mut req: Request<B>, state: &S) -> Result<Vec<u8>>
where where
B: HttpBody + Sync + Send + 'static, B: HttpBody + Sync + Send + 'static,

View file

@ -1,7 +1,7 @@
use serde::Serialize; use serde::Serialize;
use tracing::instrument; use tracing::instrument;
use crate::database; use crate::database::Database;
use crate::types::http::{ResponseCode, Result}; use crate::types::http::{ResponseCode, Result};
#[derive(Serialize)] #[derive(Serialize)]
@ -11,9 +11,9 @@ pub struct Like {
} }
impl Like { impl Like {
#[instrument()] #[instrument(skip(db))]
pub fn add_liked(user_id: u64, post_id: u64) -> Result<()> { pub fn add_liked(db: &Database, user_id: u64, post_id: u64) -> Result<()> {
let Ok(liked) = database::likes::add_liked(user_id, post_id) else { let Ok(liked) = db.add_liked(user_id, post_id) else {
return Err(ResponseCode::BadRequest.text("Failed to add like status")) return Err(ResponseCode::BadRequest.text("Failed to add like status"))
}; };
@ -24,9 +24,9 @@ impl Like {
Ok(()) Ok(())
} }
#[instrument()] #[instrument(skip(db))]
pub fn remove_liked(user_id: u64, post_id: u64) -> Result<()> { pub fn remove_liked(db: &Database, user_id: u64, post_id: u64) -> Result<()> {
let Ok(liked) = database::likes::remove_liked(user_id, post_id) else { let Ok(liked) = db.remove_liked(user_id, post_id) else {
return Err(ResponseCode::BadRequest.text("Failed to remove like status")) return Err(ResponseCode::BadRequest.text("Failed to remove like status"))
}; };
@ -37,9 +37,9 @@ impl Like {
Ok(()) Ok(())
} }
#[instrument()] #[instrument(skip(db))]
pub fn reterieve_all() -> Result<Vec<Self>> { pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
let Ok(likes) = database::likes::get_all_likes() else { let Ok(likes) = db.get_all_likes() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch likes")) return Err(ResponseCode::InternalServerError.text("Failed to fetch likes"))
}; };
Ok(likes) Ok(likes)

View file

@ -2,7 +2,7 @@ use core::fmt;
use serde::Serialize; use serde::Serialize;
use tracing::instrument; use tracing::instrument;
use crate::database; use crate::database::Database;
use crate::types::http::{ResponseCode, Result}; use crate::types::http::{ResponseCode, Result};
use super::comment::Comment; use super::comment::Comment;
@ -27,57 +27,62 @@ impl fmt::Debug for Post {
} }
impl Post { impl Post {
#[instrument()] #[instrument(skip(db))]
pub fn from_post_id(self_id: u64, post_id: u64) -> Result<Self> { pub fn from_post_id(db: &Database, self_id: u64, post_id: u64) -> Result<Self> {
let Ok(Some(mut post)) = database::posts::get_post(post_id) else { let Ok(Some(mut post)) = db.get_post(post_id) else {
return Err(ResponseCode::BadRequest.text("Post does not exist")) return Err(ResponseCode::BadRequest.text("Post does not exist"))
}; };
let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false); let liked = db.get_liked(self_id, post.post_id).unwrap_or(false);
post.liked = liked; post.liked = liked;
Ok(post) Ok(post)
} }
#[instrument()] #[instrument(skip(db))]
pub fn from_post_page(self_id: u64, page: u64) -> Result<Vec<Self>> { pub fn from_post_page(db: &Database, self_id: u64, page: u64) -> Result<Vec<Self>> {
let Ok(mut posts) = database::posts::get_post_page(page) else { let Ok(mut posts) = db.get_post_page(page) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch posts")) return Err(ResponseCode::BadRequest.text("Failed to fetch posts"))
}; };
for post in &mut posts { for post in &mut posts {
let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false); let liked = db.get_liked(self_id, post.post_id).unwrap_or(false);
post.liked = liked; post.liked = liked;
} }
Ok(posts) Ok(posts)
} }
#[instrument()] #[instrument(skip(db))]
pub fn from_user_post_page(self_id: u64, user_id: u64, page: u64) -> Result<Vec<Self>> { pub fn from_user_post_page(
let Ok(mut posts) = database::posts::get_users_post_page(user_id, page) else { db: &Database,
self_id: u64,
user_id: u64,
page: u64,
) -> Result<Vec<Self>> {
let Ok(mut posts) = db.get_users_post_page(user_id, page) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch posts")) return Err(ResponseCode::BadRequest.text("Failed to fetch posts"))
}; };
for post in &mut posts { for post in &mut posts {
let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false); let liked = db.get_liked(self_id, post.post_id).unwrap_or(false);
post.liked = liked; post.liked = liked;
} }
Ok(posts) Ok(posts)
} }
#[instrument()] #[instrument(skip(db))]
pub fn reterieve_all() -> Result<Vec<Self>> { pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
let Ok(posts) = database::posts::get_all_posts() else { let Ok(posts) = db.get_all_posts() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch posts")) return Err(ResponseCode::InternalServerError.text("Failed to fetch posts"))
}; };
Ok(posts) Ok(posts)
} }
#[instrument()] #[instrument(skip(db))]
pub fn new(user_id: u64, content: String) -> Result<Self> { pub fn new(db: &Database, user_id: u64, content: String) -> Result<Self> {
let Ok(post) = database::posts::add_post(user_id, &content) else { let Ok(post) = db.add_post(user_id, &content) else {
tracing::error!("Failed to create post"); tracing::error!("Failed to create post");
return Err(ResponseCode::InternalServerError.text("Failed to create post")) return Err(ResponseCode::InternalServerError.text("Failed to create post"))
}; };

View file

@ -2,7 +2,7 @@ use rand::{distributions::Alphanumeric, Rng};
use serde::Serialize; use serde::Serialize;
use tracing::instrument; use tracing::instrument;
use crate::database; use crate::database::Database;
use crate::types::http::{ResponseCode, Result}; use crate::types::http::{ResponseCode, Result};
#[derive(Serialize)] #[derive(Serialize)]
@ -12,39 +12,39 @@ pub struct Session {
} }
impl Session { impl Session {
#[instrument()] #[instrument(skip(db))]
pub fn from_token(token: &str) -> Result<Self> { pub fn from_token(db: &Database, token: &str) -> Result<Self> {
let Ok(Some(session)) = database::sessions::get_session(token) else { let Ok(Some(session)) = db.get_session(token) else {
return Err(ResponseCode::BadRequest.text("Invalid auth token")); return Err(ResponseCode::BadRequest.text("Invalid auth token"));
}; };
Ok(session) Ok(session)
} }
#[instrument()] #[instrument(skip(db))]
pub fn reterieve_all() -> Result<Vec<Self>> { pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
let Ok(sessions) = database::sessions::get_all_sessions() else { let Ok(sessions) = db.get_all_sessions() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch sessions")) return Err(ResponseCode::InternalServerError.text("Failed to fetch sessions"))
}; };
Ok(sessions) Ok(sessions)
} }
#[instrument()] #[instrument(skip(db))]
pub fn new(user_id: u64) -> Result<Self> { pub fn new(db: &Database, user_id: u64) -> Result<Self> {
let token: String = rand::thread_rng() let token: String = rand::thread_rng()
.sample_iter(&Alphanumeric) .sample_iter(&Alphanumeric)
.take(32) .take(32)
.map(char::from) .map(char::from)
.collect(); .collect();
match database::sessions::set_session(user_id, &token) { match db.set_session(user_id, &token) {
Err(_) => Err(ResponseCode::BadRequest.text("Failed to create session")), Err(_) => Err(ResponseCode::BadRequest.text("Failed to create session")),
Ok(_) => Ok(Self { user_id, token }), Ok(_) => Ok(Self { user_id, token }),
} }
} }
#[instrument()] #[instrument(skip(db))]
pub fn delete(user_id: u64) -> Result<()> { pub fn delete(db: &Database, user_id: u64) -> Result<()> {
if database::sessions::delete_session(user_id).is_err() { if db.delete_session(user_id).is_err() {
tracing::error!("Failed to logout user"); tracing::error!("Failed to logout user");
return Err(ResponseCode::InternalServerError.text("Failed to logout")); return Err(ResponseCode::InternalServerError.text("Failed to logout"));
}; };

View file

@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize};
use tracing::instrument; use tracing::instrument;
use crate::api::RegistrationRequet; use crate::api::RegistrationRequet;
use crate::database; use crate::database::Database;
use crate::types::http::{ResponseCode, Result}; use crate::types::http::{ResponseCode, Result};
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
@ -24,21 +24,21 @@ pub const FOLLOWING: u8 = 1;
pub const FOLLOWED: u8 = 2; pub const FOLLOWED: u8 = 2;
impl User { impl User {
#[instrument()] #[instrument(skip(db))]
pub fn from_user_id(user_id: u64, hide_password: bool) -> Result<Self> { pub fn from_user_id(db: &Database, user_id: u64, hide_password: bool) -> Result<Self> {
let Ok(Some(user)) = database::users::get_user_by_id(user_id, hide_password) else { let Ok(Some(user)) = db.get_user_by_id(user_id, hide_password) else {
return Err(ResponseCode::BadRequest.text("User does not exist")) return Err(ResponseCode::BadRequest.text("User does not exist"))
}; };
Ok(user) Ok(user)
} }
#[instrument()] #[instrument(skip(db))]
pub fn from_user_ids(user_ids: Vec<u64>) -> Vec<Self> { pub fn from_user_ids(db: &Database, user_ids: Vec<u64>) -> Vec<Self> {
user_ids user_ids
.iter() .iter()
.filter_map(|user_id| { .filter_map(|user_id| {
let Ok(Some(user)) = database::users::get_user_by_id(*user_id, true) else { let Ok(Some(user)) = db.get_user_by_id(*user_id, true) else {
return None; return None;
}; };
Some(user) Some(user)
@ -46,53 +46,53 @@ impl User {
.collect() .collect()
} }
#[instrument()] #[instrument(skip(db))]
pub fn from_user_page(page: u64) -> Result<Vec<Self>> { pub fn from_user_page(db: &Database, page: u64) -> Result<Vec<Self>> {
let Ok(users) = database::users::get_user_page(page, true) else { let Ok(users) = db.get_user_page(page, true) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch users")) return Err(ResponseCode::BadRequest.text("Failed to fetch users"))
}; };
Ok(users) Ok(users)
} }
#[instrument()] #[instrument(skip(db))]
pub fn from_email(email: &str) -> Result<Self> { pub fn from_email(db: &Database, email: &str) -> Result<Self> {
let Ok(Some(user)) = database::users::get_user_by_email(email, false) else { let Ok(Some(user)) = db.get_user_by_email(email, false) else {
return Err(ResponseCode::BadRequest.text("User does not exist")) return Err(ResponseCode::BadRequest.text("User does not exist"))
}; };
Ok(user) Ok(user)
} }
#[instrument()] #[instrument(skip(db))]
pub fn from_password(password: &str) -> Result<Self> { pub fn from_password(db: &Database, password: &str) -> Result<Self> {
let Ok(Some(user)) = database::users::get_user_by_password(password, true) else { let Ok(Some(user)) = db.get_user_by_password(password, true) else {
return Err(ResponseCode::BadRequest.text("User does not exist")) return Err(ResponseCode::BadRequest.text("User does not exist"))
}; };
Ok(user) Ok(user)
} }
#[instrument()] #[instrument(skip(db))]
pub fn reterieve_all() -> Result<Vec<Self>> { pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
let Ok(users) = database::users::get_all_users() else { let Ok(users) = db.get_all_users() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch users")) return Err(ResponseCode::InternalServerError.text("Failed to fetch users"))
}; };
Ok(users) Ok(users)
} }
#[instrument()] #[instrument(skip(db))]
pub fn new(request: RegistrationRequet) -> Result<Self> { pub fn new(db: &Database, request: RegistrationRequet) -> Result<Self> {
if Self::from_email(&request.email).is_ok() { if Self::from_email(db, &request.email).is_ok() {
return Err(ResponseCode::BadRequest return Err(ResponseCode::BadRequest
.text(&format!("Email is already in use by {}", &request.email))); .text(&format!("Email is already in use by {}", &request.email)));
} }
if let Ok(user) = Self::from_password(&request.password) { if let Ok(user) = Self::from_password(db, &request.password) {
return Err(ResponseCode::BadRequest return Err(ResponseCode::BadRequest
.text(&format!("Password is already in use by {}", user.email))); .text(&format!("Password is already in use by {}", user.email)));
} }
let Ok(user) = database::users::add_user(request) else { let Ok(user) = db.add_user(request) else {
tracing::error!("Failed to create new user"); tracing::error!("Failed to create new user");
return Err(ResponseCode::InternalServerError.text("Failed to create new uesr")) return Err(ResponseCode::InternalServerError.text("Failed to create new uesr"))
}; };
@ -100,8 +100,9 @@ impl User {
Ok(user) Ok(user)
} }
pub fn add_following(user_id_1: u64, user_id_2: u64) -> Result<()> { #[instrument(skip(db))]
let Ok(followed) = database::friends::set_following(user_id_1, user_id_2) else { pub fn add_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<()> {
let Ok(followed) = db.set_following(user_id_1, user_id_2) else {
return Err(ResponseCode::BadRequest.text("Failed to add follow status")) return Err(ResponseCode::BadRequest.text("Failed to add follow status"))
}; };
@ -112,8 +113,9 @@ impl User {
Ok(()) Ok(())
} }
pub fn remove_following(user_id_1: u64, user_id_2: u64) -> Result<()> { #[instrument(skip(db))]
let Ok(followed) = database::friends::remove_following(user_id_1, user_id_2) else { pub fn remove_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<()> {
let Ok(followed) = db.remove_following(user_id_1, user_id_2) else {
return Err(ResponseCode::BadRequest.text("Failed to remove follow status")) return Err(ResponseCode::BadRequest.text("Failed to remove follow status"))
}; };
@ -124,15 +126,17 @@ impl User {
Ok(()) Ok(())
} }
pub fn get_following(user_id_1: u64, user_id_2: u64) -> Result<u8> { #[instrument(skip(db))]
let Ok(followed) = database::friends::get_friend_status(user_id_1, user_id_2) else { pub fn get_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<u8> {
let Ok(followed) = db.get_friend_status(user_id_1, user_id_2) else {
return Err(ResponseCode::InternalServerError.text("Failed to get follow status")) return Err(ResponseCode::InternalServerError.text("Failed to get follow status"))
}; };
Ok(followed) Ok(followed)
} }
pub fn get_friends(user_id: u64) -> Result<Vec<Self>> { #[instrument(skip(db))]
let Ok(users) = database::friends::get_friends(user_id) else { pub fn get_friends(db: &Database, user_id: u64) -> Result<Vec<Self>> {
let Ok(users) = db.get_friends(user_id) else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch friends")) return Err(ResponseCode::InternalServerError.text("Failed to fetch friends"))
}; };
Ok(users) Ok(users)