From a8b6798dfe4939dc8c36cec6b36a4261477fb087 Mon Sep 17 00:00:00 2001 From: Tyler Murphy Date: Wed, 15 Feb 2023 00:47:55 -0500 Subject: [PATCH] fix root db call --- src/api/mod.rs | 25 +++---------------------- src/public/pages.rs | 5 +++-- src/types/extract.rs | 16 ++++++++++++++-- 3 files changed, 20 insertions(+), 26 deletions(-) diff --git a/src/api/mod.rs b/src/api/mod.rs index eeaaa0a..8b631c8 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,15 +1,8 @@ -use crate::{ - database, - types::extract::{DatabaseExtention, RouterURI}, -}; +use crate::types::extract::{RouterURI, self}; use axum::{ error_handling::HandleErrorLayer, - http::Request, - middleware::{self, Next}, - response::Response, - BoxError, Extension, Router, + BoxError, Extension, Router, middleware, }; -use tokio::sync::Mutex; use tower::ServiceBuilder; use tower_governor::{ errors::display_error, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, @@ -23,18 +16,6 @@ pub mod users; pub use auth::RegistrationRequet; -async fn connect(mut req: Request, next: Next) -> Response -where - B: Send, -{ - if let Ok(db) = database::Database::connect() { - let ex = DatabaseExtention(Mutex::new(db)); - req.extensions_mut().insert(ex); - } - - next.run(req).await -} - pub fn router() -> Router { let governor_conf = Box::new( GovernorConfigBuilder::default() @@ -71,5 +52,5 @@ pub fn router() -> Router { config: Box::leak(governor_conf), }), ) - .layer(middleware::from_fn(connect)) + .layer(middleware::from_fn(extract::connect)) } diff --git a/src/public/pages.rs b/src/public/pages.rs index 426727e..a7789b2 100644 --- a/src/public/pages.rs +++ b/src/public/pages.rs @@ -1,13 +1,13 @@ use axum::{ response::{IntoResponse, Redirect, Response}, routing::get, - Router, + Router, middleware, }; use crate::{ public::console, types::{ - extract::{AuthorizedUser, Log, UserAgent}, + extract::{AuthorizedUser, Log, UserAgent, self}, http::ResponseCode, }, }; @@ -70,6 +70,7 @@ pub fn router() -> Router { Router::new() .route("/", get(root)) .route("/login", get(login)) + .layer(middleware::from_fn(extract::connect)) .route("/home", get(home)) .route("/people", get(people)) .route("/profile", get(profile)) diff --git a/src/types/extract.rs b/src/types/extract.rs index f05215f..a76eac4 100644 --- a/src/types/extract.rs +++ b/src/types/extract.rs @@ -9,7 +9,7 @@ use axum::{ extract::{ConnectInfo, FromRequest, FromRequestParts}, http::{header::USER_AGENT, request::Parts, Request}, response::Response, - BoxError, RequestExt, + BoxError, RequestExt, middleware::Next, }; use bytes::Bytes; use image::{io::Reader, DynamicImage, ImageFormat}; @@ -100,7 +100,7 @@ where }; let Some(db) = parts.extensions.get::() else { - return Err(ResponseCode::Forbidden.text("Could not connect to database")) + return Err(ResponseCode::InternalServerError.text("Could not connect to database")) }; let db = db.0.lock().await; @@ -288,6 +288,18 @@ where } } +pub async fn connect(mut req: Request, next: Next) -> Response +where + B: Send, +{ + if let Ok(db) = database::Database::connect() { + let ex = DatabaseExtention(Mutex::new(db)); + req.extensions_mut().insert(ex); + } + + next.run(req).await +} + async fn read_body(mut req: Request, state: &S) -> Result> where B: HttpBody + Sync + Send + 'static,