summaryrefslogtreecommitdiff
path: root/src/types
diff options
context:
space:
mode:
Diffstat (limited to 'src/types')
-rw-r--r--src/types/extract.rs97
-rw-r--r--src/types/http.rs22
-rw-r--r--src/types/mod.rs6
-rw-r--r--src/types/post.rs25
-rw-r--r--src/types/session.rs16
-rw-r--r--src/types/user.rs28
6 files changed, 119 insertions, 75 deletions
diff --git a/src/types/extract.rs b/src/types/extract.rs
index b4a6cfc..f21c352 100644
--- a/src/types/extract.rs
+++ b/src/types/extract.rs
@@ -1,43 +1,61 @@
use std::{io::Read, net::SocketAddr};
-use axum::{extract::{FromRequestParts, FromRequest, ConnectInfo}, async_trait, response::Response, http::{request::Parts, Request}, TypedHeader, headers::Cookie, body::HttpBody, BoxError, RequestExt};
+use axum::{
+ async_trait,
+ body::HttpBody,
+ extract::{ConnectInfo, FromRequest, FromRequestParts},
+ headers::Cookie,
+ http::{request::Parts, Request},
+ response::Response,
+ BoxError, RequestExt, TypedHeader,
+};
use bytes::Bytes;
use serde::de::DeserializeOwned;
-use crate::{types::{user::User, http::{ResponseCode, Result}, session::Session}, console};
+use crate::{
+ console,
+ types::{
+ http::{ResponseCode, Result},
+ session::Session,
+ user::User,
+ },
+};
pub struct AuthorizedUser(pub User);
#[async_trait]
-impl<S> FromRequestParts<S> for AuthorizedUser where S: Send + Sync {
- type Rejection = Response;
-
- async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self> {
-
+impl<S> FromRequestParts<S> for AuthorizedUser
+where
+ S: Send + Sync,
+{
+ type Rejection = Response;
+
+ async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self> {
let Ok(Some(cookies)) = Option::<TypedHeader<Cookie>>::from_request_parts(parts, state).await else {
return Err(ResponseCode::Forbidden.text("No cookies provided"))
};
-
+
let Some(token) = cookies.get("auth") else {
return Err(ResponseCode::Forbidden.text("No auth token provided"))
};
-
+
let Ok(session) = Session::from_token(token) else {
return Err(ResponseCode::Unauthorized.text("Auth token invalid"))
};
-
+
let Ok(user) = User::from_user_id(session.user_id, true) else {
tracing::error!("Valid token but no valid user");
return Err(ResponseCode::InternalServerError.text("Valid token but no valid user"))
};
Ok(Self(user))
- }
+ }
}
pub struct Log;
#[async_trait]
-impl<S, B> FromRequest<S, B> for Log where
+impl<S, B> FromRequest<S, B> for Log
+where
B: HttpBody + Sync + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
@@ -45,26 +63,35 @@ impl<S, B> FromRequest<S, B> for Log where
{
type Rejection = Response;
- async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> {
-
+ async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> {
let Ok(ConnectInfo(info)) = req.extract_parts::<ConnectInfo<SocketAddr>>().await else {
return Ok(Self)
};
let method = req.method().clone();
- let path = req.extensions().get::<RouterURI>().map_or("", |path| path.0);
+ let path = req
+ .extensions()
+ .get::<RouterURI>()
+ .map_or("", |path| path.0);
let uri = req.uri().clone();
-
+
let Ok(bytes) = Bytes::from_request(req, state).await else {
console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), None).await;
return Ok(Self)
};
-
+
let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else {
console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), None).await;
return Ok(Self)
};
-
- console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), Some(body.to_string())).await;
+
+ console::log(
+ info.ip(),
+ method.clone(),
+ uri.clone(),
+ Some(path.to_string()),
+ Some(body.to_string()),
+ )
+ .await;
Ok(Self)
}
@@ -73,7 +100,8 @@ impl<S, B> FromRequest<S, B> for Log where
pub struct Json<T>(pub T);
#[async_trait]
-impl<T, S, B> FromRequest<S, B> for Json<T> where
+impl<T, S, B> FromRequest<S, B> for Json<T>
+where
T: DeserializeOwned + Check,
B: HttpBody + Sync + Send + 'static,
B::Data: Send,
@@ -82,26 +110,35 @@ impl<T, S, B> FromRequest<S, B> for Json<T> where
{
type Rejection = Response;
- async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> {
-
+ async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> {
let Ok(ConnectInfo(info)) = req.extract_parts::<ConnectInfo<SocketAddr>>().await else {
tracing::error!("Failed to read connection info");
return Err(ResponseCode::InternalServerError.text("Failed to read connection info"));
};
let method = req.method().clone();
- let path = req.extensions().get::<RouterURI>().map_or("", |path| path.0);
+ let path = req
+ .extensions()
+ .get::<RouterURI>()
+ .map_or("", |path| path.0);
let uri = req.uri().clone();
-
+
let Ok(bytes) = Bytes::from_request(req, state).await else {
tracing::error!("Failed to read request body");
return Err(ResponseCode::InternalServerError.text("Failed to read request body"));
};
-
+
let Ok(body) = String::from_utf8(bytes.bytes().flatten().collect()) else {
return Err(ResponseCode::BadRequest.text("Invalid utf8 body"))
};
-
- console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), Some(body.to_string())).await;
+
+ console::log(
+ info.ip(),
+ method.clone(),
+ uri.clone(),
+ Some(path.to_string()),
+ Some(body.to_string()),
+ )
+ .await;
let Ok(value) = serde_json::from_str::<T>(&body) else {
return Err(ResponseCode::BadRequest.text("Invalid request body"))
@@ -118,19 +155,18 @@ impl<T, S, B> FromRequest<S, B> for Json<T> where
pub type CheckResult = std::result::Result<(), String>;
pub trait Check {
-
fn check(&self) -> CheckResult;
fn assert_length(string: &str, min: usize, max: usize, message: &str) -> CheckResult {
if string.len() < min || string.len() > max {
- return Err(message.to_string())
+ return Err(message.to_string());
}
Ok(())
}
fn assert_range(number: u64, min: u64, max: u64, message: &str) -> CheckResult {
if number < min || number > max {
- return Err(message.to_string())
+ return Err(message.to_string());
}
Ok(())
}
@@ -138,4 +174,3 @@ pub trait Check {
#[derive(Clone)]
pub struct RouterURI(pub &'static str);
-
diff --git a/src/types/http.rs b/src/types/http.rs
index 0e7b703..8524b15 100644
--- a/src/types/http.rs
+++ b/src/types/http.rs
@@ -1,4 +1,9 @@
-use axum::{response::{IntoResponse, Response}, http::{StatusCode, Request, HeaderValue}, body::Body, headers::HeaderName};
+use axum::{
+ body::Body,
+ headers::HeaderName,
+ http::{HeaderValue, Request, StatusCode},
+ response::{IntoResponse, Response},
+};
use tower::ServiceExt;
use tower_http::services::ServeFile;
use tracing::instrument;
@@ -12,11 +17,10 @@ pub enum ResponseCode {
Forbidden,
NotFound,
ImATeapot,
- InternalServerError
+ InternalServerError,
}
impl ResponseCode {
-
const fn code(self) -> StatusCode {
match self {
Self::Success => StatusCode::OK,
@@ -26,7 +30,7 @@ impl ResponseCode {
Self::Forbidden => StatusCode::FORBIDDEN,
Self::NotFound => StatusCode::NOT_FOUND,
Self::ImATeapot => StatusCode::IM_A_TEAPOT,
- Self::InternalServerError => StatusCode::INTERNAL_SERVER_ERROR
+ Self::InternalServerError => StatusCode::INTERNAL_SERVER_ERROR,
}
}
@@ -39,7 +43,8 @@ impl ResponseCode {
pub fn json(self, json: &str) -> Response {
let mut res = (self.code(), json.to_owned()).into_response();
res.headers_mut().insert(
- HeaderName::from_static("content-type"), HeaderValue::from_static("application/json"),
+ HeaderName::from_static("content-type"),
+ HeaderValue::from_static("application/json"),
);
res
}
@@ -48,14 +53,15 @@ impl ResponseCode {
pub fn html(self, json: &str) -> Response {
let mut res = (self.code(), json.to_owned()).into_response();
res.headers_mut().insert(
- HeaderName::from_static("content-type"), HeaderValue::from_static("text/html"),
+ HeaderName::from_static("content-type"),
+ HeaderValue::from_static("text/html"),
);
res
}
#[instrument()]
pub async fn file(self, path: &str) -> Response {
- if !path.chars().any(|c| c == '.' ) {
+ if !path.chars().any(|c| c == '.') {
return Self::BadRequest.text("Folders cannot be served");
}
let path = format!("public{path}");
@@ -72,4 +78,4 @@ impl ResponseCode {
}
}
-pub type Result<T> = std::result::Result<T, Response>; \ No newline at end of file
+pub type Result<T> = std::result::Result<T, Response>;
diff --git a/src/types/mod.rs b/src/types/mod.rs
index 0ab104c..3449d5c 100644
--- a/src/types/mod.rs
+++ b/src/types/mod.rs
@@ -1,5 +1,5 @@
-pub mod user;
+pub mod extract;
+pub mod http;
pub mod post;
pub mod session;
-pub mod extract;
-pub mod http; \ No newline at end of file
+pub mod user;
diff --git a/src/types/post.rs b/src/types/post.rs
index 95aed0e..90eada2 100644
--- a/src/types/post.rs
+++ b/src/types/post.rs
@@ -1,10 +1,10 @@
use core::fmt;
-use std::collections::HashSet;
use serde::Serialize;
+use std::collections::HashSet;
use tracing::instrument;
use crate::database;
-use crate::types::http::{Result, ResponseCode};
+use crate::types::http::{ResponseCode, Result};
#[derive(Serialize)]
pub struct Post {
@@ -13,19 +13,18 @@ pub struct Post {
pub content: String,
pub likes: HashSet<u64>,
pub comments: Vec<(u64, String)>,
- pub date: u64
+ pub date: u64,
}
impl fmt::Debug for Post {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Post")
- .field("post_id", &self.post_id)
- .finish()
+ .field("post_id", &self.post_id)
+ .finish()
}
}
impl Post {
-
#[instrument()]
pub fn from_post_id(post_id: u64) -> Result<Self> {
let Ok(Some(post)) = database::posts::get_post(post_id) else {
@@ -64,10 +63,10 @@ impl Post {
#[instrument()]
pub fn comment(&mut self, user_id: u64, content: String) -> Result<()> {
self.comments.push((user_id, content));
-
+
if database::posts::update_post(self.post_id, &self.likes, &self.comments).is_err() {
tracing::error!("Failed to comment on post");
- return Err(ResponseCode::InternalServerError.text("Failed to comment on post"))
+ return Err(ResponseCode::InternalServerError.text("Failed to comment on post"));
}
Ok(())
@@ -75,19 +74,19 @@ impl Post {
#[instrument()]
pub fn like(&mut self, user_id: u64, state: bool) -> Result<()> {
-
if state {
self.likes.insert(user_id);
} else {
self.likes.remove(&user_id);
}
-
+
if database::posts::update_post(self.post_id, &self.likes, &self.comments).is_err() {
tracing::error!("Failed to change like state on post");
- return Err(ResponseCode::InternalServerError.text("Failed to change like state on post"))
+ return Err(
+ ResponseCode::InternalServerError.text("Failed to change like state on post")
+ );
}
Ok(())
}
-
-} \ No newline at end of file
+}
diff --git a/src/types/session.rs b/src/types/session.rs
index 176e389..e704ac7 100644
--- a/src/types/session.rs
+++ b/src/types/session.rs
@@ -3,16 +3,15 @@ use serde::Serialize;
use tracing::instrument;
use crate::database;
-use crate::types::http::{Result, ResponseCode};
+use crate::types::http::{ResponseCode, Result};
#[derive(Serialize)]
pub struct Session {
pub user_id: u64,
- pub token: String
+ pub token: String,
}
impl Session {
-
#[instrument()]
pub fn from_token(token: &str) -> Result<Self> {
let Ok(Some(session)) = database::sessions::get_session(token) else {
@@ -24,10 +23,14 @@ impl Session {
#[instrument()]
pub fn new(user_id: u64) -> Result<Self> {
- let token: String = rand::thread_rng().sample_iter(&Alphanumeric).take(32).map(char::from).collect();
+ let token: String = rand::thread_rng()
+ .sample_iter(&Alphanumeric)
+ .take(32)
+ .map(char::from)
+ .collect();
match database::sessions::set_session(user_id, &token) {
Err(_) => Err(ResponseCode::BadRequest.text("Failed to create session")),
- Ok(_) => Ok(Self {user_id, token})
+ Ok(_) => Ok(Self { user_id, token }),
}
}
@@ -39,5 +42,4 @@ impl Session {
};
Ok(())
}
-
-} \ No newline at end of file
+}
diff --git a/src/types/user.rs b/src/types/user.rs
index 0013d7d..fcfbe91 100644
--- a/src/types/user.rs
+++ b/src/types/user.rs
@@ -1,10 +1,9 @@
-use serde::{Serialize, Deserialize};
+use serde::{Deserialize, Serialize};
use tracing::instrument;
use crate::api::auth::RegistrationRequet;
use crate::database;
-use crate::types::http::{Result, ResponseCode};
-
+use crate::types::http::{ResponseCode, Result};
#[derive(Serialize, Deserialize, Debug)]
pub struct User {
@@ -21,7 +20,6 @@ pub struct User {
}
impl User {
-
#[instrument()]
pub fn from_user_id(user_id: u64, hide_password: bool) -> Result<Self> {
let Ok(Some(user)) = database::users::get_user_by_id(user_id, hide_password) else {
@@ -33,12 +31,15 @@ impl User {
#[instrument()]
pub fn from_user_ids(user_ids: Vec<u64>) -> Vec<Self> {
- user_ids.iter().filter_map(|user_id| {
- let Ok(Some(user)) = database::users::get_user_by_id(*user_id, true) else {
+ user_ids
+ .iter()
+ .filter_map(|user_id| {
+ let Ok(Some(user)) = database::users::get_user_by_id(*user_id, true) else {
return None;
};
- Some(user)
- }).collect()
+ Some(user)
+ })
+ .collect()
}
#[instrument()]
@@ -70,13 +71,15 @@ impl User {
#[instrument()]
pub fn new(request: RegistrationRequet) -> Result<Self> {
if Self::from_email(&request.email).is_ok() {
- return Err(ResponseCode::BadRequest.text(&format!("Email is already in use by {}", &request.email)))
+ return Err(ResponseCode::BadRequest
+ .text(&format!("Email is already in use by {}", &request.email)));
}
if let Ok(user) = Self::from_password(&request.password) {
- return Err(ResponseCode::BadRequest.text(&format!("Password is already in use by {}", user.email)))
+ return Err(ResponseCode::BadRequest
+ .text(&format!("Password is already in use by {}", user.email)));
}
-
+
let Ok(user) = database::users::add_user(request) else {
tracing::error!("Failed to create new user");
return Err(ResponseCode::InternalServerError.text("Failed to create new uesr"))
@@ -84,5 +87,4 @@ impl User {
Ok(user)
}
-
-} \ No newline at end of file
+}