From 6bea3bf2ef31f978b98848a5f2a045dcab0cc2f0 Mon Sep 17 00:00:00 2001 From: Tyler Murphy Date: Thu, 26 Jan 2023 21:29:06 -0500 Subject: [PATCH] input length and range checking --- src/api/auth.rs | 22 +++++++++++++++++++++- src/api/posts.rs | 34 +++++++++++++++++++++++++++++++++- src/api/users.rs | 14 +++++++++++++- src/database/posts.rs | 2 +- src/types/extract.rs | 31 ++++++++++++++++++++++++++++--- 5 files changed, 96 insertions(+), 7 deletions(-) diff --git a/src/api/auth.rs b/src/api/auth.rs index d60483f..b469d4d 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -3,7 +3,7 @@ use serde::Deserialize; use time::{OffsetDateTime, Duration}; use tower_cookies::{Cookies, Cookie}; -use crate::types::{user::User, response::ResponseCode, session::Session, extract::{Json, AuthorizedUser}}; +use crate::types::{user::User, response::ResponseCode, session::Session, extract::{Json, AuthorizedUser, Check, CheckResult}}; #[derive(Deserialize)] struct RegistrationRequet { @@ -17,6 +17,20 @@ struct RegistrationRequet { year: u32 } +impl Check for RegistrationRequet { + fn check(&self) -> CheckResult { + Self::assert_length(&self.firstname, 1, 20, "First name can only by 1-20 characters long")?; + Self::assert_length(&self.lastname, 1, 20, "Last name can only by 1-20 characters long")?; + Self::assert_length(&self.email, 1, 50, "Email can only by 1-50 characters long")?; + Self::assert_length(&self.password, 1, 50, "Password can only by 1-50 characters long")?; + Self::assert_length(&self.gender, 1, 100, "Gender can only by 1-100 characters long")?; + Self::assert_range(self.day as u64, 1, 255, "Birthday day can only be between 1-255")?; + Self::assert_range(self.month as u64, 1, 255, "Birthday month can only be between 1-255")?; + Self::assert_range(self.year as u64, 1, 2147483647, "Birthday year can only be between 1-2147483647")?; + Ok(()) + } +} + async fn register(cookies: Cookies, Json(body): Json) -> Response { @@ -50,6 +64,12 @@ struct LoginRequest { password: String, } +impl Check for LoginRequest { + fn check(&self) -> CheckResult { + Ok(()) + } +} + async fn login(cookies: Cookies, Json(body): Json) -> Response { let Ok(user) = User::from_email(&body.email) else { diff --git a/src/api/posts.rs b/src/api/posts.rs index 405dfa6..85ff2b2 100644 --- a/src/api/posts.rs +++ b/src/api/posts.rs @@ -1,7 +1,7 @@ use axum::{response::Response, Router, routing::{post, patch}}; use serde::Deserialize; -use crate::types::{extract::{AuthorizedUser, Json}, post::Post, response::ResponseCode}; +use crate::types::{extract::{AuthorizedUser, Json, Check, CheckResult}, post::Post, response::ResponseCode}; #[derive(Deserialize)] @@ -9,6 +9,13 @@ struct PostCreateRequest { content: String } +impl Check for PostCreateRequest { + fn check(&self) -> CheckResult { + Self::assert_length(&self.content, 1, 500, "Comments must be between 1-500 characters long")?; + Ok(()) + } +} + async fn create(AuthorizedUser(user): AuthorizedUser, Json(body): Json) -> Response { let Ok(_post) = Post::new(user.user_id, body.content) else { @@ -23,6 +30,12 @@ struct PostPageRequest { page: u64 } +impl Check for PostPageRequest { + fn check(&self) -> CheckResult { + Ok(()) + } +} + async fn page(AuthorizedUser(_user): AuthorizedUser, Json(body): Json) -> Response { let Ok(posts) = Post::from_post_page(body.page) else { @@ -41,6 +54,12 @@ struct UsersPostsRequest { user_id: u64 } +impl Check for UsersPostsRequest { + fn check(&self) -> CheckResult { + Ok(()) + } +} + async fn user(AuthorizedUser(_user): AuthorizedUser, Json(body): Json) -> Response { let Ok(posts) = Post::from_user_id(body.user_id) else { @@ -60,6 +79,13 @@ struct PostCommentRequest { post_id: u64 } +impl Check for PostCommentRequest { + fn check(&self) -> CheckResult { + Self::assert_length(&self.content, 1, 255, "Comments must be between 1-255 characters long")?; + Ok(()) + } +} + async fn comment(AuthorizedUser(user): AuthorizedUser, Json(body): Json) -> Response { let Ok(mut post) = Post::from_post_id(body.post_id) else { @@ -79,6 +105,12 @@ struct PostLikeRequest { post_id: u64 } +impl Check for PostLikeRequest { + fn check(&self) -> CheckResult { + Ok(()) + } +} + async fn like(AuthorizedUser(user): AuthorizedUser, Json(body): Json) -> Response { let Ok(mut post) = Post::from_post_id(body.post_id) else { diff --git a/src/api/users.rs b/src/api/users.rs index 283ec96..45ed195 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -1,12 +1,18 @@ use axum::{Router, response::Response, routing::post}; use serde::Deserialize; -use crate::types::{extract::{AuthorizedUser, Json}, response::ResponseCode, user::User}; +use crate::types::{extract::{AuthorizedUser, Json, Check, CheckResult}, response::ResponseCode, user::User}; #[derive(Deserialize)] struct UserLoadRequest { ids: Vec } +impl Check for UserLoadRequest { + fn check(&self) -> CheckResult { + Ok(()) + } +} + async fn load_batch(AuthorizedUser(_user): AuthorizedUser, Json(body): Json) -> Response { let users = User::from_user_ids(body.ids); @@ -22,6 +28,12 @@ struct UserPageReqiest { page: u64 } +impl Check for UserPageReqiest { + fn check(&self) -> CheckResult { + Ok(()) + } +} + async fn load_page(AuthorizedUser(_user): AuthorizedUser, Json(body): Json) -> Response { let Ok(users) = User::from_user_page(body.page) else { diff --git a/src/database/posts.rs b/src/database/posts.rs index 77d2387..96cd18a 100644 --- a/src/database/posts.rs +++ b/src/database/posts.rs @@ -11,7 +11,7 @@ pub fn init() -> Result<(), rusqlite::Error> { CREATE TABLE IF NOT EXISTS posts ( post_id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL, - content TEXT NOT NULL, + content VARCHAR(500) NOT NULL, likes TEXT NOT NULL, comments TEXT NOT NULL, date INTEGER NOT NULL, diff --git a/src/types/extract.rs b/src/types/extract.rs index 6518ca1..bb50aa7 100644 --- a/src/types/extract.rs +++ b/src/types/extract.rs @@ -38,7 +38,7 @@ pub struct Json(pub T); #[async_trait] impl FromRequest for Json where - T: DeserializeOwned, + T: DeserializeOwned + Check, B: HttpBody + Send + 'static, B::Data: Send, B::Error: Into, @@ -56,10 +56,35 @@ impl FromRequest for Json where return Err(ResponseCode::BadRequest.msg("Invalid utf8 body")) }; - let Ok(value) = serde_json::from_str(&string) else { + let Ok(value) = serde_json::from_str::(&string) else { return Err(ResponseCode::BadRequest.msg("Invalid request body")) }; + if let Err(msg) = value.check() { + return Err(ResponseCode::BadRequest.msg(&msg)); + } + Ok(Json(value)) } -} \ No newline at end of file +} + +pub type CheckResult = std::result::Result<(), String>; + +pub trait Check { + + fn check(&self) -> CheckResult; + + fn assert_length(string: &str, min: usize, max: usize, message: &str) -> CheckResult { + if string.len() < min || string.len() > max { + return Err(message.to_string()) + } + Ok(()) + } + + fn assert_range(number: u64, min: u64, max: u64, message: &str) -> CheckResult { + if number < min || number > max { + return Err(message.to_string()) + } + Ok(()) + } +} \ No newline at end of file