diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/api/mod.rs | 2 | ||||
-rw-r--r-- | src/api/posts.rs | 2 | ||||
-rw-r--r-- | src/api/users.rs | 143 | ||||
-rw-r--r-- | src/database/friends.rs | 97 | ||||
-rw-r--r-- | src/database/likes.rs | 8 | ||||
-rw-r--r-- | src/database/mod.rs | 2 | ||||
-rw-r--r-- | src/database/users.rs | 2 | ||||
-rw-r--r-- | src/public/admin.rs | 10 | ||||
-rw-r--r-- | src/public/console.rs | 33 | ||||
-rw-r--r-- | src/public/docs.rs | 6 | ||||
-rw-r--r-- | src/public/mod.rs | 2 | ||||
-rw-r--r-- | src/public/pages.rs | 6 | ||||
-rw-r--r-- | src/types/extract.rs | 6 | ||||
-rw-r--r-- | src/types/like.rs | 3 | ||||
-rw-r--r-- | src/types/user.rs | 42 |
15 files changed, 329 insertions, 35 deletions
diff --git a/src/api/mod.rs b/src/api/mod.rs index 12563e3..cd2190c 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -16,7 +16,7 @@ pub use auth::RegistrationRequet; pub fn router() -> Router { let governor_conf = Box::new( GovernorConfigBuilder::default() - .burst_size(10) + .burst_size(15) .per_second(1) .key_extractor(SmartIpKeyExtractor) .finish() diff --git a/src/api/posts.rs b/src/api/posts.rs index ca459cd..ee590ec 100644 --- a/src/api/posts.rs +++ b/src/api/posts.rs @@ -138,7 +138,7 @@ pub const COMMENTS_PAGE: EndpointDocumentation = EndpointDocumentation { #[derive(Deserialize)] struct CommentsPageRequest { page: u64, - post_id: u64 + post_id: u64, } impl Check for CommentsPageRequest { diff --git a/src/api/users.rs b/src/api/users.rs index 0ce9988..082926e 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -1,7 +1,7 @@ use crate::{ public::docs::{EndpointDocumentation, EndpointMethod}, types::{ - extract::{AuthorizedUser, Check, CheckResult, Json, Png}, + extract::{AuthorizedUser, Check, CheckResult, Json, Log, Png}, http::ResponseCode, user::User, }, @@ -116,7 +116,7 @@ pub const USERS_SELF: EndpointDocumentation = EndpointDocumentation { cookie: Some("auth"), }; -async fn load_self(AuthorizedUser(user): AuthorizedUser) -> Response { +async fn load_self(AuthorizedUser(user): AuthorizedUser, _: Log) -> Response { let Ok(json) = serde_json::to_string(&user) else { return ResponseCode::InternalServerError.text("Failed to fetch user") }; @@ -172,6 +172,143 @@ async fn banner(AuthorizedUser(user): AuthorizedUser, Png(img): Png) -> Response ResponseCode::Success.text("Successfully updated banner") } +pub const USERS_FOLLOW: EndpointDocumentation = EndpointDocumentation { + uri: "/api/users/follow", + method: EndpointMethod::Put, + description: "Set following status of another user", + body: Some( + r#" + { + "user_id": 13, + "status": false + } + "#, + ), + responses: &[ + (200, "Returns new follow status if successfull, see below"), + (400, "Body does not match parameters"), + (401, "Unauthorized"), + (500, "Failed to change follow status"), + ], + cookie: Some("auth"), +}; + +#[derive(Deserialize)] +struct UserFollowRequest { + user_id: u64, + state: bool, +} + +impl Check for UserFollowRequest { + fn check(&self) -> CheckResult { + Ok(()) + } +} + +async fn follow( + AuthorizedUser(user): AuthorizedUser, + Json(body): Json<UserFollowRequest>, +) -> Response { + if body.state { + if let Err(err) = User::add_following(user.user_id, body.user_id) { + return err; + } + } else if let Err(err) = User::remove_following(user.user_id, body.user_id) { + return err; + } + + match User::get_following(user.user_id, body.user_id) { + Ok(status) => ResponseCode::Success.text(&format!("{status}")), + Err(err) => err, + } +} + +pub const USERS_FOLLOW_STATUS: EndpointDocumentation = EndpointDocumentation { + uri: "/api/users/follow", + method: EndpointMethod::Post, + description: "Get following status of another user", + body: Some( + r#" + { + "user_id": 13 + } + "#, + ), + responses: &[ + ( + 200, + "Returns 0 if no relation, 1 if following, 2 if followed, 3 if both", + ), + (400, "Body does not match parameters"), + (401, "Unauthorized"), + (500, "Failed to retrieve follow status"), + ], + cookie: Some("auth"), +}; + +#[derive(Deserialize)] +struct UserFollowStatusRequest { + user_id: u64, +} + +impl Check for UserFollowStatusRequest { + fn check(&self) -> CheckResult { + Ok(()) + } +} + +async fn follow_status( + AuthorizedUser(user): AuthorizedUser, + Json(body): Json<UserFollowStatusRequest>, +) -> Response { + match User::get_following(user.user_id, body.user_id) { + Ok(status) => ResponseCode::Success.text(&format!("{status}")), + Err(err) => err, + } +} + +pub const USERS_FRIENDS: EndpointDocumentation = EndpointDocumentation { + uri: "/api/users/friends", + method: EndpointMethod::Post, + description: "Returns friends of a user", + body: Some( + r#" + { + "user_id": 13 + } + "#, + ), + responses: &[ + (200, "Returns users in <span>application/json<span>"), + (401, "Unauthorized"), + (500, "Failed to fetch friends"), + ], + cookie: Some("auth"), +}; + +#[derive(Deserialize)] +struct UserFriendsRequest { + user_id: u64, +} + +impl Check for UserFriendsRequest { + fn check(&self) -> CheckResult { + Ok(()) + } +} + +async fn friends(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UserFriendsRequest>) -> Response { + let Ok(users) = User::get_friends(body.user_id) else { + return ResponseCode::InternalServerError.text("Failed to fetch user") + }; + + let Ok(json) = serde_json::to_string(&users) else { + return ResponseCode::InternalServerError.text("Failed to fetch user") + }; + + ResponseCode::Success.json(&json) +} + pub fn router() -> Router { Router::new() .route("/load", post(load_batch)) @@ -179,4 +316,6 @@ pub fn router() -> Router { .route("/page", post(load_page)) .route("/avatar", put(avatar)) .route("/banner", put(banner)) + .route("/follow", put(follow).post(follow_status)) + .route("/friends", post(friends)) } diff --git a/src/database/friends.rs b/src/database/friends.rs new file mode 100644 index 0000000..0b78488 --- /dev/null +++ b/src/database/friends.rs @@ -0,0 +1,97 @@ +use tracing::instrument; + +use crate::{ + database::{self, users::user_from_row}, + types::user::{User, FOLLOWED, FOLLOWING, NO_RELATION}, +}; + +pub fn init() -> Result<(), rusqlite::Error> { + let sql = " + CREATE TABLE IF NOT EXISTS friends ( + follower_id INTEGER NOT NULL, + followee_id INTEGER NOT NULL, + FOREIGN KEY(follower_id) REFERENCES users(user_id), + FOREIGN KEY(followee_id) REFERENCES users(user_id), + PRIMARY KEY (follower_id, followee_id) + ); + "; + let conn = database::connect()?; + conn.execute(sql, ())?; + Ok(()) +} + +#[instrument()] +pub fn get_friend_status(user_id_1: u64, user_id_2: u64) -> Result<u8, rusqlite::Error> { + tracing::trace!("Retrieving friend status"); + let conn = database::connect()?; + let mut stmt = conn.prepare("SELECT * FROM friends WHERE (follower_id = ? AND followee_id = ?) OR (follower_id = ? AND followee_id = ?);")?; + let mut status = NO_RELATION; + let rows: Vec<u64> = stmt + .query_map([user_id_1, user_id_2, user_id_2, user_id_1], |row| { + let id: u64 = row.get(0)?; + Ok(id) + })? + .into_iter() + .flatten() + .collect(); + + for follower in rows { + if follower == user_id_1 { + status |= FOLLOWING; + } + + if follower == user_id_2 { + status |= FOLLOWED; + } + } + + Ok(status) +} + +#[instrument()] +pub fn get_friends(user_id: u64) -> Result<Vec<User>, rusqlite::Error> { + tracing::trace!("Retrieving friends"); + let conn = database::connect()?; + let mut stmt = conn.prepare( + " + SELECT * + FROM users u + WHERE EXISTS ( + SELECT NULL + FROM friends f + WHERE u.user_id = f.follower_id + AND f.followee_id = ? + ) + AND EXISTS ( + SELECT NULL + FROM friends f + WHERE u.user_id = f.followee_id + AND f.follower_id = ? + ) + ", + )?; + let row = stmt.query_map([user_id, user_id], |row| { + let row = user_from_row(row, true)?; + Ok(row) + })?; + Ok(row.into_iter().flatten().collect()) +} + +#[instrument()] +pub fn set_following(user_id_1: u64, user_id_2: u64) -> Result<bool, rusqlite::Error> { + tracing::trace!("Setting following"); + let conn = database::connect()?; + let mut stmt = + conn.prepare("INSERT OR REPLACE INTO friends (follower_id, followee_id) VALUES (?,?)")?; + let changes = stmt.execute([user_id_1, user_id_2])?; + Ok(changes == 1) +} + +#[instrument()] +pub fn remove_following(user_id_1: u64, user_id_2: u64) -> Result<bool, rusqlite::Error> { + tracing::trace!("Removing following"); + let conn = database::connect()?; + let mut stmt = conn.prepare("DELETE FROM friends WHERE follower_id = ? AND followee_id = ?")?; + let changes = stmt.execute([user_id_1, user_id_2])?; + Ok(changes == 1) +} diff --git a/src/database/likes.rs b/src/database/likes.rs index 6f6939e..f6a130b 100644 --- a/src/database/likes.rs +++ b/src/database/likes.rs @@ -37,9 +37,7 @@ pub fn get_liked(user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> { tracing::trace!("Retrieving if liked"); let conn = database::connect()?; let mut stmt = conn.prepare("SELECT * FROM likes WHERE user_id = ? AND post_id = ?")?; - let liked = stmt.query_row([user_id, post_id], |_| { - Ok(()) - }).optional()?; + let liked = stmt.query_row([user_id, post_id], |_| Ok(())).optional()?; Ok(liked.is_some()) } @@ -49,7 +47,7 @@ pub fn add_liked(user_id: u64, post_id: u64) -> Result<bool, rusqlite::Error> { let conn = database::connect()?; let mut stmt = conn.prepare("INSERT OR REPLACE INTO likes (user_id, post_id) VALUES (?,?)")?; let changes = stmt.execute([user_id, post_id])?; - Ok(changes == 1) + Ok(changes == 1) } #[instrument()] @@ -69,7 +67,7 @@ pub fn get_all_likes() -> Result<Vec<Like>, rusqlite::Error> { let row = stmt.query_map([], |row| { let like = Like { user_id: row.get(0)?, - post_id: row.get(1)? + post_id: row.get(1)?, }; Ok(like) })?; diff --git a/src/database/mod.rs b/src/database/mod.rs index 6d4853a..d22a350 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,6 +1,7 @@ use tracing::instrument; pub mod comments; +pub mod friends; pub mod likes; pub mod posts; pub mod sessions; @@ -16,6 +17,7 @@ pub fn init() -> Result<(), rusqlite::Error> { sessions::init()?; likes::init()?; comments::init()?; + friends::init()?; Ok(()) } diff --git a/src/database/users.rs b/src/database/users.rs index 15565f1..6062ea8 100644 --- a/src/database/users.rs +++ b/src/database/users.rs @@ -31,7 +31,7 @@ pub fn init() -> Result<(), rusqlite::Error> { Ok(()) } -fn user_from_row(row: &Row, hide_password: bool) -> Result<User, rusqlite::Error> { +pub fn user_from_row(row: &Row, hide_password: bool) -> Result<User, rusqlite::Error> { let user_id = row.get(0)?; let firstname = row.get(1)?; let lastname = row.get(2)?; diff --git a/src/public/admin.rs b/src/public/admin.rs index 25941f1..bf0a155 100644 --- a/src/public/admin.rs +++ b/src/public/admin.rs @@ -5,7 +5,9 @@ use tokio::sync::Mutex; use crate::{ console::sanatize, - types::{http::ResponseCode, post::Post, session::Session, user::User, comment::Comment, like::Like}, + types::{ + comment::Comment, http::ResponseCode, like::Like, post::Post, session::Session, user::User, + }, }; lazy_static! { @@ -141,7 +143,11 @@ pub fn generate_comments() -> Response { for comment in comments { html.push_str(&format!( "<tr><td>{}</td><td>{}</td><td>{}</td><td>{}</td><td>{}</td></tr>", - comment.comment_id, comment.user_id, comment.post_id, sanatize(&comment.content), comment.date + comment.comment_id, + comment.user_id, + comment.post_id, + sanatize(&comment.content), + comment.date )); } diff --git a/src/public/console.rs b/src/public/console.rs index 16bf4a3..251dbc1 100644 --- a/src/public/console.rs +++ b/src/public/console.rs @@ -84,9 +84,9 @@ impl Formatter for HtmlFormatter { W: ?Sized + io::Write, { let s = if value { - b"<span class='bool'>true</span>" as &[u8] + b"<span class='bool'> true </span>" as &[u8] } else { - b"<span class='bool'>false</span>" as &[u8] + b"<span class='bool'> false </span>" as &[u8] }; writer.write_all(s) } @@ -95,7 +95,7 @@ impl Formatter for HtmlFormatter { where W: ?Sized + io::Write, { - let buff = format!("<span class='number'>{value}</span>"); + let buff = format!("<span class='number'> {value} </span>"); writer.write_all(buff.as_bytes()) } @@ -103,7 +103,7 @@ impl Formatter for HtmlFormatter { where W: ?Sized + io::Write, { - let buff = format!("<span class='number'>{value}</span>"); + let buff = format!("<span class='number'> {value} </span>"); writer.write_all(buff.as_bytes()) } @@ -111,7 +111,7 @@ impl Formatter for HtmlFormatter { where W: ?Sized + io::Write, { - let buff = format!("<span class='number'>{value}</span>"); + let buff = format!("<span class='number'> {value} </span>"); writer.write_all(buff.as_bytes()) } @@ -119,7 +119,7 @@ impl Formatter for HtmlFormatter { where W: ?Sized + io::Write, { - let buff = format!("<span class='number'>{value}</span>"); + let buff = format!("<span class='number'> {value} </span>"); writer.write_all(buff.as_bytes()) } @@ -127,7 +127,7 @@ impl Formatter for HtmlFormatter { where W: ?Sized + io::Write, { - let buff = format!("<span class='number'>{value}</span>"); + let buff = format!("<span class='number'> {value} </span>"); writer.write_all(buff.as_bytes()) } @@ -135,7 +135,7 @@ impl Formatter for HtmlFormatter { where W: ?Sized + io::Write, { - let buff = format!("<span class='number'>{value}</span>"); + let buff = format!("<span class='number'> {value} </span>"); writer.write_all(buff.as_bytes()) } @@ -143,7 +143,7 @@ impl Formatter for HtmlFormatter { where W: ?Sized + io::Write, { - let buff = format!("<span class='number'>{value}</span>"); + let buff = format!("<span class='number'> {value} </span>"); writer.write_all(buff.as_bytes()) } @@ -151,7 +151,7 @@ impl Formatter for HtmlFormatter { where W: ?Sized + io::Write, { - let buff = format!("<span class='number'>{value}</span>"); + let buff = format!("<span class='number'> {value} </span>"); writer.write_all(buff.as_bytes()) } @@ -159,7 +159,7 @@ impl Formatter for HtmlFormatter { where W: ?Sized + io::Write, { - let buff = format!("<span class='number'>{value}</span>"); + let buff = format!("<span class='number'> {value} </span>"); writer.write_all(buff.as_bytes()) } @@ -167,7 +167,7 @@ impl Formatter for HtmlFormatter { where W: ?Sized + io::Write, { - let buff = format!("<span class='number'>{value}</span>"); + let buff = format!("<span class='number'> {value} </span>"); writer.write_all(buff.as_bytes()) } @@ -192,7 +192,7 @@ impl Formatter for HtmlFormatter { if first { writer.write_all(b"<span class='key'>") } else { - writer.write_all(b"<span class='key'>,") + writer.write_all(b",<span class='key'>") } } @@ -202,6 +202,13 @@ impl Formatter for HtmlFormatter { { writer.write_all(b"</span>") } + + fn begin_object_value<W>(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { + writer.write_all(b" : ") + } } pub fn sanatize(input: &str) -> String { diff --git a/src/public/docs.rs b/src/public/docs.rs index f4e26be..397e696 100644 --- a/src/public/docs.rs +++ b/src/public/docs.rs @@ -49,6 +49,7 @@ fn generate_body(body: Option<&'static str>) -> String { </div> "# .to_string(); + let body = body.trim(); if body.starts_with('{') { return html.replace( @@ -135,13 +136,16 @@ pub async fn init() { users::USERS_SELF, users::USERS_AVATAR, users::USERS_BANNER, + users::USERS_FOLLOW, + users::USERS_FOLLOW_STATUS, + users::USERS_FRIENDS, admin::ADMIN_AUTH, admin::ADMIN_QUERY, admin::ADMIN_POSTS, admin::ADMIN_USERS, admin::ADMIN_SESSIONS, admin::ADMIN_COMMENTS, - admin::ADMIN_LIKES + admin::ADMIN_LIKES, ]; let mut endpoints = ENDPOINTS.lock().await; for doc in docs { diff --git a/src/public/mod.rs b/src/public/mod.rs index 76796ea..bb75ef0 100644 --- a/src/public/mod.rs +++ b/src/public/mod.rs @@ -25,7 +25,7 @@ pub mod pages; pub fn router() -> Router { let governor_conf = Box::new( GovernorConfigBuilder::default() - .burst_size(20) + .burst_size(30) .per_second(1) .key_extractor(SmartIpKeyExtractor) .finish() diff --git a/src/public/pages.rs b/src/public/pages.rs index 6d5c0de..426727e 100644 --- a/src/public/pages.rs +++ b/src/public/pages.rs @@ -1,6 +1,7 @@ use axum::{ response::{IntoResponse, Redirect, Response}, - routing::get, Router + routing::get, + Router, }; use crate::{ @@ -58,9 +59,8 @@ async fn wordpress(_: Log) -> Response { } async fn forgot(UserAgent(agent): UserAgent, _: Log) -> Response { - if agent.starts_with("curl") { - return super::serve("/404.html").await + return super::serve("/404.html").await; } Redirect::to("https://www.youtube.com/watch?v=dQw4w9WgXcQ").into_response() diff --git a/src/types/extract.rs b/src/types/extract.rs index 6a01ad2..65d9f1a 100644 --- a/src/types/extract.rs +++ b/src/types/extract.rs @@ -7,7 +7,7 @@ use axum::{ async_trait, body::HttpBody, extract::{ConnectInfo, FromRequest, FromRequestParts}, - http::{request::Parts, Request, header::USER_AGENT}, + http::{header::USER_AGENT, request::Parts, Request}, response::Response, BoxError, RequestExt, }; @@ -205,7 +205,7 @@ where }; let Ok(value) = serde_json::from_str::<T>(&body) else { - return Err(ResponseCode::BadRequest.text("Invalid request body")) + return Err(ResponseCode::BadRequest.text("Body does not match paramaters")) }; if let Err(msg) = value.check() { @@ -256,7 +256,7 @@ where return Err(ResponseCode::BadRequest.text("Bad Request")); }; - Ok(UserAgent(agent.to_string())) + Ok(Self(agent.to_string())) } } diff --git a/src/types/like.rs b/src/types/like.rs index bf10b2d..1c113c1 100644 --- a/src/types/like.rs +++ b/src/types/like.rs @@ -7,13 +7,12 @@ use crate::types::http::{ResponseCode, Result}; #[derive(Serialize)] pub struct Like { pub user_id: u64, - pub post_id: u64 + pub post_id: u64, } impl Like { #[instrument()] pub fn add_liked(user_id: u64, post_id: u64) -> Result<()> { - let Ok(liked) = database::likes::add_liked(user_id, post_id) else { return Err(ResponseCode::BadRequest.text("Failed to add like status")) }; diff --git a/src/types/user.rs b/src/types/user.rs index 835b675..245e9b7 100644 --- a/src/types/user.rs +++ b/src/types/user.rs @@ -19,6 +19,10 @@ pub struct User { pub year: u32, } +pub const NO_RELATION: u8 = 0; +pub const FOLLOWING: u8 = 1; +pub const FOLLOWED: u8 = 2; + impl User { #[instrument()] pub fn from_user_id(user_id: u64, hide_password: bool) -> Result<Self> { @@ -95,4 +99,42 @@ impl User { Ok(user) } + + pub fn add_following(user_id_1: u64, user_id_2: u64) -> Result<()> { + let Ok(followed) = database::friends::set_following(user_id_1, user_id_2) else { + return Err(ResponseCode::BadRequest.text("Failed to add follow status")) + }; + + if !followed { + return Err(ResponseCode::InternalServerError.text("Failed to add follow status")); + } + + Ok(()) + } + + pub fn remove_following(user_id_1: u64, user_id_2: u64) -> Result<()> { + let Ok(followed) = database::friends::remove_following(user_id_1, user_id_2) else { + return Err(ResponseCode::BadRequest.text("Failed to remove follow status")) + }; + + if !followed { + return Err(ResponseCode::InternalServerError.text("Failed to remove follow status")); + } + + Ok(()) + } + + pub fn get_following(user_id_1: u64, user_id_2: u64) -> Result<u8> { + let Ok(followed) = database::friends::get_friend_status(user_id_1, user_id_2) else { + return Err(ResponseCode::InternalServerError.text("Failed to get follow status")) + }; + Ok(followed) + } + + pub fn get_friends(user_id: u64) -> Result<Vec<Self>> { + let Ok(users) = database::friends::get_friends(user_id) else { + return Err(ResponseCode::InternalServerError.text("Failed to fetch friends")) + }; + Ok(users) + } } |