diff options
Diffstat (limited to 'src/types/extract.rs')
-rw-r--r-- | src/types/extract.rs | 97 |
1 files changed, 66 insertions, 31 deletions
diff --git a/src/types/extract.rs b/src/types/extract.rs index b4a6cfc..f21c352 100644 --- a/src/types/extract.rs +++ b/src/types/extract.rs @@ -1,43 +1,61 @@ use std::{io::Read, net::SocketAddr}; -use axum::{extract::{FromRequestParts, FromRequest, ConnectInfo}, async_trait, response::Response, http::{request::Parts, Request}, TypedHeader, headers::Cookie, body::HttpBody, BoxError, RequestExt}; +use axum::{ + async_trait, + body::HttpBody, + extract::{ConnectInfo, FromRequest, FromRequestParts}, + headers::Cookie, + http::{request::Parts, Request}, + response::Response, + BoxError, RequestExt, TypedHeader, +}; use bytes::Bytes; use serde::de::DeserializeOwned; -use crate::{types::{user::User, http::{ResponseCode, Result}, session::Session}, console}; +use crate::{ + console, + types::{ + http::{ResponseCode, Result}, + session::Session, + user::User, + }, +}; pub struct AuthorizedUser(pub User); #[async_trait] -impl<S> FromRequestParts<S> for AuthorizedUser where S: Send + Sync { - type Rejection = Response; - - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self> { - +impl<S> FromRequestParts<S> for AuthorizedUser +where + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self> { let Ok(Some(cookies)) = Option::<TypedHeader<Cookie>>::from_request_parts(parts, state).await else { return Err(ResponseCode::Forbidden.text("No cookies provided")) }; - + let Some(token) = cookies.get("auth") else { return Err(ResponseCode::Forbidden.text("No auth token provided")) }; - + let Ok(session) = Session::from_token(token) else { return Err(ResponseCode::Unauthorized.text("Auth token invalid")) }; - + let Ok(user) = User::from_user_id(session.user_id, true) else { tracing::error!("Valid token but no valid user"); return Err(ResponseCode::InternalServerError.text("Valid token but no valid user")) }; Ok(Self(user)) - } + } } pub struct Log; #[async_trait] -impl<S, B> FromRequest<S, B> for Log where +impl<S, B> FromRequest<S, B> for Log +where B: HttpBody + Sync + Send + 'static, B::Data: Send, B::Error: Into<BoxError>, @@ -45,26 +63,35 @@ impl<S, B> FromRequest<S, B> for Log where { type Rejection = Response; - async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> { - + async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> { let Ok(ConnectInfo(info)) = req.extract_parts::<ConnectInfo<SocketAddr>>().await else { return Ok(Self) }; let method = req.method().clone(); - let path = req.extensions().get::<RouterURI>().map_or("", |path| path.0); + let path = req + .extensions() + .get::<RouterURI>() + .map_or("", |path| path.0); let uri = req.uri().clone(); - + let Ok(bytes) = Bytes::from_request(req, state).await else { console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), None).await; return Ok(Self) }; - + let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else { console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), None).await; return Ok(Self) }; - - console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), Some(body.to_string())).await; + + console::log( + info.ip(), + method.clone(), + uri.clone(), + Some(path.to_string()), + Some(body.to_string()), + ) + .await; Ok(Self) } @@ -73,7 +100,8 @@ impl<S, B> FromRequest<S, B> for Log where pub struct Json<T>(pub T); #[async_trait] -impl<T, S, B> FromRequest<S, B> for Json<T> where +impl<T, S, B> FromRequest<S, B> for Json<T> +where T: DeserializeOwned + Check, B: HttpBody + Sync + Send + 'static, B::Data: Send, @@ -82,26 +110,35 @@ impl<T, S, B> FromRequest<S, B> for Json<T> where { type Rejection = Response; - async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> { - + async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> { let Ok(ConnectInfo(info)) = req.extract_parts::<ConnectInfo<SocketAddr>>().await else { tracing::error!("Failed to read connection info"); return Err(ResponseCode::InternalServerError.text("Failed to read connection info")); }; let method = req.method().clone(); - let path = req.extensions().get::<RouterURI>().map_or("", |path| path.0); + let path = req + .extensions() + .get::<RouterURI>() + .map_or("", |path| path.0); let uri = req.uri().clone(); - + let Ok(bytes) = Bytes::from_request(req, state).await else { tracing::error!("Failed to read request body"); return Err(ResponseCode::InternalServerError.text("Failed to read request body")); }; - + let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else { return Err(ResponseCode::BadRequest.text("Invalid utf8 body")) }; - - console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), Some(body.to_string())).await; + + console::log( + info.ip(), + method.clone(), + uri.clone(), + Some(path.to_string()), + Some(body.to_string()), + ) + .await; let Ok(value) = serde_json::from_str::<T>(&body) else { return Err(ResponseCode::BadRequest.text("Invalid request body")) @@ -118,19 +155,18 @@ impl<T, S, B> FromRequest<S, B> for Json<T> where pub type CheckResult = std::result::Result<(), String>; pub trait Check { - fn check(&self) -> CheckResult; fn assert_length(string: &str, min: usize, max: usize, message: &str) -> CheckResult { if string.len() < min || string.len() > max { - return Err(message.to_string()) + return Err(message.to_string()); } Ok(()) } fn assert_range(number: u64, min: u64, max: u64, message: &str) -> CheckResult { if number < min || number > max { - return Err(message.to_string()) + return Err(message.to_string()); } Ok(()) } @@ -138,4 +174,3 @@ pub trait Check { #[derive(Clone)] pub struct RouterURI(pub &'static str); - |