use std::{ io::{Cursor, Read}, net::{IpAddr, SocketAddr}, }; use axum::{ async_trait, body::HttpBody, extract::{ConnectInfo, FromRequest, FromRequestParts}, http::{header::USER_AGENT, request::Parts, Request}, response::Response, BoxError, RequestExt, }; use bytes::Bytes; use image::{io::Reader, DynamicImage, ImageFormat}; use serde::de::DeserializeOwned; use tower_cookies::Cookies; use crate::{ public::admin, public::console, types::{ http::{ResponseCode, Result}, session::Session, user::User, }, }; pub struct RequestIp(pub IpAddr); #[async_trait] impl FromRequestParts for RequestIp where S: Send + Sync, { type Rejection = Response; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let headers = &parts.headers; let forwardedfor = headers .get("x-forwarded-for") .and_then(|h| h.to_str().ok()) .and_then(|h| { h.split(',') .rev() .find_map(|s| s.trim().parse::().ok()) }); if let Some(forwardedfor) = forwardedfor { return Ok(Self(forwardedfor)); } let realip = headers .get("x-real-ip") .and_then(|hv| hv.to_str().ok()) .and_then(|s| s.parse::().ok()); if let Some(realip) = realip { return Ok(Self(realip)); } let realip = headers .get("x-real-ip") .and_then(|hv| hv.to_str().ok()) .and_then(|s| s.parse::().ok()); if let Some(realip) = realip { return Ok(Self(realip)); } let info = parts.extensions.get::>(); if let Some(info) = info { return Ok(Self(info.0.ip())); } Err(ResponseCode::Forbidden.text("You have no ip")) } } pub struct AuthorizedUser(pub User); #[async_trait] impl FromRequestParts for AuthorizedUser where S: Send + Sync, { type Rejection = Response; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let Ok(Some(cookies)) = Option::::from_request_parts(parts, state).await else { return Err(ResponseCode::Forbidden.text("No cookies provided")) }; let Some(token) = cookies.get("auth") else { return Err(ResponseCode::Forbidden.text("No auth token provided")) }; let Ok(session) = Session::from_token(token.value()) else { return Err(ResponseCode::Unauthorized.text("Auth token invalid")) }; let Ok(user) = User::from_user_id(session.user_id, true) else { tracing::error!("Valid token but no valid user"); return Err(ResponseCode::InternalServerError.text("Valid token but no valid user")) }; Ok(Self(user)) } } pub struct AdminUser; #[async_trait] impl FromRequestParts for AdminUser where S: Send + Sync, { type Rejection = Response; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let Ok(Some(cookies)) = Option::::from_request_parts(parts, state).await else { return Err(ResponseCode::Forbidden.text("No cookies provided")) }; let Some(secret) = cookies.get("admin") else { return Err(ResponseCode::Forbidden.text("No admin secret provided")) }; let check = admin::get_secret().await; if check != secret.value() { return Err(ResponseCode::Unauthorized.text("Auth token invalid")); } Ok(Self) } } pub struct Log; #[async_trait] impl FromRequest for Log where B: HttpBody + Sync + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = Response; async fn from_request(req: Request, state: &S) -> Result { parse_body(req, state).await?; Ok(Self) } } pub struct Png(pub DynamicImage); #[async_trait] impl FromRequest for Png where B: HttpBody + Sync + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = Response; async fn from_request(req: Request, state: &S) -> Result { let bytes = match read_body(req, state).await { Ok(body) => body, Err(err) => return Err(err), }; let mut reader = Reader::new(Cursor::new(bytes)); reader.set_format(ImageFormat::Png); let Ok(img) = reader.decode() else { return Err(ResponseCode::BadRequest.text("Failed to decode png image")) }; Ok(Self(img)) } } pub struct Json(pub T); #[async_trait] impl FromRequest for Json where T: DeserializeOwned + Check, B: HttpBody + Sync + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = Response; async fn from_request(req: Request, state: &S) -> Result { let body = match parse_body(req, state).await { Ok(body) => body, Err(err) => return Err(err), }; let Ok(value) = serde_json::from_str::(&body) else { return Err(ResponseCode::BadRequest.text("Body does not match paramaters")) }; if let Err(msg) = value.check() { return Err(ResponseCode::BadRequest.text(&msg)); } Ok(Self(value)) } } pub type CheckResult = std::result::Result<(), String>; pub trait Check { fn check(&self) -> CheckResult; fn assert_length(string: &str, min: usize, max: usize, message: &str) -> CheckResult { if string.len() < min || string.len() > max { return Err(message.to_string()); } Ok(()) } fn assert_range(number: u64, min: u64, max: u64, message: &str) -> CheckResult { if number < min || number > max { return Err(message.to_string()); } Ok(()) } } pub struct UserAgent(pub String); #[async_trait] impl FromRequestParts for UserAgent where S: Send + Sync, { type Rejection = Response; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let agent = parts.headers.get(USER_AGENT); let Some(agent) = agent else { return Err(ResponseCode::BadRequest.text("Bad Request")); }; let Ok(agent) = agent.to_str() else { return Err(ResponseCode::BadRequest.text("Bad Request")); }; Ok(Self(agent.to_string())) } } async fn read_body(mut req: Request, state: &S) -> Result> where B: HttpBody + Sync + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { let Ok(RequestIp(ip)) = req.extract_parts::().await else { tracing::error!("Failed to read client ip"); return Err(ResponseCode::InternalServerError.text("Failed to read client ip")); }; let method = req.method().clone(); let uri = req.uri().clone(); let path = req .extensions() .get::() .map_or("", |path| path.0); let Ok(bytes) = Bytes::from_request(req, state).await else { return Err(ResponseCode::BadRequest.text("Request can be at most 512kb")); }; console::log(ip, method, uri, Some(path.to_string()), None).await; Ok(bytes.bytes().flatten().collect()) } async fn parse_body(mut req: Request, state: &S) -> Result where B: HttpBody + Sync + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { let Ok(RequestIp(ip)) = req.extract_parts::().await else { tracing::error!("Failed to read client ip"); return Err(ResponseCode::InternalServerError.text("Failed to read client ip")); }; let method = req.method().clone(); let uri = req.uri().clone(); let path = req .extensions() .get::() .map_or("", |path| path.0); let Ok(bytes) = Bytes::from_request(req, state).await else { return Err(ResponseCode::BadRequest.text("Request can be at most 512kb")); }; let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else { return Err(ResponseCode::BadRequest.text("Invalid utf8 body")) }; console::log( ip, method, uri, Some(path.to_string()), Some(body.to_string()), ) .await; Ok(body) } #[derive(Clone)] pub struct RouterURI(pub &'static str);