@ -6,6 +6,7 @@ using System.Diagnostics;
using System.Diagnostics.CodeAnalysis ;
using System.Diagnostics.CodeAnalysis ;
using System.Globalization ;
using System.Globalization ;
using System.Linq ;
using System.Linq ;
using System.Text ;
using System.Threading.Tasks ;
using System.Threading.Tasks ;
using Microsoft.AspNet.SignalR.Configuration ;
using Microsoft.AspNet.SignalR.Configuration ;
using Microsoft.AspNet.SignalR.Hosting ;
using Microsoft.AspNet.SignalR.Hosting ;
@ -165,7 +166,7 @@ namespace Microsoft.AspNet.SignalR
if ( Transport = = null )
if ( Transport = = null )
{
{
throw new InvalidOperationException ( String . Format ( CultureInfo . CurrentCulture , Resources . Error_ProtocolErrorUnknownTransport ) ) ;
return FailResponse ( context . Response , String . Format ( CultureInfo . CurrentCulture , Resources . Error_ProtocolErrorUnknownTransport ) ) ;
}
}
string connectionToken = context . Request . QueryString [ "connectionToken" ] ;
string connectionToken = context . Request . QueryString [ "connectionToken" ] ;
@ -173,10 +174,17 @@ namespace Microsoft.AspNet.SignalR
// If there's no connection id then this is a bad request
// If there's no connection id then this is a bad request
if ( String . IsNullOrEmpty ( connectionToken ) )
if ( String . IsNullOrEmpty ( connectionToken ) )
{
{
throw new InvalidOperationException ( String . Format ( CultureInfo . CurrentCulture , Resources . Error_ProtocolErrorMissingConnectionToken ) ) ;
return FailResponse ( context . Response , String . Format ( CultureInfo . CurrentCulture , Resources . Error_ProtocolErrorMissingConnectionToken ) ) ;
}
}
string connectionId = GetConnectionId ( context , connectionToken ) ;
string connectionId ;
string message ;
int statusCode ;
if ( ! TryGetConnectionId ( context , connectionToken , out connectionId , out message , out statusCode ) )
{
return FailResponse ( context . Response , message , statusCode ) ;
}
// Set the transport's connection id to the unprotected one
// Set the transport's connection id to the unprotected one
Transport . ConnectionId = connectionId ;
Transport . ConnectionId = connectionId ;
@ -227,10 +235,21 @@ namespace Microsoft.AspNet.SignalR
}
}
[SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "We want to catch any exception when unprotecting data.")]
[SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "We want to catch any exception when unprotecting data.")]
internal string GetConnectionId ( HostContext context , string connectionToken )
internal bool TryGetConnectionId ( HostContext context ,
string connectionToken ,
out string connectionId ,
out string message ,
out int statusCode )
{
{
string unprotectedConnectionToken = null ;
string unprotectedConnectionToken = null ;
// connectionId is only valid when this method returns true
connectionId = null ;
// message and statusCode are only valid when this method returns false
message = null ;
statusCode = 400 ;
try
try
{
{
unprotectedConnectionToken = ProtectedData . Unprotect ( connectionToken , Purposes . ConnectionToken ) ;
unprotectedConnectionToken = ProtectedData . Unprotect ( connectionToken , Purposes . ConnectionToken ) ;
@ -242,21 +261,24 @@ namespace Microsoft.AspNet.SignalR
if ( String . IsNullOrEmpty ( unprotectedConnectionToken ) )
if ( String . IsNullOrEmpty ( unprotectedConnectionToken ) )
{
{
throw new InvalidOperationException ( String . Format ( CultureInfo . CurrentCulture , Resources . Error_ConnectionIdIncorrectFormat ) ) ;
message = String . Format ( CultureInfo . CurrentCulture , Resources . Error_ConnectionIdIncorrectFormat ) ;
return false ;
}
}
var tokens = unprotectedConnectionToken . Split ( SplitChars , 2 ) ;
var tokens = unprotectedConnectionToken . Split ( SplitChars , 2 ) ;
string connectionId = tokens [ 0 ] ;
connectionId = tokens [ 0 ] ;
string tokenUserName = tokens . Length > 1 ? tokens [ 1 ] : String . Empty ;
string tokenUserName = tokens . Length > 1 ? tokens [ 1 ] : String . Empty ;
string userName = GetUserIdentity ( context ) ;
string userName = GetUserIdentity ( context ) ;
if ( ! String . Equals ( tokenUserName , userName , StringComparison . OrdinalIgnoreCase ) )
if ( ! String . Equals ( tokenUserName , userName , StringComparison . OrdinalIgnoreCase ) )
{
{
throw new InvalidOperationException ( String . Format ( CultureInfo . CurrentCulture , Resources . Error_UnrecognizedUserIdentity ) ) ;
message = String . Format ( CultureInfo . CurrentCulture , Resources . Error_UnrecognizedUserIdentity ) ;
statusCode = 403 ;
return false ;
}
}
return connectionId ;
return true ;
}
}
[SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "We want to prevent any failures in unprotecting")]
[SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "We want to prevent any failures in unprotecting")]
@ -477,6 +499,12 @@ namespace Microsoft.AspNet.SignalR
return context . Response . End ( data ) ;
return context . Response . End ( data ) ;
}
}
private static Task FailResponse ( IResponse response , string message , int statusCode = 400 )
{
response . StatusCode = statusCode ;
return response . End ( message ) ;
}
private static bool IsNegotiationRequest ( IRequest request )
private static bool IsNegotiationRequest ( IRequest request )
{
{
return request . Url . LocalPath . EndsWith ( "/negotiate" , StringComparison . OrdinalIgnoreCase ) ;
return request . Url . LocalPath . EndsWith ( "/negotiate" , StringComparison . OrdinalIgnoreCase ) ;