summaryrefslogtreecommitdiff
path: root/src/types/extract.rs
blob: bb50aa729db4b8bb88833c45be1fb9b212543bfd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
use std::io::Read;

use axum::{extract::{FromRequestParts, FromRequest}, async_trait, response::Response, http::{request::Parts, Request}, TypedHeader, headers::Cookie, body::HttpBody, BoxError};
use bytes::Bytes;
use serde::de::DeserializeOwned;

use crate::types::{user::User, response::{ResponseCode, Result}, session::Session};

/// Extractor that resolves the requesting [`User`] from the `auth` cookie.
///
/// Rejections (all built via `ResponseCode::msg`):
/// - `Forbidden` when no `Cookie` header could be extracted, or it has no `auth` cookie;
/// - `Unauthorized` when `Session::from_token` rejects the token;
/// - `InternalServerError` when the session's `user_id` cannot be loaded as a `User`.
pub struct AuthorizedUser(pub User);

#[async_trait]
impl<S> FromRequestParts<S> for AuthorizedUser
where
    S: Send + Sync,
{
    type Rejection = Response;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self> {
        // Extracting `Option<TypedHeader<Cookie>>` lets a missing header surface as `None`
        // rather than a typed-header rejection; both cases are treated the same here.
        let header = Option::<TypedHeader<Cookie>>::from_request_parts(parts, state).await;
        let cookies = match header {
            Ok(Some(cookies)) => cookies,
            _ => return Err(ResponseCode::Forbidden.msg("No cookies provided")),
        };

        let token = cookies
            .get("auth")
            .ok_or_else(|| ResponseCode::Forbidden.msg("No auth token provided"))?;

        let session = Session::from_token(&token)
            .map_err(|_| ResponseCode::Unauthorized.msg("Auth token invalid"))?;

        // NOTE(review): meaning of the `true` flag is defined in `User::from_user_id` — confirm there.
        let user = User::from_user_id(session.user_id, true)
            .map_err(|_| ResponseCode::InternalServerError.msg("Valid token but no valid user"))?;

        Ok(AuthorizedUser(user))
    }
}

/// JSON request-body extractor that also runs the payload's [`Check`] validation.
///
/// Every rejection is a `Response` built with `ResponseCode::msg`:
/// - `InternalServerError` if the body cannot be buffered into [`Bytes`];
/// - `BadRequest` if the body is not valid UTF-8, does not deserialize into `T`,
///   or `T::check` fails (the check's message is forwarded as-is).
pub struct Json<T>(pub T);

#[async_trait]
impl<T, S, B> FromRequest<S, B> for Json<T>
where
    T: DeserializeOwned + Check,
    B: HttpBody + Send + 'static,
    B::Data: Send,
    B::Error: Into<BoxError>,
    S: Send + Sync,
{
    type Rejection = Response;

    async fn from_request(req: Request<B>, state: &S) -> Result<Self> {
        let Ok(bytes) = Bytes::from_request(req, state).await else {
            return Err(ResponseCode::InternalServerError.msg("Failed to read request body"));
        };

        // Validate UTF-8 directly on the buffered body. This borrows the bytes instead of
        // copying them one at a time through `io::Read::bytes()` into a new Vec.
        let Ok(string) = std::str::from_utf8(&bytes) else {
            return Err(ResponseCode::BadRequest.msg("Invalid utf8 body"));
        };

        let Ok(value) = serde_json::from_str::<T>(string) else {
            return Err(ResponseCode::BadRequest.msg("Invalid request body"));
        };

        if let Err(msg) = value.check() {
            return Err(ResponseCode::BadRequest.msg(&msg));
        }

        Ok(Json(value))
    }
}

/// Outcome of a validation: `Ok(())` when valid, otherwise the message to report.
pub type CheckResult = std::result::Result<(), String>;

/// Validation hook for deserialized request payloads.
///
/// Implementors supply [`Check::check`]; the provided helpers build simple
/// inclusive-bounds assertions that fail with the caller-supplied `message`.
pub trait Check {
    /// Validates `self`, returning `Err(message)` when it is not acceptable.
    fn check(&self) -> CheckResult;

    /// Fails with `message` unless `min <= string.len() <= max`.
    ///
    /// `str::len` counts UTF-8 bytes, not characters, so non-ASCII text reaches
    /// the bounds sooner than its visible length suggests.
    fn assert_length(string: &str, min: usize, max: usize, message: &str) -> CheckResult {
        let byte_len = string.len();
        if (min..=max).contains(&byte_len) {
            Ok(())
        } else {
            Err(message.to_owned())
        }
    }

    /// Fails with `message` unless `min <= number <= max`.
    fn assert_range(number: u64, min: u64, max: u64, message: &str) -> CheckResult {
        if (min..=max).contains(&number) {
            Ok(())
        } else {
            Err(message.to_owned())
        }
    }
}