diff options
Diffstat (limited to 'src/types')
-rw-r--r-- | src/types/extract.rs | 60 |
1 files changed, 55 insertions, 5 deletions
diff --git a/src/types/extract.rs b/src/types/extract.rs index 54f250a..4d7ac51 100644 --- a/src/types/extract.rs +++ b/src/types/extract.rs @@ -1,14 +1,13 @@ -use std::io::{Read, Cursor}; +use std::{io::{Read, Cursor}, net::{IpAddr, SocketAddr}}; use axum::{ async_trait, body::HttpBody, - extract::{FromRequest, FromRequestParts}, + extract::{FromRequest, FromRequestParts, ConnectInfo}, http::{request::Parts, Request}, response::Response, BoxError, RequestExt, }; -use axum_client_ip::ClientIp; use bytes::Bytes; use image::{io::Reader, ImageFormat, DynamicImage}; use serde::de::DeserializeOwned; @@ -23,6 +22,57 @@ use crate::{ }, }; +pub struct RequestIp(pub IpAddr); + +#[async_trait] +impl<S> FromRequestParts<S> for RequestIp +where + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self> { + + let headers = &parts.headers; + + let forwardedfor = headers.get("x-forwarded-for") + .and_then(|h| h.to_str().ok()) + .and_then(|h| + h.split(',') + .rev() + .find_map(|s| s.trim().parse::<IpAddr>().ok()) + ); + + if let Some(forwardedfor) = forwardedfor { + return Ok(RequestIp(forwardedfor)) + } + + let realip = headers.get("x-real-ip") + .and_then(|hv| hv.to_str().ok()) + .and_then(|s| s.parse::<IpAddr>().ok()); + + if let Some(realip) = realip { + return Ok(RequestIp(realip)) + } + + let realip = headers.get("x-real-ip") + .and_then(|hv| hv.to_str().ok()) + .and_then(|s| s.parse::<IpAddr>().ok()); + + if let Some(realip) = realip { + return Ok(RequestIp(realip)) + } + + let info = parts.extensions.get::<ConnectInfo<SocketAddr>>(); + + if let Some(info) = info { + return Ok(RequestIp(info.0.ip())) + } + + Err(ResponseCode::Forbidden.text("You have no ip")) + } +} + pub struct AuthorizedUser(pub User); #[async_trait] @@ -189,7 +239,7 @@ where S: Send + Sync, { - let Ok(ClientIp(ip)) = req.extract_parts::<ClientIp>().await else { + let Ok(RequestIp(ip)) = req.extract_parts::<RequestIp>().await else { tracing::error!("Failed to read client ip"); return Err(ResponseCode::InternalServerError.text("Failed to read client ip")); }; @@ -224,7 +274,7 @@ where B::Error: Into<BoxError>, S: Send + Sync, { - let Ok(ClientIp(ip)) = req.extract_parts::<ClientIp>().await else { + let Ok(RequestIp(ip)) = req.extract_parts::<RequestIp>().await else { tracing::error!("Failed to read client ip"); return Err(ResponseCode::InternalServerError.text("Failed to read client ip")); }; |