diff --git a/src/api/auth.rs b/src/api/auth.rs index 410b643..4656ca8 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -3,7 +3,7 @@ use serde::Deserialize; use time::{OffsetDateTime, Duration}; use tower_cookies::{Cookies, Cookie}; -use crate::types::{user::User, response::ResponseCode, session::Session, extract::{Json, AuthorizedUser, Check, CheckResult, Log}}; +use crate::types::{user::User, http::ResponseCode, session::Session, extract::{Json, AuthorizedUser, Check, CheckResult, Log}}; #[derive(Deserialize, Debug)] pub struct RegistrationRequet { @@ -24,9 +24,9 @@ impl Check for RegistrationRequet { Self::assert_length(&self.email, 1, 50, "Email can only by 1-50 characters long")?; Self::assert_length(&self.password, 1, 50, "Password can only by 1-50 characters long")?; Self::assert_length(&self.gender, 1, 100, "Gender can only by 1-100 characters long")?; - Self::assert_range(self.day as u64, 1, 255, "Birthday day can only be between 1-255")?; - Self::assert_range(self.month as u64, 1, 255, "Birthday month can only be between 1-255")?; - Self::assert_range(self.year as u64, 1, 4294967295, "Birthday year can only be between 1-4294967295")?; + Self::assert_range(u64::from(self.day), 1, 255, "Birthday day can only be between 1-255")?; + Self::assert_range(u64::from(self.month), 1, 255, "Birthday month can only be between 1-255")?; + Self::assert_range(u64::from(self.year), 1, 4_294_967_295, "Birthday year can only be between 1-4294967295")?; Ok(()) } } diff --git a/src/api/pages.rs b/src/api/pages.rs index f4f0f42..41a63a8 100644 --- a/src/api/pages.rs +++ b/src/api/pages.rs @@ -1,6 +1,6 @@ use axum::{Router, response::{Response, Redirect, IntoResponse}, routing::get}; -use crate::{types::{extract::AuthorizedUser, response::ResponseCode}, console}; +use crate::{types::{extract::AuthorizedUser, http::ResponseCode}, console}; async fn root(user: Option) -> Response { if user.is_some() { diff --git a/src/api/posts.rs b/src/api/posts.rs index fda1fb1..e2e64b2 100644 --- a/src/api/posts.rs +++ b/src/api/posts.rs @@ -1,7 +1,7 @@ use axum::{response::Response, Router, routing::{post, patch}}; use serde::Deserialize; -use crate::types::{extract::{AuthorizedUser, Json, Check, CheckResult}, post::Post, response::ResponseCode}; +use crate::types::{extract::{AuthorizedUser, Json, Check, CheckResult}, post::Post, http::ResponseCode}; #[derive(Deserialize)] diff --git a/src/api/users.rs b/src/api/users.rs index 7bea200..97a9e6e 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -1,6 +1,6 @@ use axum::{Router, response::Response, routing::post}; use serde::Deserialize; -use crate::types::{extract::{AuthorizedUser, Json, Check, CheckResult}, response::ResponseCode, user::User}; +use crate::types::{extract::{AuthorizedUser, Json, Check, CheckResult}, http::ResponseCode, user::User}; #[derive(Deserialize)] struct UserLoadRequest { diff --git a/src/console.rs b/src/console.rs index 019ae2c..1596cd9 100644 --- a/src/console.rs +++ b/src/console.rs @@ -5,7 +5,7 @@ use serde::Serialize; use serde_json::{ser::Formatter, Value}; use tokio::sync::Mutex; -use crate::types::response::ResponseCode; +use crate::types::http::ResponseCode; struct LogMessage { ip: IpAddr, @@ -19,7 +19,7 @@ impl ToString for LogMessage { fn to_string(&self) -> String { let mut ip = self.ip.to_string(); if ip.contains("::ffff:") { - ip = ip.as_str()[7..].to_string() + ip = ip.as_str()[7..].to_string(); } let color = match self.method { Method::GET => "#3fe04f", @@ -81,52 +81,52 @@ impl Formatter for HtmlFormatter { } fn write_i8(&mut self, writer: &mut W, value: i8) -> io::Result<()> where W: ?Sized + io::Write { - let buff = format!("{}", value); + let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } fn write_i16(&mut self, writer: &mut W, value: i16) -> io::Result<()> where W: ?Sized + io::Write { - let buff = format!("{}", value); + let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } fn write_i32(&mut self, writer: &mut W, value: i32) -> io::Result<()> where W: ?Sized + io::Write { - let buff = format!("{}", value); + let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } fn write_i64(&mut self, writer: &mut W, value: i64) -> io::Result<()> where W: ?Sized + io::Write { - let buff = format!("{}", value); + let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } fn write_u8(&mut self, writer: &mut W, value: u8) -> io::Result<()> where W: ?Sized + io::Write { - let buff = format!("{}", value); + let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } fn write_u16(&mut self, writer: &mut W, value: u16) -> io::Result<()> where W: ?Sized + io::Write { - let buff = format!("{}", value); + let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } fn write_u32(&mut self, writer: &mut W, value: u32) -> io::Result<()> where W: ?Sized + io::Write { - let buff = format!("{}", value); + let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } fn write_u64(&mut self, writer: &mut W, value: u64) -> io::Result<()> where W: ?Sized + io::Write { - let buff = format!("{}", value); + let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } fn write_f32(&mut self, writer: &mut W, value: f32) -> io::Result<()> where W: ?Sized + io::Write { - let buff = format!("{}", value); + let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } fn write_f64(&mut self, writer: &mut W, value: f64) -> io::Result<()> where W: ?Sized + io::Write { - let buff = format!("{}", value); + let buff = format!("{value}"); writer.write_all(buff.as_bytes()) } @@ -154,7 +154,7 @@ impl Formatter for HtmlFormatter { fn beautify(body: String) -> String { if body.is_empty() { - return "".to_string() + return String::new() } let Ok(mut json) = serde_json::from_str::(&body) else { return body diff --git a/src/database/mod.rs b/src/database/mod.rs index 19e4203..59cc377 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,5 +1,3 @@ -use rusqlite::Result; - pub mod posts; pub mod users; pub mod sessions; @@ -8,7 +6,7 @@ pub fn connect() -> Result { rusqlite::Connection::open("xssbook.db") } -pub fn init() -> Result<()> { +pub fn init() -> Result<(), rusqlite::Error> { users::init()?; posts::init()?; sessions::init()?; diff --git a/src/database/posts.rs b/src/database/posts.rs index 7892683..6086fdc 100644 --- a/src/database/posts.rs +++ b/src/database/posts.rs @@ -1,5 +1,5 @@ use std::collections::HashSet; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::{SystemTime, UNIX_EPOCH, Duration}; use rusqlite::{OptionalExtension, Row}; use tracing::instrument; @@ -91,7 +91,7 @@ pub fn add_post(user_id: u64, content: &str) -> Result { let Ok(comments_json) = serde_json::to_string(&comments) else { return Err(rusqlite::Error::InvalidQuery) }; - let date = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as u64; + let date = u64::try_from(SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_millis()).unwrap_or(0); let conn = database::connect()?; let mut stmt = conn.prepare("INSERT INTO posts (user_id, content, likes, comments, date) VALUES(?,?,?,?,?) RETURNING *;")?; let post = stmt.query_row((user_id, content, likes_json, comments_json, date), |row| { diff --git a/src/database/sessions.rs b/src/database/sessions.rs index a2a2f6e..2283b58 100644 --- a/src/database/sessions.rs +++ b/src/database/sessions.rs @@ -4,7 +4,6 @@ use tracing::instrument; use crate::{database, types::session::Session}; pub fn init() -> Result<(), rusqlite::Error> { - tracing::trace!("Retrieving posts page"); let sql = " CREATE TABLE IF NOT EXISTS sessions ( user_id INTEGER PRIMARY KEY NOT NULL, diff --git a/src/database/users.rs b/src/database/users.rs index d9e35b1..a578e69 100644 --- a/src/database/users.rs +++ b/src/database/users.rs @@ -1,4 +1,4 @@ -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::{SystemTime, UNIX_EPOCH, Duration}; use rusqlite::{OptionalExtension, Row}; use tracing::instrument; @@ -36,7 +36,7 @@ fn user_from_row(row: &Row, hide_password: bool) -> Result Result, rusqli #[instrument()] pub fn add_user(request: RegistrationRequet) -> Result { tracing::trace!("Adding new user"); - let date = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as u64; + let date = u64::try_from(SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_millis()).unwrap_or(0); let conn = database::connect()?; let mut stmt = conn.prepare("INSERT INTO users (firstname, lastname, email, password, gender, date, day, month, year) VALUES(?,?,?,?,?,?,?,?,?) RETURNING *;")?; diff --git a/src/main.rs b/src/main.rs index 9ccc45b..31b749e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,9 @@ -use std::net::SocketAddr; +use std::{net::SocketAddr, process::exit}; use axum::{Router, response::Response, http::{Request, StatusCode}, middleware::{Next, self}, extract::ConnectInfo, RequestExt, body::HttpBody, Extension}; use tower_cookies::CookieManagerLayer; -use tracing::metadata::LevelFilter; +use tracing::{metadata::LevelFilter, error, info}; use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, Layer, filter::filter_fn}; -use types::response::ResponseCode; +use types::http::ResponseCode; use crate::{api::{pages, auth, users, posts}, types::extract::RouterURI}; @@ -12,8 +12,11 @@ mod database; mod types; mod console; -async fn serve(req: Request, next: Next) -> Response { - let file = ResponseCode::Success.file(&req.uri().to_string()).await; +async fn serve(req: Request, next: Next) -> Response where + B: Send + Sync + 'static + HttpBody, +{ + let uri = req.uri(); + let file = ResponseCode::Success.file(&uri.to_string()).await; if file.status() != StatusCode::OK { return next.run(req).await; } @@ -40,8 +43,6 @@ async fn not_found() -> Response { #[tokio::main] async fn main() { - database::init().unwrap(); - let fmt_layer = tracing_subscriber::fmt::layer(); tracing_subscriber::registry() .with( @@ -51,6 +52,11 @@ async fn main() { ) .init(); + if database::init().is_err() { + error!("Failed to connect to the sqlite database"); + exit(1) + }; + let app = Router::new() .fallback(not_found) .nest("/", pages::router()) @@ -64,12 +70,16 @@ async fn main() { .layer(Extension(RouterURI("/api/posts"))) ).layer(CookieManagerLayer::new()); - let addr = "[::]:8080".parse::().unwrap(); - tracing::info!("listening on {}", addr); + let Ok(addr) = "[::]:8080".parse::() else { + error!("Failed to parse port binding"); + exit(1) + }; + + info!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service_with_connect_info::()) .await - .unwrap(); + .unwrap_or(()); } diff --git a/src/types/extract.rs b/src/types/extract.rs index 7dbf386..b4a6cfc 100644 --- a/src/types/extract.rs +++ b/src/types/extract.rs @@ -4,7 +4,7 @@ use axum::{extract::{FromRequestParts, FromRequest, ConnectInfo}, async_trait, r use bytes::Bytes; use serde::de::DeserializeOwned; -use crate::{types::{user::User, response::{ResponseCode, Result}, session::Session}, console}; +use crate::{types::{user::User, http::{ResponseCode, Result}, session::Session}, console}; pub struct AuthorizedUser(pub User); @@ -31,7 +31,7 @@ impl FromRequestParts for AuthorizedUser where S: Send + Sync { return Err(ResponseCode::InternalServerError.text("Valid token but no valid user")) }; - Ok(AuthorizedUser(user)) + Ok(Self(user)) } } @@ -48,25 +48,25 @@ impl FromRequest for Log where async fn from_request(mut req: Request, state: &S) -> Result { let Ok(ConnectInfo(info)) = req.extract_parts::>().await else { - return Ok(Log) + return Ok(Self) }; let method = req.method().clone(); - let path = req.extensions().get::().unwrap().0; + let path = req.extensions().get::().map_or("", |path| path.0); let uri = req.uri().clone(); let Ok(bytes) = Bytes::from_request(req, state).await else { console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), None).await; - return Ok(Log) + return Ok(Self) }; let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else { console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), None).await; - return Ok(Log) + return Ok(Self) }; console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), Some(body.to_string())).await; - Ok(Log) + Ok(Self) } } @@ -89,7 +89,7 @@ impl FromRequest for Json where return Err(ResponseCode::InternalServerError.text("Failed to read connection info")); }; let method = req.method().clone(); - let path = req.extensions().get::().unwrap().0; + let path = req.extensions().get::().map_or("", |path| path.0); let uri = req.uri().clone(); let Ok(bytes) = Bytes::from_request(req, state).await else { @@ -111,7 +111,7 @@ impl FromRequest for Json where return Err(ResponseCode::BadRequest.text(&msg)); } - Ok(Json(value)) + Ok(Self(value)) } } diff --git a/src/types/response.rs b/src/types/http.rs similarity index 87% rename from src/types/response.rs rename to src/types/http.rs index 0c5a78c..0e7b703 100644 --- a/src/types/response.rs +++ b/src/types/http.rs @@ -17,7 +17,7 @@ pub enum ResponseCode { impl ResponseCode { - pub fn code(self) -> StatusCode { + const fn code(self) -> StatusCode { match self { Self::Success => StatusCode::OK, Self::Created => StatusCode::CREATED, @@ -56,16 +56,16 @@ impl ResponseCode { #[instrument()] pub async fn file(self, path: &str) -> Response { if !path.chars().any(|c| c == '.' ) { - return ResponseCode::BadRequest.text("Folders cannot be served"); + return Self::BadRequest.text("Folders cannot be served"); } - let path = format!("public{}", path); + let path = format!("public{path}"); let svc = ServeFile::new(path); let Ok(mut res) = svc.oneshot(Request::new(Body::empty())).await else { tracing::error!("Error while fetching file"); - return ResponseCode::InternalServerError.text("Error while fetching file"); + return Self::InternalServerError.text("Error while fetching file"); }; if res.status() != StatusCode::OK { - return ResponseCode::NotFound.text("File not found"); + return Self::NotFound.text("File not found"); } *res.status_mut() = self.code(); res.into_response() diff --git a/src/types/mod.rs b/src/types/mod.rs index 089885e..0ab104c 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -2,4 +2,4 @@ pub mod user; pub mod post; pub mod session; pub mod extract; -pub mod response; \ No newline at end of file +pub mod http; \ No newline at end of file diff --git a/src/types/post.rs b/src/types/post.rs index 7ca0a3c..95aed0e 100644 --- a/src/types/post.rs +++ b/src/types/post.rs @@ -4,7 +4,7 @@ use serde::Serialize; use tracing::instrument; use crate::database; -use crate::types::response::{Result, ResponseCode}; +use crate::types::http::{Result, ResponseCode}; #[derive(Serialize)] pub struct Post { diff --git a/src/types/session.rs b/src/types/session.rs index 30e430e..176e389 100644 --- a/src/types/session.rs +++ b/src/types/session.rs @@ -3,7 +3,7 @@ use serde::Serialize; use tracing::instrument; use crate::database; -use crate::types::response::{Result, ResponseCode}; +use crate::types::http::{Result, ResponseCode}; #[derive(Serialize)] pub struct Session { @@ -27,7 +27,7 @@ impl Session { let token: String = rand::thread_rng().sample_iter(&Alphanumeric).take(32).map(char::from).collect(); match database::sessions::set_session(user_id, &token) { Err(_) => Err(ResponseCode::BadRequest.text("Failed to create session")), - Ok(_) => Ok(Session {user_id, token}) + Ok(_) => Ok(Self {user_id, token}) } } diff --git a/src/types/user.rs b/src/types/user.rs index b26cfd7..0013d7d 100644 --- a/src/types/user.rs +++ b/src/types/user.rs @@ -3,7 +3,7 @@ use tracing::instrument; use crate::api::auth::RegistrationRequet; use crate::database; -use crate::types::response::{Result, ResponseCode}; +use crate::types::http::{Result, ResponseCode}; #[derive(Serialize, Deserialize, Debug)] @@ -69,11 +69,11 @@ impl User { #[instrument()] pub fn new(request: RegistrationRequet) -> Result { - if User::from_email(&request.email).is_ok() { + if Self::from_email(&request.email).is_ok() { return Err(ResponseCode::BadRequest.text(&format!("Email is already in use by {}", &request.email))) } - if let Ok(user) = User::from_password(&request.password) { + if let Ok(user) = Self::from_password(&request.password) { return Err(ResponseCode::BadRequest.text(&format!("Password is already in use by {}", user.email))) }