use std::io::Read; use axum::{extract::{FromRequestParts, FromRequest}, async_trait, response::Response, http::{request::Parts, Request}, TypedHeader, headers::Cookie, body::HttpBody, BoxError}; use bytes::Bytes; use serde::de::DeserializeOwned; use crate::types::{user::User, response::{ResponseCode, Result}, session::Session}; pub struct AuthorizedUser(pub User); #[async_trait] impl FromRequestParts for AuthorizedUser where S: Send + Sync { type Rejection = Response; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let Ok(Some(cookies)) = Option::>::from_request_parts(parts, state).await else { return Err(ResponseCode::Forbidden.msg("No cookies provided")) }; let Some(token) = cookies.get("auth") else { return Err(ResponseCode::Forbidden.msg("No auth token provided")) }; let Ok(session) = Session::from_token(&token) else { return Err(ResponseCode::Unauthorized.msg("Auth token invalid")) }; let Ok(user) = User::from_user_id(session.user_id, true) else { return Err(ResponseCode::InternalServerError.msg("Valid token but no valid user")) }; Ok(AuthorizedUser(user)) } } pub struct Json(pub T); #[async_trait] impl FromRequest for Json where T: DeserializeOwned + Check, B: HttpBody + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = Response; async fn from_request(req: Request, state: &S) -> Result { let Ok(bytes) = Bytes::from_request(req, state).await else { return Err(ResponseCode::InternalServerError.msg("Failed to read request body")); }; let Ok(string) = String::from_utf8(bytes.bytes().flatten().collect()) else { return Err(ResponseCode::BadRequest.msg("Invalid utf8 body")) }; let Ok(value) = serde_json::from_str::(&string) else { return Err(ResponseCode::BadRequest.msg("Invalid request body")) }; if let Err(msg) = value.check() { return Err(ResponseCode::BadRequest.msg(&msg)); } Ok(Json(value)) } } pub type CheckResult = std::result::Result<(), String>; pub trait Check { fn check(&self) -> CheckResult; fn assert_length(string: &str, min: usize, max: usize, message: &str) -> CheckResult { if string.len() < min || string.len() > max { return Err(message.to_string()) } Ok(()) } fn assert_range(number: u64, min: u64, max: u64, message: &str) -> CheckResult { if number < min || number > max { return Err(message.to_string()) } Ok(()) } }