make database calls 1 conn

This commit is contained in:
Tyler Murphy 2023-02-15 00:01:44 -05:00
parent 192be95f84
commit aec4fdecc1
19 changed files with 829 additions and 695 deletions

View file

@ -5,13 +5,12 @@ use serde::Deserialize;
use tower_cookies::{Cookie, Cookies};
use crate::{
database,
public::{
admin,
docs::{EndpointDocumentation, EndpointMethod},
},
types::{
extract::{AdminUser, Check, CheckResult, Json},
extract::{AdminUser, Check, CheckResult, Database, Json},
http::ResponseCode,
},
};
@ -92,8 +91,8 @@ impl Check for QueryRequest {
}
}
async fn query(_: AdminUser, Json(body): Json<QueryRequest>) -> Response {
match database::query(body.query) {
async fn query(_: AdminUser, Database(db): Database, Json(body): Json<QueryRequest>) -> Response {
match db.query(body.query) {
Ok(changes) => ResponseCode::Success.text(&format!(
"Query executed successfully. {changes} lines changed."
)),
@ -114,8 +113,8 @@ pub const ADMIN_POSTS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"),
};
async fn posts(_: AdminUser) -> Response {
admin::generate_posts()
async fn posts(_: AdminUser, Database(db): Database) -> Response {
admin::generate_posts(&db)
}
pub const ADMIN_USERS: EndpointDocumentation = EndpointDocumentation {
@ -131,8 +130,8 @@ pub const ADMIN_USERS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"),
};
async fn users(_: AdminUser) -> Response {
admin::generate_users()
async fn users(_: AdminUser, Database(db): Database) -> Response {
admin::generate_users(&db)
}
pub const ADMIN_SESSIONS: EndpointDocumentation = EndpointDocumentation {
@ -148,8 +147,8 @@ pub const ADMIN_SESSIONS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"),
};
async fn sessions(_: AdminUser) -> Response {
admin::generate_sessions()
async fn sessions(_: AdminUser, Database(db): Database) -> Response {
admin::generate_sessions(&db)
}
pub const ADMIN_COMMENTS: EndpointDocumentation = EndpointDocumentation {
@ -165,8 +164,8 @@ pub const ADMIN_COMMENTS: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"),
};
async fn comments(_: AdminUser) -> Response {
admin::generate_comments()
async fn comments(_: AdminUser, Database(db): Database) -> Response {
admin::generate_comments(&db)
}
pub const ADMIN_LIKES: EndpointDocumentation = EndpointDocumentation {
@ -182,8 +181,8 @@ pub const ADMIN_LIKES: EndpointDocumentation = EndpointDocumentation {
cookie: Some("admin"),
};
async fn likes(_: AdminUser) -> Response {
admin::generate_likes()
async fn likes(_: AdminUser, Database(db): Database) -> Response {
admin::generate_likes(&db)
}
async fn check(check: Option<AdminUser>) -> Response {

View file

@ -6,7 +6,7 @@ use tower_cookies::{Cookie, Cookies};
use crate::{
public::docs::{EndpointDocumentation, EndpointMethod},
types::{
extract::{AuthorizedUser, Check, CheckResult, Json, Log},
extract::{AuthorizedUser, Check, CheckResult, Database, Json, Log},
http::ResponseCode,
session::Session,
user::User,
@ -99,13 +99,17 @@ impl Check for RegistrationRequet {
}
}
async fn register(cookies: Cookies, Json(body): Json<RegistrationRequet>) -> Response {
let user = match User::new(body) {
async fn register(
cookies: Cookies,
Database(db): Database,
Json(body): Json<RegistrationRequet>,
) -> Response {
let user = match User::new(&db, body) {
Ok(user) => user,
Err(err) => return err,
};
let session = match Session::new(user.user_id) {
let session = match Session::new(&db, user.user_id) {
Ok(session) => session,
Err(err) => return err,
};
@ -158,8 +162,12 @@ impl Check for LoginRequest {
}
}
async fn login(cookies: Cookies, Json(body): Json<LoginRequest>) -> Response {
let Ok(user) = User::from_email(&body.email) else {
async fn login(
cookies: Cookies,
Database(db): Database,
Json(body): Json<LoginRequest>,
) -> Response {
let Ok(user) = User::from_email(&db, &body.email) else {
return ResponseCode::BadRequest.text("Email is not registered")
};
@ -167,7 +175,7 @@ async fn login(cookies: Cookies, Json(body): Json<LoginRequest>) -> Response {
return ResponseCode::BadRequest.text("Password is not correct");
}
let session = match Session::new(user.user_id) {
let session = match Session::new(&db, user.user_id) {
Ok(session) => session,
Err(err) => return err,
};
@ -199,10 +207,15 @@ pub const AUTH_LOGOUT: EndpointDocumentation = EndpointDocumentation {
cookie: None,
};
async fn logout(cookies: Cookies, AuthorizedUser(user): AuthorizedUser, _: Log) -> Response {
async fn logout(
cookies: Cookies,
AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
_: Log,
) -> Response {
cookies.remove(Cookie::new("auth", ""));
if let Err(err) = Session::delete(user.user_id) {
if let Err(err) = Session::delete(&db, user.user_id) {
return err;
}

View file

@ -1,5 +1,15 @@
use crate::types::extract::RouterURI;
use axum::{error_handling::HandleErrorLayer, BoxError, Extension, Router};
use crate::{
database,
types::extract::{DatabaseExtention, RouterURI},
};
use axum::{
error_handling::HandleErrorLayer,
http::Request,
middleware::{self, Next},
response::Response,
BoxError, Extension, Router,
};
use tokio::sync::Mutex;
use tower::ServiceBuilder;
use tower_governor::{
errors::display_error, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor,
@ -13,6 +23,18 @@ pub mod users;
pub use auth::RegistrationRequet;
async fn connect<B>(mut req: Request<B>, next: Next<B>) -> Response
where
B: Send,
{
if let Ok(db) = database::Database::connect() {
let ex = DatabaseExtention(Mutex::new(db));
req.extensions_mut().insert(ex);
}
next.run(req).await
}
pub fn router() -> Router {
let governor_conf = Box::new(
GovernorConfigBuilder::default()
@ -49,4 +71,5 @@ pub fn router() -> Router {
config: Box::leak(governor_conf),
}),
)
.layer(middleware::from_fn(connect))
}

View file

@ -9,7 +9,7 @@ use crate::{
public::docs::{EndpointDocumentation, EndpointMethod},
types::{
comment::Comment,
extract::{AuthorizedUser, Check, CheckResult, Json},
extract::{AuthorizedUser, Check, CheckResult, Database, Json},
http::ResponseCode,
like::Like,
post::Post,
@ -55,9 +55,10 @@ impl Check for PostCreateRequest {
async fn create(
AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<PostCreateRequest>,
) -> Response {
let Ok(post) = Post::new(user.user_id, body.content) else {
let Ok(post) = Post::new(&db, user.user_id, body.content) else {
return ResponseCode::InternalServerError.text("Failed to create post")
};
@ -101,9 +102,10 @@ impl Check for PostPageRequest {
async fn page(
AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<PostPageRequest>,
) -> Response {
let Ok(posts) = Post::from_post_page(user.user_id, body.page) else {
let Ok(posts) = Post::from_post_page(&db, user.user_id, body.page) else {
return ResponseCode::InternalServerError.text("Failed to fetch posts")
};
@ -149,9 +151,10 @@ impl Check for CommentsPageRequest {
async fn comments(
AuthorizedUser(_user): AuthorizedUser,
Database(db): Database,
Json(body): Json<CommentsPageRequest>,
) -> Response {
let Ok(comments) = Comment::from_comment_page(body.page, body.post_id) else {
let Ok(comments) = Comment::from_comment_page(&db, body.page, body.post_id) else {
return ResponseCode::InternalServerError.text("Failed to fetch comments")
};
@ -197,9 +200,10 @@ impl Check for UsersPostsRequest {
async fn user(
AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UsersPostsRequest>,
) -> Response {
let Ok(posts) = Post::from_user_post_page(user.user_id, body.user_id, body.page) else {
let Ok(posts) = Post::from_user_post_page(&db, user.user_id, body.user_id, body.page) else {
return ResponseCode::InternalServerError.text("Failed to fetch posts")
};
@ -251,9 +255,10 @@ impl Check for PostCommentRequest {
async fn comment(
AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<PostCommentRequest>,
) -> Response {
if let Err(err) = Comment::new(user.user_id, body.post_id, &body.content) {
if let Err(err) = Comment::new(&db, user.user_id, body.post_id, &body.content) {
return err;
}
@ -293,12 +298,16 @@ impl Check for PostLikeRequest {
}
}
async fn like(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostLikeRequest>) -> Response {
async fn like(
AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<PostLikeRequest>,
) -> Response {
if body.state {
if let Err(err) = Like::add_liked(user.user_id, body.post_id) {
if let Err(err) = Like::add_liked(&db, user.user_id, body.post_id) {
return err;
}
} else if let Err(err) = Like::remove_liked(user.user_id, body.post_id) {
} else if let Err(err) = Like::remove_liked(&db, user.user_id, body.post_id) {
return err;
}

View file

@ -1,7 +1,7 @@
use crate::{
public::docs::{EndpointDocumentation, EndpointMethod},
types::{
extract::{AuthorizedUser, Check, CheckResult, Json, Log, Png},
extract::{AuthorizedUser, Check, CheckResult, Database, Json, Log, Png},
http::ResponseCode,
user::User,
},
@ -46,9 +46,10 @@ impl Check for UserLoadRequest {
async fn load_batch(
AuthorizedUser(_user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UserLoadRequest>,
) -> Response {
let users = User::from_user_ids(body.ids);
let users = User::from_user_ids(&db, body.ids);
let Ok(json) = serde_json::to_string(&users) else {
return ResponseCode::InternalServerError.text("Failed to fetch users")
};
@ -90,9 +91,10 @@ impl Check for UserPageReqiest {
async fn load_page(
AuthorizedUser(_user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UserPageReqiest>,
) -> Response {
let Ok(users) = User::from_user_page(body.page) else {
let Ok(users) = User::from_user_page(&db, body.page) else {
return ResponseCode::InternalServerError.text("Failed to fetch users")
};
@ -207,17 +209,18 @@ impl Check for UserFollowRequest {
async fn follow(
AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UserFollowRequest>,
) -> Response {
if body.state {
if let Err(err) = User::add_following(user.user_id, body.user_id) {
if let Err(err) = User::add_following(&db, user.user_id, body.user_id) {
return err;
}
} else if let Err(err) = User::remove_following(user.user_id, body.user_id) {
} else if let Err(err) = User::remove_following(&db, user.user_id, body.user_id) {
return err;
}
match User::get_following(user.user_id, body.user_id) {
match User::get_following(&db, user.user_id, body.user_id) {
Ok(status) => ResponseCode::Success.text(&format!("{status}")),
Err(err) => err,
}
@ -259,9 +262,10 @@ impl Check for UserFollowStatusRequest {
async fn follow_status(
AuthorizedUser(user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UserFollowStatusRequest>,
) -> Response {
match User::get_following(user.user_id, body.user_id) {
match User::get_following(&db, user.user_id, body.user_id) {
Ok(status) => ResponseCode::Success.text(&format!("{status}")),
Err(err) => err,
}
@ -297,8 +301,12 @@ impl Check for UserFriendsRequest {
}
}
async fn friends(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UserFriendsRequest>) -> Response {
let Ok(users) = User::get_friends(body.user_id) else {
async fn friends(
AuthorizedUser(_user): AuthorizedUser,
Database(db): Database,
Json(body): Json<UserFriendsRequest>,
) -> Response {
let Ok(users) = User::get_friends(&db, body.user_id) else {
return ResponseCode::InternalServerError.text("Failed to fetch user")
};

View file

@ -3,89 +3,100 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
use rusqlite::Row;
use tracing::instrument;
use crate::{database, types::comment::Comment};
use crate::types::comment::Comment;
pub fn init() -> Result<(), rusqlite::Error> {
let sql = "
CREATE TABLE IF NOT EXISTS comments (
comment_id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
post_id INTEGER NOT NULL,
date INTEGER NOT NULL,
content VARCHAR(255) NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(user_id),
FOREIGN KEY(post_id) REFERENCES posts(post_id)
);
";
let conn = database::connect()?;
conn.execute(sql, ())?;
use super::Database;
let sql2 = "CREATE INDEX IF NOT EXISTS post_ids on comments (post_id);";
conn.execute(sql2, ())?;
impl Database {
pub fn init_comments(&self) -> Result<(), rusqlite::Error> {
let sql = "
CREATE TABLE IF NOT EXISTS comments (
comment_id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
post_id INTEGER NOT NULL,
date INTEGER NOT NULL,
content VARCHAR(255) NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(user_id),
FOREIGN KEY(post_id) REFERENCES posts(post_id)
);
";
self.0.execute(sql, ())?;
Ok(())
}
fn comment_from_row(row: &Row) -> Result<Comment, rusqlite::Error> {
let comment_id = row.get(0)?;
let user_id = row.get(1)?;
let post_id = row.get(2)?;
let date = row.get(3)?;
let content = row.get(4)?;
Ok(Comment {
comment_id,
user_id,
post_id,
date,
content,
})
}
#[instrument()]
pub fn get_comments_page(page: u64, post_id: u64) -> Result<Vec<Comment>, rusqlite::Error> {
tracing::trace!("Retrieving comments page");
let page_size = 5;
let conn = database::connect()?;
let mut stmt = conn.prepare(
"SELECT * FROM comments WHERE post_id = ? ORDER BY comment_id ASC LIMIT ? OFFSET ?",
)?;
let row = stmt.query_map([post_id, page_size, page_size * page], |row| {
let row = comment_from_row(row)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument()]
pub fn get_all_comments() -> Result<Vec<Comment>, rusqlite::Error> {
tracing::trace!("Retrieving comments page");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM comments ORDER BY comment_id DESC")?;
let row = stmt.query_map([], |row| {
let row = comment_from_row(row)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument()]
pub fn add_comment(user_id: u64, post_id: u64, content: &str) -> Result<Comment, rusqlite::Error> {
tracing::trace!("Adding comment");
let date = u64::try_from(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::ZERO)
.as_millis(),
)
.unwrap_or(0);
let conn = database::connect()?;
let mut stmt = conn.prepare(
"INSERT INTO comments (user_id, post_id, date, content) VALUES(?,?,?,?) RETURNING *;",
)?;
let post = stmt.query_row((user_id, post_id, date, content), |row| {
let row = comment_from_row(row)?;
Ok(row)
})?;
Ok(post)
let sql2 = "CREATE INDEX IF NOT EXISTS post_ids on comments (post_id);";
self.0.execute(sql2, ())?;
Ok(())
}
fn comment_from_row(row: &Row) -> Result<Comment, rusqlite::Error> {
let comment_id = row.get(0)?;
let user_id = row.get(1)?;
let post_id = row.get(2)?;
let date = row.get(3)?;
let content = row.get(4)?;
Ok(Comment {
comment_id,
user_id,
post_id,
date,
content,
})
}
#[instrument(skip(self))]
pub fn get_comments_page(
&self,
page: u64,
post_id: u64,
) -> Result<Vec<Comment>, rusqlite::Error> {
tracing::trace!("Retrieving comments page");
let page_size = 5;
let mut stmt = self.0.prepare(
"SELECT * FROM comments WHERE post_id = ? ORDER BY comment_id ASC LIMIT ? OFFSET ?",
)?;
let row = stmt.query_map([post_id, page_size, page_size * page], |row| {
let row = Self::comment_from_row(row)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument(skip(self))]
pub fn get_all_comments(&self) -> Result<Vec<Comment>, rusqlite::Error> {
tracing::trace!("Retrieving comments page");
let mut stmt = self
.0
.prepare("SELECT * FROM comments ORDER BY comment_id DESC")?;
let row = stmt.query_map([], |row| {
let row = Self::comment_from_row(row)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument(skip(self))]
pub fn add_comment(
&self,
user_id: u64,
post_id: u64,
content: &str,
) -> Result<Comment, rusqlite::Error> {
tracing::trace!("Adding comment");
let date = u64::try_from(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::ZERO)
.as_millis(),
)
.unwrap_or(0);
let mut stmt = self.0.prepare(
"INSERT INTO comments (user_id, post_id, date, content) VALUES(?,?,?,?) RETURNING *;",
)?;
let post = stmt.query_row((user_id, post_id, date, content), |row| {
let row = Self::comment_from_row(row)?;
Ok(row)
})?;
Ok(post)
}
}

View file

@ -1,97 +1,100 @@
use tracing::instrument;
use crate::{
database::{self, users::user_from_row},
types::user::{User, FOLLOWED, FOLLOWING, NO_RELATION},
};
use crate::types::user::{User, FOLLOWED, FOLLOWING, NO_RELATION};
pub fn init() -> Result<(), rusqlite::Error> {
let sql = "
CREATE TABLE IF NOT EXISTS friends (
follower_id INTEGER NOT NULL,
followee_id INTEGER NOT NULL,
FOREIGN KEY(follower_id) REFERENCES users(user_id),
FOREIGN KEY(followee_id) REFERENCES users(user_id),
PRIMARY KEY (follower_id, followee_id)
);
";
let conn = database::connect()?;
conn.execute(sql, ())?;
Ok(())
}
use super::Database;
#[instrument()]
pub fn get_friend_status(user_id_1: u64, user_id_2: u64) -> Result<u8, rusqlite::Error> {
tracing::trace!("Retrieving friend status");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM friends WHERE (follower_id = ? AND followee_id = ?) OR (follower_id = ? AND followee_id = ?);")?;
let mut status = NO_RELATION;
let rows: Vec<u64> = stmt
.query_map([user_id_1, user_id_2, user_id_2, user_id_1], |row| {
let id: u64 = row.get(0)?;
Ok(id)
})?
.into_iter()
.flatten()
.collect();
for follower in rows {
if follower == user_id_1 {
status |= FOLLOWING;
}
if follower == user_id_2 {
status |= FOLLOWED;
}
impl Database {
pub fn init_friends(&self) -> Result<(), rusqlite::Error> {
let sql = "
CREATE TABLE IF NOT EXISTS friends (
follower_id INTEGER NOT NULL,
followee_id INTEGER NOT NULL,
FOREIGN KEY(follower_id) REFERENCES users(user_id),
FOREIGN KEY(followee_id) REFERENCES users(user_id),
PRIMARY KEY (follower_id, followee_id)
);
";
self.0.execute(sql, ())?;
Ok(())
}
Ok(status)
}
#[instrument(skip(self))]
pub fn get_friend_status(&self, user_id_1: u64, user_id_2: u64) -> Result<u8, rusqlite::Error> {
tracing::trace!("Retrieving friend status");
let mut stmt = self.0.prepare("SELECT * FROM friends WHERE (follower_id = ? AND followee_id = ?) OR (follower_id = ? AND followee_id = ?);")?;
let mut status = NO_RELATION;
let rows: Vec<u64> = stmt
.query_map([user_id_1, user_id_2, user_id_2, user_id_1], |row| {
let id: u64 = row.get(0)?;
Ok(id)
})?
.into_iter()
.flatten()
.collect();
#[instrument()]
pub fn get_friends(user_id: u64) -> Result<Vec<User>, rusqlite::Error> {
tracing::trace!("Retrieving friends");
let conn = database::connect()?;
let mut stmt = conn.prepare(
"
SELECT *
FROM users u
WHERE EXISTS (
SELECT NULL
FROM friends f
WHERE u.user_id = f.follower_id
AND f.followee_id = ?
)
AND EXISTS (
SELECT NULL
FROM friends f
WHERE u.user_id = f.followee_id
AND f.follower_id = ?
)
",
)?;
let row = stmt.query_map([user_id, user_id], |row| {
let row = user_from_row(row, true)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
for follower in rows {
if follower == user_id_1 {
status |= FOLLOWING;
}
#[instrument()]
pub fn set_following(user_id_1: u64, user_id_2: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Setting following");
let conn = database::connect()?;
let mut stmt =
conn.prepare("INSERT OR REPLACE INTO friends (follower_id, followee_id) VALUES (?,?)")?;
let changes = stmt.execute([user_id_1, user_id_2])?;
Ok(changes == 1)
}
if follower == user_id_2 {
status |= FOLLOWED;
}
}
#[instrument()]
pub fn remove_following(user_id_1: u64, user_id_2: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Removing following");
let conn = database::connect()?;
let mut stmt = conn.prepare("DELETE FROM friends WHERE follower_id = ? AND followee_id = ?")?;
let changes = stmt.execute([user_id_1, user_id_2])?;
Ok(changes == 1)
Ok(status)
}
#[instrument(skip(self))]
pub fn get_friends(&self, user_id: u64) -> Result<Vec<User>, rusqlite::Error> {
tracing::trace!("Retrieving friends");
let mut stmt = self.0.prepare(
"
SELECT *
FROM users u
WHERE EXISTS (
SELECT NULL
FROM friends f
WHERE u.user_id = f.follower_id
AND f.followee_id = ?
)
AND EXISTS (
SELECT NULL
FROM friends f
WHERE u.user_id = f.followee_id
AND f.follower_id = ?
)
",
)?;
let row = stmt.query_map([user_id, user_id], |row| {
let row = Self::user_from_row(row, true)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument(skip(self))]
pub fn set_following(&self, user_id_1: u64, user_id_2: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Setting following");
let mut stmt = self
.0
.prepare("INSERT OR REPLACE INTO friends (follower_id, followee_id) VALUES (?,?)")?;
let changes = stmt.execute([user_id_1, user_id_2])?;
Ok(changes == 1)
}
#[instrument(skip(self))]
pub fn remove_following(
&self,
user_id_1: u64,
user_id_2: u64,
) -> Result<bool, rusqlite::Error> {
tracing::trace!("Removing following");
let mut stmt = self
.0
.prepare("DELETE FROM friends WHERE follower_id = ? AND followee_id = ?")?;
let changes = stmt.execute([user_id_1, user_id_2])?;
Ok(changes == 1)
}
}

View file

@ -1,75 +1,81 @@
use rusqlite::OptionalExtension;
use tracing::instrument;
use crate::{database, types::like::Like};
use crate::types::like::Like;
pub fn init() -> Result<(), rusqlite::Error> {
let sql = "
CREATE TABLE IF NOT EXISTS likes (
user_id INTEGER NOT NULL,
post_id INTEGER NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(user_id),
FOREIGN KEY(post_id) REFERENCES posts(post_id),
PRIMARY KEY (user_id, post_id)
);
";
let conn = database::connect()?;
conn.execute(sql, ())?;
Ok(())
}
use super::Database;
#[instrument()]
pub fn get_like_count(post_id: u64) -> Result<Option<u64>, rusqlite::Error> {
tracing::trace!("Retrieving like count");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT COUNT(post_id) FROM likes WHERE post_id = ?")?;
let row = stmt
.query_row([post_id], |row| {
let row = row.get(0)?;
Ok(row)
})
.optional()?;
Ok(row)
}
impl Database {
pub fn init_likes(&self) -> Result<(), rusqlite::Error> {
let sql = "
CREATE TABLE IF NOT EXISTS likes (
user_id INTEGER NOT NULL,
post_id INTEGER NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(user_id),
FOREIGN KEY(post_id) REFERENCES posts(post_id),
PRIMARY KEY (user_id, post_id)
);
";
self.0.execute(sql, ())?;
Ok(())
}
#[instrument()]
pub fn get_liked(user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Retrieving if liked");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM likes WHERE user_id = ? AND post_id = ?")?;
let liked = stmt.query_row([user_id, post_id], |_| Ok(())).optional()?;
Ok(liked.is_some())
}
#[instrument(skip(self))]
pub fn get_like_count(&self, post_id: u64) -> Result<Option<u64>, rusqlite::Error> {
tracing::trace!("Retrieving like count");
let mut stmt = self
.0
.prepare("SELECT COUNT(post_id) FROM likes WHERE post_id = ?")?;
let row = stmt
.query_row([post_id], |row| {
let row = row.get(0)?;
Ok(row)
})
.optional()?;
Ok(row)
}
#[instrument()]
pub fn add_liked(user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Adding like");
let conn = database::connect()?;
let mut stmt = conn.prepare("INSERT OR REPLACE INTO likes (user_id, post_id) VALUES (?,?)")?;
let changes = stmt.execute([user_id, post_id])?;
Ok(changes == 1)
}
#[instrument(skip(self))]
pub fn get_liked(&self, user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Retrieving if liked");
let mut stmt = self
.0
.prepare("SELECT * FROM likes WHERE user_id = ? AND post_id = ?")?;
let liked = stmt.query_row([user_id, post_id], |_| Ok(())).optional()?;
Ok(liked.is_some())
}
#[instrument()]
pub fn remove_liked(user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Removing like");
let conn = database::connect()?;
let mut stmt = conn.prepare("DELETE FROM likes WHERE user_id = ? AND post_id = ?;")?;
let changes = stmt.execute((user_id, post_id))?;
Ok(changes == 1)
}
#[instrument(skip(self))]
pub fn add_liked(&self, user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Adding like");
let mut stmt = self
.0
.prepare("INSERT OR REPLACE INTO likes (user_id, post_id) VALUES (?,?)")?;
let changes = stmt.execute([user_id, post_id])?;
Ok(changes == 1)
}
#[instrument()]
pub fn get_all_likes() -> Result<Vec<Like>, rusqlite::Error> {
tracing::trace!("Retrieving comments page");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM likes")?;
let row = stmt.query_map([], |row| {
let like = Like {
user_id: row.get(0)?,
post_id: row.get(1)?,
};
Ok(like)
})?;
Ok(row.into_iter().flatten().collect())
#[instrument(skip(self))]
pub fn remove_liked(&self, user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> {
tracing::trace!("Removing like");
let mut stmt = self
.0
.prepare("DELETE FROM likes WHERE user_id = ? AND post_id = ?;")?;
let changes = stmt.execute((user_id, post_id))?;
Ok(changes == 1)
}
#[instrument(skip(self))]
pub fn get_all_likes(&self) -> Result<Vec<Like>, rusqlite::Error> {
tracing::trace!("Retrieving comments page");
let mut stmt = self.0.prepare("SELECT * FROM likes")?;
let row = stmt.query_map([], |row| {
let like = Like {
user_id: row.get(0)?,
post_id: row.get(1)?,
};
Ok(like)
})?;
Ok(row.into_iter().flatten().collect())
}
}

View file

@ -1,3 +1,4 @@
use rusqlite::Connection;
use tracing::instrument;
pub mod comments;
@ -7,23 +8,29 @@ pub mod posts;
pub mod sessions;
pub mod users;
pub fn connect() -> Result<rusqlite::Connection, rusqlite::Error> {
rusqlite::Connection::open("xssbook.db")
#[derive(Debug)]
pub struct Database(Connection);
impl Database {
pub fn connect() -> Result<Self, rusqlite::Error> {
let conn = rusqlite::Connection::open("xssbook.db")?;
Ok(Self(conn))
}
#[instrument(skip(self))]
pub fn query(&self, query: String) -> Result<usize, rusqlite::Error> {
tracing::trace!("Running custom query");
self.0.execute(&query, [])
}
}
pub fn init() -> Result<(), rusqlite::Error> {
users::init()?;
posts::init()?;
sessions::init()?;
likes::init()?;
comments::init()?;
friends::init()?;
let db = Database::connect()?;
db.init_users()?;
db.init_posts()?;
db.init_sessions()?;
db.init_likes()?;
db.init_comments()?;
db.init_friends()?;
Ok(())
}
#[instrument()]
pub fn query(query: String) -> Result<usize, rusqlite::Error> {
tracing::trace!("Running custom query");
let conn = connect()?;
conn.execute(&query, [])
}

View file

@ -3,115 +3,122 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
use rusqlite::{OptionalExtension, Row};
use tracing::instrument;
use crate::database;
use crate::types::post::Post;
use super::{comments, likes};
use super::Database;
pub fn init() -> Result<(), rusqlite::Error> {
let sql = "
CREATE TABLE IF NOT EXISTS posts (
post_id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
content VARCHAR(500) NOT NULL,
date INTEGER NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(user_id)
);
";
let conn = database::connect()?;
conn.execute(sql, ())?;
Ok(())
}
impl Database {
pub fn init_posts(&self) -> Result<(), rusqlite::Error> {
let sql = "
CREATE TABLE IF NOT EXISTS posts (
post_id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
content VARCHAR(500) NOT NULL,
date INTEGER NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(user_id)
);
";
self.0.execute(sql, ())?;
Ok(())
}
fn post_from_row(row: &Row) -> Result<Post, rusqlite::Error> {
let post_id = row.get(0)?;
let user_id = row.get(1)?;
let content = row.get(2)?;
let date = row.get(3)?;
fn post_from_row(&self, row: &Row) -> Result<Post, rusqlite::Error> {
let post_id = row.get(0)?;
let user_id = row.get(1)?;
let content = row.get(2)?;
let date = row.get(3)?;
let comments = comments::get_comments_page(0, post_id).unwrap_or_else(|_| Vec::new());
let likes = likes::get_like_count(post_id).unwrap_or(None).unwrap_or(0);
let comments = self
.get_comments_page(0, post_id)
.unwrap_or_else(|_| Vec::new());
let likes = self.get_like_count(post_id).unwrap_or(None).unwrap_or(0);
Ok(Post {
post_id,
user_id,
content,
date,
likes,
liked: false,
comments,
})
}
#[instrument()]
pub fn get_post(post_id: u64) -> Result<Option<Post>, rusqlite::Error> {
tracing::trace!("Retrieving post");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM posts WHERE post_id = ?")?;
let row = stmt
.query_row([post_id], |row| {
let row = post_from_row(row)?;
Ok(row)
Ok(Post {
post_id,
user_id,
content,
date,
likes,
liked: false,
comments,
})
.optional()?;
Ok(row)
}
}
#[instrument()]
pub fn get_post_page(page: u64) -> Result<Vec<Post>, rusqlite::Error> {
tracing::trace!("Retrieving posts page");
let page_size = 10;
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM posts ORDER BY post_id DESC LIMIT ? OFFSET ?")?;
let row = stmt.query_map([page_size, page_size * page], |row| {
let row = post_from_row(row)?;
#[instrument(skip(self))]
pub fn get_post(&self, post_id: u64) -> Result<Option<Post>, rusqlite::Error> {
tracing::trace!("Retrieving post");
let mut stmt = self.0.prepare("SELECT * FROM posts WHERE post_id = ?")?;
let row = stmt
.query_row([post_id], |row| {
let row = self.post_from_row(row)?;
Ok(row)
})
.optional()?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
}
#[instrument()]
pub fn get_all_posts() -> Result<Vec<Post>, rusqlite::Error> {
tracing::trace!("Retrieving posts page");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM posts ORDER BY post_id DESC")?;
let row = stmt.query_map([], |row| {
let row = post_from_row(row)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument(skip(self))]
pub fn get_post_page(&self, page: u64) -> Result<Vec<Post>, rusqlite::Error> {
tracing::trace!("Retrieving posts page");
let page_size = 10;
let mut stmt = self
.0
.prepare("SELECT * FROM posts ORDER BY post_id DESC LIMIT ? OFFSET ?")?;
let row = stmt.query_map([page_size, page_size * page], |row| {
let row = self.post_from_row(row)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument()]
pub fn get_users_post_page(user_id: u64, page: u64) -> Result<Vec<Post>, rusqlite::Error> {
tracing::trace!("Retrieving users posts");
let page_size = 10;
let conn = database::connect()?;
let mut stmt = conn
.prepare("SELECT * FROM posts WHERE user_id = ? ORDER BY post_id DESC LIMIT ? OFFSET ?")?;
let row = stmt.query_map([user_id, page_size, page_size * page], |row| {
let row = post_from_row(row)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument(skip(self))]
pub fn get_all_posts(&self) -> Result<Vec<Post>, rusqlite::Error> {
tracing::trace!("Retrieving posts page");
let mut stmt = self
.0
.prepare("SELECT * FROM posts ORDER BY post_id DESC")?;
let row = stmt.query_map([], |row| {
let row = self.post_from_row(row)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument()]
pub fn add_post(user_id: u64, content: &str) -> Result<Post, rusqlite::Error> {
tracing::trace!("Adding post");
let date = u64::try_from(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::ZERO)
.as_millis(),
)
.unwrap_or(0);
let conn = database::connect()?;
let mut stmt =
conn.prepare("INSERT INTO posts (user_id, content, date) VALUES(?,?,?) RETURNING *;")?;
let post = stmt.query_row((user_id, content, date), |row| {
let row = post_from_row(row)?;
Ok(row)
})?;
Ok(post)
#[instrument(skip(self))]
pub fn get_users_post_page(
&self,
user_id: u64,
page: u64,
) -> Result<Vec<Post>, rusqlite::Error> {
tracing::trace!("Retrieving users posts");
let page_size = 10;
let mut stmt = self.0.prepare(
"SELECT * FROM posts WHERE user_id = ? ORDER BY post_id DESC LIMIT ? OFFSET ?",
)?;
let row = stmt.query_map([user_id, page_size, page_size * page], |row| {
let row = self.post_from_row(row)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument(skip(self))]
pub fn add_post(&self, user_id: u64, content: &str) -> Result<Post, rusqlite::Error> {
tracing::trace!("Adding post");
let date = u64::try_from(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::ZERO)
.as_millis(),
)
.unwrap_or(0);
let mut stmt = self
.0
.prepare("INSERT INTO posts (user_id, content, date) VALUES(?,?,?) RETURNING *;")?;
let post = stmt.query_row((user_id, content, date), |row| {
let row = self.post_from_row(row)?;
Ok(row)
})?;
Ok(post)
}
}

View file

@ -1,65 +1,64 @@
use rusqlite::OptionalExtension;
use tracing::instrument;
use crate::{database, types::session::Session};
use crate::types::session::Session;
pub fn init() -> Result<(), rusqlite::Error> {
let sql = "
CREATE TABLE IF NOT EXISTS sessions (
user_id INTEGER PRIMARY KEY NOT NULL,
token TEXT NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(user_id)
);
";
let conn = database::connect()?;
conn.execute(sql, ())?;
Ok(())
}
use super::Database;
#[instrument()]
pub fn get_session(token: &str) -> Result<Option<Session>, rusqlite::Error> {
tracing::trace!("Retrieving session");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM sessions WHERE token = ?")?;
let row = stmt
.query_row([token], |row| {
impl Database {
pub fn init_sessions(&self) -> Result<(), rusqlite::Error> {
let sql = "
CREATE TABLE IF NOT EXISTS sessions (
user_id INTEGER PRIMARY KEY NOT NULL,
token TEXT NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(user_id)
);
";
self.0.execute(sql, ())?;
Ok(())
}
#[instrument(skip(self))]
pub fn get_session(&self, token: &str) -> Result<Option<Session>, rusqlite::Error> {
tracing::trace!("Retrieving session");
let mut stmt = self.0.prepare("SELECT * FROM sessions WHERE token = ?")?;
let row = stmt
.query_row([token], |row| {
Ok(Session {
user_id: row.get(0)?,
token: row.get(1)?,
})
})
.optional()?;
Ok(row)
}
#[instrument(skip(self))]
pub fn get_all_sessions(&self) -> Result<Vec<Session>, rusqlite::Error> {
tracing::trace!("Retrieving session");
let mut stmt = self.0.prepare("SELECT * FROM sessions")?;
let row = stmt.query_map([], |row| {
Ok(Session {
user_id: row.get(0)?,
token: row.get(1)?,
})
})
.optional()?;
Ok(row)
}
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument()]
pub fn get_all_sessions() -> Result<Vec<Session>, rusqlite::Error> {
tracing::trace!("Retrieving session");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM sessions")?;
let row = stmt.query_map([], |row| {
Ok(Session {
user_id: row.get(0)?,
token: row.get(1)?,
})
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument(skip(self))]
pub fn set_session(&self, user_id: u64, token: &str) -> Result<(), Box<dyn std::error::Error>> {
tracing::trace!("Setting new session");
let sql = "INSERT OR REPLACE INTO sessions (user_id, token) VALUES (?, ?);";
self.0.execute(sql, (user_id, token))?;
Ok(())
}
#[instrument()]
pub fn set_session(user_id: u64, token: &str) -> Result<(), Box<dyn std::error::Error>> {
tracing::trace!("Setting new session");
let conn = database::connect()?;
let sql = "INSERT OR REPLACE INTO sessions (user_id, token) VALUES (?, ?);";
conn.execute(sql, (user_id, token))?;
Ok(())
}
#[instrument()]
pub fn delete_session(user_id: u64) -> Result<(), Box<dyn std::error::Error>> {
tracing::trace!("Deleting session");
let conn = database::connect()?;
let sql = "DELETE FROM sessions WHERE user_id = ?;";
conn.execute(sql, [user_id])?;
Ok(())
#[instrument(skip(self))]
pub fn delete_session(&self, user_id: u64) -> Result<(), Box<dyn std::error::Error>> {
tracing::trace!("Deleting session");
let sql = "DELETE FROM sessions WHERE user_id = ?;";
self.0.execute(sql, [user_id])?;
Ok(())
}
}

View file

@ -2,169 +2,180 @@ use rusqlite::{OptionalExtension, Row};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tracing::instrument;
use crate::{api::RegistrationRequet, database, types::user::User};
use crate::{api::RegistrationRequet, types::user::User};
pub fn init() -> Result<(), rusqlite::Error> {
let sql = "
CREATE TABLE IF NOT EXISTS users (
user_id INTEGER PRIMARY KEY AUTOINCREMENT,
firstname VARCHAR(20) NOT NULL,
lastname VARCHAR(20) NOT NULL,
email VARCHAR(50) NOT NULL,
password VARCHAR(50) NOT NULL,
gender VARCHAR(100) NOT NULL,
date BIGINT NOT NULL,
day TINYINT NOT NULL,
month TINYINT NOT NULL,
year INTEGER NOT NULL
);
";
let conn = database::connect()?;
conn.execute(sql, ())?;
use super::Database;
let sql2 = "CREATE UNIQUE INDEX IF NOT EXISTS emails on users (email);";
conn.execute(sql2, ())?;
impl Database {
pub fn init_users(&self) -> Result<(), rusqlite::Error> {
let sql = "
CREATE TABLE IF NOT EXISTS users (
user_id INTEGER PRIMARY KEY AUTOINCREMENT,
firstname VARCHAR(20) NOT NULL,
lastname VARCHAR(20) NOT NULL,
email VARCHAR(50) NOT NULL,
password VARCHAR(50) NOT NULL,
gender VARCHAR(100) NOT NULL,
date BIGINT NOT NULL,
day TINYINT NOT NULL,
month TINYINT NOT NULL,
year INTEGER NOT NULL
);
";
self.0.execute(sql, ())?;
let sql3 = "CREATE UNIQUE INDEX IF NOT EXISTS passwords on users (password);";
conn.execute(sql3, ())?;
let sql2 = "CREATE UNIQUE INDEX IF NOT EXISTS emails on users (email);";
self.0.execute(sql2, ())?;
Ok(())
}
let sql3 = "CREATE UNIQUE INDEX IF NOT EXISTS passwords on users (password);";
self.0.execute(sql3, ())?;
pub fn user_from_row(row: &Row, hide_password: bool) -> Result<User, rusqlite::Error> {
let user_id = row.get(0)?;
let firstname = row.get(1)?;
let lastname = row.get(2)?;
let email = row.get(3)?;
let password = row.get(4)?;
let gender = row.get(5)?;
let date = row.get(6)?;
let day = row.get(7)?;
let month = row.get(8)?;
let year = row.get(9)?;
Ok(())
}
let password = if hide_password {
String::new()
} else {
password
};
pub fn user_from_row(row: &Row, hide_password: bool) -> Result<User, rusqlite::Error> {
let user_id = row.get(0)?;
let firstname = row.get(1)?;
let lastname = row.get(2)?;
let email = row.get(3)?;
let password = row.get(4)?;
let gender = row.get(5)?;
let date = row.get(6)?;
let day = row.get(7)?;
let month = row.get(8)?;
let year = row.get(9)?;
Ok(User {
user_id,
firstname,
lastname,
email,
password,
gender,
date,
day,
month,
year,
})
}
let password = if hide_password {
String::new()
} else {
password
};
#[instrument()]
pub fn get_user_by_id(user_id: u64, hide_password: bool) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by id");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM users WHERE user_id = ?")?;
let row = stmt
.query_row([user_id], |row| {
let row = user_from_row(row, hide_password)?;
Ok(row)
})
.optional()?;
Ok(row)
}
#[instrument()]
pub fn get_user_by_email(
email: &str,
hide_password: bool,
) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by email");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM users WHERE email = ?")?;
let row = stmt
.query_row([email], |row| {
let row = user_from_row(row, hide_password)?;
Ok(row)
})
.optional()?;
Ok(row)
}
#[instrument()]
pub fn get_user_by_password(
password: &str,
hide_password: bool,
) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by password");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM users WHERE password = ?")?;
let row = stmt
.query_row([password], |row| {
let row = user_from_row(row, hide_password)?;
Ok(row)
})
.optional()?;
Ok(row)
}
#[instrument()]
pub fn get_user_page(page: u64, hide_password: bool) -> Result<Vec<User>, rusqlite::Error> {
tracing::trace!("Retrieving user page");
let page_size = 5;
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM users ORDER BY user_id DESC LIMIT ? OFFSET ?")?;
let row = stmt.query_map([page_size, page_size * page], |row| {
let row = user_from_row(row, hide_password)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument()]
pub fn get_all_users() -> Result<Vec<User>, rusqlite::Error> {
tracing::trace!("Retrieving user page");
let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM users ORDER BY user_id DESC")?;
let row = stmt.query_map([], |row| {
let row = user_from_row(row, false)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument()]
pub fn add_user(request: RegistrationRequet) -> Result<User, rusqlite::Error> {
tracing::trace!("Adding new user");
let date = u64::try_from(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::ZERO)
.as_millis(),
)
.unwrap_or(0);
let conn = database::connect()?;
let mut stmt = conn.prepare("INSERT INTO users (firstname, lastname, email, password, gender, date, day, month, year) VALUES(?,?,?,?,?,?,?,?,?) RETURNING *;")?;
let user = stmt.query_row(
(
request.firstname,
request.lastname,
request.email,
request.password,
request.gender,
Ok(User {
user_id,
firstname,
lastname,
email,
password,
gender,
date,
request.day,
request.month,
request.year,
),
|row| {
let row = user_from_row(row, false)?;
day,
month,
year,
})
}
#[instrument(skip(self))]
pub fn get_user_by_id(
&self,
user_id: u64,
hide_password: bool,
) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by id");
let mut stmt = self.0.prepare("SELECT * FROM users WHERE user_id = ?")?;
let row = stmt
.query_row([user_id], |row| {
let row = Self::user_from_row(row, hide_password)?;
Ok(row)
})
.optional()?;
Ok(row)
}
#[instrument(skip(self))]
pub fn get_user_by_email(
&self,
email: &str,
hide_password: bool,
) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by email");
let mut stmt = self.0.prepare("SELECT * FROM users WHERE email = ?")?;
let row = stmt
.query_row([email], |row| {
let row = Self::user_from_row(row, hide_password)?;
Ok(row)
})
.optional()?;
Ok(row)
}
#[instrument(skip(self))]
pub fn get_user_by_password(
&self,
password: &str,
hide_password: bool,
) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by password");
let mut stmt = self.0.prepare("SELECT * FROM users WHERE password = ?")?;
let row = stmt
.query_row([password], |row| {
let row = Self::user_from_row(row, hide_password)?;
Ok(row)
})
.optional()?;
Ok(row)
}
#[instrument(skip(self))]
pub fn get_user_page(
&self,
page: u64,
hide_password: bool,
) -> Result<Vec<User>, rusqlite::Error> {
tracing::trace!("Retrieving user page");
let page_size = 5;
let mut stmt = self
.0
.prepare("SELECT * FROM users ORDER BY user_id DESC LIMIT ? OFFSET ?")?;
let row = stmt.query_map([page_size, page_size * page], |row| {
let row = Self::user_from_row(row, hide_password)?;
Ok(row)
},
)?;
Ok(user)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument(skip(self))]
pub fn get_all_users(&self) -> Result<Vec<User>, rusqlite::Error> {
tracing::trace!("Retrieving user page");
let mut stmt = self
.0
.prepare("SELECT * FROM users ORDER BY user_id DESC")?;
let row = stmt.query_map([], |row| {
let row = Self::user_from_row(row, false)?;
Ok(row)
})?;
Ok(row.into_iter().flatten().collect())
}
#[instrument(skip(self))]
pub fn add_user(&self, request: RegistrationRequet) -> Result<User, rusqlite::Error> {
tracing::trace!("Adding new user");
let date = u64::try_from(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::ZERO)
.as_millis(),
)
.unwrap_or(0);
let mut stmt = self.0.prepare("INSERT INTO users (firstname, lastname, email, password, gender, date, day, month, year) VALUES(?,?,?,?,?,?,?,?,?) RETURNING *;")?;
let user = stmt.query_row(
(
request.firstname,
request.lastname,
request.email,
request.password,
request.gender,
date,
request.day,
request.month,
request.year,
),
|row| {
let row = Self::user_from_row(row, false)?;
Ok(row)
},
)?;
Ok(user)
}
}

View file

@ -5,6 +5,7 @@ use tokio::sync::Mutex;
use crate::{
console::sanatize,
database::Database,
types::{
comment::Comment, http::ResponseCode, like::Like, post::Post, session::Session, user::User,
},
@ -36,8 +37,8 @@ pub async fn regen_secret() -> String {
secret.clone()
}
pub fn generate_users() -> Response {
let users = match User::reterieve_all() {
pub fn generate_users(db: &Database) -> Response {
let users = match User::reterieve_all(db) {
Ok(users) => users,
Err(err) => return err,
};
@ -70,8 +71,8 @@ pub fn generate_users() -> Response {
ResponseCode::Success.text(&html)
}
pub fn generate_posts() -> Response {
let posts = match Post::reterieve_all() {
pub fn generate_posts(db: &Database) -> Response {
let posts = match Post::reterieve_all(db) {
Ok(posts) => posts,
Err(err) => return err,
};
@ -99,8 +100,8 @@ pub fn generate_posts() -> Response {
ResponseCode::Success.text(&html)
}
pub fn generate_sessions() -> Response {
let sessions = match Session::reterieve_all() {
pub fn generate_sessions(db: &Database) -> Response {
let sessions = match Session::reterieve_all(db) {
Ok(sessions) => sessions,
Err(err) => return err,
};
@ -123,8 +124,8 @@ pub fn generate_sessions() -> Response {
ResponseCode::Success.text(&html)
}
pub fn generate_comments() -> Response {
let comments = match Comment::reterieve_all() {
pub fn generate_comments(db: &Database) -> Response {
let comments = match Comment::reterieve_all(db) {
Ok(comments) => comments,
Err(err) => return err,
};
@ -154,8 +155,8 @@ pub fn generate_comments() -> Response {
ResponseCode::Success.text(&html)
}
pub fn generate_likes() -> Response {
let likes = match Like::reterieve_all() {
pub fn generate_likes(db: &Database) -> Response {
let likes = match Like::reterieve_all(db) {
Ok(likes) => likes,
Err(err) => return err,
};

View file

@ -2,7 +2,7 @@ use serde::Serialize;
use tracing::instrument;
use crate::{
database::{self, comments},
database::Database,
types::http::{ResponseCode, Result},
};
@ -16,9 +16,9 @@ pub struct Comment {
}
impl Comment {
#[instrument()]
pub fn new(user_id: u64, post_id: u64, content: &str) -> Result<Self> {
let Ok(comment) = comments::add_comment(user_id, post_id, content) else {
#[instrument(skip(db))]
pub fn new(db: &Database, user_id: u64, post_id: u64, content: &str) -> Result<Self> {
let Ok(comment) = db.add_comment(user_id, post_id, content) else {
tracing::error!("Failed to create comment");
return Err(ResponseCode::InternalServerError.text("Failed to create post"))
};
@ -26,17 +26,17 @@ impl Comment {
Ok(comment)
}
#[instrument()]
pub fn from_comment_page(page: u64, post_id: u64) -> Result<Vec<Self>> {
let Ok(posts) = database::comments::get_comments_page(page, post_id) else {
#[instrument(skip(db))]
pub fn from_comment_page(db: &Database, page: u64, post_id: u64) -> Result<Vec<Self>> {
let Ok(posts) = db.get_comments_page(page, post_id) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch comments"))
};
Ok(posts)
}
#[instrument()]
pub fn reterieve_all() -> Result<Vec<Self>> {
let Ok(posts) = database::comments::get_all_comments() else {
#[instrument(skip(db))]
pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
let Ok(posts) = db.get_all_comments() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch comments"))
};
Ok(posts)

View file

@ -14,9 +14,11 @@ use axum::{
use bytes::Bytes;
use image::{io::Reader, DynamicImage, ImageFormat};
use serde::de::DeserializeOwned;
use tokio::sync::Mutex;
use tower_cookies::Cookies;
use crate::{
database,
public::admin,
public::console,
types::{
@ -97,11 +99,17 @@ where
return Err(ResponseCode::Forbidden.text("No auth token provided"))
};
let Ok(session) = Session::from_token(token.value()) else {
let Some(db) = parts.extensions.get::<DatabaseExtention>() else {
return Err(ResponseCode::Forbidden.text("Could not connect to database"))
};
let db = db.0.lock().await;
let Ok(session) = Session::from_token(&db, token.value()) else {
return Err(ResponseCode::Unauthorized.text("Auth token invalid"))
};
let Ok(user) = User::from_user_id(session.user_id, true) else {
let Ok(user) = User::from_user_id(&db, session.user_id, true) else {
tracing::error!("Valid token but no valid user");
return Err(ResponseCode::InternalServerError.text("Valid token but no valid user"))
};
@ -260,6 +268,26 @@ where
}
}
pub struct DatabaseExtention(pub Mutex<database::Database>);
pub struct Database(pub database::Database);
#[async_trait]
impl<S> FromRequestParts<S> for Database
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self> {
let db = parts.extensions.remove::<DatabaseExtention>();
let Some(db) = db else {
return Err(ResponseCode::InternalServerError.text("Database is not loaded"))
};
Ok(Self(db.0.into_inner()))
}
}
async fn read_body<S, B>(mut req: Request<B>, state: &S) -> Result<Vec<u8>>
where
B: HttpBody + Sync + Send + 'static,

View file

@ -1,7 +1,7 @@
use serde::Serialize;
use tracing::instrument;
use crate::database;
use crate::database::Database;
use crate::types::http::{ResponseCode, Result};
#[derive(Serialize)]
@ -11,9 +11,9 @@ pub struct Like {
}
impl Like {
#[instrument()]
pub fn add_liked(user_id: u64, post_id: u64) -> Result<()> {
let Ok(liked) = database::likes::add_liked(user_id, post_id) else {
#[instrument(skip(db))]
pub fn add_liked(db: &Database, user_id: u64, post_id: u64) -> Result<()> {
let Ok(liked) = db.add_liked(user_id, post_id) else {
return Err(ResponseCode::BadRequest.text("Failed to add like status"))
};
@ -24,9 +24,9 @@ impl Like {
Ok(())
}
#[instrument()]
pub fn remove_liked(user_id: u64, post_id: u64) -> Result<()> {
let Ok(liked) = database::likes::remove_liked(user_id, post_id) else {
#[instrument(skip(db))]
pub fn remove_liked(db: &Database, user_id: u64, post_id: u64) -> Result<()> {
let Ok(liked) = db.remove_liked(user_id, post_id) else {
return Err(ResponseCode::BadRequest.text("Failed to remove like status"))
};
@ -37,9 +37,9 @@ impl Like {
Ok(())
}
#[instrument()]
pub fn reterieve_all() -> Result<Vec<Self>> {
let Ok(likes) = database::likes::get_all_likes() else {
#[instrument(skip(db))]
pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
let Ok(likes) = db.get_all_likes() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch likes"))
};
Ok(likes)

View file

@ -2,7 +2,7 @@ use core::fmt;
use serde::Serialize;
use tracing::instrument;
use crate::database;
use crate::database::Database;
use crate::types::http::{ResponseCode, Result};
use super::comment::Comment;
@ -27,57 +27,62 @@ impl fmt::Debug for Post {
}
impl Post {
#[instrument()]
pub fn from_post_id(self_id: u64, post_id: u64) -> Result<Self> {
let Ok(Some(mut post)) = database::posts::get_post(post_id) else {
#[instrument(skip(db))]
pub fn from_post_id(db: &Database, self_id: u64, post_id: u64) -> Result<Self> {
let Ok(Some(mut post)) = db.get_post(post_id) else {
return Err(ResponseCode::BadRequest.text("Post does not exist"))
};
let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false);
let liked = db.get_liked(self_id, post.post_id).unwrap_or(false);
post.liked = liked;
Ok(post)
}
#[instrument()]
pub fn from_post_page(self_id: u64, page: u64) -> Result<Vec<Self>> {
let Ok(mut posts) = database::posts::get_post_page(page) else {
#[instrument(skip(db))]
pub fn from_post_page(db: &Database, self_id: u64, page: u64) -> Result<Vec<Self>> {
let Ok(mut posts) = db.get_post_page(page) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch posts"))
};
for post in &mut posts {
let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false);
let liked = db.get_liked(self_id, post.post_id).unwrap_or(false);
post.liked = liked;
}
Ok(posts)
}
#[instrument()]
pub fn from_user_post_page(self_id: u64, user_id: u64, page: u64) -> Result<Vec<Self>> {
let Ok(mut posts) = database::posts::get_users_post_page(user_id, page) else {
#[instrument(skip(db))]
pub fn from_user_post_page(
db: &Database,
self_id: u64,
user_id: u64,
page: u64,
) -> Result<Vec<Self>> {
let Ok(mut posts) = db.get_users_post_page(user_id, page) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch posts"))
};
for post in &mut posts {
let liked = database::likes::get_liked(self_id, post.post_id).unwrap_or(false);
let liked = db.get_liked(self_id, post.post_id).unwrap_or(false);
post.liked = liked;
}
Ok(posts)
}
#[instrument()]
pub fn reterieve_all() -> Result<Vec<Self>> {
let Ok(posts) = database::posts::get_all_posts() else {
#[instrument(skip(db))]
pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
let Ok(posts) = db.get_all_posts() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch posts"))
};
Ok(posts)
}
#[instrument()]
pub fn new(user_id: u64, content: String) -> Result<Self> {
let Ok(post) = database::posts::add_post(user_id, &content) else {
#[instrument(skip(db))]
pub fn new(db: &Database, user_id: u64, content: String) -> Result<Self> {
let Ok(post) = db.add_post(user_id, &content) else {
tracing::error!("Failed to create post");
return Err(ResponseCode::InternalServerError.text("Failed to create post"))
};

View file

@ -2,7 +2,7 @@ use rand::{distributions::Alphanumeric, Rng};
use serde::Serialize;
use tracing::instrument;
use crate::database;
use crate::database::Database;
use crate::types::http::{ResponseCode, Result};
#[derive(Serialize)]
@ -12,39 +12,39 @@ pub struct Session {
}
impl Session {
#[instrument()]
pub fn from_token(token: &str) -> Result<Self> {
let Ok(Some(session)) = database::sessions::get_session(token) else {
#[instrument(skip(db))]
pub fn from_token(db: &Database, token: &str) -> Result<Self> {
let Ok(Some(session)) = db.get_session(token) else {
return Err(ResponseCode::BadRequest.text("Invalid auth token"));
};
Ok(session)
}
#[instrument()]
pub fn reterieve_all() -> Result<Vec<Self>> {
let Ok(sessions) = database::sessions::get_all_sessions() else {
#[instrument(skip(db))]
pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
let Ok(sessions) = db.get_all_sessions() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch sessions"))
};
Ok(sessions)
}
#[instrument()]
pub fn new(user_id: u64) -> Result<Self> {
#[instrument(skip(db))]
pub fn new(db: &Database, user_id: u64) -> Result<Self> {
let token: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect();
match database::sessions::set_session(user_id, &token) {
match db.set_session(user_id, &token) {
Err(_) => Err(ResponseCode::BadRequest.text("Failed to create session")),
Ok(_) => Ok(Self { user_id, token }),
}
}
#[instrument()]
pub fn delete(user_id: u64) -> Result<()> {
if database::sessions::delete_session(user_id).is_err() {
#[instrument(skip(db))]
pub fn delete(db: &Database, user_id: u64) -> Result<()> {
if db.delete_session(user_id).is_err() {
tracing::error!("Failed to logout user");
return Err(ResponseCode::InternalServerError.text("Failed to logout"));
};

View file

@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize};
use tracing::instrument;
use crate::api::RegistrationRequet;
use crate::database;
use crate::database::Database;
use crate::types::http::{ResponseCode, Result};
#[derive(Serialize, Deserialize, Debug)]
@ -24,21 +24,21 @@ pub const FOLLOWING: u8 = 1;
pub const FOLLOWED: u8 = 2;
impl User {
#[instrument()]
pub fn from_user_id(user_id: u64, hide_password: bool) -> Result<Self> {
let Ok(Some(user)) = database::users::get_user_by_id(user_id, hide_password) else {
#[instrument(skip(db))]
pub fn from_user_id(db: &Database, user_id: u64, hide_password: bool) -> Result<Self> {
let Ok(Some(user)) = db.get_user_by_id(user_id, hide_password) else {
return Err(ResponseCode::BadRequest.text("User does not exist"))
};
Ok(user)
}
#[instrument()]
pub fn from_user_ids(user_ids: Vec<u64>) -> Vec<Self> {
#[instrument(skip(db))]
pub fn from_user_ids(db: &Database, user_ids: Vec<u64>) -> Vec<Self> {
user_ids
.iter()
.filter_map(|user_id| {
let Ok(Some(user)) = database::users::get_user_by_id(*user_id, true) else {
let Ok(Some(user)) = db.get_user_by_id(*user_id, true) else {
return None;
};
Some(user)
@ -46,53 +46,53 @@ impl User {
.collect()
}
#[instrument()]
pub fn from_user_page(page: u64) -> Result<Vec<Self>> {
let Ok(users) = database::users::get_user_page(page, true) else {
#[instrument(skip(db))]
pub fn from_user_page(db: &Database, page: u64) -> Result<Vec<Self>> {
let Ok(users) = db.get_user_page(page, true) else {
return Err(ResponseCode::BadRequest.text("Failed to fetch users"))
};
Ok(users)
}
#[instrument()]
pub fn from_email(email: &str) -> Result<Self> {
let Ok(Some(user)) = database::users::get_user_by_email(email, false) else {
#[instrument(skip(db))]
pub fn from_email(db: &Database, email: &str) -> Result<Self> {
let Ok(Some(user)) = db.get_user_by_email(email, false) else {
return Err(ResponseCode::BadRequest.text("User does not exist"))
};
Ok(user)
}
#[instrument()]
pub fn from_password(password: &str) -> Result<Self> {
let Ok(Some(user)) = database::users::get_user_by_password(password, true) else {
#[instrument(skip(db))]
pub fn from_password(db: &Database, password: &str) -> Result<Self> {
let Ok(Some(user)) = db.get_user_by_password(password, true) else {
return Err(ResponseCode::BadRequest.text("User does not exist"))
};
Ok(user)
}
#[instrument()]
pub fn reterieve_all() -> Result<Vec<Self>> {
let Ok(users) = database::users::get_all_users() else {
#[instrument(skip(db))]
pub fn reterieve_all(db: &Database) -> Result<Vec<Self>> {
let Ok(users) = db.get_all_users() else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch users"))
};
Ok(users)
}
#[instrument()]
pub fn new(request: RegistrationRequet) -> Result<Self> {
if Self::from_email(&request.email).is_ok() {
#[instrument(skip(db))]
pub fn new(db: &Database, request: RegistrationRequet) -> Result<Self> {
if Self::from_email(db, &request.email).is_ok() {
return Err(ResponseCode::BadRequest
.text(&format!("Email is already in use by {}", &request.email)));
}
if let Ok(user) = Self::from_password(&request.password) {
if let Ok(user) = Self::from_password(db, &request.password) {
return Err(ResponseCode::BadRequest
.text(&format!("Password is already in use by {}", user.email)));
}
let Ok(user) = database::users::add_user(request) else {
let Ok(user) = db.add_user(request) else {
tracing::error!("Failed to create new user");
return Err(ResponseCode::InternalServerError.text("Failed to create new uesr"))
};
@ -100,8 +100,9 @@ impl User {
Ok(user)
}
pub fn add_following(user_id_1: u64, user_id_2: u64) -> Result<()> {
let Ok(followed) = database::friends::set_following(user_id_1, user_id_2) else {
#[instrument(skip(db))]
pub fn add_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<()> {
let Ok(followed) = db.set_following(user_id_1, user_id_2) else {
return Err(ResponseCode::BadRequest.text("Failed to add follow status"))
};
@ -112,8 +113,9 @@ impl User {
Ok(())
}
pub fn remove_following(user_id_1: u64, user_id_2: u64) -> Result<()> {
let Ok(followed) = database::friends::remove_following(user_id_1, user_id_2) else {
#[instrument(skip(db))]
pub fn remove_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<()> {
let Ok(followed) = db.remove_following(user_id_1, user_id_2) else {
return Err(ResponseCode::BadRequest.text("Failed to remove follow status"))
};
@ -124,15 +126,17 @@ impl User {
Ok(())
}
pub fn get_following(user_id_1: u64, user_id_2: u64) -> Result<u8> {
let Ok(followed) = database::friends::get_friend_status(user_id_1, user_id_2) else {
#[instrument(skip(db))]
pub fn get_following(db: &Database, user_id_1: u64, user_id_2: u64) -> Result<u8> {
let Ok(followed) = db.get_friend_status(user_id_1, user_id_2) else {
return Err(ResponseCode::InternalServerError.text("Failed to get follow status"))
};
Ok(followed)
}
pub fn get_friends(user_id: u64) -> Result<Vec<Self>> {
let Ok(users) = database::friends::get_friends(user_id) else {
#[instrument(skip(db))]
pub fn get_friends(db: &Database, user_id: u64) -> Result<Vec<Self>> {
let Ok(users) = db.get_friends(user_id) else {
return Err(ResponseCode::InternalServerError.text("Failed to fetch friends"))
};
Ok(users)