1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
|
use std::io::Read;
use axum::{extract::{FromRequestParts, FromRequest}, async_trait, response::Response, http::{request::Parts, Request}, TypedHeader, headers::Cookie, body::HttpBody, BoxError};
use bytes::Bytes;
use serde::de::DeserializeOwned;
use crate::types::{user::User, response::{ResponseCode, Result}, session::Session};
/// Extractor that resolves the request's `auth` cookie to a logged-in [`User`].
///
/// Resolution steps, and the rejection produced when each one fails:
/// 1. Extract the `Cookie` header — `Forbidden`, "No cookies provided".
/// 2. Read the `auth` cookie from it — `Forbidden`, "No auth token provided".
/// 3. Turn the token into a [`Session`] — `Unauthorized`, "Auth token invalid".
/// 4. Load the session's user — `InternalServerError`, "Valid token but no valid user".
pub struct AuthorizedUser(pub User);

#[async_trait]
impl<S> FromRequestParts<S> for AuthorizedUser
where
    S: Send + Sync,
{
    type Rejection = Response;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self> {
        // A missing or unextractable `Cookie` header both collapse to `None` here.
        let cookies = Option::<TypedHeader<Cookie>>::from_request_parts(parts, state)
            .await
            .ok()
            .flatten()
            .ok_or_else(|| ResponseCode::Forbidden.msg("No cookies provided"))?;

        let token = cookies
            .get("auth")
            .ok_or_else(|| ResponseCode::Forbidden.msg("No auth token provided"))?;

        let session = Session::from_token(&token)
            .map_err(|_| ResponseCode::Unauthorized.msg("Auth token invalid"))?;

        // NOTE(review): the meaning of the `true` flag is defined in `User::from_user_id`;
        // if that call does blocking I/O, consider `spawn_blocking` — TODO confirm.
        let user = User::from_user_id(session.user_id, true).map_err(|_| {
            ResponseCode::InternalServerError.msg("Valid token but no valid user")
        })?;

        Ok(Self(user))
    }
}
/// JSON body extractor that additionally runs [`Check::check`] on the parsed value.
///
/// Every failure is mapped onto a `ResponseCode` response:
/// - body could not be buffered: `InternalServerError`, "Failed to read request body";
/// - body is not valid UTF-8: `BadRequest`, "Invalid utf8 body";
/// - body does not deserialize into `T`: `BadRequest`, "Invalid request body";
/// - `T::check` fails: `BadRequest` carrying the message returned by `check`.
pub struct Json<T>(pub T);

#[async_trait]
impl<T, S, B> FromRequest<S, B> for Json<T>
where
    T: DeserializeOwned + Check,
    B: HttpBody + Send + 'static,
    B::Data: Send,
    B::Error: Into<BoxError>,
    S: Send + Sync,
{
    type Rejection = Response;

    async fn from_request(req: Request<B>, state: &S) -> Result<Self> {
        let Ok(bytes) = Bytes::from_request(req, state).await else {
            return Err(ResponseCode::InternalServerError.msg("Failed to read request body"));
        };
        // Validate UTF-8 directly on the buffered body. This borrows the bytes instead of
        // iterating them one `io::Result<u8>` at a time and copying into a new Vec.
        let Ok(body) = std::str::from_utf8(&bytes) else {
            return Err(ResponseCode::BadRequest.msg("Invalid utf8 body"));
        };
        let Ok(value) = serde_json::from_str::<T>(body) else {
            return Err(ResponseCode::BadRequest.msg("Invalid request body"));
        };
        if let Err(msg) = value.check() {
            return Err(ResponseCode::BadRequest.msg(&msg));
        }
        Ok(Json(value))
    }
}
/// Outcome of a [`Check`] validation: `Ok(())` when valid, otherwise the
/// human-readable message describing what failed.
pub type CheckResult = std::result::Result<(), String>;

/// Validation hook for request payloads.
///
/// Implementors provide [`Check::check`]. The associated `assert_*` helpers are
/// convenience bounds checks that return `Err(message)` on failure so they can be
/// chained with `?` inside `check`.
pub trait Check {
    /// Validates `self`, returning `Err(message)` describing the problem found.
    fn check(&self) -> CheckResult;

    /// Fails with `message` unless `min <= string.len() <= max` (inclusive).
    ///
    /// The length is measured in UTF-8 **bytes**, not characters, so a non-ASCII
    /// character counts as more than one unit. If `min > max`, this always fails.
    fn assert_length(string: &str, min: usize, max: usize, message: &str) -> CheckResult {
        if (min..=max).contains(&string.len()) {
            Ok(())
        } else {
            Err(message.to_string())
        }
    }

    /// Fails with `message` unless `min <= number <= max` (inclusive).
    ///
    /// If `min > max`, this always fails.
    fn assert_range(number: u64, min: u64, max: u64, message: &str) -> CheckResult {
        if (min..=max).contains(&number) {
            Ok(())
        } else {
            Err(message.to_string())
        }
    }
}
|