use std::io::Read; use axum::{ async_trait, body::HttpBody, extract::{FromRequest, FromRequestParts}, headers::Cookie, http::{request::Parts, Request}, response::Response, BoxError, RequestExt, TypedHeader, }; use axum_client_ip::ClientIp; use bytes::Bytes; use serde::de::DeserializeOwned; use crate::{ console, types::{ http::{ResponseCode, Result}, session::Session, user::User, }, }; pub struct AuthorizedUser(pub User); #[async_trait] impl FromRequestParts for AuthorizedUser where S: Send + Sync, { type Rejection = Response; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let Ok(Some(cookies)) = Option::>::from_request_parts(parts, state).await else { return Err(ResponseCode::Forbidden.text("No cookies provided")) }; let Some(token) = cookies.get("auth") else { return Err(ResponseCode::Forbidden.text("No auth token provided")) }; let Ok(session) = Session::from_token(token) else { return Err(ResponseCode::Unauthorized.text("Auth token invalid")) }; let Ok(user) = User::from_user_id(session.user_id, true) else { tracing::error!("Valid token but no valid user"); return Err(ResponseCode::InternalServerError.text("Valid token but no valid user")) }; Ok(Self(user)) } } pub struct Log; #[async_trait] impl FromRequest for Log where B: HttpBody + Sync + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = Response; async fn from_request(req: Request, state: &S) -> Result { parse_body(req, state).await?; Ok(Self) } } pub struct Json(pub T); #[async_trait] impl FromRequest for Json where T: DeserializeOwned + Check, B: HttpBody + Sync + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = Response; async fn from_request(req: Request, state: &S) -> Result { let body = match parse_body(req, state).await { Ok(body) => body, Err(err) => return Err(err) }; let Ok(value) = serde_json::from_str::(&body) else { return Err(ResponseCode::BadRequest.text("Invalid request body")) }; if let Err(msg) = value.check() { return Err(ResponseCode::BadRequest.text(&msg)); } Ok(Self(value)) } } pub type CheckResult = std::result::Result<(), String>; pub trait Check { fn check(&self) -> CheckResult; fn assert_length(string: &str, min: usize, max: usize, message: &str) -> CheckResult { if string.len() < min || string.len() > max { return Err(message.to_string()); } Ok(()) } fn assert_range(number: u64, min: u64, max: u64, message: &str) -> CheckResult { if number < min || number > max { return Err(message.to_string()); } Ok(()) } } pub async fn parse_body(mut req: Request, state: &S) -> Result where B: HttpBody + Sync + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync { let Ok(ClientIp(ip)) = req.extract_parts::().await else { tracing::error!("Failed to read client ip"); return Err(ResponseCode::InternalServerError.text("Failed to read client ip")); }; let method = req.method().clone(); let uri = req.uri().clone(); let path = req .extensions() .get::() .map_or("", |path| path.0); let Ok(bytes) = Bytes::from_request(req, state).await else { tracing::error!("Failed to read request body"); return Err(ResponseCode::InternalServerError.text("Failed to read request body")); }; let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else { return Err(ResponseCode::BadRequest.text("Invalid utf8 body")) }; console::log( ip, method, uri, Some(path.to_string()), Some(body.to_string()), ) .await; Ok(body) } #[derive(Clone)] pub struct RouterURI(pub &'static str);