diff --git a/Emby.Server.Implementations/Middleware/WebSocketMiddleware.cs b/Emby.Server.Implementations/Middleware/WebSocketMiddleware.cs index a1d0e77d65..268bf4042d 100644 --- a/Emby.Server.Implementations/Middleware/WebSocketMiddleware.cs +++ b/Emby.Server.Implementations/Middleware/WebSocketMiddleware.cs @@ -25,7 +25,10 @@ namespace Emby.Server.Implementations.Middleware if (httpContext.WebSockets.IsWebSocketRequest) { var webSocketContext = await httpContext.WebSockets.AcceptWebSocketAsync(null).ConfigureAwait(false); - _webSocketManager.AddSocket(webSocketContext); + if (webSocketContext != null) + { + await _webSocketManager.OnWebSocketConnected(webSocketContext); + } } else { diff --git a/Emby.Server.Implementations/WebSockets/WebSocketHandler.cs b/Emby.Server.Implementations/WebSockets/WebSocketHandler.cs new file mode 100644 index 0000000000..70b9e85aa9 --- /dev/null +++ b/Emby.Server.Implementations/WebSockets/WebSocketHandler.cs @@ -0,0 +1,10 @@ +using System.Threading.Tasks; +using MediaBrowser.Model.Net; + +namespace Emby.Server.Implementations.WebSockets +{ + public interface IWebSocketHandler + { + Task ProcessMessage(WebSocketMessage message); + } +} diff --git a/Emby.Server.Implementations/WebSockets/WebSocketManager.cs b/Emby.Server.Implementations/WebSockets/WebSocketManager.cs index 7e74a45277..888f2f0fce 100644 --- a/Emby.Server.Implementations/WebSockets/WebSocketManager.cs +++ b/Emby.Server.Implementations/WebSockets/WebSocketManager.cs @@ -1,22 +1,100 @@ using System; using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; using System.Net.WebSockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using MediaBrowser.Controller.Net; +using MediaBrowser.Model.Net; +using MediaBrowser.Model.Serialization; +using Microsoft.Extensions.Logging; +using UtfUnknown; namespace Emby.Server.Implementations.WebSockets { public class WebSocketManager { - private readonly ConcurrentDictionary _activeWebSockets; + private readonly IWebSocketHandler[] _webSocketHandlers; + private readonly IJsonSerializer _jsonSerializer; + private readonly ILogger _logger; + private const int BufferSize = 4096; - public WebSocketManager() + public WebSocketManager(IWebSocketHandler[] webSocketHandlers, IJsonSerializer jsonSerializer, ILogger logger) { - _activeWebSockets = new ConcurrentDictionary(); + _webSocketHandlers = webSocketHandlers; + _jsonSerializer = jsonSerializer; + _logger = logger; } - public void AddSocket(WebSocket webSocket) + public async Task OnWebSocketConnected(WebSocket webSocket) { - var guid = Guid.NewGuid(); - _activeWebSockets.TryAdd(guid, webSocket); + var taskCompletionSource = new TaskCompletionSource(); + var cancellationToken = new CancellationTokenSource().Token; + WebSocketReceiveResult result; + var message = new List(); + + do + { + var buffer = WebSocket.CreateServerBuffer(BufferSize); + result = await webSocket.ReceiveAsync(buffer, cancellationToken); + message.AddRange(buffer.Array.Take(result.Count)); + + if (result.EndOfMessage) + { + await ProcessMessage(message.ToArray(), taskCompletionSource); + message.Clear(); + } + } while (!taskCompletionSource.Task.IsCompleted && + webSocket.State == WebSocketState.Open && + result.MessageType != WebSocketMessageType.Close); + + if (webSocket.State == WebSocketState.Open) + { + await webSocket.CloseAsync(result.CloseStatus ?? WebSocketCloseStatus.NormalClosure, + result.CloseStatusDescription, cancellationToken); + } + } + + public async Task ProcessMessage(byte[] messageBytes, TaskCompletionSource taskCompletionSource) + { + var charset = CharsetDetector.DetectFromBytes(messageBytes).Detected?.EncodingName; + var message = string.Equals(charset, "utf-8", StringComparison.OrdinalIgnoreCase) + ? Encoding.UTF8.GetString(messageBytes, 0, messageBytes.Length) + : Encoding.ASCII.GetString(messageBytes, 0, messageBytes.Length); + + // All messages are expected to be json + if (!message.StartsWith("{", StringComparison.OrdinalIgnoreCase)) + { + _logger.LogDebug("Received web socket message that is not a json structure: {Message}", message); + return; + } + + try + { + var info = _jsonSerializer.DeserializeFromString>(message); + + _logger.LogDebug("Websocket message received: {0}", info.MessageType); + + var tasks = _webSocketHandlers.Select(handler => Task.Run(() => + { + try + { + handler.ProcessMessage(info).ConfigureAwait(false); + } + catch (Exception ex) + { + _logger.LogError(ex, "{0} failed processing WebSocket message {1}", handler.GetType().Name, info.MessageType ?? string.Empty); + } + })); + + await Task.WhenAll(tasks); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error processing web socket message"); + } } } }