summaryrefslogtreecommitdiff
path: root/src/api/chat.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/api/chat.rs')
-rw-r--r--src/api/chat.rs193
1 files changed, 187 insertions, 6 deletions
diff --git a/src/api/chat.rs b/src/api/chat.rs
index 0f3b92e..1e56c3e 100644
--- a/src/api/chat.rs
+++ b/src/api/chat.rs
@@ -1,13 +1,63 @@
-use axum::{response::Response, Router, routing::{post, patch, delete}};
+use std::collections::HashMap;
+
+use axum::{response::Response, Router, routing::{post, patch, delete, get}, extract::{ws::Message, WebSocketUpgrade}};
use serde::Deserialize;
+use tokio::sync::{Mutex, mpsc::{Sender, self}};
use crate::{
public::docs::{EndpointDocumentation, EndpointMethod},
types::{
extract::{AuthorizedUser, Check, CheckResult, Database, Json},
http::ResponseCode,
- chat::ChatRoom, user::User,
+ chat::{ChatRoom, ChatEvent}, user::User,
},
};
+use std::collections::hash_map::Values;
+use lazy_static::lazy_static;
+
+lazy_static!(
+ static ref CONNECTIONS: Mutex<HashMap<u64, ConnectionPool>> = Mutex::new(HashMap::new());
+);
+
+struct ConnectionPool {
+ inner: HashMap<usize, Sender<ChatEvent>>,
+ index: usize
+}
+
+impl ConnectionPool {
+ fn new() -> Self {
+ Self {
+ inner: HashMap::new(),
+ index: 0
+ }
+ }
+
+ fn add(&mut self, send: Sender<ChatEvent>) -> usize {
+ let idx = self.index;
+ self.index += 1;
+ self.inner.insert(idx, send);
+ idx
+ }
+
+ fn del(&mut self, idx: &usize) {
+ self.inner.remove(idx);
+ }
+
+ fn values(&self) -> Values<'_, usize, Sender<ChatEvent>> {
+ self.inner.values()
+ }
+}
+
+async fn send_event(event: ChatEvent, room: &ChatRoom) {
+ for user in &room.users {
+ let lock = CONNECTIONS.lock().await;
+ let Some(connection) = lock.get(&user) else {
+ continue
+ };
+ for channel in connection.values() {
+ channel.send(event.clone()).await.ok();
+ }
+ }
+}
pub const CHAT_LIST: EndpointDocumentation = EndpointDocumentation {
uri: "/api/chat/list",
@@ -80,11 +130,18 @@ async fn create (
Database(db): Database,
Json(body): Json<RoomCreateRequest>,
) -> Response {
- let Ok(post) = ChatRoom::new(&db, vec![user.user_id], body.name) else {
+ let Ok(room) = ChatRoom::new(&db, vec![user.user_id], body.name) else {
return ResponseCode::InternalServerError.text("Failed to create room")
};
- let Ok(json) = serde_json::to_string(&post) else {
+ for user in &room.users {
+ send_event(ChatEvent::Add {
+ user_id: *user,
+ room_id: room.room_id
+ }, &room).await;
+ }
+
+ let Ok(json) = serde_json::to_string(&room) else {
return ResponseCode::InternalServerError.text("Failed to create room")
};
@@ -145,6 +202,11 @@ async fn add (
if !success {
return ResponseCode::BadRequest.text("User is already in the room")
}
+
+ send_event(ChatEvent::Add {
+ user_id: to_add.user_id,
+ room_id: room.room_id
+ }, &room).await;
ResponseCode::Success.text("Successfully added user")
}
@@ -197,6 +259,11 @@ async fn leave (
if !success {
return ResponseCode::BadRequest.text("You are currently not in this room (how did this happen?)")
}
+
+ send_event(ChatEvent::Leave {
+ user_id: user.user_id,
+ room_id: room.room_id
+ }, &room).await;
ResponseCode::Success.text("Successfully left room")
}
@@ -250,9 +317,17 @@ async fn send (
return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it")
};
- let Ok(_msg) = room.send_message(&db, user.user_id, body.content) else {
+ let Ok(msg) = room.send_message(&db, user.user_id, body.content) else {
return ResponseCode::InternalServerError.text("Failed to send message")
};
+
+ send_event(ChatEvent::Message {
+ user_id: msg.user_id,
+ room_id: msg.room_id,
+ message_id: msg.message_id,
+ content: msg.content,
+ date: msg.date
+ }, &room).await;
ResponseCode::Created.text("Successfully sent message")
}
@@ -302,7 +377,7 @@ async fn load (
return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it")
};
- let Ok(msgs) = room.load_old_chat_messagegs(&db, body.newest_msg, body.page) else {
+ let Ok(msgs) = room.load_old_chat_messages(&db, body.newest_msg, body.page) else {
return ResponseCode::InternalServerError.text("Failed to load messages")
};
@@ -313,6 +388,110 @@ async fn load (
ResponseCode::Created.json(&json)
}
+pub const CHAT_TYPING: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/typing",
+ method: EndpointMethod::Post,
+ description: "Set if your typing in a given room",
+ body: Some(
+ r#"
+ {
+ "room_id": 69,
+ }
+ "#,
+ ),
+ responses: &[
+ (201, "Successfully sent typing indicator"),
+ (400, "Body does not match parameters"),
+ (401, "Unauthorized"),
+ (500, "Failed to send typing indicator"),
+ ],
+ cookie: Some("auth"),
+};
+
+#[derive(Deserialize)]
+struct TypingRequest {
+ room_id: u64,
+}
+
+impl Check for TypingRequest {
+ fn check(&self) -> CheckResult {
+ Ok(())
+ }
+}
+
+async fn typing (
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ Json(body): Json<TypingRequest>,
+) -> Response {
+
+ let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else {
+ return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it")
+ };
+
+ send_event(ChatEvent::Typing {
+ user_id: user.user_id,
+ room_id: room.room_id,
+ }, &room).await;
+
+ ResponseCode::Success.text("Successfully sent typing indicator")
+}
+
+pub const CHAT_CONNECT: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/connect",
+ method: EndpointMethod::Get,
+ description: "Start a websocket connection for chat events",
+ body: None,
+ responses: &[],
+ cookie: Some("auth"),
+};
+
+async fn connect (
+ AuthorizedUser(user): AuthorizedUser,
+ ws: WebSocketUpgrade
+) -> Response {
+ ws.on_upgrade(|mut ws| async move {
+ let user = user;
+ let (send, mut recv) = mpsc::channel::<ChatEvent>(20);
+ let id: usize;
+ {
+ let mut lock = CONNECTIONS.lock().await;
+ match lock.get_mut(&user.user_id) {
+ Some(pool) => {
+ id = pool.add(send);
+ },
+ None => {
+ let mut pool = ConnectionPool::new();
+ id = pool.add(send);
+ lock.insert(user.user_id, pool);
+ }
+ };
+ }
+ loop {
+ tokio::select! {
+ m = ws.recv() => {
+ let Some(Ok(_)) = m else {
+ break;
+ };
+ }
+ s = recv.recv() => {
+ let Some(msg) = s else {
+ break;
+ };
+ if let Ok(string) = serde_json::to_string(&msg) {
+ ws.send(Message::Text(string)).await.ok();
+ }
+ }
+ }
+ }
+
+ let mut lock = CONNECTIONS.lock().await;
+ if let Some(conn) = lock.get_mut(&user.user_id) {
+ conn.del(&id);
+ };
+ })
+}
+
pub fn router() -> Router {
Router::new()
.route("/create", post(create))
@@ -321,4 +500,6 @@ pub fn router() -> Router {
.route("/leave", delete(leave))
.route("/send", post(send))
.route("/load", post(load))
+ .route("/typing", post(typing))
+ .route("/connect", get(connect))
}