This commit is contained in:
Tyler Murphy 2023-01-28 18:04:00 -05:00
parent 0fbecaba3d
commit b58654fd70
17 changed files with 520 additions and 258 deletions

View file

@ -1,9 +1,14 @@
use axum::{Router, routing::post, response::Response}; use axum::{response::Response, routing::post, Router};
use serde::Deserialize; use serde::Deserialize;
use time::{OffsetDateTime, Duration}; use time::{Duration, OffsetDateTime};
use tower_cookies::{Cookies, Cookie}; use tower_cookies::{Cookie, Cookies};
use crate::types::{user::User, http::ResponseCode, session::Session, extract::{Json, AuthorizedUser, Check, CheckResult, Log}}; use crate::types::{
extract::{AuthorizedUser, Check, CheckResult, Json, Log},
http::ResponseCode,
session::Session,
user::User,
};
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub struct RegistrationRequet { pub struct RegistrationRequet {
@ -14,34 +19,67 @@ pub struct RegistrationRequet {
pub gender: String, pub gender: String,
pub day: u8, pub day: u8,
pub month: u8, pub month: u8,
pub year: u32 pub year: u32,
} }
impl Check for RegistrationRequet { impl Check for RegistrationRequet {
fn check(&self) -> CheckResult { fn check(&self) -> CheckResult {
Self::assert_length(&self.firstname, 1, 20, "First name can only by 1-20 characters long")?; Self::assert_length(
Self::assert_length(&self.lastname, 1, 20, "Last name can only by 1-20 characters long")?; &self.firstname,
1,
20,
"First name can only by 1-20 characters long",
)?;
Self::assert_length(
&self.lastname,
1,
20,
"Last name can only by 1-20 characters long",
)?;
Self::assert_length(&self.email, 1, 50, "Email can only by 1-50 characters long")?; Self::assert_length(&self.email, 1, 50, "Email can only by 1-50 characters long")?;
Self::assert_length(&self.password, 1, 50, "Password can only by 1-50 characters long")?; Self::assert_length(
Self::assert_length(&self.gender, 1, 100, "Gender can only by 1-100 characters long")?; &self.password,
Self::assert_range(u64::from(self.day), 1, 255, "Birthday day can only be between 1-255")?; 1,
Self::assert_range(u64::from(self.month), 1, 255, "Birthday month can only be between 1-255")?; 50,
Self::assert_range(u64::from(self.year), 1, 4_294_967_295, "Birthday year can only be between 1-4294967295")?; "Password can only by 1-50 characters long",
)?;
Self::assert_length(
&self.gender,
1,
100,
"Gender can only by 1-100 characters long",
)?;
Self::assert_range(
u64::from(self.day),
1,
255,
"Birthday day can only be between 1-255",
)?;
Self::assert_range(
u64::from(self.month),
1,
255,
"Birthday month can only be between 1-255",
)?;
Self::assert_range(
u64::from(self.year),
1,
4_294_967_295,
"Birthday year can only be between 1-4294967295",
)?;
Ok(()) Ok(())
} }
} }
async fn register(cookies: Cookies, Json(body): Json<RegistrationRequet>) -> Response { async fn register(cookies: Cookies, Json(body): Json<RegistrationRequet>) -> Response {
let user = match User::new(body) { let user = match User::new(body) {
Ok(user) => user, Ok(user) => user,
Err(err) => return err Err(err) => return err,
}; };
let session = match Session::new(user.user_id) { let session = match Session::new(user.user_id) {
Ok(session) => session, Ok(session) => session,
Err(err) => return err Err(err) => return err,
}; };
let mut now = OffsetDateTime::now_utc(); let mut now = OffsetDateTime::now_utc();
@ -71,18 +109,17 @@ impl Check for LoginRequest {
} }
async fn login(cookies: Cookies, Json(body): Json<LoginRequest>) -> Response { async fn login(cookies: Cookies, Json(body): Json<LoginRequest>) -> Response {
let Ok(user) = User::from_email(&body.email) else { let Ok(user) = User::from_email(&body.email) else {
return ResponseCode::BadRequest.text("Email is not registered") return ResponseCode::BadRequest.text("Email is not registered")
}; };
if user.password != body.password { if user.password != body.password {
return ResponseCode::BadRequest.text("Password is not correct") return ResponseCode::BadRequest.text("Password is not correct");
} }
let session = match Session::new(user.user_id) { let session = match Session::new(user.user_id) {
Ok(session) => session, Ok(session) => session,
Err(err) => return err Err(err) => return err,
}; };
let mut now = OffsetDateTime::now_utc(); let mut now = OffsetDateTime::now_utc();
@ -100,11 +137,10 @@ async fn login(cookies: Cookies, Json(body): Json<LoginRequest>) -> Response {
} }
async fn logout(cookies: Cookies, AuthorizedUser(user): AuthorizedUser, _: Log) -> Response { async fn logout(cookies: Cookies, AuthorizedUser(user): AuthorizedUser, _: Log) -> Response {
cookies.remove(Cookie::new("auth", "")); cookies.remove(Cookie::new("auth", ""));
if let Err(err) = Session::delete(user.user_id) { if let Err(err) = Session::delete(user.user_id) {
return err return err;
} }
ResponseCode::Success.text("Successfully logged out") ResponseCode::Success.text("Successfully logged out")

View file

@ -1,6 +1,13 @@
use axum::{Router, response::{Response, Redirect, IntoResponse}, routing::get}; use axum::{
response::{IntoResponse, Redirect, Response},
routing::get,
Router,
};
use crate::{types::{extract::AuthorizedUser, http::ResponseCode}, console}; use crate::{
console,
types::{extract::AuthorizedUser, http::ResponseCode},
};
async fn root(user: Option<AuthorizedUser>) -> Response { async fn root(user: Option<AuthorizedUser>) -> Response {
if user.is_some() { if user.is_some() {

View file

@ -1,23 +1,37 @@
use axum::{response::Response, Router, routing::{post, patch}}; use axum::{
response::Response,
routing::{patch, post},
Router,
};
use serde::Deserialize; use serde::Deserialize;
use crate::types::{extract::{AuthorizedUser, Json, Check, CheckResult}, post::Post, http::ResponseCode}; use crate::types::{
extract::{AuthorizedUser, Check, CheckResult, Json},
http::ResponseCode,
post::Post,
};
#[derive(Deserialize)] #[derive(Deserialize)]
struct PostCreateRequest { struct PostCreateRequest {
content: String content: String,
} }
impl Check for PostCreateRequest { impl Check for PostCreateRequest {
fn check(&self) -> CheckResult { fn check(&self) -> CheckResult {
Self::assert_length(&self.content, 1, 500, "Comments must be between 1-500 characters long")?; Self::assert_length(
&self.content,
1,
500,
"Comments must be between 1-500 characters long",
)?;
Ok(()) Ok(())
} }
} }
async fn create(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostCreateRequest>) -> Response { async fn create(
AuthorizedUser(user): AuthorizedUser,
Json(body): Json<PostCreateRequest>,
) -> Response {
let Ok(post) = Post::new(user.user_id, body.content) else { let Ok(post) = Post::new(user.user_id, body.content) else {
return ResponseCode::InternalServerError.text("Failed to create post") return ResponseCode::InternalServerError.text("Failed to create post")
}; };
@ -31,7 +45,7 @@ async fn create(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostCreat
#[derive(Deserialize)] #[derive(Deserialize)]
struct PostPageRequest { struct PostPageRequest {
page: u64 page: u64,
} }
impl Check for PostPageRequest { impl Check for PostPageRequest {
@ -40,8 +54,10 @@ impl Check for PostPageRequest {
} }
} }
async fn page(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<PostPageRequest>) -> Response { async fn page(
AuthorizedUser(_user): AuthorizedUser,
Json(body): Json<PostPageRequest>,
) -> Response {
let Ok(posts) = Post::from_post_page(body.page) else { let Ok(posts) = Post::from_post_page(body.page) else {
return ResponseCode::InternalServerError.text("Failed to fetch posts") return ResponseCode::InternalServerError.text("Failed to fetch posts")
}; };
@ -55,7 +71,7 @@ async fn page(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<PostPageRe
#[derive(Deserialize)] #[derive(Deserialize)]
struct UsersPostsRequest { struct UsersPostsRequest {
user_id: u64 user_id: u64,
} }
impl Check for UsersPostsRequest { impl Check for UsersPostsRequest {
@ -64,8 +80,10 @@ impl Check for UsersPostsRequest {
} }
} }
async fn user(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UsersPostsRequest>) -> Response { async fn user(
AuthorizedUser(_user): AuthorizedUser,
Json(body): Json<UsersPostsRequest>,
) -> Response {
let Ok(posts) = Post::from_user_id(body.user_id) else { let Ok(posts) = Post::from_user_id(body.user_id) else {
return ResponseCode::InternalServerError.text("Failed to fetch posts") return ResponseCode::InternalServerError.text("Failed to fetch posts")
}; };
@ -80,18 +98,25 @@ async fn user(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UsersPosts
#[derive(Deserialize)] #[derive(Deserialize)]
struct PostCommentRequest { struct PostCommentRequest {
content: String, content: String,
post_id: u64 post_id: u64,
} }
impl Check for PostCommentRequest { impl Check for PostCommentRequest {
fn check(&self) -> CheckResult { fn check(&self) -> CheckResult {
Self::assert_length(&self.content, 1, 255, "Comments must be between 1-255 characters long")?; Self::assert_length(
&self.content,
1,
255,
"Comments must be between 1-255 characters long",
)?;
Ok(()) Ok(())
} }
} }
async fn comment(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostCommentRequest>) -> Response { async fn comment(
AuthorizedUser(user): AuthorizedUser,
Json(body): Json<PostCommentRequest>,
) -> Response {
let Ok(mut post) = Post::from_post_id(body.post_id) else { let Ok(mut post) = Post::from_post_id(body.post_id) else {
return ResponseCode::InternalServerError.text("Failed to fetch posts") return ResponseCode::InternalServerError.text("Failed to fetch posts")
}; };
@ -106,7 +131,7 @@ async fn comment(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostComm
#[derive(Deserialize)] #[derive(Deserialize)]
struct PostLikeRequest { struct PostLikeRequest {
state: bool, state: bool,
post_id: u64 post_id: u64,
} }
impl Check for PostLikeRequest { impl Check for PostLikeRequest {
@ -116,7 +141,6 @@ impl Check for PostLikeRequest {
} }
async fn like(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostLikeRequest>) -> Response { async fn like(AuthorizedUser(user): AuthorizedUser, Json(body): Json<PostLikeRequest>) -> Response {
let Ok(mut post) = Post::from_post_id(body.post_id) else { let Ok(mut post) = Post::from_post_id(body.post_id) else {
return ResponseCode::InternalServerError.text("Failed to fetch posts") return ResponseCode::InternalServerError.text("Failed to fetch posts")
}; };

View file

@ -1,10 +1,14 @@
use axum::{Router, response::Response, routing::post}; use crate::types::{
extract::{AuthorizedUser, Check, CheckResult, Json},
http::ResponseCode,
user::User,
};
use axum::{response::Response, routing::post, Router};
use serde::Deserialize; use serde::Deserialize;
use crate::types::{extract::{AuthorizedUser, Json, Check, CheckResult}, http::ResponseCode, user::User};
#[derive(Deserialize)] #[derive(Deserialize)]
struct UserLoadRequest { struct UserLoadRequest {
ids: Vec<u64> ids: Vec<u64>,
} }
impl Check for UserLoadRequest { impl Check for UserLoadRequest {
@ -13,8 +17,10 @@ impl Check for UserLoadRequest {
} }
} }
async fn load_batch(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UserLoadRequest>) -> Response { async fn load_batch(
AuthorizedUser(_user): AuthorizedUser,
Json(body): Json<UserLoadRequest>,
) -> Response {
let users = User::from_user_ids(body.ids); let users = User::from_user_ids(body.ids);
let Ok(json) = serde_json::to_string(&users) else { let Ok(json) = serde_json::to_string(&users) else {
return ResponseCode::InternalServerError.text("Failed to fetch users") return ResponseCode::InternalServerError.text("Failed to fetch users")
@ -25,7 +31,7 @@ async fn load_batch(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<User
#[derive(Deserialize)] #[derive(Deserialize)]
struct UserPageReqiest { struct UserPageReqiest {
page: u64 page: u64,
} }
impl Check for UserPageReqiest { impl Check for UserPageReqiest {
@ -34,8 +40,10 @@ impl Check for UserPageReqiest {
} }
} }
async fn load_page(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UserPageReqiest>) -> Response { async fn load_page(
AuthorizedUser(_user): AuthorizedUser,
Json(body): Json<UserPageReqiest>,
) -> Response {
let Ok(users) = User::from_user_page(body.page) else { let Ok(users) = User::from_user_page(body.page) else {
return ResponseCode::InternalServerError.text("Failed to fetch users") return ResponseCode::InternalServerError.text("Failed to fetch users")
}; };
@ -48,7 +56,6 @@ async fn load_page(AuthorizedUser(_user): AuthorizedUser, Json(body): Json<UserP
} }
async fn load_self(AuthorizedUser(user): AuthorizedUser) -> Response { async fn load_self(AuthorizedUser(user): AuthorizedUser) -> Response {
let Ok(json) = serde_json::to_string(&user) else { let Ok(json) = serde_json::to_string(&user) else {
return ResponseCode::InternalServerError.text("Failed to fetch user") return ResponseCode::InternalServerError.text("Failed to fetch user")
}; };

View file

@ -1,8 +1,11 @@
use std::{net::IpAddr, collections::VecDeque, io, }; use axum::{
use axum::{http::{Method, Uri}, response::Response}; http::{Method, Uri},
response::Response,
};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use serde::Serialize; use serde::Serialize;
use serde_json::{ser::Formatter, Value}; use serde_json::{ser::Formatter, Value};
use std::{collections::VecDeque, io, net::IpAddr};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::types::http::ResponseCode; use crate::types::http::ResponseCode;
@ -12,7 +15,7 @@ struct LogMessage {
method: Method, method: Method,
uri: Uri, uri: Uri,
path: String, path: String,
body: String body: String,
} }
impl ToString for LogMessage { impl ToString for LogMessage {
@ -31,7 +34,7 @@ impl ToString for LogMessage {
Method::CONNECT => "#3fe0ad", Method::CONNECT => "#3fe0ad",
Method::TRACE => "#e03fc5", Method::TRACE => "#e03fc5",
Method::OPTIONS => "#423fe0", Method::OPTIONS => "#423fe0",
_ => "white" _ => "white",
}; };
format!("<div><span class='ip'>{}</span> <span class='method' style='color: {};'>{}</span> <span class='path'>{}{}</span> <span class='body'>{}</span></div>", ip, color, self.method, self.path, self.uri, self.body) format!("<div><span class='ip'>{}</span> <span class='method' style='color: {};'>{}</span> <span class='path'>{}{}</span> <span class='body'>{}</span></div>", ip, color, self.method, self.path, self.uri, self.body)
} }
@ -42,8 +45,9 @@ lazy_static! {
} }
pub async fn log(ip: IpAddr, method: Method, uri: Uri, path: Option<String>, body: Option<String>) { pub async fn log(ip: IpAddr, method: Method, uri: Uri, path: Option<String>, body: Option<String>) {
if uri.to_string().starts_with("/console") {
if uri.to_string().starts_with("/console") { return; } return;
}
let path = path.unwrap_or_default(); let path = path.unwrap_or_default();
let body = body.unwrap_or_default(); let body = body.unwrap_or_default();
@ -55,7 +59,7 @@ pub async fn log(ip: IpAddr, method: Method, uri: Uri, path: Option<String>, bod
method, method,
uri, uri,
path, path,
body: beautify(body) body: beautify(body),
}; };
let mut lock = LOG.lock().await; let mut lock = LOG.lock().await;
@ -67,11 +71,17 @@ pub async fn log(ip: IpAddr, method: Method, uri: Uri, path: Option<String>, bod
struct HtmlFormatter; struct HtmlFormatter;
impl Formatter for HtmlFormatter { impl Formatter for HtmlFormatter {
fn write_null<W>(&mut self, writer: &mut W) -> io::Result<()> where W: ?Sized + io::Write { fn write_null<W>(&mut self, writer: &mut W) -> io::Result<()>
where
W: ?Sized + io::Write,
{
writer.write_all(b"<span class='null'>null</span>") writer.write_all(b"<span class='null'>null</span>")
} }
fn write_bool<W>(&mut self, writer: &mut W, value: bool) -> io::Result<()> where W: ?Sized + io::Write { fn write_bool<W>(&mut self, writer: &mut W, value: bool) -> io::Result<()>
where
W: ?Sized + io::Write,
{
let s = if value { let s = if value {
b"<span class='bool'>true</span>" as &[u8] b"<span class='bool'>true</span>" as &[u8]
} else { } else {
@ -80,65 +90,104 @@ impl Formatter for HtmlFormatter {
writer.write_all(s) writer.write_all(s)
} }
fn write_i8<W>(&mut self, writer: &mut W, value: i8) -> io::Result<()> where W: ?Sized + io::Write { fn write_i8<W>(&mut self, writer: &mut W, value: i8) -> io::Result<()>
where
W: ?Sized + io::Write,
{
let buff = format!("<span class='number'>{value}</span>"); let buff = format!("<span class='number'>{value}</span>");
writer.write_all(buff.as_bytes()) writer.write_all(buff.as_bytes())
} }
fn write_i16<W>(&mut self, writer: &mut W, value: i16) -> io::Result<()> where W: ?Sized + io::Write { fn write_i16<W>(&mut self, writer: &mut W, value: i16) -> io::Result<()>
where
W: ?Sized + io::Write,
{
let buff = format!("<span class='number'>{value}</span>"); let buff = format!("<span class='number'>{value}</span>");
writer.write_all(buff.as_bytes()) writer.write_all(buff.as_bytes())
} }
fn write_i32<W>(&mut self, writer: &mut W, value: i32) -> io::Result<()> where W: ?Sized + io::Write { fn write_i32<W>(&mut self, writer: &mut W, value: i32) -> io::Result<()>
where
W: ?Sized + io::Write,
{
let buff = format!("<span class='number'>{value}</span>"); let buff = format!("<span class='number'>{value}</span>");
writer.write_all(buff.as_bytes()) writer.write_all(buff.as_bytes())
} }
fn write_i64<W>(&mut self, writer: &mut W, value: i64) -> io::Result<()> where W: ?Sized + io::Write { fn write_i64<W>(&mut self, writer: &mut W, value: i64) -> io::Result<()>
where
W: ?Sized + io::Write,
{
let buff = format!("<span class='number'>{value}</span>"); let buff = format!("<span class='number'>{value}</span>");
writer.write_all(buff.as_bytes()) writer.write_all(buff.as_bytes())
} }
fn write_u8<W>(&mut self, writer: &mut W, value: u8) -> io::Result<()> where W: ?Sized + io::Write { fn write_u8<W>(&mut self, writer: &mut W, value: u8) -> io::Result<()>
where
W: ?Sized + io::Write,
{
let buff = format!("<span class='number'>{value}</span>"); let buff = format!("<span class='number'>{value}</span>");
writer.write_all(buff.as_bytes()) writer.write_all(buff.as_bytes())
} }
fn write_u16<W>(&mut self, writer: &mut W, value: u16) -> io::Result<()> where W: ?Sized + io::Write { fn write_u16<W>(&mut self, writer: &mut W, value: u16) -> io::Result<()>
where
W: ?Sized + io::Write,
{
let buff = format!("<span class='number'>{value}</span>"); let buff = format!("<span class='number'>{value}</span>");
writer.write_all(buff.as_bytes()) writer.write_all(buff.as_bytes())
} }
fn write_u32<W>(&mut self, writer: &mut W, value: u32) -> io::Result<()> where W: ?Sized + io::Write { fn write_u32<W>(&mut self, writer: &mut W, value: u32) -> io::Result<()>
where
W: ?Sized + io::Write,
{
let buff = format!("<span class='number'>{value}</span>"); let buff = format!("<span class='number'>{value}</span>");
writer.write_all(buff.as_bytes()) writer.write_all(buff.as_bytes())
} }
fn write_u64<W>(&mut self, writer: &mut W, value: u64) -> io::Result<()> where W: ?Sized + io::Write { fn write_u64<W>(&mut self, writer: &mut W, value: u64) -> io::Result<()>
where
W: ?Sized + io::Write,
{
let buff = format!("<span class='number'>{value}</span>"); let buff = format!("<span class='number'>{value}</span>");
writer.write_all(buff.as_bytes()) writer.write_all(buff.as_bytes())
} }
fn write_f32<W>(&mut self, writer: &mut W, value: f32) -> io::Result<()> where W: ?Sized + io::Write { fn write_f32<W>(&mut self, writer: &mut W, value: f32) -> io::Result<()>
where
W: ?Sized + io::Write,
{
let buff = format!("<span class='number'>{value}</span>"); let buff = format!("<span class='number'>{value}</span>");
writer.write_all(buff.as_bytes()) writer.write_all(buff.as_bytes())
} }
fn write_f64<W>(&mut self, writer: &mut W, value: f64) -> io::Result<()> where W: ?Sized + io::Write { fn write_f64<W>(&mut self, writer: &mut W, value: f64) -> io::Result<()>
where
W: ?Sized + io::Write,
{
let buff = format!("<span class='number'>{value}</span>"); let buff = format!("<span class='number'>{value}</span>");
writer.write_all(buff.as_bytes()) writer.write_all(buff.as_bytes())
} }
fn begin_string<W>(&mut self, writer: &mut W) -> io::Result<()> where W: ?Sized + io::Write { fn begin_string<W>(&mut self, writer: &mut W) -> io::Result<()>
where
W: ?Sized + io::Write,
{
writer.write_all(b"<span class='string'>\"") writer.write_all(b"<span class='string'>\"")
} }
fn end_string<W>(&mut self, writer: &mut W) -> io::Result<()> where W: ?Sized + io::Write { fn end_string<W>(&mut self, writer: &mut W) -> io::Result<()>
where
W: ?Sized + io::Write,
{
writer.write_all(b"\"</span>") writer.write_all(b"\"</span>")
} }
fn begin_object_key<W>(&mut self, writer: &mut W, first: bool) -> io::Result<()> where W: ?Sized + io::Write { fn begin_object_key<W>(&mut self, writer: &mut W, first: bool) -> io::Result<()>
where
W: ?Sized + io::Write,
{
if first { if first {
writer.write_all(b"<span class='key'>") writer.write_all(b"<span class='key'>")
} else { } else {
@ -146,15 +195,17 @@ impl Formatter for HtmlFormatter {
} }
} }
fn end_object_key<W>(&mut self, writer: &mut W) -> io::Result<()> where W: ?Sized + io::Write { fn end_object_key<W>(&mut self, writer: &mut W) -> io::Result<()>
where
W: ?Sized + io::Write,
{
writer.write_all(b"</span>") writer.write_all(b"</span>")
} }
} }
fn beautify(body: String) -> String { fn beautify(body: String) -> String {
if body.is_empty() { if body.is_empty() {
return String::new() return String::new();
} }
let Ok(mut json) = serde_json::from_str::<Value>(&body) else { let Ok(mut json) = serde_json::from_str::<Value>(&body) else {
return body return body
@ -165,13 +216,12 @@ fn beautify(body: String) -> String {
let mut writer: Vec<u8> = Vec::with_capacity(128); let mut writer: Vec<u8> = Vec::with_capacity(128);
let mut serializer = serde_json::Serializer::with_formatter(&mut writer, HtmlFormatter); let mut serializer = serde_json::Serializer::with_formatter(&mut writer, HtmlFormatter);
if json.serialize(&mut serializer).is_err() { if json.serialize(&mut serializer).is_err() {
return body return body;
} }
String::from_utf8_lossy(&writer).to_string() String::from_utf8_lossy(&writer).to_string()
} }
pub async fn generate() -> Response { pub async fn generate() -> Response {
let lock = LOG.lock().await; let lock = LOG.lock().await;
let mut html = r#"<!DOCTYPE html> let mut html = r#"<!DOCTYPE html>
@ -183,7 +233,8 @@ pub async fn generate() -> Response {
<title>XSSBook - Console</title> <title>XSSBook - Console</title>
</head> </head>
<body> <body>
"#.to_string(); "#
.to_string();
for message in lock.iter() { for message in lock.iter() {
html.push_str(&message.to_string()); html.push_str(&message.to_string());

View file

@ -1,6 +1,6 @@
pub mod posts; pub mod posts;
pub mod users;
pub mod sessions; pub mod sessions;
pub mod users;
pub fn connect() -> Result<rusqlite::Connection, rusqlite::Error> { pub fn connect() -> Result<rusqlite::Connection, rusqlite::Error> {
rusqlite::Connection::open("xssbook.db") rusqlite::Connection::open("xssbook.db")

View file

@ -1,11 +1,11 @@
use std::collections::HashSet; use std::collections::HashSet;
use std::time::{SystemTime, UNIX_EPOCH, Duration}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use rusqlite::{OptionalExtension, Row}; use rusqlite::{OptionalExtension, Row};
use tracing::instrument; use tracing::instrument;
use crate::types::post::Post;
use crate::database; use crate::database;
use crate::types::post::Post;
pub fn init() -> Result<(), rusqlite::Error> { pub fn init() -> Result<(), rusqlite::Error> {
let sql = " let sql = "
@ -40,7 +40,14 @@ fn post_from_row(row: &Row) -> Result<Post, rusqlite::Error> {
return Err(rusqlite::Error::InvalidQuery) return Err(rusqlite::Error::InvalidQuery)
}; };
Ok(Post{post_id, user_id, content, likes, comments, date}) Ok(Post {
post_id,
user_id,
content,
likes,
comments,
date,
})
} }
#[instrument()] #[instrument()]
@ -48,10 +55,12 @@ pub fn get_post(post_id: u64) -> Result<Option<Post>, rusqlite::Error> {
tracing::trace!("Retrieving post"); tracing::trace!("Retrieving post");
let conn = database::connect()?; let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM posts WHERE post_id = ?")?; let mut stmt = conn.prepare("SELECT * FROM posts WHERE post_id = ?")?;
let row = stmt.query_row([post_id], |row| { let row = stmt
.query_row([post_id], |row| {
let row = post_from_row(row)?; let row = post_from_row(row)?;
Ok(row) Ok(row)
}).optional()?; })
.optional()?;
Ok(row) Ok(row)
} }
@ -91,7 +100,13 @@ pub fn add_post(user_id: u64, content: &str) -> Result<Post, rusqlite::Error> {
let Ok(comments_json) = serde_json::to_string(&comments) else { let Ok(comments_json) = serde_json::to_string(&comments) else {
return Err(rusqlite::Error::InvalidQuery) return Err(rusqlite::Error::InvalidQuery)
}; };
let date = u64::try_from(SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_millis()).unwrap_or(0); let date = u64::try_from(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::ZERO)
.as_millis(),
)
.unwrap_or(0);
let conn = database::connect()?; let conn = database::connect()?;
let mut stmt = conn.prepare("INSERT INTO posts (user_id, content, likes, comments, date) VALUES(?,?,?,?,?) RETURNING *;")?; let mut stmt = conn.prepare("INSERT INTO posts (user_id, content, likes, comments, date) VALUES(?,?,?,?,?) RETURNING *;")?;
let post = stmt.query_row((user_id, content, likes_json, comments_json, date), |row| { let post = stmt.query_row((user_id, content, likes_json, comments_json, date), |row| {
@ -102,7 +117,11 @@ pub fn add_post(user_id: u64, content: &str) -> Result<Post, rusqlite::Error> {
} }
#[instrument()] #[instrument()]
pub fn update_post(post_id: u64, likes: &HashSet<u64>, comments: &Vec<(u64, String)>) -> Result<(), rusqlite::Error> { pub fn update_post(
post_id: u64,
likes: &HashSet<u64>,
comments: &Vec<(u64, String)>,
) -> Result<(), rusqlite::Error> {
tracing::trace!("Updating post"); tracing::trace!("Updating post");
let Ok(likes_json) = serde_json::to_string(&likes) else { let Ok(likes_json) = serde_json::to_string(&likes) else {
return Err(rusqlite::Error::InvalidQuery) return Err(rusqlite::Error::InvalidQuery)

View file

@ -21,12 +21,14 @@ pub fn get_session(token: &str) -> Result<Option<Session>, rusqlite::Error> {
tracing::trace!("Retrieving session"); tracing::trace!("Retrieving session");
let conn = database::connect()?; let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM sessions WHERE token = ?")?; let mut stmt = conn.prepare("SELECT * FROM sessions WHERE token = ?")?;
let row = stmt.query_row([token], |row| { let row = stmt
.query_row([token], |row| {
Ok(Session { Ok(Session {
user_id: row.get(0)?, user_id: row.get(0)?,
token: row.get(1)?, token: row.get(1)?,
}) })
}).optional()?; })
.optional()?;
Ok(row) Ok(row)
} }

View file

@ -1,8 +1,8 @@
use std::time::{SystemTime, UNIX_EPOCH, Duration};
use rusqlite::{OptionalExtension, Row}; use rusqlite::{OptionalExtension, Row};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tracing::instrument; use tracing::instrument;
use crate::{database, types::user::User, api::auth::RegistrationRequet}; use crate::{api::auth::RegistrationRequet, database, types::user::User};
pub fn init() -> Result<(), rusqlite::Error> { pub fn init() -> Result<(), rusqlite::Error> {
let sql = " let sql = "
@ -36,9 +36,24 @@ fn user_from_row(row: &Row, hide_password: bool) -> Result<User, rusqlite::Error
let month = row.get(8)?; let month = row.get(8)?;
let year = row.get(9)?; let year = row.get(9)?;
let password = if hide_password { String::new() } else { password }; let password = if hide_password {
String::new()
} else {
password
};
Ok(User{user_id, firstname, lastname, email, password, gender,date, day, month, year}) Ok(User {
user_id,
firstname,
lastname,
email,
password,
gender,
date,
day,
month,
year,
})
} }
#[instrument()] #[instrument()]
@ -46,34 +61,46 @@ pub fn get_user_by_id(user_id: u64, hide_password: bool) -> Result<Option<User>,
tracing::trace!("Retrieving user by id"); tracing::trace!("Retrieving user by id");
let conn = database::connect()?; let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM users WHERE user_id = ?")?; let mut stmt = conn.prepare("SELECT * FROM users WHERE user_id = ?")?;
let row = stmt.query_row([user_id], |row| { let row = stmt
.query_row([user_id], |row| {
let row = user_from_row(row, hide_password)?; let row = user_from_row(row, hide_password)?;
Ok(row) Ok(row)
}).optional()?; })
.optional()?;
Ok(row) Ok(row)
} }
#[instrument()] #[instrument()]
pub fn get_user_by_email(email: &str, hide_password: bool) -> Result<Option<User>, rusqlite::Error> { pub fn get_user_by_email(
email: &str,
hide_password: bool,
) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by email"); tracing::trace!("Retrieving user by email");
let conn = database::connect()?; let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM users WHERE email = ?")?; let mut stmt = conn.prepare("SELECT * FROM users WHERE email = ?")?;
let row = stmt.query_row([email], |row| { let row = stmt
.query_row([email], |row| {
let row = user_from_row(row, hide_password)?; let row = user_from_row(row, hide_password)?;
Ok(row) Ok(row)
}).optional()?; })
.optional()?;
Ok(row) Ok(row)
} }
#[instrument()] #[instrument()]
pub fn get_user_by_password(password: &str, hide_password: bool) -> Result<Option<User>, rusqlite::Error> { pub fn get_user_by_password(
password: &str,
hide_password: bool,
) -> Result<Option<User>, rusqlite::Error> {
tracing::trace!("Retrieving user by password"); tracing::trace!("Retrieving user by password");
let conn = database::connect()?; let conn = database::connect()?;
let mut stmt = conn.prepare("SELECT * FROM users WHERE password = ?")?; let mut stmt = conn.prepare("SELECT * FROM users WHERE password = ?")?;
let row = stmt.query_row([password], |row| { let row = stmt
.query_row([password], |row| {
let row = user_from_row(row, hide_password)?; let row = user_from_row(row, hide_password)?;
Ok(row) Ok(row)
}).optional()?; })
.optional()?;
Ok(row) Ok(row)
} }
@ -93,13 +120,32 @@ pub fn get_user_page(page: u64, hide_password: bool) -> Result<Vec<User>, rusqli
#[instrument()] #[instrument()]
pub fn add_user(request: RegistrationRequet) -> Result<User, rusqlite::Error> { pub fn add_user(request: RegistrationRequet) -> Result<User, rusqlite::Error> {
tracing::trace!("Adding new user"); tracing::trace!("Adding new user");
let date = u64::try_from(SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_millis()).unwrap_or(0); let date = u64::try_from(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::ZERO)
.as_millis(),
)
.unwrap_or(0);
let conn = database::connect()?; let conn = database::connect()?;
let mut stmt = conn.prepare("INSERT INTO users (firstname, lastname, email, password, gender, date, day, month, year) VALUES(?,?,?,?,?,?,?,?,?) RETURNING *;")?; let mut stmt = conn.prepare("INSERT INTO users (firstname, lastname, email, password, gender, date, day, month, year) VALUES(?,?,?,?,?,?,?,?,?) RETURNING *;")?;
let user = stmt.query_row((request.firstname, request.lastname, request.email, request.password, request.gender, date, request.day, request.month, request.year), |row| { let user = stmt.query_row(
(
request.firstname,
request.lastname,
request.email,
request.password,
request.gender,
date,
request.day,
request.month,
request.year,
),
|row| {
let row = user_from_row(row, false)?; let row = user_from_row(row, false)?;
Ok(row) Ok(row)
})?; },
)?;
Ok(user) Ok(user)
} }

View file

@ -1,18 +1,31 @@
use axum::{
body::HttpBody,
extract::ConnectInfo,
http::{Request, StatusCode},
middleware::{self, Next},
response::Response,
Extension, RequestExt, Router,
};
use std::{net::SocketAddr, process::exit}; use std::{net::SocketAddr, process::exit};
use axum::{Router, response::Response, http::{Request, StatusCode}, middleware::{Next, self}, extract::ConnectInfo, RequestExt, body::HttpBody, Extension};
use tower_cookies::CookieManagerLayer; use tower_cookies::CookieManagerLayer;
use tracing::{metadata::LevelFilter, error, info}; use tracing::{error, info, metadata::LevelFilter};
use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, Layer, filter::filter_fn}; use tracing_subscriber::{
filter::filter_fn, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, Layer,
};
use types::http::ResponseCode; use types::http::ResponseCode;
use crate::{api::{pages, auth, users, posts}, types::extract::RouterURI}; use crate::{
api::{auth, pages, posts, users},
types::extract::RouterURI,
};
mod api; mod api;
mod console;
mod database; mod database;
mod types; mod types;
mod console;
async fn serve<B>(req: Request<B>, next: Next<B>) -> Response where async fn serve<B>(req: Request<B>, next: Next<B>) -> Response
where
B: Send + Sync + 'static + HttpBody, B: Send + Sync + 'static + HttpBody,
{ {
let uri = req.uri(); let uri = req.uri();
@ -23,15 +36,22 @@ async fn serve<B>(req: Request<B>, next: Next<B>) -> Response where
file file
} }
async fn log<B>(mut req: Request<B>, next: Next<B>) -> Response where async fn log<B>(mut req: Request<B>, next: Next<B>) -> Response
where
B: Send + Sync + 'static + HttpBody, B: Send + Sync + 'static + HttpBody,
{ {
let Ok(ConnectInfo(info)) = req.extract_parts::<ConnectInfo<SocketAddr>>().await else { let Ok(ConnectInfo(info)) = req.extract_parts::<ConnectInfo<SocketAddr>>().await else {
return next.run(req).await return next.run(req).await
}; };
console::log(info.ip(), req.method().clone(), req.uri().clone(), None, None).await; console::log(
info.ip(),
req.method().clone(),
req.uri().clone(),
None,
None,
)
.await;
next.run(req).await next.run(req).await
} }
@ -42,13 +62,14 @@ async fn not_found() -> Response {
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let fmt_layer = tracing_subscriber::fmt::layer(); let fmt_layer = tracing_subscriber::fmt::layer();
tracing_subscriber::registry() tracing_subscriber::registry()
.with( .with(
fmt_layer.with_filter(LevelFilter::TRACE).with_filter(filter_fn(|metadata| { fmt_layer
.with_filter(LevelFilter::TRACE)
.with_filter(filter_fn(|metadata| {
metadata.target().starts_with("xssbook") metadata.target().starts_with("xssbook")
})) })),
) )
.init(); .init();
@ -62,13 +83,19 @@ async fn main() {
.nest("/", pages::router()) .nest("/", pages::router())
.layer(middleware::from_fn(log)) .layer(middleware::from_fn(log))
.layer(middleware::from_fn(serve)) .layer(middleware::from_fn(serve))
.nest("/api/auth", auth::router() .nest(
.layer(Extension(RouterURI("/api/auth"))) "/api/auth",
).nest("/api/users", users::router() auth::router().layer(Extension(RouterURI("/api/auth"))),
.layer(Extension(RouterURI("/api/users"))) )
).nest("/api/posts", posts::router() .nest(
.layer(Extension(RouterURI("/api/posts"))) "/api/users",
).layer(CookieManagerLayer::new()); users::router().layer(Extension(RouterURI("/api/users"))),
)
.nest(
"/api/posts",
posts::router().layer(Extension(RouterURI("/api/posts"))),
)
.layer(CookieManagerLayer::new());
let Ok(addr) = "[::]:8080".parse::<std::net::SocketAddr>() else { let Ok(addr) = "[::]:8080".parse::<std::net::SocketAddr>() else {
error!("Failed to parse port binding"); error!("Failed to parse port binding");
@ -81,5 +108,4 @@ async fn main() {
.serve(app.into_make_service_with_connect_info::<SocketAddr>()) .serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await .await
.unwrap_or(()); .unwrap_or(());
} }

View file

@ -1,19 +1,36 @@
use std::{io::Read, net::SocketAddr}; use std::{io::Read, net::SocketAddr};
use axum::{extract::{FromRequestParts, FromRequest, ConnectInfo}, async_trait, response::Response, http::{request::Parts, Request}, TypedHeader, headers::Cookie, body::HttpBody, BoxError, RequestExt}; use axum::{
async_trait,
body::HttpBody,
extract::{ConnectInfo, FromRequest, FromRequestParts},
headers::Cookie,
http::{request::Parts, Request},
response::Response,
BoxError, RequestExt, TypedHeader,
};
use bytes::Bytes; use bytes::Bytes;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use crate::{types::{user::User, http::{ResponseCode, Result}, session::Session}, console}; use crate::{
console,
types::{
http::{ResponseCode, Result},
session::Session,
user::User,
},
};
pub struct AuthorizedUser(pub User); pub struct AuthorizedUser(pub User);
#[async_trait] #[async_trait]
impl<S> FromRequestParts<S> for AuthorizedUser where S: Send + Sync { impl<S> FromRequestParts<S> for AuthorizedUser
where
S: Send + Sync,
{
type Rejection = Response; type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self> {
let Ok(Some(cookies)) = Option::<TypedHeader<Cookie>>::from_request_parts(parts, state).await else { let Ok(Some(cookies)) = Option::<TypedHeader<Cookie>>::from_request_parts(parts, state).await else {
return Err(ResponseCode::Forbidden.text("No cookies provided")) return Err(ResponseCode::Forbidden.text("No cookies provided"))
}; };
@ -37,7 +54,8 @@ impl<S> FromRequestParts<S> for AuthorizedUser where S: Send + Sync {
pub struct Log; pub struct Log;
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for Log where impl<S, B> FromRequest<S, B> for Log
where
B: HttpBody + Sync + Send + 'static, B: HttpBody + Sync + Send + 'static,
B::Data: Send, B::Data: Send,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
@ -46,12 +64,14 @@ impl<S, B> FromRequest<S, B> for Log where
type Rejection = Response; type Rejection = Response;
async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> { async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> {
let Ok(ConnectInfo(info)) = req.extract_parts::<ConnectInfo<SocketAddr>>().await else { let Ok(ConnectInfo(info)) = req.extract_parts::<ConnectInfo<SocketAddr>>().await else {
return Ok(Self) return Ok(Self)
}; };
let method = req.method().clone(); let method = req.method().clone();
let path = req.extensions().get::<RouterURI>().map_or("", |path| path.0); let path = req
.extensions()
.get::<RouterURI>()
.map_or("", |path| path.0);
let uri = req.uri().clone(); let uri = req.uri().clone();
let Ok(bytes) = Bytes::from_request(req, state).await else { let Ok(bytes) = Bytes::from_request(req, state).await else {
@ -64,7 +84,14 @@ impl<S, B> FromRequest<S, B> for Log where
return Ok(Self) return Ok(Self)
}; };
console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), Some(body.to_string())).await; console::log(
info.ip(),
method.clone(),
uri.clone(),
Some(path.to_string()),
Some(body.to_string()),
)
.await;
Ok(Self) Ok(Self)
} }
@ -73,7 +100,8 @@ impl<S, B> FromRequest<S, B> for Log where
pub struct Json<T>(pub T); pub struct Json<T>(pub T);
#[async_trait] #[async_trait]
impl<T, S, B> FromRequest<S, B> for Json<T> where impl<T, S, B> FromRequest<S, B> for Json<T>
where
T: DeserializeOwned + Check, T: DeserializeOwned + Check,
B: HttpBody + Sync + Send + 'static, B: HttpBody + Sync + Send + 'static,
B::Data: Send, B::Data: Send,
@ -83,13 +111,15 @@ impl<T, S, B> FromRequest<S, B> for Json<T> where
type Rejection = Response; type Rejection = Response;
async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> { async fn from_request(mut req: Request<B>, state: &S) -> Result<Self> {
let Ok(ConnectInfo(info)) = req.extract_parts::<ConnectInfo<SocketAddr>>().await else { let Ok(ConnectInfo(info)) = req.extract_parts::<ConnectInfo<SocketAddr>>().await else {
tracing::error!("Failed to read connection info"); tracing::error!("Failed to read connection info");
return Err(ResponseCode::InternalServerError.text("Failed to read connection info")); return Err(ResponseCode::InternalServerError.text("Failed to read connection info"));
}; };
let method = req.method().clone(); let method = req.method().clone();
let path = req.extensions().get::<RouterURI>().map_or("", |path| path.0); let path = req
.extensions()
.get::<RouterURI>()
.map_or("", |path| path.0);
let uri = req.uri().clone(); let uri = req.uri().clone();
let Ok(bytes) = Bytes::from_request(req, state).await else { let Ok(bytes) = Bytes::from_request(req, state).await else {
@ -101,7 +131,14 @@ impl<T, S, B> FromRequest<S, B> for Json<T> where
return Err(ResponseCode::BadRequest.text("Invalid utf8 body")) return Err(ResponseCode::BadRequest.text("Invalid utf8 body"))
}; };
console::log(info.ip(), method.clone(), uri.clone(), Some(path.to_string()), Some(body.to_string())).await; console::log(
info.ip(),
method.clone(),
uri.clone(),
Some(path.to_string()),
Some(body.to_string()),
)
.await;
let Ok(value) = serde_json::from_str::<T>(&body) else { let Ok(value) = serde_json::from_str::<T>(&body) else {
return Err(ResponseCode::BadRequest.text("Invalid request body")) return Err(ResponseCode::BadRequest.text("Invalid request body"))
@ -118,19 +155,18 @@ impl<T, S, B> FromRequest<S, B> for Json<T> where
pub type CheckResult = std::result::Result<(), String>; pub type CheckResult = std::result::Result<(), String>;
pub trait Check { pub trait Check {
fn check(&self) -> CheckResult; fn check(&self) -> CheckResult;
fn assert_length(string: &str, min: usize, max: usize, message: &str) -> CheckResult { fn assert_length(string: &str, min: usize, max: usize, message: &str) -> CheckResult {
if string.len() < min || string.len() > max { if string.len() < min || string.len() > max {
return Err(message.to_string()) return Err(message.to_string());
} }
Ok(()) Ok(())
} }
fn assert_range(number: u64, min: u64, max: u64, message: &str) -> CheckResult { fn assert_range(number: u64, min: u64, max: u64, message: &str) -> CheckResult {
if number < min || number > max { if number < min || number > max {
return Err(message.to_string()) return Err(message.to_string());
} }
Ok(()) Ok(())
} }
@ -138,4 +174,3 @@ pub trait Check {
#[derive(Clone)] #[derive(Clone)]
pub struct RouterURI(pub &'static str); pub struct RouterURI(pub &'static str);

View file

@ -1,4 +1,9 @@
use axum::{response::{IntoResponse, Response}, http::{StatusCode, Request, HeaderValue}, body::Body, headers::HeaderName}; use axum::{
body::Body,
headers::HeaderName,
http::{HeaderValue, Request, StatusCode},
response::{IntoResponse, Response},
};
use tower::ServiceExt; use tower::ServiceExt;
use tower_http::services::ServeFile; use tower_http::services::ServeFile;
use tracing::instrument; use tracing::instrument;
@ -12,11 +17,10 @@ pub enum ResponseCode {
Forbidden, Forbidden,
NotFound, NotFound,
ImATeapot, ImATeapot,
InternalServerError InternalServerError,
} }
impl ResponseCode { impl ResponseCode {
const fn code(self) -> StatusCode { const fn code(self) -> StatusCode {
match self { match self {
Self::Success => StatusCode::OK, Self::Success => StatusCode::OK,
@ -26,7 +30,7 @@ impl ResponseCode {
Self::Forbidden => StatusCode::FORBIDDEN, Self::Forbidden => StatusCode::FORBIDDEN,
Self::NotFound => StatusCode::NOT_FOUND, Self::NotFound => StatusCode::NOT_FOUND,
Self::ImATeapot => StatusCode::IM_A_TEAPOT, Self::ImATeapot => StatusCode::IM_A_TEAPOT,
Self::InternalServerError => StatusCode::INTERNAL_SERVER_ERROR Self::InternalServerError => StatusCode::INTERNAL_SERVER_ERROR,
} }
} }
@ -39,7 +43,8 @@ impl ResponseCode {
pub fn json(self, json: &str) -> Response { pub fn json(self, json: &str) -> Response {
let mut res = (self.code(), json.to_owned()).into_response(); let mut res = (self.code(), json.to_owned()).into_response();
res.headers_mut().insert( res.headers_mut().insert(
HeaderName::from_static("content-type"), HeaderValue::from_static("application/json"), HeaderName::from_static("content-type"),
HeaderValue::from_static("application/json"),
); );
res res
} }
@ -48,14 +53,15 @@ impl ResponseCode {
pub fn html(self, json: &str) -> Response { pub fn html(self, json: &str) -> Response {
let mut res = (self.code(), json.to_owned()).into_response(); let mut res = (self.code(), json.to_owned()).into_response();
res.headers_mut().insert( res.headers_mut().insert(
HeaderName::from_static("content-type"), HeaderValue::from_static("text/html"), HeaderName::from_static("content-type"),
HeaderValue::from_static("text/html"),
); );
res res
} }
#[instrument()] #[instrument()]
pub async fn file(self, path: &str) -> Response { pub async fn file(self, path: &str) -> Response {
if !path.chars().any(|c| c == '.' ) { if !path.chars().any(|c| c == '.') {
return Self::BadRequest.text("Folders cannot be served"); return Self::BadRequest.text("Folders cannot be served");
} }
let path = format!("public{path}"); let path = format!("public{path}");

View file

@ -1,5 +1,5 @@
pub mod user;
pub mod post;
pub mod session;
pub mod extract; pub mod extract;
pub mod http; pub mod http;
pub mod post;
pub mod session;
pub mod user;

View file

@ -1,10 +1,10 @@
use core::fmt; use core::fmt;
use std::collections::HashSet;
use serde::Serialize; use serde::Serialize;
use std::collections::HashSet;
use tracing::instrument; use tracing::instrument;
use crate::database; use crate::database;
use crate::types::http::{Result, ResponseCode}; use crate::types::http::{ResponseCode, Result};
#[derive(Serialize)] #[derive(Serialize)]
pub struct Post { pub struct Post {
@ -13,7 +13,7 @@ pub struct Post {
pub content: String, pub content: String,
pub likes: HashSet<u64>, pub likes: HashSet<u64>,
pub comments: Vec<(u64, String)>, pub comments: Vec<(u64, String)>,
pub date: u64 pub date: u64,
} }
impl fmt::Debug for Post { impl fmt::Debug for Post {
@ -25,7 +25,6 @@ impl fmt::Debug for Post {
} }
impl Post { impl Post {
#[instrument()] #[instrument()]
pub fn from_post_id(post_id: u64) -> Result<Self> { pub fn from_post_id(post_id: u64) -> Result<Self> {
let Ok(Some(post)) = database::posts::get_post(post_id) else { let Ok(Some(post)) = database::posts::get_post(post_id) else {
@ -67,7 +66,7 @@ impl Post {
if database::posts::update_post(self.post_id, &self.likes, &self.comments).is_err() { if database::posts::update_post(self.post_id, &self.likes, &self.comments).is_err() {
tracing::error!("Failed to comment on post"); tracing::error!("Failed to comment on post");
return Err(ResponseCode::InternalServerError.text("Failed to comment on post")) return Err(ResponseCode::InternalServerError.text("Failed to comment on post"));
} }
Ok(()) Ok(())
@ -75,7 +74,6 @@ impl Post {
#[instrument()] #[instrument()]
pub fn like(&mut self, user_id: u64, state: bool) -> Result<()> { pub fn like(&mut self, user_id: u64, state: bool) -> Result<()> {
if state { if state {
self.likes.insert(user_id); self.likes.insert(user_id);
} else { } else {
@ -84,10 +82,11 @@ impl Post {
if database::posts::update_post(self.post_id, &self.likes, &self.comments).is_err() { if database::posts::update_post(self.post_id, &self.likes, &self.comments).is_err() {
tracing::error!("Failed to change like state on post"); tracing::error!("Failed to change like state on post");
return Err(ResponseCode::InternalServerError.text("Failed to change like state on post")) return Err(
ResponseCode::InternalServerError.text("Failed to change like state on post")
);
} }
Ok(()) Ok(())
} }
} }

View file

@ -3,16 +3,15 @@ use serde::Serialize;
use tracing::instrument; use tracing::instrument;
use crate::database; use crate::database;
use crate::types::http::{Result, ResponseCode}; use crate::types::http::{ResponseCode, Result};
#[derive(Serialize)] #[derive(Serialize)]
pub struct Session { pub struct Session {
pub user_id: u64, pub user_id: u64,
pub token: String pub token: String,
} }
impl Session { impl Session {
#[instrument()] #[instrument()]
pub fn from_token(token: &str) -> Result<Self> { pub fn from_token(token: &str) -> Result<Self> {
let Ok(Some(session)) = database::sessions::get_session(token) else { let Ok(Some(session)) = database::sessions::get_session(token) else {
@ -24,10 +23,14 @@ impl Session {
#[instrument()] #[instrument()]
pub fn new(user_id: u64) -> Result<Self> { pub fn new(user_id: u64) -> Result<Self> {
let token: String = rand::thread_rng().sample_iter(&Alphanumeric).take(32).map(char::from).collect(); let token: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect();
match database::sessions::set_session(user_id, &token) { match database::sessions::set_session(user_id, &token) {
Err(_) => Err(ResponseCode::BadRequest.text("Failed to create session")), Err(_) => Err(ResponseCode::BadRequest.text("Failed to create session")),
Ok(_) => Ok(Self {user_id, token}) Ok(_) => Ok(Self { user_id, token }),
} }
} }
@ -39,5 +42,4 @@ impl Session {
}; };
Ok(()) Ok(())
} }
} }

View file

@ -1,10 +1,9 @@
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use tracing::instrument; use tracing::instrument;
use crate::api::auth::RegistrationRequet; use crate::api::auth::RegistrationRequet;
use crate::database; use crate::database;
use crate::types::http::{Result, ResponseCode}; use crate::types::http::{ResponseCode, Result};
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct User { pub struct User {
@ -21,7 +20,6 @@ pub struct User {
} }
impl User { impl User {
#[instrument()] #[instrument()]
pub fn from_user_id(user_id: u64, hide_password: bool) -> Result<Self> { pub fn from_user_id(user_id: u64, hide_password: bool) -> Result<Self> {
let Ok(Some(user)) = database::users::get_user_by_id(user_id, hide_password) else { let Ok(Some(user)) = database::users::get_user_by_id(user_id, hide_password) else {
@ -33,12 +31,15 @@ impl User {
#[instrument()] #[instrument()]
pub fn from_user_ids(user_ids: Vec<u64>) -> Vec<Self> { pub fn from_user_ids(user_ids: Vec<u64>) -> Vec<Self> {
user_ids.iter().filter_map(|user_id| { user_ids
.iter()
.filter_map(|user_id| {
let Ok(Some(user)) = database::users::get_user_by_id(*user_id, true) else { let Ok(Some(user)) = database::users::get_user_by_id(*user_id, true) else {
return None; return None;
}; };
Some(user) Some(user)
}).collect() })
.collect()
} }
#[instrument()] #[instrument()]
@ -70,11 +71,13 @@ impl User {
#[instrument()] #[instrument()]
pub fn new(request: RegistrationRequet) -> Result<Self> { pub fn new(request: RegistrationRequet) -> Result<Self> {
if Self::from_email(&request.email).is_ok() { if Self::from_email(&request.email).is_ok() {
return Err(ResponseCode::BadRequest.text(&format!("Email is already in use by {}", &request.email))) return Err(ResponseCode::BadRequest
.text(&format!("Email is already in use by {}", &request.email)));
} }
if let Ok(user) = Self::from_password(&request.password) { if let Ok(user) = Self::from_password(&request.password) {
return Err(ResponseCode::BadRequest.text(&format!("Password is already in use by {}", user.email))) return Err(ResponseCode::BadRequest
.text(&format!("Password is already in use by {}", user.email)));
} }
let Ok(user) = database::users::add_user(request) else { let Ok(user) = database::users::add_user(request) else {
@ -84,5 +87,4 @@ impl User {
Ok(user) Ok(user)
} }
} }