diff --git a/src/api/auth.rs b/src/api/auth.rs index 4656ca8..7f7cf9e 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -1,9 +1,14 @@ -use axum::{Router, routing::post, response::Response}; +use axum::{response::Response, routing::post, Router}; use serde::Deserialize; -use time::{OffsetDateTime, Duration}; -use tower_cookies::{Cookies, Cookie}; +use time::{Duration, OffsetDateTime}; +use tower_cookies::{Cookie, Cookies}; -use crate::types::{user::User, http::ResponseCode, session::Session, extract::{Json, AuthorizedUser, Check, CheckResult, Log}}; +use crate::types::{ + extract::{AuthorizedUser, Check, CheckResult, Json, Log}, + http::ResponseCode, + session::Session, + user::User, +}; #[derive(Deserialize, Debug)] pub struct RegistrationRequet { @@ -14,36 +19,69 @@ pub struct RegistrationRequet { pub gender: String, pub day: u8, pub month: u8, - pub year: u32 + pub year: u32, } impl Check for RegistrationRequet { fn check(&self) -> CheckResult { - Self::assert_length(&self.firstname, 1, 20, "First name can only by 1-20 characters long")?; - Self::assert_length(&self.lastname, 1, 20, "Last name can only by 1-20 characters long")?; + Self::assert_length( + &self.firstname, + 1, + 20, + "First name can only by 1-20 characters long", + )?; + Self::assert_length( + &self.lastname, + 1, + 20, + "Last name can only by 1-20 characters long", + )?; Self::assert_length(&self.email, 1, 50, "Email can only by 1-50 characters long")?; - Self::assert_length(&self.password, 1, 50, "Password can only by 1-50 characters long")?; - Self::assert_length(&self.gender, 1, 100, "Gender can only by 1-100 characters long")?; - Self::assert_range(u64::from(self.day), 1, 255, "Birthday day can only be between 1-255")?; - Self::assert_range(u64::from(self.month), 1, 255, "Birthday month can only be between 1-255")?; - Self::assert_range(u64::from(self.year), 1, 4_294_967_295, "Birthday year can only be between 1-4294967295")?; + Self::assert_length( + &self.password, + 1, + 50, + "Password can only by 1-50 characters long", + )?; + Self::assert_length( + &self.gender, + 1, + 100, + "Gender can only by 1-100 characters long", + )?; + Self::assert_range( + u64::from(self.day), + 1, + 255, + "Birthday day can only be between 1-255", + )?; + Self::assert_range( + u64::from(self.month), + 1, + 255, + "Birthday month can only be between 1-255", + )?; + Self::assert_range( + u64::from(self.year), + 1, + 4_294_967_295, + "Birthday year can only be between 1-4294967295", + )?; Ok(()) } } - async fn register(cookies: Cookies, Json(body): Json) -> Response { - let user = match User::new(body) { Ok(user) => user, - Err(err) => return err + Err(err) => return err, }; let session = match Session::new(user.user_id) { Ok(session) => session, - Err(err) => return err + Err(err) => return err, }; - + let mut now = OffsetDateTime::now_utc(); now += Duration::weeks(52); @@ -71,20 +109,19 @@ impl Check for LoginRequest { } async fn login(cookies: Cookies, Json(body): Json) -> Response { - let Ok(user) = User::from_email(&body.email) else { return ResponseCode::BadRequest.text("Email is not registered") }; if user.password != body.password { - return ResponseCode::BadRequest.text("Password is not correct") + return ResponseCode::BadRequest.text("Password is not correct"); } let session = match Session::new(user.user_id) { Ok(session) => session, - Err(err) => return err + Err(err) => return err, }; - + let mut now = OffsetDateTime::now_utc(); now += Duration::weeks(52); @@ -100,11 +137,10 @@ async fn login(cookies: Cookies, Json(body): Json) -> Response { } async fn logout(cookies: Cookies, AuthorizedUser(user): AuthorizedUser, _: Log) -> Response { - cookies.remove(Cookie::new("auth", "")); if let Err(err) = Session::delete(user.user_id) { - return err + return err; } ResponseCode::Success.text("Successfully logged out") @@ -112,7 +148,7 @@ async fn logout(cookies: Cookies, AuthorizedUser(user): AuthorizedUser, _: Log) pub fn router() -> Router { Router::new() - .route("/register", post(register)) - .route("/login", post(login)) - .route("/logout", post(logout)) -} \ No newline at end of file + .route("/register", post(register)) + .route("/login", post(login)) + .route("/logout", post(logout)) +} diff --git a/src/api/mod.rs b/src/api/mod.rs index ba38aeb..a2083fe 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,4 +1,4 @@ pub mod auth; pub mod pages; pub mod posts; -pub mod users; \ No newline at end of file +pub mod users; diff --git a/src/api/pages.rs b/src/api/pages.rs index 41a63a8..4661b91 100644 --- a/src/api/pages.rs +++ b/src/api/pages.rs @@ -1,6 +1,13 @@ -use axum::{Router, response::{Response, Redirect, IntoResponse}, routing::get}; +use axum::{ + response::{IntoResponse, Redirect, Response}, + routing::get, + Router, +}; -use crate::{types::{extract::AuthorizedUser, http::ResponseCode}, console}; +use crate::{ + console, + types::{extract::AuthorizedUser, http::ResponseCode}, +}; async fn root(user: Option) -> Response { if user.is_some() { @@ -52,11 +59,11 @@ async fn wordpress() -> Response { pub fn router() -> Router { Router::new() - .route("/", get(root)) - .route("/login", get(login)) - .route("/home", get(home)) - .route("/people", get(people)) - .route("/profile", get(profile)) - .route("/console", get(console)) - .route("/wp-admin", get(wordpress)) -} \ No newline at end of file + .route("/", get(root)) + .route("/login", get(login)) + .route("/home", get(home)) + .route("/people", get(people)) + .route("/profile", get(profile)) + .route("/console", get(console)) + .route("/wp-admin", get(wordpress)) +} diff --git a/src/api/posts.rs b/src/api/posts.rs index e2e64b2..3a2e507 100644 --- a/src/api/posts.rs +++ b/src/api/posts.rs @@ -1,23 +1,37 @@ -use axum::{response::Response, Router, routing::{post, patch}}; +use axum::{ + response::Response, + routing::{patch, post}, + Router, +}; use serde::Deserialize; -use crate::types::{extract::{AuthorizedUser, Json, Check, CheckResult}, post::Post, http::ResponseCode}; - +use crate::types::{ + extract::{AuthorizedUser, Check, CheckResult, Json}, + http::ResponseCode, + post::Post, +}; #[derive(Deserialize)] struct PostCreateRequest { - content: String + content: String, } impl Check for PostCreateRequest { fn check(&self) -> CheckResult { - Self::assert_length(&self.content, 1, 500, "Comments must be between 1-500 characters long")?; + Self::assert_length( + &self.content, + 1, + 500, + "Comments must be between 1-500 characters long", + )?; Ok(()) } } -async fn create(AuthorizedUser(user): AuthorizedUser, Json(body): Json) -> Response { - +async fn create( + AuthorizedUser(user): AuthorizedUser, + Json(body): Json, +) -> Response { let Ok(post) = Post::new(user.user_id, body.content) else { return ResponseCode::InternalServerError.text("Failed to create post") }; @@ -31,7 +45,7 @@ async fn create(AuthorizedUser(user): AuthorizedUser, Json(body): Json) -> Response { - +async fn page( + AuthorizedUser(_user): AuthorizedUser, + Json(body): Json, +) -> Response { let Ok(posts) = Post::from_post_page(body.page) else { return ResponseCode::InternalServerError.text("Failed to fetch posts") }; @@ -55,7 +71,7 @@ async fn page(AuthorizedUser(_user): AuthorizedUser, Json(body): Json) -> Response { - +async fn user( + AuthorizedUser(_user): AuthorizedUser, + Json(body): Json, +) -> Response { let Ok(posts) = Post::from_user_id(body.user_id) else { return ResponseCode::InternalServerError.text("Failed to fetch posts") }; @@ -80,18 +98,25 @@ async fn user(AuthorizedUser(_user): AuthorizedUser, Json(body): Json CheckResult { - Self::assert_length(&self.content, 1, 255, "Comments must be between 1-255 characters long")?; + Self::assert_length( + &self.content, + 1, + 255, + "Comments must be between 1-255 characters long", + )?; Ok(()) } } -async fn comment(AuthorizedUser(user): AuthorizedUser, Json(body): Json) -> Response { - +async fn comment( + AuthorizedUser(user): AuthorizedUser, + Json(body): Json, +) -> Response { let Ok(mut post) = Post::from_post_id(body.post_id) else { return ResponseCode::InternalServerError.text("Failed to fetch posts") }; @@ -106,7 +131,7 @@ async fn comment(AuthorizedUser(user): AuthorizedUser, Json(body): Json) -> Response { - let Ok(mut post) = Post::from_post_id(body.post_id) else { return ResponseCode::InternalServerError.text("Failed to fetch posts") }; @@ -130,9 +154,9 @@ async fn like(AuthorizedUser(user): AuthorizedUser, Json(body): Json Router { Router::new() - .route("/create", post(create)) - .route("/page", post(page)) - .route("/user", post(user)) - .route("/comment", patch(comment)) - .route("/like", patch(like)) -} \ No newline at end of file + .route("/create", post(create)) + .route("/page", post(page)) + .route("/user", post(user)) + .route("/comment", patch(comment)) + .route("/like", patch(like)) +} diff --git a/src/api/users.rs b/src/api/users.rs index 97a9e6e..afcdddd 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -1,10 +1,14 @@ -use axum::{Router, response::Response, routing::post}; +use crate::types::{ + extract::{AuthorizedUser, Check, CheckResult, Json}, + http::ResponseCode, + user::User, +}; +use axum::{response::Response, routing::post, Router}; use serde::Deserialize; -use crate::types::{extract::{AuthorizedUser, Json, Check, CheckResult}, http::ResponseCode, user::User}; #[derive(Deserialize)] struct UserLoadRequest { - ids: Vec + ids: Vec, } impl Check for UserLoadRequest { @@ -13,8 +17,10 @@ impl Check for UserLoadRequest { } } -async fn load_batch(AuthorizedUser(_user): AuthorizedUser, Json(body): Json) -> Response { - +async fn load_batch( + AuthorizedUser(_user): AuthorizedUser, + Json(body): Json, +) -> Response { let users = User::from_user_ids(body.ids); let Ok(json) = serde_json::to_string(&users) else { return ResponseCode::InternalServerError.text("Failed to fetch users") @@ -25,7 +31,7 @@ async fn load_batch(AuthorizedUser(_user): AuthorizedUser, Json(body): Json) -> Response { - +async fn load_page( + AuthorizedUser(_user): AuthorizedUser, + Json(body): Json, +) -> Response { let Ok(users) = User::from_user_page(body.page) else { return ResponseCode::InternalServerError.text("Failed to fetch users") }; @@ -48,7 +56,6 @@ async fn load_page(AuthorizedUser(_user): AuthorizedUser, Json(body): Json Response { - let Ok(json) = serde_json::to_string(&user) else { return ResponseCode::InternalServerError.text("Failed to fetch user") }; @@ -58,7 +65,7 @@ async fn load_self(AuthorizedUser(user): AuthorizedUser) -> Response { pub fn router() -> Router { Router::new() - .route("/load", post(load_batch)) - .route("/self", post(load_self)) - .route("/page", post(load_page)) -} \ No newline at end of file + .route("/load", post(load_batch)) + .route("/self", post(load_self)) + .route("/page", post(load_page)) +} diff --git a/src/console.rs b/src/console.rs index 1596cd9..14324fa 100644 --- a/src/console.rs +++ b/src/console.rs @@ -1,8 +1,11 @@ -use std::{net::IpAddr, collections::VecDeque, io, }; -use axum::{http::{Method, Uri}, response::Response}; +use axum::{ + http::{Method, Uri}, + response::Response, +}; use lazy_static::lazy_static; use serde::Serialize; use serde_json::{ser::Formatter, Value}; +use std::{collections::VecDeque, io, net::IpAddr}; use tokio::sync::Mutex; use crate::types::http::ResponseCode; @@ -12,7 +15,7 @@ struct LogMessage { method: Method, uri: Uri, path: String, - body: String + body: String, } impl ToString for LogMessage { @@ -31,7 +34,7 @@ impl ToString for LogMessage { Method::CONNECT => "#3fe0ad", Method::TRACE => "#e03fc5", Method::OPTIONS => "#423fe0", - _ => "white" + _ => "white", }; format!("
{} {} {}{} {}
", ip, color, self.method, self.path, self.uri, self.body) } @@ -42,36 +45,43 @@ lazy_static! { } pub async fn log(ip: IpAddr, method: Method, uri: Uri, path: Option, body: Option) { - - if uri.to_string().starts_with("/console") { return; } + if uri.to_string().starts_with("/console") { + return; + } let path = path.unwrap_or_default(); let body = body.unwrap_or_default(); tracing::info!("{} {} {}{} {}", &ip, &method, &path, &uri, &body); - + let message = LogMessage { - ip, - method, - uri, - path, - body: beautify(body) + ip, + method, + uri, + path, + body: beautify(body), }; - + let mut lock = LOG.lock().await; if lock.len() > 200 { lock.pop_back(); } - lock.push_front(message); + lock.push_front(message); } struct HtmlFormatter; impl Formatter for HtmlFormatter { - fn write_null(&mut self, writer: &mut W) -> io::Result<()> where W: ?Sized + io::Write { + fn write_null(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { writer.write_all(b"null") } - fn write_bool(&mut self, writer: &mut W, value: bool) -> io::Result<()> where W: ?Sized + io::Write { + fn write_bool(&mut self, writer: &mut W, value: bool) -> io::Result<()> + where + W: ?Sized + io::Write, + { let s = if value { b"true" as &[u8] } else { @@ -80,65 +90,104 @@ impl Formatter for HtmlFormatter { writer.write_all(s) } - fn write_i8(&mut self, writer: &mut W, value: i8) -> io::Result<()> where W: ?Sized + io::Write { + fn write_i8(&mut self, writer: &mut W, value: i8) -> io::Result<()> + where + W: ?Sized + io::Write, + { let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } - fn write_i16(&mut self, writer: &mut W, value: i16) -> io::Result<()> where W: ?Sized + io::Write { + fn write_i16(&mut self, writer: &mut W, value: i16) -> io::Result<()> + where + W: ?Sized + io::Write, + { let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } - fn write_i32(&mut self, writer: &mut W, value: i32) -> io::Result<()> where W: ?Sized + io::Write { + fn write_i32(&mut self, writer: &mut W, value: i32) -> io::Result<()> + where + W: ?Sized + io::Write, + { let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } - fn write_i64(&mut self, writer: &mut W, value: i64) -> io::Result<()> where W: ?Sized + io::Write { + fn write_i64(&mut self, writer: &mut W, value: i64) -> io::Result<()> + where + W: ?Sized + io::Write, + { let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } - fn write_u8(&mut self, writer: &mut W, value: u8) -> io::Result<()> where W: ?Sized + io::Write { + fn write_u8(&mut self, writer: &mut W, value: u8) -> io::Result<()> + where + W: ?Sized + io::Write, + { let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } - fn write_u16(&mut self, writer: &mut W, value: u16) -> io::Result<()> where W: ?Sized + io::Write { + fn write_u16(&mut self, writer: &mut W, value: u16) -> io::Result<()> + where + W: ?Sized + io::Write, + { let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } - fn write_u32(&mut self, writer: &mut W, value: u32) -> io::Result<()> where W: ?Sized + io::Write { + fn write_u32(&mut self, writer: &mut W, value: u32) -> io::Result<()> + where + W: ?Sized + io::Write, + { let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } - fn write_u64(&mut self, writer: &mut W, value: u64) -> io::Result<()> where W: ?Sized + io::Write { + fn write_u64(&mut self, writer: &mut W, value: u64) -> io::Result<()> + where + W: ?Sized + io::Write, + { let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } - fn write_f32(&mut self, writer: &mut W, value: f32) -> io::Result<()> where W: ?Sized + io::Write { + fn write_f32(&mut self, writer: &mut W, value: f32) -> io::Result<()> + where + W: ?Sized + io::Write, + { let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } - fn write_f64(&mut self, writer: &mut W, value: f64) -> io::Result<()> where W: ?Sized + io::Write { + fn write_f64(&mut self, writer: &mut W, value: f64) -> io::Result<()> + where + W: ?Sized + io::Write, + { let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } - fn begin_string(&mut self, writer: &mut W) -> io::Result<()> where W: ?Sized + io::Write { + fn begin_string(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { writer.write_all(b"\"") } - fn end_string(&mut self, writer: &mut W) -> io::Result<()> where W: ?Sized + io::Write { + fn end_string(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { writer.write_all(b"\"") } - fn begin_object_key(&mut self, writer: &mut W, first: bool) -> io::Result<()> where W: ?Sized + io::Write { + fn begin_object_key(&mut self, writer: &mut W, first: bool) -> io::Result<()> + where + W: ?Sized + io::Write, + { if first { writer.write_all(b"") } else { @@ -146,15 +195,17 @@ impl Formatter for HtmlFormatter { } } - fn end_object_key(&mut self, writer: &mut W) -> io::Result<()> where W: ?Sized + io::Write { + fn end_object_key(&mut self, writer: &mut W) -> io::Result<()> + where + W: ?Sized + io::Write, + { writer.write_all(b"") } - } fn beautify(body: String) -> String { if body.is_empty() { - return String::new() + return String::new(); } let Ok(mut json) = serde_json::from_str::(&body) else { return body @@ -165,13 +216,12 @@ fn beautify(body: String) -> String { let mut writer: Vec = Vec::with_capacity(128); let mut serializer = serde_json::Serializer::with_formatter(&mut writer, HtmlFormatter); if json.serialize(&mut serializer).is_err() { - return body + return body; } String::from_utf8_lossy(&writer).to_string() } pub async fn generate() -> Response { - let lock = LOG.lock().await; let mut html = r#" @@ -183,7 +233,8 @@ pub async fn generate() -> Response { XSSBook - Console - "#.to_string(); + "# + .to_string(); for message in lock.iter() { html.push_str(&message.to_string()); @@ -192,4 +243,4 @@ pub async fn generate() -> Response { html.push_str(""); ResponseCode::Success.html(&html) -} \ No newline at end of file +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 59cc377..d48f352 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,6 +1,6 @@ pub mod posts; -pub mod users; pub mod sessions; +pub mod users; pub fn connect() -> Result { rusqlite::Connection::open("xssbook.db") @@ -11,4 +11,4 @@ pub fn init() -> Result<(), rusqlite::Error> { posts::init()?; sessions::init()?; Ok(()) -} \ No newline at end of file +} diff --git a/src/database/posts.rs b/src/database/posts.rs index 6086fdc..58470f0 100644 --- a/src/database/posts.rs +++ b/src/database/posts.rs @@ -1,11 +1,11 @@ use std::collections::HashSet; -use std::time::{SystemTime, UNIX_EPOCH, Duration}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; use rusqlite::{OptionalExtension, Row}; use tracing::instrument; -use crate::types::post::Post; use crate::database; +use crate::types::post::Post; pub fn init() -> Result<(), rusqlite::Error> { let sql = " @@ -40,7 +40,14 @@ fn post_from_row(row: &Row) -> Result { return Err(rusqlite::Error::InvalidQuery) }; - Ok(Post{post_id, user_id, content, likes, comments, date}) + Ok(Post { + post_id, + user_id, + content, + likes, + comments, + date, + }) } #[instrument()] @@ -48,10 +55,12 @@ pub fn get_post(post_id: u64) -> Result, rusqlite::Error> { tracing::trace!("Retrieving post"); let conn = database::connect()?; let mut stmt = conn.prepare("SELECT * FROM posts WHERE post_id = ?")?; - let row = stmt.query_row([post_id], |row| { - let row = post_from_row(row)?; - Ok(row) - }).optional()?; + let row = stmt + .query_row([post_id], |row| { + let row = post_from_row(row)?; + Ok(row) + }) + .optional()?; Ok(row) } @@ -73,7 +82,7 @@ pub fn get_users_posts(user_id: u64) -> Result, rusqlite::Error> { tracing::trace!("Retrieving users posts"); let conn = database::connect()?; let mut stmt = conn.prepare("SELECT * FROM posts WHERE user_id = ? ORDER BY post_id DESC")?; - let row = stmt.query_map([user_id], |row| { + let row = stmt.query_map([user_id], |row| { let row = post_from_row(row)?; Ok(row) })?; @@ -91,10 +100,16 @@ pub fn add_post(user_id: u64, content: &str) -> Result { let Ok(comments_json) = serde_json::to_string(&comments) else { return Err(rusqlite::Error::InvalidQuery) }; - let date = u64::try_from(SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_millis()).unwrap_or(0); + let date = u64::try_from( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::ZERO) + .as_millis(), + ) + .unwrap_or(0); let conn = database::connect()?; let mut stmt = conn.prepare("INSERT INTO posts (user_id, content, likes, comments, date) VALUES(?,?,?,?,?) RETURNING *;")?; - let post = stmt.query_row((user_id, content, likes_json, comments_json, date), |row| { + let post = stmt.query_row((user_id, content, likes_json, comments_json, date), |row| { let row = post_from_row(row)?; Ok(row) })?; @@ -102,7 +117,11 @@ pub fn add_post(user_id: u64, content: &str) -> Result { } #[instrument()] -pub fn update_post(post_id: u64, likes: &HashSet, comments: &Vec<(u64, String)>) -> Result<(), rusqlite::Error> { +pub fn update_post( + post_id: u64, + likes: &HashSet, + comments: &Vec<(u64, String)>, +) -> Result<(), rusqlite::Error> { tracing::trace!("Updating post"); let Ok(likes_json) = serde_json::to_string(&likes) else { return Err(rusqlite::Error::InvalidQuery) @@ -114,4 +133,4 @@ pub fn update_post(post_id: u64, likes: &HashSet, comments: &Vec<(u64, Stri let sql = "UPDATE posts SET likes = ?, comments = ? WHERE post_id = ?"; conn.execute(sql, (likes_json, comments_json, post_id))?; Ok(()) -} \ No newline at end of file +} diff --git a/src/database/sessions.rs b/src/database/sessions.rs index 2283b58..8d4ca73 100644 --- a/src/database/sessions.rs +++ b/src/database/sessions.rs @@ -21,12 +21,14 @@ pub fn get_session(token: &str) -> Result, rusqlite::Error> { tracing::trace!("Retrieving session"); let conn = database::connect()?; let mut stmt = conn.prepare("SELECT * FROM sessions WHERE token = ?")?; - let row = stmt.query_row([token], |row| { - Ok(Session { - user_id: row.get(0)?, - token: row.get(1)?, + let row = stmt + .query_row([token], |row| { + Ok(Session { + user_id: row.get(0)?, + token: row.get(1)?, + }) }) - }).optional()?; + .optional()?; Ok(row) } @@ -46,4 +48,4 @@ pub fn delete_session(user_id: u64) -> Result<(), Box> { let sql = "DELETE FROM sessions WHERE user_id = ?;"; conn.execute(sql, [user_id])?; Ok(()) -} \ No newline at end of file +} diff --git a/src/database/users.rs b/src/database/users.rs index a578e69..05a3a57 100644 --- a/src/database/users.rs +++ b/src/database/users.rs @@ -1,8 +1,8 @@ -use std::time::{SystemTime, UNIX_EPOCH, Duration}; use rusqlite::{OptionalExtension, Row}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tracing::instrument; -use crate::{database, types::user::User, api::auth::RegistrationRequet}; +use crate::{api::auth::RegistrationRequet, database, types::user::User}; pub fn init() -> Result<(), rusqlite::Error> { let sql = " @@ -36,9 +36,24 @@ fn user_from_row(row: &Row, hide_password: bool) -> Result Result, tracing::trace!("Retrieving user by id"); let conn = database::connect()?; let mut stmt = conn.prepare("SELECT * FROM users WHERE user_id = ?")?; - let row = stmt.query_row([user_id], |row| { - let row = user_from_row(row, hide_password)?; - Ok(row) - }).optional()?; + let row = stmt + .query_row([user_id], |row| { + let row = user_from_row(row, hide_password)?; + Ok(row) + }) + .optional()?; Ok(row) } #[instrument()] -pub fn get_user_by_email(email: &str, hide_password: bool) -> Result, rusqlite::Error> { +pub fn get_user_by_email( + email: &str, + hide_password: bool, +) -> Result, rusqlite::Error> { tracing::trace!("Retrieving user by email"); let conn = database::connect()?; let mut stmt = conn.prepare("SELECT * FROM users WHERE email = ?")?; - let row = stmt.query_row([email], |row| { - let row = user_from_row(row, hide_password)?; - Ok(row) - }).optional()?; + let row = stmt + .query_row([email], |row| { + let row = user_from_row(row, hide_password)?; + Ok(row) + }) + .optional()?; Ok(row) } #[instrument()] -pub fn get_user_by_password(password: &str, hide_password: bool) -> Result, rusqlite::Error> { +pub fn get_user_by_password( + password: &str, + hide_password: bool, +) -> Result, rusqlite::Error> { tracing::trace!("Retrieving user by password"); let conn = database::connect()?; let mut stmt = conn.prepare("SELECT * FROM users WHERE password = ?")?; - let row = stmt.query_row([password], |row| { - let row = user_from_row(row, hide_password)?; - Ok(row) - }).optional()?; + let row = stmt + .query_row([password], |row| { + let row = user_from_row(row, hide_password)?; + Ok(row) + }) + .optional()?; Ok(row) } @@ -93,13 +120,32 @@ pub fn get_user_page(page: u64, hide_password: bool) -> Result, rusqli #[instrument()] pub fn add_user(request: RegistrationRequet) -> Result { tracing::trace!("Adding new user"); - let date = u64::try_from(SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_millis()).unwrap_or(0); + let date = u64::try_from( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::ZERO) + .as_millis(), + ) + .unwrap_or(0); let conn = database::connect()?; let mut stmt = conn.prepare("INSERT INTO users (firstname, lastname, email, password, gender, date, day, month, year) VALUES(?,?,?,?,?,?,?,?,?) RETURNING *;")?; - let user = stmt.query_row((request.firstname, request.lastname, request.email, request.password, request.gender, date, request.day, request.month, request.year), |row| { - let row = user_from_row(row, false)?; - Ok(row) - })?; + let user = stmt.query_row( + ( + request.firstname, + request.lastname, + request.email, + request.password, + request.gender, + date, + request.day, + request.month, + request.year, + ), + |row| { + let row = user_from_row(row, false)?; + Ok(row) + }, + )?; Ok(user) -} \ No newline at end of file +} diff --git a/src/main.rs b/src/main.rs index 31b749e..cd137b9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,18 +1,31 @@ +use axum::{ + body::HttpBody, + extract::ConnectInfo, + http::{Request, StatusCode}, + middleware::{self, Next}, + response::Response, + Extension, RequestExt, Router, +}; use std::{net::SocketAddr, process::exit}; -use axum::{Router, response::Response, http::{Request, StatusCode}, middleware::{Next, self}, extract::ConnectInfo, RequestExt, body::HttpBody, Extension}; use tower_cookies::CookieManagerLayer; -use tracing::{metadata::LevelFilter, error, info}; -use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, Layer, filter::filter_fn}; +use tracing::{error, info, metadata::LevelFilter}; +use tracing_subscriber::{ + filter::filter_fn, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, Layer, +}; use types::http::ResponseCode; -use crate::{api::{pages, auth, users, posts}, types::extract::RouterURI}; +use crate::{ + api::{auth, pages, posts, users}, + types::extract::RouterURI, +}; mod api; +mod console; mod database; mod types; -mod console; -async fn serve(req: Request, next: Next) -> Response where +async fn serve(req: Request, next: Next) -> Response +where B: Send + Sync + 'static + HttpBody, { let uri = req.uri(); @@ -23,15 +36,22 @@ async fn serve(req: Request, next: Next) -> Response where file } -async fn log(mut req: Request, next: Next) -> Response where +async fn log(mut req: Request, next: Next) -> Response +where B: Send + Sync + 'static + HttpBody, { - let Ok(ConnectInfo(info)) = req.extract_parts::>().await else { return next.run(req).await }; - console::log(info.ip(), req.method().clone(), req.uri().clone(), None, None).await; + console::log( + info.ip(), + req.method().clone(), + req.uri().clone(), + None, + None, + ) + .await; next.run(req).await } @@ -42,13 +62,14 @@ async fn not_found() -> Response { #[tokio::main] async fn main() { - let fmt_layer = tracing_subscriber::fmt::layer(); - tracing_subscriber::registry() + tracing_subscriber::registry() .with( - fmt_layer.with_filter(LevelFilter::TRACE).with_filter(filter_fn(|metadata| { - metadata.target().starts_with("xssbook") - })) + fmt_layer + .with_filter(LevelFilter::TRACE) + .with_filter(filter_fn(|metadata| { + metadata.target().starts_with("xssbook") + })), ) .init(); @@ -58,17 +79,23 @@ async fn main() { }; let app = Router::new() - .fallback(not_found) - .nest("/", pages::router()) - .layer(middleware::from_fn(log)) - .layer(middleware::from_fn(serve)) - .nest("/api/auth", auth::router() - .layer(Extension(RouterURI("/api/auth"))) - ).nest("/api/users", users::router() - .layer(Extension(RouterURI("/api/users"))) - ).nest("/api/posts", posts::router() - .layer(Extension(RouterURI("/api/posts"))) - ).layer(CookieManagerLayer::new()); + .fallback(not_found) + .nest("/", pages::router()) + .layer(middleware::from_fn(log)) + .layer(middleware::from_fn(serve)) + .nest( + "/api/auth", + auth::router().layer(Extension(RouterURI("/api/auth"))), + ) + .nest( + "/api/users", + users::router().layer(Extension(RouterURI("/api/users"))), + ) + .nest( + "/api/posts", + posts::router().layer(Extension(RouterURI("/api/posts"))), + ) + .layer(CookieManagerLayer::new()); let Ok(addr) = "[::]:8080".parse::() else { error!("Failed to parse port binding"); @@ -76,10 +103,9 @@ async fn main() { }; info!("listening on {}", addr); - + axum::Server::bind(&addr) .serve(app.into_make_service_with_connect_info::()) .await .unwrap_or(()); - } diff --git a/src/types/extract.rs b/src/types/extract.rs index b4a6cfc..f21c352 100644 --- a/src/types/extract.rs +++ b/src/types/extract.rs @@ -1,43 +1,61 @@ use std::{io::Read, net::SocketAddr}; -use axum::{extract::{FromRequestParts, FromRequest, ConnectInfo}, async_trait, response::Response, http::{request::Parts, Request}, TypedHeader, headers::Cookie, body::HttpBody, BoxError, RequestExt}; +use axum::{ + async_trait, + body::HttpBody, + extract::{ConnectInfo, FromRequest, FromRequestParts}, + headers::Cookie, + http::{request::Parts, Request}, + response::Response, + BoxError, RequestExt, TypedHeader, +}; use bytes::Bytes; use serde::de::DeserializeOwned; -use crate::{types::{user::User, http::{ResponseCode, Result}, session::Session}, console}; +use crate::{ + console, + types::{ + http::{ResponseCode, Result}, + session::Session, + user::User, + }, +}; pub struct AuthorizedUser(pub User); #[async_trait] -impl FromRequestParts for AuthorizedUser where S: Send + Sync { - type Rejection = Response; - - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - +impl FromRequestParts for AuthorizedUser +where + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let Ok(Some(cookies)) = Option::>::from_request_parts(parts, state).await else { return Err(ResponseCode::Forbidden.text("No cookies provided")) }; - + let Some(token) = cookies.get("auth") else { return Err(ResponseCode::Forbidden.text("No auth token provided")) }; - + let Ok(session) = Session::from_token(token) else { return Err(ResponseCode::Unauthorized.text("Auth token invalid")) }; - + let Ok(user) = User::from_user_id(session.user_id, true) else { tracing::error!("Valid token but no valid user"); return Err(ResponseCode::InternalServerError.text("Valid token but no valid user")) }; Ok(Self(user)) - } + } } pub struct Log; #[async_trait] -impl FromRequest for Log where +impl FromRequest for Log +where B: HttpBody + Sync + Send + 'static, B::Data: Send, B::Error: Into, @@ -45,26 +63,35 @@ impl FromRequest for Log where { type Rejection = Response; - async fn from_request(mut req: Request, state: &S) -> Result { - + async fn from_request(mut req: Request, state: &S) -> Result { let Ok(ConnectInfo(info)) = req.extract_parts::>().await else { return Ok(Self) }; let method = req.method().clone(); - let path = req.extensions().get::().map_or("", |path| path.0); + let path = req + .extensions() + .get::() + .map_or("", |path| path.0); let uri = req.uri().clone(); - + let Ok(bytes) = Bytes::from_request(req, state).await else { console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), None).await; return Ok(Self) }; - + let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else { console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), None).await; return Ok(Self) }; - - console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), Some(body.to_string())).await; + + console::log( + info.ip(), + method.clone(), + uri.clone(), + Some(path.to_string()), + Some(body.to_string()), + ) + .await; Ok(Self) } @@ -73,7 +100,8 @@ impl FromRequest for Log where pub struct Json(pub T); #[async_trait] -impl FromRequest for Json where +impl FromRequest for Json +where T: DeserializeOwned + Check, B: HttpBody + Sync + Send + 'static, B::Data: Send, @@ -82,26 +110,35 @@ impl FromRequest for Json where { type Rejection = Response; - async fn from_request(mut req: Request, state: &S) -> Result { - + async fn from_request(mut req: Request, state: &S) -> Result { let Ok(ConnectInfo(info)) = req.extract_parts::>().await else { tracing::error!("Failed to read connection info"); return Err(ResponseCode::InternalServerError.text("Failed to read connection info")); }; let method = req.method().clone(); - let path = req.extensions().get::().map_or("", |path| path.0); + let path = req + .extensions() + .get::() + .map_or("", |path| path.0); let uri = req.uri().clone(); - + let Ok(bytes) = Bytes::from_request(req, state).await else { tracing::error!("Failed to read request body"); return Err(ResponseCode::InternalServerError.text("Failed to read request body")); }; - + let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else { return Err(ResponseCode::BadRequest.text("Invalid utf8 body")) }; - - console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), Some(body.to_string())).await; + + console::log( + info.ip(), + method.clone(), + uri.clone(), + Some(path.to_string()), + Some(body.to_string()), + ) + .await; let Ok(value) = serde_json::from_str::(&body) else { return Err(ResponseCode::BadRequest.text("Invalid request body")) @@ -118,19 +155,18 @@ impl FromRequest for Json where pub type CheckResult = std::result::Result<(), String>; pub trait Check { - fn check(&self) -> CheckResult; fn assert_length(string: &str, min: usize, max: usize, message: &str) -> CheckResult { if string.len() < min || string.len() > max { - return Err(message.to_string()) + return Err(message.to_string()); } Ok(()) } fn assert_range(number: u64, min: u64, max: u64, message: &str) -> CheckResult { if number < min || number > max { - return Err(message.to_string()) + return Err(message.to_string()); } Ok(()) } @@ -138,4 +174,3 @@ pub trait Check { #[derive(Clone)] pub struct RouterURI(pub &'static str); - diff --git a/src/types/http.rs b/src/types/http.rs index 0e7b703..8524b15 100644 --- a/src/types/http.rs +++ b/src/types/http.rs @@ -1,4 +1,9 @@ -use axum::{response::{IntoResponse, Response}, http::{StatusCode, Request, HeaderValue}, body::Body, headers::HeaderName}; +use axum::{ + body::Body, + headers::HeaderName, + http::{HeaderValue, Request, StatusCode}, + response::{IntoResponse, Response}, +}; use tower::ServiceExt; use tower_http::services::ServeFile; use tracing::instrument; @@ -12,11 +17,10 @@ pub enum ResponseCode { Forbidden, NotFound, ImATeapot, - InternalServerError + InternalServerError, } impl ResponseCode { - const fn code(self) -> StatusCode { match self { Self::Success => StatusCode::OK, @@ -26,7 +30,7 @@ impl ResponseCode { Self::Forbidden => StatusCode::FORBIDDEN, Self::NotFound => StatusCode::NOT_FOUND, Self::ImATeapot => StatusCode::IM_A_TEAPOT, - Self::InternalServerError => StatusCode::INTERNAL_SERVER_ERROR + Self::InternalServerError => StatusCode::INTERNAL_SERVER_ERROR, } } @@ -39,7 +43,8 @@ impl ResponseCode { pub fn json(self, json: &str) -> Response { let mut res = (self.code(), json.to_owned()).into_response(); res.headers_mut().insert( - HeaderName::from_static("content-type"), HeaderValue::from_static("application/json"), + HeaderName::from_static("content-type"), + HeaderValue::from_static("application/json"), ); res } @@ -48,14 +53,15 @@ impl ResponseCode { pub fn html(self, json: &str) -> Response { let mut res = (self.code(), json.to_owned()).into_response(); res.headers_mut().insert( - HeaderName::from_static("content-type"), HeaderValue::from_static("text/html"), + HeaderName::from_static("content-type"), + HeaderValue::from_static("text/html"), ); res } #[instrument()] pub async fn file(self, path: &str) -> Response { - if !path.chars().any(|c| c == '.' ) { + if !path.chars().any(|c| c == '.') { return Self::BadRequest.text("Folders cannot be served"); } let path = format!("public{path}"); @@ -72,4 +78,4 @@ impl ResponseCode { } } -pub type Result = std::result::Result; \ No newline at end of file +pub type Result = std::result::Result; diff --git a/src/types/mod.rs b/src/types/mod.rs index 0ab104c..3449d5c 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,5 +1,5 @@ -pub mod user; +pub mod extract; +pub mod http; pub mod post; pub mod session; -pub mod extract; -pub mod http; \ No newline at end of file +pub mod user; diff --git a/src/types/post.rs b/src/types/post.rs index 95aed0e..90eada2 100644 --- a/src/types/post.rs +++ b/src/types/post.rs @@ -1,10 +1,10 @@ use core::fmt; -use std::collections::HashSet; use serde::Serialize; +use std::collections::HashSet; use tracing::instrument; use crate::database; -use crate::types::http::{Result, ResponseCode}; +use crate::types::http::{ResponseCode, Result}; #[derive(Serialize)] pub struct Post { @@ -13,19 +13,18 @@ pub struct Post { pub content: String, pub likes: HashSet, pub comments: Vec<(u64, String)>, - pub date: u64 + pub date: u64, } impl fmt::Debug for Post { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Post") - .field("post_id", &self.post_id) - .finish() + .field("post_id", &self.post_id) + .finish() } } impl Post { - #[instrument()] pub fn from_post_id(post_id: u64) -> Result { let Ok(Some(post)) = database::posts::get_post(post_id) else { @@ -64,10 +63,10 @@ impl Post { #[instrument()] pub fn comment(&mut self, user_id: u64, content: String) -> Result<()> { self.comments.push((user_id, content)); - + if database::posts::update_post(self.post_id, &self.likes, &self.comments).is_err() { tracing::error!("Failed to comment on post"); - return Err(ResponseCode::InternalServerError.text("Failed to comment on post")) + return Err(ResponseCode::InternalServerError.text("Failed to comment on post")); } Ok(()) @@ -75,19 +74,19 @@ impl Post { #[instrument()] pub fn like(&mut self, user_id: u64, state: bool) -> Result<()> { - if state { self.likes.insert(user_id); } else { self.likes.remove(&user_id); } - + if database::posts::update_post(self.post_id, &self.likes, &self.comments).is_err() { tracing::error!("Failed to change like state on post"); - return Err(ResponseCode::InternalServerError.text("Failed to change like state on post")) + return Err( + ResponseCode::InternalServerError.text("Failed to change like state on post") + ); } Ok(()) } - -} \ No newline at end of file +} diff --git a/src/types/session.rs b/src/types/session.rs index 176e389..e704ac7 100644 --- a/src/types/session.rs +++ b/src/types/session.rs @@ -3,16 +3,15 @@ use serde::Serialize; use tracing::instrument; use crate::database; -use crate::types::http::{Result, ResponseCode}; +use crate::types::http::{ResponseCode, Result}; #[derive(Serialize)] pub struct Session { pub user_id: u64, - pub token: String + pub token: String, } impl Session { - #[instrument()] pub fn from_token(token: &str) -> Result { let Ok(Some(session)) = database::sessions::get_session(token) else { @@ -24,10 +23,14 @@ impl Session { #[instrument()] pub fn new(user_id: u64) -> Result { - let token: String = rand::thread_rng().sample_iter(&Alphanumeric).take(32).map(char::from).collect(); + let token: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(32) + .map(char::from) + .collect(); match database::sessions::set_session(user_id, &token) { Err(_) => Err(ResponseCode::BadRequest.text("Failed to create session")), - Ok(_) => Ok(Self {user_id, token}) + Ok(_) => Ok(Self { user_id, token }), } } @@ -39,5 +42,4 @@ impl Session { }; Ok(()) } - -} \ No newline at end of file +} diff --git a/src/types/user.rs b/src/types/user.rs index 0013d7d..fcfbe91 100644 --- a/src/types/user.rs +++ b/src/types/user.rs @@ -1,10 +1,9 @@ -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; use tracing::instrument; use crate::api::auth::RegistrationRequet; use crate::database; -use crate::types::http::{Result, ResponseCode}; - +use crate::types::http::{ResponseCode, Result}; #[derive(Serialize, Deserialize, Debug)] pub struct User { @@ -21,7 +20,6 @@ pub struct User { } impl User { - #[instrument()] pub fn from_user_id(user_id: u64, hide_password: bool) -> Result { let Ok(Some(user)) = database::users::get_user_by_id(user_id, hide_password) else { @@ -33,12 +31,15 @@ impl User { #[instrument()] pub fn from_user_ids(user_ids: Vec) -> Vec { - user_ids.iter().filter_map(|user_id| { - let Ok(Some(user)) = database::users::get_user_by_id(*user_id, true) else { + user_ids + .iter() + .filter_map(|user_id| { + let Ok(Some(user)) = database::users::get_user_by_id(*user_id, true) else { return None; }; - Some(user) - }).collect() + Some(user) + }) + .collect() } #[instrument()] @@ -70,13 +71,15 @@ impl User { #[instrument()] pub fn new(request: RegistrationRequet) -> Result { if Self::from_email(&request.email).is_ok() { - return Err(ResponseCode::BadRequest.text(&format!("Email is already in use by {}", &request.email))) + return Err(ResponseCode::BadRequest + .text(&format!("Email is already in use by {}", &request.email))); } if let Ok(user) = Self::from_password(&request.password) { - return Err(ResponseCode::BadRequest.text(&format!("Password is already in use by {}", user.email))) + return Err(ResponseCode::BadRequest + .text(&format!("Password is already in use by {}", user.email))); } - + let Ok(user) = database::users::add_user(request) else { tracing::error!("Failed to create new user"); return Err(ResponseCode::InternalServerError.text("Failed to create new uesr")) @@ -84,5 +87,4 @@ impl User { Ok(user) } - -} \ No newline at end of file +}