summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authortylerm <tylerm@tylerm.dev>2023-08-22 04:16:31 +0000
committertylerm <tylerm@tylerm.dev>2023-08-22 04:16:31 +0000
commitedbbdf72c78536c48357a86181bbf6897fc52074 (patch)
tree91d91e9dfb77ae3b7d75f4348c01bba59d0f13dc /src
parentallow port env (diff)
parentfinish dms (diff)
downloadxssbook-edbbdf72c78536c48357a86181bbf6897fc52074.tar.gz
xssbook-edbbdf72c78536c48357a86181bbf6897fc52074.tar.bz2
xssbook-edbbdf72c78536c48357a86181bbf6897fc52074.zip
Merge pull request 'dms are cool' (#1) from dev into main
Reviewed-on: https://g.tylerm.dev/tylerm/xssbook/pulls/1
Diffstat (limited to '')
-rw-r--r--src/api/chat.rs512
-rw-r--r--src/api/mod.rs33
-rw-r--r--src/database/chat.rs211
-rw-r--r--src/database/mod.rs2
-rw-r--r--src/main.rs2
-rw-r--r--src/public/docs.rs14
-rw-r--r--src/public/mod.rs29
-rw-r--r--src/public/pages.rs5
-rw-r--r--src/types/chat.rs129
-rw-r--r--src/types/mod.rs1
10 files changed, 883 insertions, 55 deletions
diff --git a/src/api/chat.rs b/src/api/chat.rs
new file mode 100644
index 0000000..02bdfbd
--- /dev/null
+++ b/src/api/chat.rs
@@ -0,0 +1,512 @@
+use std::collections::HashMap;
+
+use axum::{response::Response, Router, routing::{post, patch, delete, get}, extract::{ws::Message, WebSocketUpgrade}};
+use serde::Deserialize;
+use tokio::sync::{Mutex, mpsc::{Sender, self}};
+use crate::{
+ public::docs::{EndpointDocumentation, EndpointMethod},
+ types::{
+ extract::{AuthorizedUser, Check, CheckResult, Database, Json, Log},
+ http::ResponseCode,
+ chat::{ChatRoom, ChatEvent}, user::User,
+ },
+};
+use std::collections::hash_map::Values;
+use lazy_static::lazy_static;
+
+lazy_static!(
+ static ref CONNECTIONS: Mutex<HashMap<u64, ConnectionPool>> = Mutex::new(HashMap::new());
+);
+
+struct ConnectionPool {
+ inner: HashMap<usize, Sender<ChatEvent>>,
+ index: usize
+}
+
+impl ConnectionPool {
+ fn new() -> Self {
+ Self {
+ inner: HashMap::new(),
+ index: 0
+ }
+ }
+
+ fn add(&mut self, send: Sender<ChatEvent>) -> usize {
+ let idx = self.index;
+ self.index += 1;
+ self.inner.insert(idx, send);
+ idx
+ }
+
+ fn del(&mut self, idx: &usize) {
+ self.inner.remove(idx);
+ }
+
+ fn values(&self) -> Values<'_, usize, Sender<ChatEvent>> {
+ self.inner.values()
+ }
+}
+
+async fn send_event(event: ChatEvent, room: &ChatRoom) {
+ for user in &room.users {
+ let lock = CONNECTIONS.lock().await;
+ let Some(connection) = lock.get(&user) else {
+ continue
+ };
+ for channel in connection.values() {
+ channel.send(event.clone()).await.ok();
+ }
+ }
+}
+
+pub const CHAT_LIST: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/list",
+ method: EndpointMethod::Post,
+ description: "Returns the rooms you are in",
+ body: None,
+ responses: &[
+ (201, "Returns rooms in a list"),
+ (400, "Body does not match parameters"),
+ (401, "Unauthorized"),
+ (500, "Failed to retrieve rooms"),
+ ],
+ cookie: Some("auth"),
+};
+
+async fn list (
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ _: Log
+) -> Response {
+ let Ok(rooms) = ChatRoom::from_user_id(&db, user.user_id) else {
+ return ResponseCode::InternalServerError.text("Failed to retrieve rooms")
+ };
+
+ let Ok(json) = serde_json::to_string(&rooms) else {
+ return ResponseCode::InternalServerError.text("Failed to retrieve rooms")
+ };
+
+ ResponseCode::Success.json(&json)
+}
+
+pub const CHAT_CREATE: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/create",
+ method: EndpointMethod::Post,
+ description: "Creates a new room",
+ body: Some(
+ r#"
+ {
+ "name" : "Funny memes"
+ }
+ "#,
+ ),
+ responses: &[
+ (201, "Successfully created room"),
+ (400, "Body does not match parameters"),
+ (401, "Unauthorized"),
+ (500, "Failed to create room"),
+ ],
+ cookie: Some("auth"),
+};
+
+#[derive(Deserialize)]
+struct RoomCreateRequest {
+ name: String,
+}
+
+impl Check for RoomCreateRequest {
+ fn check(&self) -> CheckResult {
+ Self::assert_length(
+ &self.name,
+ 1,
+ 255,
+ "Room names must be between 1-255 characters long",
+ )?;
+ Ok(())
+ }
+}
+
+async fn create (
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ Json(body): Json<RoomCreateRequest>,
+) -> Response {
+ let Ok(room) = ChatRoom::new(&db, vec![user.user_id], body.name) else {
+ return ResponseCode::InternalServerError.text("Failed to create room")
+ };
+
+ for user in &room.users {
+ send_event(ChatEvent::Add {
+ user_id: *user,
+ room: room.clone()
+ }, &room).await;
+ }
+
+ let Ok(json) = serde_json::to_string(&room) else {
+ return ResponseCode::InternalServerError.text("Failed to create room")
+ };
+
+ ResponseCode::Created.json(&json)
+}
+
+pub const CHAT_ADD: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/add",
+ method: EndpointMethod::Patch,
+ description: "Adds a user to a room",
+ body: Some(
+ r#"
+ {
+ "room_id": 69,
+ "email" : "joebide@house.gov"
+ }
+ "#,
+ ),
+ responses: &[
+ (201, "Successfully added user"),
+ (400, "Body does not match parameters"),
+ (401, "Unauthorized"),
+ (500, "Failed to add user"),
+ ],
+ cookie: Some("auth"),
+};
+
+#[derive(Deserialize)]
+struct AddUserRequest {
+ room_id: u64,
+ email: String,
+}
+
+impl Check for AddUserRequest {
+ fn check(&self) -> CheckResult {
+ Ok(())
+ }
+}
+
+async fn add (
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ Json(body): Json<AddUserRequest>,
+) -> Response {
+
+ let Ok(to_add) = User::from_email(&db, &body.email) else {
+ return ResponseCode::BadRequest.text("User does not exist")
+ };
+
+ let Ok(mut room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else {
+ return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it")
+ };
+
+ if room.users.contains(&to_add.user_id) {
+ return ResponseCode::BadRequest.text("User is already in the room")
+ }
+
+ let Ok(success) = room.add_user(&db, to_add.user_id) else {
+ return ResponseCode::InternalServerError.text("Failed to add user")
+ };
+
+ if !success {
+ return ResponseCode::BadRequest.text("User is already in the room")
+ }
+
+ room.users.push(to_add.user_id);
+
+ send_event(ChatEvent::Add {
+ user_id: to_add.user_id,
+ room: room.clone()
+ }, &room).await;
+
+ ResponseCode::Success.text("Successfully added user")
+}
+
+pub const CHAT_LEAVE: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/leave",
+ method: EndpointMethod::Delete,
+ description: "Leaves a room",
+ body: Some(
+ r#"
+ {
+ "room_id": 69
+ }
+ "#,
+ ),
+ responses: &[
+ (201, "Successfully left room"),
+ (400, "Body does not match parameters"),
+ (401, "Unauthorized"),
+ (500, "Failed to leave a room"),
+ ],
+ cookie: Some("auth"),
+};
+
+#[derive(Deserialize)]
+struct LeaveRoomRequest {
+ room_id: u64,
+}
+
+impl Check for LeaveRoomRequest {
+ fn check(&self) -> CheckResult {
+ Ok(())
+ }
+}
+
+async fn leave (
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ Json(body): Json<LeaveRoomRequest>,
+) -> Response {
+
+ let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else {
+ return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it")
+ };
+
+ let Ok(success) = room.remove_user(&db, user.user_id) else {
+ return ResponseCode::InternalServerError.text("Failed to leave room")
+ };
+
+ if !success {
+ return ResponseCode::BadRequest.text("You are currently not in this room (how did this happen?)")
+ }
+
+ send_event(ChatEvent::Leave {
+ user_id: user.user_id,
+ room_id: room.room_id
+ }, &room).await;
+
+ ResponseCode::Success.text("Successfully left room")
+}
+
+pub const CHAT_SEND: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/send",
+ method: EndpointMethod::Post,
+ description: "Send a message to a room",
+ body: Some(
+ r#"
+ {
+ "room_id": 420,
+ "content" : "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+ }
+ "#,
+ ),
+ responses: &[
+ (201, "Successfully sent message"),
+ (400, "Body does not match parameters"),
+ (401, "Unauthorized"),
+ (500, "Failed to send message"),
+ ],
+ cookie: Some("auth"),
+};
+
+#[derive(Deserialize)]
+struct SendMessageRequest {
+ room_id: u64,
+ content: String
+}
+
+impl Check for SendMessageRequest {
+ fn check(&self) -> CheckResult {
+ Self::assert_length(
+ &self.content,
+ 1,
+ 500,
+ "Messages must be between 1-500 length"
+ )?;
+ Ok(())
+ }
+}
+
+async fn send (
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ Json(body): Json<SendMessageRequest>,
+) -> Response {
+
+ let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else {
+ return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it")
+ };
+
+ let Ok(msg) = room.send_message(&db, user.user_id, body.content) else {
+ return ResponseCode::InternalServerError.text("Failed to send message")
+ };
+
+ send_event(ChatEvent::Message {
+ user_id: msg.user_id,
+ room_id: msg.room_id,
+ message_id: msg.message_id,
+ content: msg.content,
+ date: msg.date
+ }, &room).await;
+
+ ResponseCode::Created.text("Successfully sent message")
+}
+
+pub const CHAT_LOAD: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/load",
+ method: EndpointMethod::Post,
+ description: "Get a page of historic room messages starting before given message id",
+ body: Some(
+ r#"
+ {
+ "room_id": 69,
+ "newest_msg": 400,
+ "page": 3
+ }
+ "#,
+ ),
+ responses: &[
+ (201, "Successfully sent message"),
+ (400, "Body does not match parameters"),
+ (401, "Unauthorized"),
+ (500, "Failed to send message"),
+ ],
+ cookie: Some("auth"),
+};
+
+#[derive(Deserialize)]
+struct LoadMessagesRequest {
+ room_id: u64,
+ newest_msg: u64,
+ page: u64
+}
+
+impl Check for LoadMessagesRequest {
+ fn check(&self) -> CheckResult {
+ Ok(())
+ }
+}
+
+async fn load (
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ Json(body): Json<LoadMessagesRequest>,
+) -> Response {
+
+ let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else {
+ return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it")
+ };
+
+ let Ok(msgs) = room.load_old_chat_messages(&db, body.newest_msg, body.page) else {
+ return ResponseCode::InternalServerError.text("Failed to load messages")
+ };
+
+ let Ok(json) = serde_json::to_string(&msgs) else {
+ return ResponseCode::InternalServerError.text("Failed to load messages")
+ };
+
+ ResponseCode::Created.json(&json)
+}
+
+pub const CHAT_TYPING: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/typing",
+ method: EndpointMethod::Post,
+ description: "Set if your typing in a given room",
+ body: Some(
+ r#"
+ {
+ "room_id": 69,
+ }
+ "#,
+ ),
+ responses: &[
+ (201, "Successfully sent typing indicator"),
+ (400, "Body does not match parameters"),
+ (401, "Unauthorized"),
+ (500, "Failed to send typing indicator"),
+ ],
+ cookie: Some("auth"),
+};
+
+#[derive(Deserialize)]
+struct TypingRequest {
+ room_id: u64,
+}
+
+impl Check for TypingRequest {
+ fn check(&self) -> CheckResult {
+ Ok(())
+ }
+}
+
+async fn typing (
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ Json(body): Json<TypingRequest>,
+) -> Response {
+
+ let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else {
+ return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it")
+ };
+
+ send_event(ChatEvent::Typing {
+ user_id: user.user_id,
+ room_id: room.room_id,
+ }, &room).await;
+
+ ResponseCode::Success.text("Successfully sent typing indicator")
+}
+
+pub const CHAT_CONNECT: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/connect",
+ method: EndpointMethod::Get,
+ description: "Start a websocket connection for chat events",
+ body: None,
+ responses: &[],
+ cookie: Some("auth"),
+};
+
+async fn connect (
+ AuthorizedUser(user): AuthorizedUser,
+ ws: WebSocketUpgrade
+) -> Response {
+ ws.on_upgrade(|mut ws| async move {
+ let user = user;
+ let (send, mut recv) = mpsc::channel::<ChatEvent>(20);
+ let id: usize;
+ {
+ let mut lock = CONNECTIONS.lock().await;
+ match lock.get_mut(&user.user_id) {
+ Some(pool) => {
+ id = pool.add(send);
+ },
+ None => {
+ let mut pool = ConnectionPool::new();
+ id = pool.add(send);
+ lock.insert(user.user_id, pool);
+ }
+ };
+ }
+ loop {
+ tokio::select! {
+ m = ws.recv() => {
+ let Some(Ok(_)) = m else {
+ break;
+ };
+ }
+ s = recv.recv() => {
+ let Some(msg) = s else {
+ break;
+ };
+ if let Ok(string) = serde_json::to_string(&msg) {
+ ws.send(Message::Text(string)).await.ok();
+ }
+ }
+ }
+ }
+
+ let mut lock = CONNECTIONS.lock().await;
+ if let Some(conn) = lock.get_mut(&user.user_id) {
+ conn.del(&id);
+ };
+ })
+}
+
+pub fn router() -> Router {
+ Router::new()
+ .route("/create", post(create))
+ .route("/list", post(list))
+ .route("/add", patch(add))
+ .route("/leave", delete(leave))
+ .route("/send", post(send))
+ .route("/load", post(load))
+ .route("/typing", post(typing))
+ .route("/connect", get(connect))
+}
diff --git a/src/api/mod.rs b/src/api/mod.rs
index 8b631c8..0c01ea0 100644
--- a/src/api/mod.rs
+++ b/src/api/mod.rs
@@ -1,33 +1,21 @@
use crate::types::extract::{RouterURI, self};
-use axum::{
- error_handling::HandleErrorLayer,
- BoxError, Extension, Router, middleware,
-};
-use tower::ServiceBuilder;
-use tower_governor::{
- errors::display_error, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor,
- GovernorLayer,
-};
+pub mod chat;
pub mod admin;
pub mod auth;
pub mod posts;
pub mod users;
pub use auth::RegistrationRequet;
+use axum::{Extension, Router, middleware};
pub fn router() -> Router {
- let governor_conf = Box::new(
- GovernorConfigBuilder::default()
- .burst_size(15)
- .per_second(1)
- .key_extractor(SmartIpKeyExtractor)
- .finish()
- .expect("Failed to create rate limiter"),
- );
-
Router::new()
.nest(
+ "/chat",
+ chat::router().layer(Extension(RouterURI("/api/chat"))),
+ )
+ .nest(
"/admin",
admin::router().layer(Extension(RouterURI("/api/admin"))),
)
@@ -43,14 +31,5 @@ pub fn router() -> Router {
"/posts",
posts::router().layer(Extension(RouterURI("/api/posts"))),
)
- .layer(
- ServiceBuilder::new()
- .layer(HandleErrorLayer::new(|e: BoxError| async move {
- display_error(e)
- }))
- .layer(GovernorLayer {
- config: Box::leak(governor_conf),
- }),
- )
.layer(middleware::from_fn(extract::connect))
}
diff --git a/src/database/chat.rs b/src/database/chat.rs
new file mode 100644
index 0000000..99ec86c
--- /dev/null
+++ b/src/database/chat.rs
@@ -0,0 +1,211 @@
+use std::time::{SystemTime, UNIX_EPOCH, Duration};
+
+use tracing::instrument;
+
+use crate::types::chat::{ChatRoom, ChatMessage};
+
+use super::Database;
+
+impl Database {
+
+ #[instrument(skip(self))]
+ pub fn init_chat(&self) -> Result<(), rusqlite::Error> {
+ let sql = "
+ CREATE TABLE IF NOT EXISTS chat_rooms (
+ room_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name VARCHAR(255) NOT NULL
+ );
+ ";
+ self.0.execute(sql, ())?;
+
+ let sql2 = "
+ CREATE TABLE IF NOT EXISTS chat_users (
+ room_id INTEGER NOT NULL,
+ user_id INTEGER NOT NULL,
+ FOREIGN KEY(room_id) REFERENCES chat_rooms(room_id),
+ FOREIGN KEY(user_id) REFERENCES users(user_id),
+ PRIMARY KEY (room_id, user_id)
+ );
+ ";
+ self.0.execute(sql2, ())?;
+
+ let sql3 = "
+ CREATE TABLE IF NOT EXISTS chat_messages (
+ message_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ room_id INTEGER NOT NULL,
+ date INTEGER NOT NULL,
+ content VARCHAR(500) NOT NULL,
+ FOREIGN KEY(user_id) REFERENCES users(user_id),
+ FOREIGN KEY(room_id) REFERENCES chat_rooms(room_id)
+ );
+ ";
+ self.0.execute(sql3, ())?;
+
+ let sql4 = "CREATE INDEX IF NOT EXISTS chat_message_ids ON chat_messages(room_id);";
+ self.0.execute(sql4, ())?;
+
+ Ok(())
+ }
+
+ #[instrument(skip(self))]
+ pub fn get_rooms(&self, user_id: u64) -> Result<Vec<ChatRoom>, rusqlite::Error> {
+ tracing::trace!("Retrieving rooms");
+ let mut stmt = self.0.prepare(
+ "
+ SELECT * FROM chat_rooms
+ WHERE room_id IN (
+ SELECT room_id
+ FROM chat_users
+ WHERE user_id = ?
+ );
+ ",
+ )?;
+
+ let row = stmt.query_map([user_id], |row| {
+ let room_id: u64 = row.get(0)?;
+ let name: String = row.get(1)?;
+
+ let mut stmt2 = self.0.prepare(
+ "
+ SELECT user_id FROM chat_users
+ WHERE room_id = ?;
+ "
+ )?;
+
+ let users = stmt2.query_map([room_id], |row2| {
+ Ok(row2.get(0)?)
+ })?.into_iter().flatten().collect();
+
+ let room = ChatRoom {
+ room_id,
+ users,
+ name
+ };
+
+ Ok(room)
+ })?;
+
+ Ok(row.into_iter().flatten().collect())
+ }
+
+ #[instrument(skip(self))]
+ pub fn create_room(&self, users: Vec<u64>, name: String) -> Result<ChatRoom, rusqlite::Error> {
+ tracing::trace!("Creating new room");
+ let mut stmt = self.0.prepare(
+ "INSERT INTO chat_rooms (name) VALUES (?) RETURNING *;"
+ )?;
+ let mut room = stmt.query_row([name], |row| {
+ let room_id = row.get(0)?;
+ let name = row.get(1)?;
+ Ok(ChatRoom {
+ room_id,
+ users: Vec::new(),
+ name
+ })
+ })?;
+
+ let mut stmt2 = self.0.prepare(
+ "INSERT INTO chat_users (room_id, user_id) VALUES (?, ?);"
+ )?;
+
+ for user_id in users {
+ stmt2.execute([room.room_id, user_id])?;
+ room.users.push(user_id);
+ }
+
+ Ok(room)
+ }
+
+ #[instrument(skip(self))]
+ pub fn add_user_to_room(&self, room_id: u64, user_id: u64) -> Result<bool, rusqlite::Error> {
+ tracing::trace!("Adding user to room");
+ let mut stmt = self.0.prepare(
+ "INSERT OR REPLACE INTO chat_users (room_id, user_id) VALUES(?,?);"
+ )?;
+
+ let changes = stmt.execute([room_id, user_id])?;
+
+ Ok(changes == 1)
+ }
+
+ #[instrument(skip(self))]
+ pub fn remove_user_from_room(&self, room_id: u64, user_id: u64) -> Result<bool, rusqlite::Error> {
+ tracing::trace!("Removing user from room");
+ let mut stmt = self.0.prepare(
+ "DELETE FROM chat_users WHERE room_id = ? AND user_id = ?;"
+ )?;
+
+ let changes = stmt.execute([room_id, user_id])?;
+
+ Ok(changes == 1)
+ }
+
+ #[instrument(skip(self))]
+ pub fn create_message(&self, room_id: u64, user_id: u64, content: String) -> Result<ChatMessage, rusqlite::Error> {
+ tracing::trace!("Creating new chat message");
+ let date = u64::try_from(
+ SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .unwrap_or(Duration::ZERO)
+ .as_millis(),
+ )
+ .unwrap_or(0);
+
+ let mut stmt = self.0.prepare(
+ "INSERT INTO chat_messages (user_id, room_id, date, content) VALUES (?,?,?,?) RETURNING *;"
+ )?;
+
+ let msg = stmt.query_row((user_id, room_id, date, content), |row| {
+ let message_id = row.get(0)?;
+ let user_id = row.get(1)?;
+ let room_id = row.get(2)?;
+ let date = row.get(3)?;
+ let content = row.get(4)?;
+
+ Ok(ChatMessage {
+ message_id,
+ room_id,
+ user_id,
+ date,
+ content
+ })
+ })?;
+
+ Ok(msg)
+ }
+
+ #[instrument(skip(self))]
+ pub fn load_old_chat_messages(&self, room_id: u64, newest_message: u64, page: u64) -> Result<Vec<ChatMessage>, rusqlite::Error> {
+ tracing::trace!("Loading old chat messages");
+ let mut stmt = self.0.prepare(
+ "
+ SELECT * FROM chat_messages
+ WHERE room_id = ?
+ AND message_id < ?
+ ORDER BY message_id DESC
+ LIMIT ?
+ OFFSET ?
+ "
+ )?;
+
+ let messages = stmt.query_map((room_id, newest_message, 20, 20 * page), |row| {
+ let message_id = row.get(0)?;
+ let user_id = row.get(1)?;
+ let room_id = row.get(2)?;
+ let date = row.get(3)?;
+ let content = row.get(4)?;
+
+ Ok(ChatMessage {
+ message_id,
+ room_id,
+ user_id,
+ date,
+ content
+ })
+ })?;
+
+ Ok(messages.into_iter().flatten().collect())
+ }
+
+}
diff --git a/src/database/mod.rs b/src/database/mod.rs
index 67e05c6..7d0928f 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -1,6 +1,7 @@
use rusqlite::Connection;
use tracing::instrument;
+pub mod chat;
pub mod comments;
pub mod friends;
pub mod likes;
@@ -32,5 +33,6 @@ pub fn init() -> Result<(), rusqlite::Error> {
db.init_likes()?;
db.init_comments()?;
db.init_friends()?;
+ db.init_chat()?;
Ok(())
}
diff --git a/src/main.rs b/src/main.rs
index 817f8ac..cc8a61e 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -46,7 +46,7 @@ async fn main() {
tracing_subscriber::registry()
.with(
fmt_layer
- .with_filter(LevelFilter::TRACE)
+ .with_filter(LevelFilter::INFO)
.with_filter(filter_fn(|metadata| {
metadata.target().starts_with("xssbook")
})),
diff --git a/src/public/docs.rs b/src/public/docs.rs
index 397e696..976638b 100644
--- a/src/public/docs.rs
+++ b/src/public/docs.rs
@@ -3,24 +3,28 @@ use lazy_static::lazy_static;
use tokio::sync::Mutex;
use crate::{
- api::{admin, auth, posts, users},
+ api::{admin, auth, posts, users, chat},
types::http::ResponseCode,
};
use super::console::beautify;
pub enum EndpointMethod {
+ Get,
Post,
Put,
Patch,
+ Delete
}
impl ToString for EndpointMethod {
fn to_string(&self) -> String {
match self {
+ Self::Get => "GET".to_owned(),
Self::Post => "POST".to_owned(),
Self::Put => "PUT".to_owned(),
Self::Patch => "PATCH".to_owned(),
+ Self::Delete => "DELETE".to_owned(),
}
}
}
@@ -139,6 +143,14 @@ pub async fn init() {
users::USERS_FOLLOW,
users::USERS_FOLLOW_STATUS,
users::USERS_FRIENDS,
+ chat::CHAT_LIST,
+ chat::CHAT_CREATE,
+ chat::CHAT_ADD,
+ chat::CHAT_LEAVE,
+ chat::CHAT_SEND,
+ chat::CHAT_LOAD,
+ chat::CHAT_TYPING,
+ chat::CHAT_CONNECT,
admin::ADMIN_AUTH,
admin::ADMIN_QUERY,
admin::ADMIN_POSTS,
diff --git a/src/public/mod.rs b/src/public/mod.rs
index bb75ef0..bd40fda 100644
--- a/src/public/mod.rs
+++ b/src/public/mod.rs
@@ -1,17 +1,12 @@
use axum::{
body::Body,
- error_handling::HandleErrorLayer,
headers::HeaderName,
http::{HeaderValue, Request, StatusCode},
- response::{IntoResponse, Response},
+ response::{Response, IntoResponse},
routing::get,
- BoxError, Router,
-};
-use tower::{ServiceBuilder, ServiceExt};
-use tower_governor::{
- errors::display_error, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor,
- GovernorLayer,
+ Router,
};
+use tower::ServiceExt;
use tower_http::services::ServeFile;
use crate::types::http::ResponseCode;
@@ -23,15 +18,6 @@ pub mod file;
pub mod pages;
pub fn router() -> Router {
- let governor_conf = Box::new(
- GovernorConfigBuilder::default()
- .burst_size(30)
- .per_second(1)
- .key_extractor(SmartIpKeyExtractor)
- .finish()
- .expect("Failed to create rate limiter"),
- );
-
Router::new()
.nest("/", pages::router())
.route("/favicon.ico", get(file::favicon))
@@ -42,15 +28,6 @@ pub fn router() -> Router {
.route("/image/*path", get(file::image))
.route("/image/avatar", get(file::avatar))
.route("/image/banner", get(file::banner))
- .layer(
- ServiceBuilder::new()
- .layer(HandleErrorLayer::new(|e: BoxError| async move {
- display_error(e)
- }))
- .layer(GovernorLayer {
- config: Box::leak(governor_conf),
- }),
- )
}
pub async fn serve(path: &str) -> Response {
diff --git a/src/public/pages.rs b/src/public/pages.rs
index a7789b2..0eef51b 100644
--- a/src/public/pages.rs
+++ b/src/public/pages.rs
@@ -66,6 +66,10 @@ async fn forgot(UserAgent(agent): UserAgent, _: Log) -> Response {
Redirect::to("https://www.youtube.com/watch?v=dQw4w9WgXcQ").into_response()
}
+async fn chat() -> Response {
+ super::serve("/chat.html").await
+}
+
pub fn router() -> Router {
Router::new()
.route("/", get(root))
@@ -79,4 +83,5 @@ pub fn router() -> Router {
.route("/admin", get(admin))
.route("/docs", get(api))
.route("/forgot", get(forgot))
+ .route("/chat", get(chat))
}
diff --git a/src/types/chat.rs b/src/types/chat.rs
new file mode 100644
index 0000000..8413f77
--- /dev/null
+++ b/src/types/chat.rs
@@ -0,0 +1,129 @@
+use serde::{Serialize, Deserialize};
+use tracing::instrument;
+use crate::{types::http::{ResponseCode, Result}, database::Database};
+
+#[derive(Deserialize, Serialize, Clone, Debug)]
+pub struct ChatRoom {
+ pub room_id: u64,
+ pub users: Vec<u64>,
+ pub name: String
+}
+
+#[derive(Serialize, Clone, Debug)]
+pub struct ChatMessage {
+ pub message_id: u64,
+ pub user_id: u64,
+ pub room_id: u64,
+ pub date: u64,
+ pub content: String
+}
+
+#[derive(Serialize, Deserialize, Clone, Debug)]
+#[serde(tag = "type")]
+pub enum ChatEvent {
+ #[serde(rename = "message")]
+ Message {
+ user_id: u64,
+ message_id: u64,
+ room_id: u64,
+ content: String,
+ date: u64
+ },
+
+ #[serde(rename = "add")]
+ Add {
+ user_id: u64,
+ room: ChatRoom
+ },
+
+ #[serde(rename = "leave")]
+ Leave {
+ user_id: u64,
+ room_id: u64
+ },
+
+ #[serde(rename = "typing")]
+ Typing {
+ user_id: u64,
+ room_id: u64
+ }
+}
+
+impl ChatRoom {
+
+ #[instrument(skip(db))]
+ pub fn new(db: &Database, users: Vec<u64>, name: String) -> Result<Self> {
+ let Ok(room) = db.create_room(users, name) else {
+ tracing::error!("Failed to create room");
+ return Err(ResponseCode::InternalServerError.text("Failed to create room"))
+ };
+
+ Ok(room)
+ }
+
+ #[instrument(skip(db))]
+ pub fn from_user_id(db: &Database, user_id: u64) -> Result<Vec<Self>> {
+ let Ok(rooms) = db.get_rooms(user_id) else {
+ tracing::error!("Failed to get rooms");
+ return Err(ResponseCode::InternalServerError.text("Failed to get rooms"))
+ };
+
+ Ok(rooms)
+ }
+
+ #[instrument(skip(db))]
+ pub fn from_user_and_room_id(db: &Database, user_id: u64, room_id: u64) -> Result<Self> {
+ let Ok(rooms) = db.get_rooms(user_id) else {
+ tracing::error!("Failed to get room");
+ return Err(ResponseCode::InternalServerError.text("Failed to get room"))
+ };
+
+ for room in rooms {
+ if room.room_id == room_id {
+ return Ok(room);
+ }
+ }
+
+ return Err(ResponseCode::BadRequest.text("Room doesnt exist or you are not in it"))
+ }
+
+ #[instrument(skip(db))]
+ pub fn add_user(&self, db: &Database, user_id: u64) -> Result<bool> {
+ let Ok(success) = db.add_user_to_room(self.room_id, user_id) else {
+ tracing::error!("Failed to add user to room");
+ return Err(ResponseCode::InternalServerError.text("Failed to add user to room"))
+ };
+
+ Ok(success)
+ }
+
+ #[instrument(skip(db))]
+ pub fn remove_user(&self, db: &Database, user_id: u64) -> Result<bool> {
+ let Ok(success) = db.remove_user_from_room(self.room_id, user_id) else {
+ tracing::error!("Failed to remove user from room");
+ return Err(ResponseCode::InternalServerError.text("Failed to remove user from room"))
+ };
+
+ Ok(success)
+ }
+
+ #[instrument(skip(db))]
+ pub fn send_message(&self, db: &Database, user_id: u64, content: String) -> Result<ChatMessage> {
+ let Ok(msg) = db.create_message(self.room_id, user_id, content) else {
+ tracing::error!("Failed to create messgae");
+ return Err(ResponseCode::InternalServerError.text("Failed to create message"))
+ };
+
+ Ok(msg)
+ }
+
+ #[instrument(skip(db))]
+ pub fn load_old_chat_messages(&self, db: &Database, newest_message: u64, page: u64) -> Result<Vec<ChatMessage>> {
+ let Ok(msgs) = db.load_old_chat_messages(self.room_id, newest_message, page) else {
+ tracing::error!("Failed to load messgaes");
+ return Err(ResponseCode::InternalServerError.text("Failed to load messages"))
+ };
+
+ Ok(msgs)
+ }
+}
diff --git a/src/types/mod.rs b/src/types/mod.rs
index 1ee2d08..a325ff9 100644
--- a/src/types/mod.rs
+++ b/src/types/mod.rs
@@ -1,3 +1,4 @@
+pub mod chat;
pub mod comment;
pub mod extract;
pub mod http;