Fixed: WhereBuilder exception when string variable null

pull/4103/head
ta264 5 years ago
parent df101258c5
commit 01a03e9baf

@ -66,6 +66,23 @@ namespace NzbDrone.Core.Test.Datastore
_subject.ToString().Should().Be($"(\"Movies\".\"Id\" = \"Movies\".\"Id\")"); _subject.ToString().Should().Be($"(\"Movies\".\"Id\" = \"Movies\".\"Id\")");
} }
[Test]
public void where_string_is_null()
{
_subject = Where(x => x.ImdbId == null);
_subject.ToString().Should().Be($"(\"Movies\".\"ImdbId\" IS NULL)");
}
[Test]
public void where_string_is_null_value()
{
string imdb = null;
_subject = Where(x => x.ImdbId == imdb);
_subject.ToString().Should().Be($"(\"Movies\".\"ImdbId\" IS NULL)");
}
[Test] [Test]
public void where_column_contains_string() public void where_column_contains_string()
{ {

@ -11,15 +11,13 @@ namespace NzbDrone.Core.Datastore
{ {
public class WhereBuilder : ExpressionVisitor public class WhereBuilder : ExpressionVisitor
{ {
private const DbType EnumerableMultiParameter = (DbType)(-1); protected StringBuilder _sb;
private const DbType EnumerableMultiParameter = (DbType)(-1);
private readonly string _paramNamePrefix; private readonly string _paramNamePrefix;
private readonly bool _requireConcreteValue = false;
private int _paramCount = 0; private int _paramCount = 0;
private bool _requireConcreteValue = false;
private bool _gotConcreteValue = false; private bool _gotConcreteValue = false;
protected StringBuilder _sb;
public DynamicParameters Parameters { get; private set; }
public WhereBuilder(Expression filter, bool requireConcreteValue) public WhereBuilder(Expression filter, bool requireConcreteValue)
{ {
@ -35,6 +33,8 @@ namespace NzbDrone.Core.Datastore
} }
} }
public DynamicParameters Parameters { get; private set; }
private string AddParameter(object value, DbType? dbType = null) private string AddParameter(object value, DbType? dbType = null)
{ {
_gotConcreteValue = true; _gotConcreteValue = true;
@ -61,7 +61,7 @@ namespace NzbDrone.Core.Datastore
protected override Expression VisitMethodCall(MethodCallExpression expression) protected override Expression VisitMethodCall(MethodCallExpression expression)
{ {
string method = expression.Method.Name; var method = expression.Method.Name;
switch (expression.Method.Name) switch (expression.Method.Name)
{ {
@ -78,7 +78,7 @@ namespace NzbDrone.Core.Datastore
break; break;
default: default:
string msg = string.Format("'{0}' expressions are not yet implemented in the where clause expression tree parser.", method); var msg = string.Format("'{0}' expressions are not yet implemented in the where clause expression tree parser.", method);
throw new NotImplementedException(msg); throw new NotImplementedException(msg);
} }
@ -87,7 +87,7 @@ namespace NzbDrone.Core.Datastore
protected override Expression VisitMemberAccess(MemberExpression expression) protected override Expression VisitMemberAccess(MemberExpression expression)
{ {
string tableName = TableMapping.Mapper.TableNameMapping(expression.Expression.Type); var tableName = expression != null ? TableMapping.Mapper.TableNameMapping(expression.Expression.Type) : null;
if (tableName != null) if (tableName != null)
{ {
@ -95,27 +95,26 @@ namespace NzbDrone.Core.Datastore
} }
else else
{ {
object value = GetRightValue(expression); var value = GetRightValue(expression);
// string is IEnumerable<Char> but we don't want to pick up that case if (value != null)
var type = value.GetType();
var typeInfo = type.GetTypeInfo();
bool isEnumerable =
type != typeof(string) && (
typeInfo.ImplementedInterfaces.Any(ti => ti.IsGenericType && ti.GetGenericTypeDefinition() == typeof(IEnumerable<>)) ||
(typeInfo.IsGenericType && typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>)));
string paramName;
if (isEnumerable)
{ {
paramName = AddParameter(value, EnumerableMultiParameter); // string is IEnumerable<Char> but we don't want to pick up that case
var type = value.GetType();
var typeInfo = type.GetTypeInfo();
var isEnumerable =
type != typeof(string) && (
typeInfo.ImplementedInterfaces.Any(ti => ti.IsGenericType && ti.GetGenericTypeDefinition() == typeof(IEnumerable<>)) ||
(typeInfo.IsGenericType && typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>)));
var paramName = isEnumerable ? AddParameter(value, EnumerableMultiParameter) : AddParameter(value);
_sb.Append(paramName);
} }
else else
{ {
paramName = AddParameter(value); _gotConcreteValue = true;
_sb.Append("NULL");
} }
_sb.Append(paramName);
} }
return expression; return expression;
@ -125,51 +124,78 @@ namespace NzbDrone.Core.Datastore
{ {
if (expression.Value != null) if (expression.Value != null)
{ {
string paramName = AddParameter(expression.Value); var paramName = AddParameter(expression.Value);
_sb.Append(paramName); _sb.Append(paramName);
} }
else else
{ {
_gotConcreteValue = true;
_sb.Append("NULL"); _sb.Append("NULL");
} }
return expression; return expression;
} }
private object GetRightValue(Expression rightExpression) private bool TryGetConstantValue(Expression expression, out object result)
{ {
object rightValue = null; if (expression is ConstantExpression constExp)
{
result = constExp.Value;
return true;
}
var right = rightExpression as ConstantExpression; result = null;
return false;
}
// Value is not directly passed in as a constant private bool TryGetPropertyValue(MemberExpression expression, out object result)
if (right == null) {
if (expression.Expression is MemberExpression nested)
{ {
var rightMemberExp = rightExpression as MemberExpression;
var parentMemberExpression = rightMemberExp.Expression as MemberExpression;
// Value is passed in as a property on a parent entity // Value is passed in as a property on a parent entity
if (parentMemberExpression != null) var container = (nested.Expression as ConstantExpression).Value;
{ var entity = GetFieldValue(container, nested.Member);
var memberInfo = (rightMemberExp.Expression as MemberExpression).Member; result = GetFieldValue(entity, expression.Member);
var container = ((rightMemberExp.Expression as MemberExpression).Expression as ConstantExpression).Value; return true;
var entity = GetFieldValue(container, memberInfo);
rightValue = GetFieldValue(entity, rightMemberExp.Member);
}
else
{
// Value is passed in as a variable
var parent = (rightMemberExp.Expression as ConstantExpression).Value;
rightValue = GetFieldValue(parent, rightMemberExp.Member);
}
} }
else
result = null;
return false;
}
private bool TryGetVariableValue(MemberExpression expression, out object result)
{
// Value is passed in as a variable
if (expression.Expression is ConstantExpression nested)
{
result = GetFieldValue(nested.Value, expression.Member);
return true;
}
result = null;
return false;
}
private object GetRightValue(Expression expression)
{
if (TryGetConstantValue(expression, out var constValue))
{
return constValue;
}
var memberExp = expression as MemberExpression;
if (TryGetPropertyValue(memberExp, out var propValue))
{
return propValue;
}
if (TryGetVariableValue(memberExp, out var variableValue))
{ {
// Value is passed in directly as a constant return variableValue;
rightValue = right.Value;
} }
return rightValue; return null;
} }
private object GetFieldValue(object entity, MemberInfo member) private object GetFieldValue(object entity, MemberInfo member)
@ -187,13 +213,29 @@ namespace NzbDrone.Core.Datastore
throw new ArgumentException(string.Format("WhereBuilder could not get the value for {0}.{1}.", entity.GetType().Name, member.Name)); throw new ArgumentException(string.Format("WhereBuilder could not get the value for {0}.{1}.", entity.GetType().Name, member.Name));
} }
private string Decode(BinaryExpression expression) private bool IsNullVariable(Expression expression)
{ {
bool isRightSideNullConstant = expression.Right.NodeType == if (expression.NodeType == ExpressionType.Constant &&
ExpressionType.Constant && TryGetConstantValue(expression, out var constResult) &&
((ConstantExpression)expression.Right).Value == null; constResult == null)
{
return true;
}
if (expression.NodeType == ExpressionType.MemberAccess &&
expression is MemberExpression member &&
TryGetVariableValue(member, out var variableResult) &&
variableResult == null)
{
return true;
}
return false;
}
if (isRightSideNullConstant) private string Decode(BinaryExpression expression)
{
if (IsNullVariable(expression.Right))
{ {
switch (expression.NodeType) switch (expression.NodeType)
{ {

@ -120,7 +120,7 @@ namespace NzbDrone.Core.Movies
public Movie FindByImdbId(string imdbid) public Movie FindByImdbId(string imdbid)
{ {
var imdbIdWithPrefix = Parser.Parser.NormalizeImdbId(imdbid); var imdbIdWithPrefix = Parser.Parser.NormalizeImdbId(imdbid);
return Query(x => x.ImdbId == imdbIdWithPrefix).FirstOrDefault(); return imdbIdWithPrefix == null ? null : Query(x => x.ImdbId == imdbIdWithPrefix).FirstOrDefault();
} }
public Movie FindByTmdbId(int tmdbid) public Movie FindByTmdbId(int tmdbid)

Loading…
Cancel
Save