input length and range checking

This commit is contained in:
Tyler Murphy 2023-01-26 21:29:06 -05:00
parent 88209d8823
commit 6bea3bf2ef
5 changed files with 96 additions and 7 deletions

View file

@ -3,7 +3,7 @@ use serde::Deserialize;
use time::{OffsetDateTime, Duration}; use time::{OffsetDateTime, Duration};
use tower_cookies::{Cookies, Cookie}; use tower_cookies::{Cookies, Cookie};
use crate::types::{user::User, response::ResponseCode, session::Session, extract::{Json, AuthorizedUser}}; use crate::types::{user::User, response::ResponseCode, session::Session, extract::{Json, AuthorizedUser, Check, CheckResult}};
#[derive(Deserialize)] #[derive(Deserialize)]
struct RegistrationRequet { struct RegistrationRequet {
@ -17,6 +17,20 @@ struct RegistrationRequet {
year: u32 year: u32
} }
impl Check for RegistrationRequet {
fn check(&self) -> CheckResult {
Self::assert_length(&self.firstname, 1, 20, "First name can only by 1-20 characters long")?;
Self::assert_length(&self.lastname, 1, 20, "Last name can only by 1-20 characters long")?;
Self::assert_length(&self.email, 1, 50, "Email can only by 1-50 characters long")?;
Self::assert_length(&self.password, 1, 50, "Password can only by 1-50 characters long")?;
Self::assert_length(&self.gender, 1, 100, "Gender can only by 1-100 characters long")?;
Self::assert_range(self.day as u64, 1, 255, "Birthday day can only be between 1-255")?;
Self::assert_range(self.month as u64, 1, 255, "Birthday month can only be between 1-255")?;
Self::assert_range(self.year as u64, 1, 2147483647, "Birthday year can only be between 1-2147483647")?;
Ok(())
}
}
async fn register(cookies: Cookies, Json(body): Json<RegistrationRequet>) -> Response { async fn register(cookies: Cookies, Json(body): Json<RegistrationRequet>) -> Response {
@ -50,6 +64,12 @@ struct LoginRequest {
password: String, password: String,
} }
impl Check for LoginRequest {
fn check(&self) -> CheckResult {
Ok(())
}
}
async fn login(cookies: Cookies, Json(body): Json<LoginRequest>) -> Response { async fn login(cookies: Cookies, Json(body): Json<LoginRequest>) -> Response {
let Ok(user) = User::from_email(&body.email) else { let Ok(user) = User::from_email(&body.email) else {

View file

@ -1,7 +1,7 @@
use axum::{response::Response, Router, routing::{post, patch}}; use axum::{response::Response, Router, routing::{post, patch}};
use serde::Deserialize; use serde::Deserialize;
use crate::types::{extract::{AuthorizedUser, Json}, post::Post, response::ResponseCode}; use crate::types::{extract::{AuthorizedUser, Json, Check, CheckResult}, post::Post, response::ResponseCode};
#[derive(Deserialize)] #[derive(Deserialize)]
@ -9,6 +9,13 @@ struct PostCreateRequest {
content: String content: String
} }
impl Check for PostCreateRequest {
fn check(&self) -> CheckResult {
Self::assert_length(&self.content, 1, 500, "Comments must be between 1-500 characters long")?;
Ok(())
}
}
async fn create(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostCreateRequest>) -> Response { async fn create(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostCreateRequest>) -> Response {
let Ok(_post) = Post::new(user.user_id, body.content) else { let Ok(_post) = Post::new(user.user_id, body.content) else {
@ -23,6 +30,12 @@ struct PostPageRequest {
page: u64 page: u64
} }
impl Check for PostPageRequest {
fn check(&self) -> CheckResult {
Ok(())
}
}
async fn page(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<PostPageRequest>) -> Response { async fn page(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<PostPageRequest>) -> Response {
let Ok(posts) = Post::from_post_page(body.page) else { let Ok(posts) = Post::from_post_page(body.page) else {
@ -41,6 +54,12 @@ struct UsersPostsRequest {
user_id: u64 user_id: u64
} }
impl Check for UsersPostsRequest {
fn check(&self) -> CheckResult {
Ok(())
}
}
async fn user(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UsersPostsRequest>) -> Response { async fn user(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UsersPostsRequest>) -> Response {
let Ok(posts) = Post::from_user_id(body.user_id) else { let Ok(posts) = Post::from_user_id(body.user_id) else {
@ -60,6 +79,13 @@ struct PostCommentRequest {
post_id: u64 post_id: u64
} }
impl Check for PostCommentRequest {
fn check(&self) -> CheckResult {
Self::assert_length(&self.content, 1, 255, "Comments must be between 1-255 characters long")?;
Ok(())
}
}
async fn comment(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostCommentRequest>) -> Response { async fn comment(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostCommentRequest>) -> Response {
let Ok(mut post) = Post::from_post_id(body.post_id) else { let Ok(mut post) = Post::from_post_id(body.post_id) else {
@ -79,6 +105,12 @@ struct PostLikeRequest {
post_id: u64 post_id: u64
} }
impl Check for PostLikeRequest {
fn check(&self) -> CheckResult {
Ok(())
}
}
async fn like(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostLikeRequest>) -> Response { async fn like(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostLikeRequest>) -> Response {
let Ok(mut post) = Post::from_post_id(body.post_id) else { let Ok(mut post) = Post::from_post_id(body.post_id) else {

View file

@ -1,12 +1,18 @@
use axum::{Router, response::Response, routing::post}; use axum::{Router, response::Response, routing::post};
use serde::Deserialize; use serde::Deserialize;
use crate::types::{extract::{AuthorizedUser, Json}, response::ResponseCode, user::User}; use crate::types::{extract::{AuthorizedUser, Json, Check, CheckResult}, response::ResponseCode, user::User};
#[derive(Deserialize)] #[derive(Deserialize)]
struct UserLoadRequest { struct UserLoadRequest {
ids: Vec<u64> ids: Vec<u64>
} }
impl Check for UserLoadRequest {
fn check(&self) -> CheckResult {
Ok(())
}
}
async fn load_batch(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UserLoadRequest>) -> Response { async fn load_batch(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UserLoadRequest>) -> Response {
let users = User::from_user_ids(body.ids); let users = User::from_user_ids(body.ids);
@ -22,6 +28,12 @@ struct UserPageReqiest {
page: u64 page: u64
} }
impl Check for UserPageReqiest {
fn check(&self) -> CheckResult {
Ok(())
}
}
async fn load_page(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UserPageReqiest>) -> Response { async fn load_page(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UserPageReqiest>) -> Response {
let Ok(users) = User::from_user_page(body.page) else { let Ok(users) = User::from_user_page(body.page) else {

View file

@ -11,7 +11,7 @@ pub fn init() -> Result<(), rusqlite::Error> {
CREATE TABLE IF NOT EXISTS posts ( CREATE TABLE IF NOT EXISTS posts (
post_id INTEGER PRIMARY KEY AUTOINCREMENT, post_id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL, user_id INTEGER NOT NULL,
content TEXT NOT NULL, content VARCHAR(500) NOT NULL,
likes TEXT NOT NULL, likes TEXT NOT NULL,
comments TEXT NOT NULL, comments TEXT NOT NULL,
date INTEGER NOT NULL, date INTEGER NOT NULL,

View file

@ -38,7 +38,7 @@ pub struct Json<T>(pub T);
#[async_trait] #[async_trait]
impl<T, S, B> FromRequest<S, B> for Json<T> where impl<T, S, B> FromRequest<S, B> for Json<T> where
T: DeserializeOwned, T: DeserializeOwned + Check,
B: HttpBody + Send + 'static, B: HttpBody + Send + 'static,
B::Data: Send, B::Data: Send,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
@ -56,10 +56,35 @@ impl<T, S, B> FromRequest<S, B> for Json<T> where
return Err(ResponseCode::BadRequest.msg("Invalid utf8 body")) return Err(ResponseCode::BadRequest.msg("Invalid utf8 body"))
}; };
let Ok(value) = serde_json::from_str(&string) else { let Ok(value) = serde_json::from_str::<T>(&string) else {
return Err(ResponseCode::BadRequest.msg("Invalid request body")) return Err(ResponseCode::BadRequest.msg("Invalid request body"))
}; };
if let Err(msg) = value.check() {
return Err(ResponseCode::BadRequest.msg(&msg));
}
Ok(Json(value)) Ok(Json(value))
} }
} }
pub type CheckResult = std::result::Result<(), String>;
pub trait Check {
fn check(&self) -> CheckResult;
fn assert_length(string: &str, min: usize, max: usize, message: &str) -> CheckResult {
if string.len() < min || string.len() > max {
return Err(message.to_string())
}
Ok(())
}
fn assert_range(number: u64, min: u64, max: u64, message: &str) -> CheckResult {
if number < min || number > max {
return Err(message.to_string())
}
Ok(())
}
}