summaryrefslogtreecommitdiff
path: root/src/types/extract.rs
diff options
context:
space:
mode:
authorTyler Murphy <tylermurphy534@gmail.com>2023-01-28 18:04:00 -0500
committerTyler Murphy <tylermurphy534@gmail.com>2023-01-28 18:04:00 -0500
commitb58654fd70958d89b344a6f7acac204f67ae9879 (patch)
tree60a1960d0d265c9f661e633022164f33e099c81c /src/types/extract.rs
parentnew rust, clippy (diff)
downloadxssbook-b58654fd70958d89b344a6f7acac204f67ae9879.tar.gz
xssbook-b58654fd70958d89b344a6f7acac204f67ae9879.tar.bz2
xssbook-b58654fd70958d89b344a6f7acac204f67ae9879.zip
fmt
Diffstat (limited to 'src/types/extract.rs')
-rw-r--r--src/types/extract.rs97
1 files changed, 66 insertions, 31 deletions
diff --git a/src/types/extract.rs b/src/types/extract.rs
index b4a6cfc..f21c352 100644
--- a/src/types/extract.rs
+++ b/src/types/extract.rs
@@ -1,43 +1,61 @@
use std::{io::Read, net::SocketAddr};
-use axum::{extract::{FromRequestParts, FromRequest, ConnectInfo}, async_trait, response::Response, http::{request::Parts, Request}, TypedHeader, headers::Cookie, body::HttpBody, BoxError, RequestExt};
+use axum::{
+ async_trait,
+ body::HttpBody,
+ extract::{ConnectInfo, FromRequest, FromRequestParts},
+ headers::Cookie,
+ http::{request::Parts, Request},
+ response::Response,
+ BoxError, RequestExt, TypedHeader,
+};
use bytes::Bytes;
use serde::de::DeserializeOwned;
-use crate::{types::{user::User, http::{ResponseCode, Result}, session::Session}, console};
+use crate::{
+ console,
+ types::{
+ http::{ResponseCode, Result},
+ session::Session,
+ user::User,
+ },
+};
pub struct AuthorizedUser(pub User);
#[async_trait]
-impl<S> FromRequestParts<S> for AuthorizedUser where S: Send + Sync {
- type Rejection = Response;
-
- async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self> {
-
+impl<S> FromRequestParts<S> for AuthorizedUser
+where
+ S: Send + Sync,
+{
+ type Rejection = Response;
+
+ async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self> {
let Ok(Some(cookies)) = Option::<TypedHeader<Cookie>>::from_request_parts(parts, state).await else {
return Err(ResponseCode::Forbidden.text("No cookies provided"))
};
-
+
let Some(token) = cookies.get("auth") else {
return Err(ResponseCode::Forbidden.text("No auth token provided"))
};
-
+
let Ok(session) = Session::from_token(token) else {
return Err(ResponseCode::Unauthorized.text("Auth token invalid"))
};
-
+
let Ok(user) = User::from_user_id(session.user_id, true) else {
tracing::error!("Valid token but no valid user");
return Err(ResponseCode::InternalServerError.text("Valid token but no valid user"))
};
Ok(Self(user))
- }
+ }
}
pub struct Log;
#[async_trait]
-impl<S, B> FromRequest<S, B> for Log where
+impl<S, B> FromRequest<S, B> for Log
+where
B: HttpBody + Sync + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
@@ -45,26 +63,35 @@ impl<S, B> FromRequest<S, B> for Log where
{
type Rejection = Response;
- async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> {
-
+ async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> {
let Ok(ConnectInfo(info)) = req.extract_parts::<ConnectInfo<SocketAddr>>().await else {
return Ok(Self)
};
let method = req.method().clone();
- let path = req.extensions().get::<RouterURI>().map_or("", |path| path.0);
+ let path = req
+ .extensions()
+ .get::<RouterURI>()
+ .map_or("", |path| path.0);
let uri = req.uri().clone();
-
+
let Ok(bytes) = Bytes::from_request(req, state).await else {
console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), None).await;
return Ok(Self)
};
-
+
let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else {
console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), None).await;
return Ok(Self)
};
-
- console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), Some(body.to_string())).await;
+
+ console::log(
+ info.ip(),
+ method.clone(),
+ uri.clone(),
+ Some(path.to_string()),
+ Some(body.to_string()),
+ )
+ .await;
Ok(Self)
}
@@ -73,7 +100,8 @@ impl<S, B> FromRequest<S, B> for Log where
pub struct Json<T>(pub T);
#[async_trait]
-impl<T, S, B> FromRequest<S, B> for Json<T> where
+impl<T, S, B> FromRequest<S, B> for Json<T>
+where
T: DeserializeOwned + Check,
B: HttpBody + Sync + Send + 'static,
B::Data: Send,
@@ -82,26 +110,35 @@ impl<T, S, B> FromRequest<S, B> for Json<T> where
{
type Rejection = Response;
- async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> {
-
+ async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> {
let Ok(ConnectInfo(info)) = req.extract_parts::<ConnectInfo<SocketAddr>>().await else {
tracing::error!("Failed to read connection info");
return Err(ResponseCode::InternalServerError.text("Failed to read connection info"));
};
let method = req.method().clone();
- let path = req.extensions().get::<RouterURI>().map_or("", |path| path.0);
+ let path = req
+ .extensions()
+ .get::<RouterURI>()
+ .map_or("", |path| path.0);
let uri = req.uri().clone();
-
+
let Ok(bytes) = Bytes::from_request(req, state).await else {
tracing::error!("Failed to read request body");
return Err(ResponseCode::InternalServerError.text("Failed to read request body"));
};
-
+
let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else {
return Err(ResponseCode::BadRequest.text("Invalid utf8 body"))
};
-
- console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), Some(body.to_string())).await;
+
+ console::log(
+ info.ip(),
+ method.clone(),
+ uri.clone(),
+ Some(path.to_string()),
+ Some(body.to_string()),
+ )
+ .await;
let Ok(value) = serde_json::from_str::<T>(&body) else {
return Err(ResponseCode::BadRequest.text("Invalid request body"))
@@ -118,19 +155,18 @@ impl<T, S, B> FromRequest<S, B> for Json<T> where
pub type CheckResult = std::result::Result<(), String>;
pub trait Check {
-
fn check(&self) -> CheckResult;
fn assert_length(string: &str, min: usize, max: usize, message: &str) -> CheckResult {
if string.len() < min || string.len() > max {
- return Err(message.to_string())
+ return Err(message.to_string());
}
Ok(())
}
fn assert_range(number: u64, min: u64, max: u64, message: &str) -> CheckResult {
if number < min || number > max {
- return Err(message.to_string())
+ return Err(message.to_string());
}
Ok(())
}
@@ -138,4 +174,3 @@ pub trait Check {
#[derive(Clone)]
pub struct RouterURI(pub &'static str);
-