diff --git a/SocketHttpListener/Ext.cs b/SocketHttpListener/Ext.cs index a02b48061d..3b500ecd89 100644 --- a/SocketHttpListener/Ext.cs +++ b/SocketHttpListener/Ext.cs @@ -4,6 +4,7 @@ using System.IO; using System.IO.Compression; using System.Net; using System.Text; +using System.Threading; using System.Threading.Tasks; using MediaBrowser.Model.Services; using HttpStatusCode = SocketHttpListener.Net.HttpStatusCode; @@ -95,8 +96,30 @@ namespace SocketHttpListener : buffer; } - private static bool readBytes( - this Stream stream, byte[] buffer, int offset, int length, Stream dest) + private static async Task ReadBytesAsync(this Stream stream, byte[] buffer, int offset, int length) + { + var len = await stream.ReadAsync(buffer, offset, length).ConfigureAwait(false); + if (len < 1) + return buffer.SubArray(0, offset); + + var tmp = 0; + while (len < length) + { + tmp = await stream.ReadAsync(buffer, offset + len, length - len).ConfigureAwait(false); + if (tmp < 1) + { + break; + } + + len += tmp; + } + + return len < length + ? buffer.SubArray(0, offset + len) + : buffer; + } + + private static bool readBytes(this Stream stream, byte[] buffer, int offset, int length, Stream dest) { var bytes = stream.readBytes(buffer, offset, length); var len = bytes.Length; @@ -105,6 +128,15 @@ namespace SocketHttpListener return len == offset + length; } + private static async Task ReadBytesAsync(this Stream stream, byte[] buffer, int offset, int length, Stream dest) + { + var bytes = await stream.ReadBytesAsync(buffer, offset, length).ConfigureAwait(false); + var len = bytes.Length; + dest.Write(bytes, 0, len); + + return len == offset + length; + } + #endregion #region Internal Methods @@ -331,12 +363,10 @@ namespace SocketHttpListener : string.Format("\"{0}\"", value.Replace("\"", "\\\"")); } - internal static byte[] ReadBytes(this Stream stream, int length) - { - return stream.readBytes(new byte[length], 0, length); - } + internal static Task ReadBytesAsync(this Stream stream, int length) + => stream.ReadBytesAsync(new byte[length], 0, length); - internal static byte[] ReadBytes(this Stream stream, long length, int bufferLength) + internal static async Task ReadBytesAsync(this Stream stream, long length, int bufferLength) { using (var result = new MemoryStream()) { @@ -347,7 +377,7 @@ namespace SocketHttpListener var end = false; for (long i = 0; i < count; i++) { - if (!stream.readBytes(buffer, 0, bufferLength, result)) + if (!await stream.ReadBytesAsync(buffer, 0, bufferLength, result).ConfigureAwait(false)) { end = true; break; @@ -355,26 +385,14 @@ namespace SocketHttpListener } if (!end && rem > 0) - stream.readBytes(new byte[rem], 0, rem, result); + { + await stream.ReadBytesAsync(new byte[rem], 0, rem, result).ConfigureAwait(false); + } return result.ToArray(); } } - internal static async Task ReadBytesAsync(this Stream stream, int length) - { - var buffer = new byte[length]; - - var len = await stream.ReadAsync(buffer, 0, length).ConfigureAwait(false); - var bytes = len < 1 - ? new byte[0] - : len < length - ? stream.readBytes(buffer, len, length - len) - : buffer; - - return bytes; - } - internal static string RemovePrefix(this string value, params string[] prefixes) { var i = 0; @@ -493,19 +511,16 @@ namespace SocketHttpListener return string.Format("{0}; {1}", m, parameters.ToString("; ")); } - internal static List ToList(this IEnumerable source) - { - return new List(source); - } - internal static ushort ToUInt16(this byte[] src, ByteOrder srcOrder) { - return BitConverter.ToUInt16(src.ToHostOrder(srcOrder), 0); + src.ToHostOrder(srcOrder); + return BitConverter.ToUInt16(src, 0); } internal static ulong ToUInt64(this byte[] src, ByteOrder srcOrder) { - return BitConverter.ToUInt64(src.ToHostOrder(srcOrder), 0); + src.ToHostOrder(srcOrder); + return BitConverter.ToUInt64(src, 0); } internal static string TrimEndSlash(this string value) @@ -852,14 +867,17 @@ namespace SocketHttpListener /// /// is . /// - public static byte[] ToHostOrder(this byte[] src, ByteOrder srcOrder) + public static void ToHostOrder(this byte[] src, ByteOrder srcOrder) { if (src == null) + { throw new ArgumentNullException(nameof(src)); + } - return src.Length > 1 && !srcOrder.IsHostOrder() - ? src.Reverse() - : src; + if (src.Length > 1 && !srcOrder.IsHostOrder()) + { + Array.Reverse(src); + } } /// diff --git a/SocketHttpListener/WebSocket.cs b/SocketHttpListener/WebSocket.cs index 128bc8b971..b71dc0f28d 100644 --- a/SocketHttpListener/WebSocket.cs +++ b/SocketHttpListener/WebSocket.cs @@ -189,11 +189,11 @@ namespace SocketHttpListener _context = null; } - private bool concatenateFragmentsInto(Stream dest) + private async Task ConcatenateFragmentsIntoAsync(Stream dest) { while (true) { - var frame = WebSocketFrame.Read(_stream, true); + var frame = await WebSocketFrame.ReadAsync(_stream, true).ConfigureAwait(false); if (frame.IsFinal) { /* FINAL */ @@ -370,20 +370,22 @@ namespace SocketHttpListener close(code, reason ?? code.GetMessage(), false); } - private bool processFragmentedFrame(WebSocketFrame frame) + private Task ProcessFragmentedFrameAsync(WebSocketFrame frame) { return frame.IsContinuation // Not first fragment - ? true - : processFragments(frame); + ? Task.FromResult(true) + : ProcessFragmentsAsync(frame); } - private bool processFragments(WebSocketFrame first) + private async Task ProcessFragmentsAsync(WebSocketFrame first) { using (var buff = new MemoryStream()) { buff.WriteBytes(first.PayloadData.ApplicationData); - if (!concatenateFragmentsInto(buff)) + if (!await ConcatenateFragmentsIntoAsync(buff).ConfigureAwait(false)) + { return false; + } byte[] data; if (_compression != CompressionMethod.None) @@ -419,7 +421,7 @@ namespace SocketHttpListener return false; } - private bool processWebSocketFrame(WebSocketFrame frame) + private async Task ProcessWebSocketFrameAsync(WebSocketFrame frame) { return frame.IsCompressed && _compression == CompressionMethod.None ? processUnsupportedFrame( @@ -427,7 +429,7 @@ namespace SocketHttpListener CloseStatusCode.IncorrectData, "A compressed data has been received without available decompression method.") : frame.IsFragmented - ? processFragmentedFrame(frame) + ? await ProcessFragmentedFrameAsync(frame).ConfigureAwait(false) : frame.IsData ? processDataFrame(frame) : frame.IsPing @@ -563,44 +565,46 @@ namespace SocketHttpListener private void startReceiving() { if (_messageEventQueue.Count > 0) + { _messageEventQueue.Clear(); + } _exitReceiving = new AutoResetEvent(false); _receivePong = new AutoResetEvent(false); Action receive = null; - receive = () => WebSocketFrame.ReadAsync( - _stream, - true, - frame => - { - if (processWebSocketFrame(frame) && _readyState != WebSocketState.Closed) - { - receive(); - - if (!frame.IsData) - return; - - lock (_forEvent) - { - try - { - var e = dequeueFromMessageEventQueue(); - if (e != null && _readyState == WebSocketState.Open) - OnMessage.Emit(this, e); - } - catch (Exception ex) - { - processException(ex, "An exception has occurred while OnMessage."); - } - } - } - else if (_exitReceiving != null) - { - _exitReceiving.Set(); - } - }, - ex => processException(ex, "An exception has occurred while receiving a message.")); + receive = async () => await WebSocketFrame.ReadAsync( + _stream, + true, + async frame => + { + if (await ProcessWebSocketFrameAsync(frame).ConfigureAwait(false) && _readyState != WebSocketState.Closed) + { + receive(); + + if (!frame.IsData) + return; + + lock (_forEvent) + { + try + { + var e = dequeueFromMessageEventQueue(); + if (e != null && _readyState == WebSocketState.Open) + OnMessage.Emit(this, e); + } + catch (Exception ex) + { + processException(ex, "An exception has occurred while OnMessage."); + } + } + } + else if (_exitReceiving != null) + { + _exitReceiving.Set(); + } + }, + ex => processException(ex, "An exception has occurred while receiving a message.")); receive(); } diff --git a/SocketHttpListener/WebSocketFrame.cs b/SocketHttpListener/WebSocketFrame.cs index 74ed23c457..2e4774b3d8 100644 --- a/SocketHttpListener/WebSocketFrame.cs +++ b/SocketHttpListener/WebSocketFrame.cs @@ -2,6 +2,7 @@ using System; using System.Collections; using System.Collections.Generic; using System.IO; +using System.Threading.Tasks; namespace SocketHttpListener { @@ -177,7 +178,7 @@ namespace SocketHttpListener return opcode == Opcode.Text || opcode == Opcode.Binary; } - private static WebSocketFrame read(byte[] header, Stream stream, bool unmask) + private static async Task ReadAsync(byte[] header, Stream stream, bool unmask) { /* Header */ @@ -229,7 +230,7 @@ namespace SocketHttpListener ? 2 : 8; - var extPayloadLen = size > 0 ? stream.ReadBytes(size) : new byte[0]; + var extPayloadLen = size > 0 ? await stream.ReadBytesAsync(size).ConfigureAwait(false) : Array.Empty(); if (size > 0 && extPayloadLen.Length != size) throw new WebSocketException( "The 'Extended Payload Length' of a frame cannot be read from the data source."); @@ -239,7 +240,7 @@ namespace SocketHttpListener /* Masking Key */ var masked = mask == Mask.Mask; - var maskingKey = masked ? stream.ReadBytes(4) : new byte[0]; + var maskingKey = masked ? await stream.ReadBytesAsync(4).ConfigureAwait(false) : Array.Empty(); if (masked && maskingKey.Length != 4) throw new WebSocketException( "The 'Masking Key' of a frame cannot be read from the data source."); @@ -264,8 +265,8 @@ namespace SocketHttpListener "The length of 'Payload Data' of a frame is greater than the allowable length."); data = payloadLen > 126 - ? stream.ReadBytes((long)len, 1024) - : stream.ReadBytes((int)len); + ? await stream.ReadBytesAsync((long)len, 1024).ConfigureAwait(false) + : await stream.ReadBytesAsync((int)len).ConfigureAwait(false); //if (data.LongLength != (long)len) // throw new WebSocketException( @@ -273,7 +274,7 @@ namespace SocketHttpListener } else { - data = new byte[0]; + data = Array.Empty(); } var payload = new PayloadData(data, masked); @@ -281,7 +282,7 @@ namespace SocketHttpListener { payload.Mask(maskingKey); frame._mask = Mask.Unmask; - frame._maskingKey = new byte[0]; + frame._maskingKey = Array.Empty(); } frame._payloadData = payload; @@ -329,41 +330,39 @@ namespace SocketHttpListener return new WebSocketFrame(fin, opcode, mask, new PayloadData(data), compressed); } - internal static WebSocketFrame Read(Stream stream) - { - return Read(stream, true); - } + internal static Task ReadAsync(Stream stream) + => ReadAsync(stream, true); - internal static WebSocketFrame Read(Stream stream, bool unmask) + internal static async Task ReadAsync(Stream stream, bool unmask) { - var header = stream.ReadBytes(2); + var header = await stream.ReadBytesAsync(2).ConfigureAwait(false); if (header.Length != 2) + { throw new WebSocketException( "The header part of a frame cannot be read from the data source."); + } - return read(header, stream, unmask); + return await ReadAsync(header, stream, unmask).ConfigureAwait(false); } - internal static async void ReadAsync( + internal static async Task ReadAsync( Stream stream, bool unmask, Action completed, Action error) { try { var header = await stream.ReadBytesAsync(2).ConfigureAwait(false); if (header.Length != 2) + { throw new WebSocketException( "The header part of a frame cannot be read from the data source."); + } - var frame = read(header, stream, unmask); - if (completed != null) - completed(frame); + var frame = await ReadAsync(header, stream, unmask).ConfigureAwait(false); + completed?.Invoke(frame); } catch (Exception ex) { - if (error != null) - { - error(ex); - } + error.Invoke(ex); } }