use std::{io::Read, net::SocketAddr}; use axum::{extract::{FromRequestParts, FromRequest, ConnectInfo}, async_trait, response::Response, http::{request::Parts, Request}, TypedHeader, headers::Cookie, body::HttpBody, BoxError, RequestExt}; use bytes::Bytes; use serde::de::DeserializeOwned; use crate::{types::{user::User, response::{ResponseCode, Result}, session::Session}, console}; pub struct AuthorizedUser(pub User); #[async_trait] impl FromRequestParts for AuthorizedUser where S: Send + Sync { type Rejection = Response; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let Ok(Some(cookies)) = Option::>::from_request_parts(parts, state).await else { return Err(ResponseCode::Forbidden.text("No cookies provided")) }; let Some(token) = cookies.get("auth") else { return Err(ResponseCode::Forbidden.text("No auth token provided")) }; let Ok(session) = Session::from_token(&token) else { return Err(ResponseCode::Unauthorized.text("Auth token invalid")) }; let Ok(user) = User::from_user_id(session.user_id, true) else { return Err(ResponseCode::InternalServerError.text("Valid token but no valid user")) }; Ok(AuthorizedUser(user)) } } pub struct Log; #[async_trait] impl FromRequest for Log where B: HttpBody + Sync + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = Response; async fn from_request(mut req: Request, state: &S) -> Result { let Ok(ConnectInfo(info)) = req.extract_parts::>().await else { return Ok(Log) }; let method = req.method().clone(); let path = req.extensions().get::().unwrap().0; let uri = req.uri().clone(); let Ok(bytes) = Bytes::from_request(req, state).await else { console::log(info.ip().clone(), method.clone(), uri.clone(), Some(path.to_string()), None).await; return Ok(Log) }; let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else { console::log(info.ip().clone(), method.clone(), uri.clone(), Some(path.to_string()), None).await; return Ok(Log) }; console::log(info.ip().clone(), method.clone(), uri.clone(), Some(path.to_string()), Some(body.to_string())).await; Ok(Log) } } pub struct Json(pub T); #[async_trait] impl FromRequest for Json where T: DeserializeOwned + Check, B: HttpBody + Sync + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = Response; async fn from_request(mut req: Request, state: &S) -> Result { let Ok(ConnectInfo(info)) = req.extract_parts::>().await else { return Err(ResponseCode::InternalServerError.text("Failed to read connection info")); }; let method = req.method().clone(); let path = req.extensions().get::().unwrap().0; let uri = req.uri().clone(); let Ok(bytes) = Bytes::from_request(req, state).await else { return Err(ResponseCode::InternalServerError.text("Failed to read request body")); }; let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else { return Err(ResponseCode::BadRequest.text("Invalid utf8 body")) }; console::log(info.ip().clone(), method.clone(), uri.clone(), Some(path.to_string()), Some(body.to_string())).await; let Ok(value) = serde_json::from_str::(&body) else { return Err(ResponseCode::BadRequest.text("Invalid request body")) }; if let Err(msg) = value.check() { return Err(ResponseCode::BadRequest.text(&msg)); } Ok(Json(value)) } } pub type CheckResult = std::result::Result<(), String>; pub trait Check { fn check(&self) -> CheckResult; fn assert_length(string: &str, min: usize, max: usize, message: &str) -> CheckResult { if string.len() < min || string.len() > max { return Err(message.to_string()) } Ok(()) } fn assert_range(number: u64, min: u64, max: u64, message: &str) -> CheckResult { if number < min || number > max { return Err(message.to_string()) } Ok(()) } } #[derive(Clone)] pub struct RouterURI(pub &'static str);