summaryrefslogtreecommitdiff
path: root/src/room
diff options
context:
space:
mode:
Diffstat (limited to 'src/room')
-rw-r--r--src/room/handle.rs106
-rw-r--r--src/room/messages.rs69
-rw-r--r--src/room/mod.rs128
-rw-r--r--src/room/websocket.rs66
4 files changed, 369 insertions, 0 deletions
diff --git a/src/room/handle.rs b/src/room/handle.rs
new file mode 100644
index 0000000..d397c70
--- /dev/null
+++ b/src/room/handle.rs
@@ -0,0 +1,106 @@
+use std::collections::HashSet;
+
+use super::messages::{ClientMessage, ServerMessage};
+
+// send a ServerMessage::Connections to all sockets
+pub async fn send_connections(v: &mut super::Clients, added: Option<usize>, removed: Option<usize>, frame: u64) {
+ // get the list of connection IDs
+ let connections: Vec<usize> = v.iter()
+ .enumerate()
+ .filter(|(_, n)| n.is_some())
+ .map(|(id, _)| id)
+ .collect();
+
+ super::send(v, |id, _c| {
+ Some(ServerMessage::Connections {
+ connections: connections.clone(),
+ added,
+ removed,
+ id,
+ frame,
+ })
+ }).await;
+}
+
+// handle incoming websocket messages
+pub async fn handle(
+ v: &mut super::Clients,
+ requests: &mut HashSet<(u64, Option<usize>, usize)>, // frame, connection, client id
+ pending: &mut Vec<(Option<usize>, Option<usize>)>,
+ id: usize,
+ msg: ClientMessage,
+) {
+ match msg {
+ // broadcast inputs to every other connection
+ ClientMessage::Input { data, frame } => {
+ super::broadcast(v, ServerMessage::Input {
+ data,
+ frame,
+ connection: id
+ }, Some(id)).await;
+ },
+ // a client needs the current game state, grab it from another client
+ ClientMessage::RequestState { frame, connection } => {
+ let count = super::conn_count(v);
+
+ if count < 2 { // nobody to request state *from*
+ if let Some(Some(client)) = v.get(id) {
+ client.send(ServerMessage::State {
+ state: serde_json::Value::Null,
+ frame: 0,
+ connection: None,
+ }).await.ok();
+ }
+ return;
+ }
+
+ // request state from other clients
+ requests.insert((frame, connection, id));
+
+ match connection {
+ None => {
+ super::broadcast(v, ServerMessage::RequestState { frame }, Some(id)).await;
+ },
+ Some(id) => { // it's to a specific connection
+ let Some(Some(client)) = v.get(id) else {
+ return;
+ };
+ client.send(ServerMessage::RequestState { frame }).await.ok();
+ },
+ }
+ },
+ // a client responded to a request for game state, tell all the requestees
+ ClientMessage::State { state, frame } => {
+ let mut new_requests = HashSet::new();
+ for (fr, conn, cid) in requests.drain() {
+ if
+ fr != frame || // this isn't the requested frame
+ (conn.is_some() && Some(id) != conn) // this isn't the requested connection
+ {
+ new_requests.insert((fr, conn, cid));
+ continue;
+ }
+ if let Some(Some(client)) = v.get(cid) {
+ client.send(ServerMessage::State {
+ state: state.clone(),
+ frame,
+ connection: Some(id),
+ }).await.ok();
+ }
+ }
+ *requests = new_requests;
+ },
+ // a client said what frame they're on, actually send the connections message
+ ClientMessage::Frame { frame } => {
+ for (added, removed) in pending.into_iter() {
+ send_connections(v, *added, *removed, frame).await;
+ }
+ *pending = Vec::new();
+ },
+ ClientMessage::Ping { frame } => {
+ if let Some(Some(client)) = v.get(id) {
+ client.send(ServerMessage::Pong { frame }).await.ok();
+ }
+ }
+ }
+}
diff --git a/src/room/messages.rs b/src/room/messages.rs
new file mode 100644
index 0000000..72958a6
--- /dev/null
+++ b/src/room/messages.rs
@@ -0,0 +1,69 @@
+use serde::{Serialize, Deserialize};
+use serde_json::Value;
+
+#[derive(Deserialize, Clone, Debug)]
+#[serde(tag = "type")]
+pub enum ClientMessage {
+ #[serde(rename = "frame")]
+ Frame {
+ frame: u64,
+ },
+ #[serde(rename = "input")]
+ Input {
+ data: Value,
+ frame: u64,
+ },
+ #[serde(rename = "requeststate")]
+ RequestState {
+ connection: Option<usize>,
+ frame: u64,
+ },
+ #[serde(rename = "state")]
+ State {
+ state: Value,
+ frame: u64,
+ },
+ #[serde(rename = "ping")]
+ Ping {
+ frame: u64,
+ },
+}
+
+#[derive(Serialize, Clone, Debug)]
+#[serde(tag = "type")]
+pub enum ServerMessage {
+ #[serde(rename = "framerequest")]
+ FrameRequest,
+ #[serde(rename = "connections")]
+ Connections {
+ connections: Vec<usize>,
+ added: Option<usize>,
+ removed: Option<usize>,
+ id: usize,
+ frame: u64,
+ },
+ #[serde(rename = "input")]
+ Input {
+ data: Value,
+ frame: u64,
+ connection: usize,
+ },
+ #[serde(rename = "requeststate")]
+ RequestState {
+ frame: u64,
+ },
+ #[serde(rename = "state")]
+ State {
+ state: Value,
+ frame: u64,
+ connection: Option<usize>,
+ },
+ #[serde(rename = "pong")]
+ Pong {
+ frame: u64,
+ },
+ #[serde(rename = "error")]
+ Error {
+ error: String,
+ },
+}
diff --git a/src/room/mod.rs b/src/room/mod.rs
new file mode 100644
index 0000000..8b3d8c2
--- /dev/null
+++ b/src/room/mod.rs
@@ -0,0 +1,128 @@
+use std::{time::Duration, collections::HashSet};
+
+use axum::extract::ws::WebSocket;
+use tokio::sync::mpsc;
+
+mod websocket;
+mod messages;
+mod handle;
+
+use messages::{ClientMessage, ServerMessage};
+
+pub enum RoomMessage {
+ Add(WebSocket),
+ Remove(usize),
+ WsMessage(usize, ClientMessage),
+}
+
+pub type Client = mpsc::Sender<ServerMessage>;
+pub type Clients = Vec<Option<Client>>;
+
+pub type Room = mpsc::Sender<RoomMessage>;
+
+// spawns a task for the room that listens for incoming messages from websockets as well as connections and disconnections
+pub fn start_room(room_id: String, room_service: super::rooms::RoomService) -> Room {
+ let (tx, rx) = mpsc::channel::<RoomMessage>(20);
+
+ let txret = tx.clone();
+
+ tokio::spawn(room_task(tx, rx, room_id, room_service));
+
+ txret
+}
+
+async fn room_task(tx: mpsc::Sender<RoomMessage>, mut rx: mpsc::Receiver<RoomMessage>, room_id: String, room_service: super::rooms::RoomService) {
+ let mut ws = Vec::new();
+ let mut state_requests = HashSet::new();
+ let mut pending: Vec<(Option<usize>, Option<usize>)> = Vec::new();
+
+ while let Some(message) = rx.recv().await {
+ match message {
+ RoomMessage::Add(w) => { // a new connection is added
+ // create channels for the websocket and start a task to send and receive from it
+ let (wstx, wsrx) = mpsc::channel(5);
+ let id = ws.len();
+ ws.push(Some(wstx));
+ tokio::spawn(websocket::start_ws(w, id, tx.clone(), wsrx));
+
+ if conn_count(&ws) < 2 { // the first connection is on frame 0
+ handle::send_connections(&mut ws, Some(id), None, 0).await;
+ } else {
+ // connections need to be added on a specific frame
+ // so ask the clients for a frame to put this event on
+ broadcast(&mut ws, ServerMessage::FrameRequest, Some(id)).await;
+ pending.push((Some(id), None));
+ }
+ },
+ RoomMessage::Remove(id) => { // a connection is closed (sent by the websocket task on exiting)
+ // only remove it if it exists
+ if let Some(item) = ws.get_mut(id) {
+ *item = None;
+ };
+ let count = conn_count(&ws);
+ if count == 0 { // remove rooms once they become empty
+ room_service.send(super::rooms::RoomServiceRequest::Remove(room_id.clone())).await.ok();
+ break;
+ }
+
+ // disconnections happen on a specific frame, ask the clients for a frame
+ broadcast(&mut ws, ServerMessage::FrameRequest, None).await;
+ pending.push((None, Some(id)));
+ },
+ RoomMessage::WsMessage(id, msg) => { // new data from a websocket
+ handle::handle(&mut ws, &mut state_requests, &mut pending, id, msg).await;
+ }
+ }
+ }
+}
+
+// send the websocket to the room task
+pub async fn add_connection(tx: &Room, ws: WebSocket) {
+ tx.send_timeout(RoomMessage::Add(ws), Duration::from_secs(1)).await.ok();
+}
+
+pub fn conn_count(v: &Clients) -> usize {
+ v.iter().filter(|i| i.is_some()).count()
+}
+
+// send a message to all or some of the clients, in parallel rather than series,
+// based on a callback
+pub async fn send(v: &mut Clients, create_message: impl Fn(usize, &Client) -> Option<ServerMessage>) -> usize {
+ let tasks = v.iter()
+ .enumerate()
+ .map(|(id, c)| {
+ // send to existing clients
+ let Some(client) = c.clone() else {
+ return None;
+ };
+
+ let Some(msg) = create_message(id, &client) else {
+ return None;
+ };
+
+ Some(tokio::spawn(async move {
+ client.send(msg).await.ok();
+ }))
+ });
+
+ let count = tasks.len();
+ // make sure all the tasks complete
+ for task in tasks {
+ if let Some(t) = task {
+ t.await.ok();
+ }
+ }
+
+ count
+}
+
+// send a message to all the websockets in the room (optionally excluding one)
+pub async fn broadcast(v: &mut Clients, msg: ServerMessage, except: Option<usize>) -> usize {
+ send(v, |id, _client| {
+ if Some(id) == except {
+ return None;
+ }
+
+ Some(msg.clone())
+ }).await
+}
diff --git a/src/room/websocket.rs b/src/room/websocket.rs
new file mode 100644
index 0000000..50a4537
--- /dev/null
+++ b/src/room/websocket.rs
@@ -0,0 +1,66 @@
+use std::time::Duration;
+
+use axum::extract::ws::{WebSocket, Message};
+use tokio::sync::mpsc;
+
+use super::RoomMessage;
+use super::messages::ServerMessage;
+
+// set up some senders and receivers so that the websocket can receive messages from the task, send messages to the task, and notify the task when it closes
+pub async fn start_ws(mut ws: WebSocket, id: usize, tx: mpsc::Sender<RoomMessage>, mut rx: mpsc::Receiver<ServerMessage>) {
+ loop {
+ tokio::select! {
+ m = ws.recv() => { // receive from the websocket and send it to `tx`
+ if let Some(Ok(msg)) = m {
+ // get the string contents
+ let optionstring = match msg {
+ Message::Text(s) => {
+ Some(s)
+ },
+ Message::Binary(bin) => {
+ String::from_utf8(bin).ok()
+ },
+ Message::Close(_) => { // quit the whole loop on disconnect
+ break;
+ },
+ _ => None
+ };
+
+ // ignore things that aren't strings
+ let Some(s) = optionstring else {
+ continue;
+ };
+
+ // decode and send to the room
+ match serde_json::from_str(&s) {
+ Ok(message) => {
+ tx.send_timeout(RoomMessage::WsMessage(id, message), Duration::from_secs(1)).await.ok();
+ },
+ Err(e) => { // let the client know if they sent a bad message
+ if let Ok(text) = serde_json::to_string(&ServerMessage::Error{
+ error: format!("Failed to decode JSON message: {}: {}", e, s),
+ }) {
+ ws.send(Message::Text(text)).await.ok();
+ }
+ }
+ }
+ } else { // websocket error
+ break;
+ }
+ }
+ s = rx.recv() => { // receive from `rx` and send it to the websocket
+ if let Some(msg) = s {
+ if let Ok(string) = serde_json::to_string(&msg) {
+ ws.send(Message::Text(string)).await.ok();
+ }
+ } else { // shouldn't happen but this is if the room drops the sender, it should close the websocket anyways
+ break;
+ }
+ }
+ }
+ }
+
+ // websocket disconnect due to either error or normal disconnect
+ // notify the room that the socket should be removed
+ tx.send_timeout(RoomMessage::Remove(id), Duration::from_secs(1)).await.ok();
+}