From 68e41f0860845ea01073624dc6a9b2ebc4f798a0 Mon Sep 17 00:00:00 2001 From: Qstick Date: Sun, 27 Feb 2022 16:14:41 -0600 Subject: [PATCH] Fixed: WhereBuilder for Postgres --- .../Datastore/BasicRepository.cs | 2 +- .../Datastore/Extensions/BuilderExtensions.cs | 20 +- src/NzbDrone.Core/Datastore/SqlBuilder.cs | 8 + src/NzbDrone.Core/Datastore/TableMapper.cs | 2 +- src/NzbDrone.Core/Datastore/WhereBuilder.cs | 384 +---------------- .../Datastore/WhereBuilderPostgres.cs | 374 +++++++++++++++++ .../Datastore/WhereBuilderSqlite.cs | 387 ++++++++++++++++++ .../History/HistoryRepository.cs | 2 +- 8 files changed, 790 insertions(+), 389 deletions(-) create mode 100644 src/NzbDrone.Core/Datastore/WhereBuilderPostgres.cs create mode 100644 src/NzbDrone.Core/Datastore/WhereBuilderSqlite.cs diff --git a/src/NzbDrone.Core/Datastore/BasicRepository.cs b/src/NzbDrone.Core/Datastore/BasicRepository.cs index c03ecae4b..33253efa9 100644 --- a/src/NzbDrone.Core/Datastore/BasicRepository.cs +++ b/src/NzbDrone.Core/Datastore/BasicRepository.cs @@ -66,7 +66,7 @@ namespace NzbDrone.Core.Datastore _updateSql = GetUpdateSql(_properties); } - protected virtual SqlBuilder Builder() => new SqlBuilder(); + protected virtual SqlBuilder Builder() => new SqlBuilder(_database.DatabaseType); protected virtual List Query(SqlBuilder builder) => _database.Query(builder).ToList(); diff --git a/src/NzbDrone.Core/Datastore/Extensions/BuilderExtensions.cs b/src/NzbDrone.Core/Datastore/Extensions/BuilderExtensions.cs index dc3e442f2..d9c68da05 100644 --- a/src/NzbDrone.Core/Datastore/Extensions/BuilderExtensions.cs +++ b/src/NzbDrone.Core/Datastore/Extensions/BuilderExtensions.cs @@ -42,21 +42,21 @@ namespace NzbDrone.Core.Datastore public static SqlBuilder Where(this SqlBuilder builder, Expression> filter) { - var wb = new WhereBuilder(filter, true, builder.Sequence); + var wb = GetWhereBuilder(builder.DatabaseType, filter, true, builder.Sequence); return builder.Where(wb.ToString(), wb.Parameters); } public static SqlBuilder OrWhere(this SqlBuilder builder, Expression> filter) { - var wb = new WhereBuilder(filter, true, builder.Sequence); + var wb = GetWhereBuilder(builder.DatabaseType, filter, true, builder.Sequence); return builder.OrWhere(wb.ToString(), wb.Parameters); } public static SqlBuilder Join(this SqlBuilder builder, Expression> filter) { - var wb = new WhereBuilder(filter, false, builder.Sequence); + var wb = GetWhereBuilder(builder.DatabaseType, filter, false, builder.Sequence); var rightTable = TableMapping.Mapper.TableNameMapping(typeof(TRight)); @@ -65,7 +65,7 @@ namespace NzbDrone.Core.Datastore public static SqlBuilder LeftJoin(this SqlBuilder builder, Expression> filter) { - var wb = new WhereBuilder(filter, false, builder.Sequence); + var wb = GetWhereBuilder(builder.DatabaseType, filter, false, builder.Sequence); var rightTable = TableMapping.Mapper.TableNameMapping(typeof(TRight)); @@ -138,6 +138,18 @@ namespace NzbDrone.Core.Datastore return sb.ToString(); } + private static WhereBuilder GetWhereBuilder(DatabaseType databaseType, Expression filter, bool requireConcrete, int seq) + { + if (databaseType == DatabaseType.PostgreSQL) + { + return new WhereBuilderPostgres(filter, requireConcrete, seq); + } + else + { + return new WhereBuilderSqlite(filter, requireConcrete, seq); + } + } + private static Dictionary ToDictionary(this DynamicParameters dynamicParams) { var argsDictionary = new Dictionary(); diff --git a/src/NzbDrone.Core/Datastore/SqlBuilder.cs b/src/NzbDrone.Core/Datastore/SqlBuilder.cs index e686d4852..8febd5b6a 100644 --- a/src/NzbDrone.Core/Datastore/SqlBuilder.cs +++ b/src/NzbDrone.Core/Datastore/SqlBuilder.cs @@ -8,9 +8,17 @@ namespace NzbDrone.Core.Datastore public class SqlBuilder { private readonly Dictionary _data = new Dictionary(); + private readonly DatabaseType _databaseType; + + public SqlBuilder(DatabaseType databaseType) + { + _databaseType = databaseType; + } public int Sequence { get; private set; } + public DatabaseType DatabaseType => _databaseType; + public Template AddTemplate(string sql, dynamic parameters = null) => new Template(this, sql, parameters); diff --git a/src/NzbDrone.Core/Datastore/TableMapper.cs b/src/NzbDrone.Core/Datastore/TableMapper.cs index bbb84b9cb..d53e782f2 100644 --- a/src/NzbDrone.Core/Datastore/TableMapper.cs +++ b/src/NzbDrone.Core/Datastore/TableMapper.cs @@ -183,7 +183,7 @@ namespace NzbDrone.Core.Datastore (db, parent) => { var id = childIdSelector(parent); - return db.Query(new SqlBuilder().Where(x => x.Id == id)).SingleOrDefault(); + return db.Query(new SqlBuilder(db.DatabaseType).Where(x => x.Id == id)).SingleOrDefault(); }, parent => childIdSelector(parent) > 0); } diff --git a/src/NzbDrone.Core/Datastore/WhereBuilder.cs b/src/NzbDrone.Core/Datastore/WhereBuilder.cs index da94e20e1..969364e4f 100644 --- a/src/NzbDrone.Core/Datastore/WhereBuilder.cs +++ b/src/NzbDrone.Core/Datastore/WhereBuilder.cs @@ -1,389 +1,9 @@ -using System; -using System.Collections.Generic; -using System.Data; -using System.Linq; -using System.Linq.Expressions; -using System.Reflection; -using System.Text; using Dapper; namespace NzbDrone.Core.Datastore { - public class WhereBuilder : ExpressionVisitor + public abstract class WhereBuilder : ExpressionVisitor { - protected StringBuilder _sb; - - private const DbType EnumerableMultiParameter = (DbType)(-1); - private readonly string _paramNamePrefix; - private readonly bool _requireConcreteValue = false; - private int _paramCount = 0; - private bool _gotConcreteValue = false; - - public WhereBuilder(Expression filter, bool requireConcreteValue, int seq) - { - _paramNamePrefix = string.Format("Clause{0}", seq + 1); - _requireConcreteValue = requireConcreteValue; - _sb = new StringBuilder(); - - Parameters = new DynamicParameters(); - - if (filter != null) - { - Visit(filter); - } - } - - public DynamicParameters Parameters { get; private set; } - - private string AddParameter(object value, DbType? dbType = null) - { - _gotConcreteValue = true; - _paramCount++; - var name = _paramNamePrefix + "_P" + _paramCount; - Parameters.Add(name, value, dbType); - return '@' + name; - } - - protected override Expression VisitBinary(BinaryExpression expression) - { - _sb.Append('('); - - Visit(expression.Left); - - _sb.AppendFormat(" {0} ", Decode(expression)); - - Visit(expression.Right); - - _sb.Append(')'); - - return expression; - } - - protected override Expression VisitMethodCall(MethodCallExpression expression) - { - var method = expression.Method.Name; - - switch (expression.Method.Name) - { - case "Contains": - ParseContainsExpression(expression); - break; - - case "StartsWith": - ParseStartsWith(expression); - break; - - case "EndsWith": - ParseEndsWith(expression); - break; - - default: - var msg = string.Format("'{0}' expressions are not yet implemented in the where clause expression tree parser.", method); - throw new NotImplementedException(msg); - } - - return expression; - } - - protected override Expression VisitMemberAccess(MemberExpression expression) - { - var tableName = expression?.Expression?.Type != null ? TableMapping.Mapper.TableNameMapping(expression.Expression.Type) : null; - var gotValue = TryGetRightValue(expression, out var value); - - // Only use the SQL condition if the expression didn't resolve to an actual value - if (tableName != null && !gotValue) - { - _sb.Append($"\"{tableName}\".\"{expression.Member.Name}\""); - } - else - { - if (value != null) - { - // string is IEnumerable but we don't want to pick up that case - var type = value.GetType(); - var typeInfo = type.GetTypeInfo(); - var isEnumerable = - type != typeof(string) && ( - typeInfo.ImplementedInterfaces.Any(ti => ti.IsGenericType && ti.GetGenericTypeDefinition() == typeof(IEnumerable<>)) || - (typeInfo.IsGenericType && typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>))); - - var paramName = isEnumerable ? AddParameter(value, EnumerableMultiParameter) : AddParameter(value); - _sb.Append(paramName); - } - else - { - _gotConcreteValue = true; - _sb.Append("NULL"); - } - } - - return expression; - } - - protected override Expression VisitConstant(ConstantExpression expression) - { - if (expression.Value != null) - { - var paramName = AddParameter(expression.Value); - _sb.Append(paramName); - } - else - { - _gotConcreteValue = true; - _sb.Append("NULL"); - } - - return expression; - } - - private bool TryGetConstantValue(Expression expression, out object result) - { - result = null; - - if (expression is ConstantExpression constExp) - { - result = constExp.Value; - return true; - } - - return false; - } - - private bool TryGetPropertyValue(MemberExpression expression, out object result) - { - result = null; - - if (expression.Expression is MemberExpression nested) - { - // Value is passed in as a property on a parent entity - var container = (nested.Expression as ConstantExpression)?.Value; - - if (container == null) - { - return false; - } - - var entity = GetFieldValue(container, nested.Member); - result = GetFieldValue(entity, expression.Member); - return true; - } - - return false; - } - - private bool TryGetVariableValue(MemberExpression expression, out object result) - { - result = null; - - // Value is passed in as a variable - if (expression.Expression is ConstantExpression nested) - { - result = GetFieldValue(nested.Value, expression.Member); - return true; - } - - return false; - } - - private bool TryGetRightValue(Expression expression, out object value) - { - value = null; - - if (TryGetConstantValue(expression, out value)) - { - return true; - } - - var memberExp = expression as MemberExpression; - - if (TryGetPropertyValue(memberExp, out value)) - { - return true; - } - - if (TryGetVariableValue(memberExp, out value)) - { - return true; - } - - return false; - } - - private object GetFieldValue(object entity, MemberInfo member) - { - if (member.MemberType == MemberTypes.Field) - { - return (member as FieldInfo).GetValue(entity); - } - - if (member.MemberType == MemberTypes.Property) - { - return (member as PropertyInfo).GetValue(entity); - } - - throw new ArgumentException(string.Format("WhereBuilder could not get the value for {0}.{1}.", entity.GetType().Name, member.Name)); - } - - private bool IsNullVariable(Expression expression) - { - if (expression.NodeType == ExpressionType.Constant && - TryGetConstantValue(expression, out var constResult) && - constResult == null) - { - return true; - } - - if (expression.NodeType == ExpressionType.MemberAccess && - expression is MemberExpression member && - ((TryGetPropertyValue(member, out var result) && result == null) || - (TryGetVariableValue(member, out result) && result == null))) - { - return true; - } - - return false; - } - - private string Decode(BinaryExpression expression) - { - if (IsNullVariable(expression.Right)) - { - switch (expression.NodeType) - { - case ExpressionType.Equal: return "IS"; - case ExpressionType.NotEqual: return "IS NOT"; - } - } - - switch (expression.NodeType) - { - case ExpressionType.AndAlso: return "AND"; - case ExpressionType.And: return "AND"; - case ExpressionType.Equal: return "="; - case ExpressionType.GreaterThan: return ">"; - case ExpressionType.GreaterThanOrEqual: return ">="; - case ExpressionType.LessThan: return "<"; - case ExpressionType.LessThanOrEqual: return "<="; - case ExpressionType.NotEqual: return "<>"; - case ExpressionType.OrElse: return "OR"; - case ExpressionType.Or: return "OR"; - default: throw new NotSupportedException(string.Format("{0} statement is not supported", expression.NodeType.ToString())); - } - } - - private void ParseContainsExpression(MethodCallExpression expression) - { - var list = expression.Object; - - if (list != null && (list.Type == typeof(string))) - { - ParseStringContains(expression); - return; - } - - ParseEnumerableContains(expression); - } - - private void ParseEnumerableContains(MethodCallExpression body) - { - // Fish out the list and the item to compare - // It's in a different form for arrays and Lists - var list = body.Object; - Expression item; - - if (list != null) - { - // Generic collection - item = body.Arguments[0]; - } - else - { - // Static method - // Must be Enumerable.Contains(source, item) - if (body.Method.DeclaringType != typeof(Enumerable) || body.Arguments.Count != 2) - { - throw new NotSupportedException("Unexpected form of Enumerable.Contains"); - } - - list = body.Arguments[0]; - item = body.Arguments[1]; - } - - _sb.Append('('); - - Visit(item); - - _sb.Append(" IN "); - - // hardcode the integer list if it exists to bypass parameter limit - if (item.Type == typeof(int) && TryGetRightValue(list, out var value)) - { - var items = (IEnumerable)value; - _sb.Append('('); - _sb.Append(string.Join(", ", items)); - _sb.Append(')'); - - _gotConcreteValue = true; - } - else - { - Visit(list); - } - - _sb.Append(')'); - } - - private void ParseStringContains(MethodCallExpression body) - { - _sb.Append('('); - - Visit(body.Object); - - _sb.Append(" LIKE '%' || "); - - Visit(body.Arguments[0]); - - _sb.Append(" || '%')"); - } - - private void ParseStartsWith(MethodCallExpression body) - { - _sb.Append('('); - - Visit(body.Object); - - _sb.Append(" LIKE "); - - Visit(body.Arguments[0]); - - _sb.Append(" || '%')"); - } - - private void ParseEndsWith(MethodCallExpression body) - { - _sb.Append('('); - - Visit(body.Object); - - _sb.Append(" LIKE '%' || "); - - Visit(body.Arguments[0]); - - _sb.Append(')'); - } - - public override string ToString() - { - var sql = _sb.ToString(); - - if (_requireConcreteValue && !_gotConcreteValue) - { - var e = new InvalidOperationException("WhereBuilder requires a concrete condition"); - e.Data.Add("sql", sql); - throw e; - } - - return sql; - } + public DynamicParameters Parameters { get; protected set; } } } diff --git a/src/NzbDrone.Core/Datastore/WhereBuilderPostgres.cs b/src/NzbDrone.Core/Datastore/WhereBuilderPostgres.cs new file mode 100644 index 000000000..bc4c39490 --- /dev/null +++ b/src/NzbDrone.Core/Datastore/WhereBuilderPostgres.cs @@ -0,0 +1,374 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Text; +using Dapper; + +namespace NzbDrone.Core.Datastore +{ + public class WhereBuilderPostgres : WhereBuilder + { + protected StringBuilder _sb; + + private const DbType EnumerableMultiParameter = (DbType)(-1); + private readonly string _paramNamePrefix; + private readonly bool _requireConcreteValue = false; + private int _paramCount = 0; + private bool _gotConcreteValue = false; + + public WhereBuilderPostgres(Expression filter, bool requireConcreteValue, int seq) + { + _paramNamePrefix = string.Format("Clause{0}", seq + 1); + _requireConcreteValue = requireConcreteValue; + _sb = new StringBuilder(); + + Parameters = new DynamicParameters(); + + if (filter != null) + { + Visit(filter); + } + } + + private string AddParameter(object value, DbType? dbType = null) + { + _gotConcreteValue = true; + _paramCount++; + var name = _paramNamePrefix + "_P" + _paramCount; + Parameters.Add(name, value, dbType); + return '@' + name; + } + + protected override Expression VisitBinary(BinaryExpression expression) + { + _sb.Append('('); + + Visit(expression.Left); + + _sb.AppendFormat(" {0} ", Decode(expression)); + + Visit(expression.Right); + + _sb.Append(')'); + + return expression; + } + + protected override Expression VisitMethodCall(MethodCallExpression expression) + { + var method = expression.Method.Name; + + switch (expression.Method.Name) + { + case "Contains": + ParseContainsExpression(expression); + break; + + case "StartsWith": + ParseStartsWith(expression); + break; + + case "EndsWith": + ParseEndsWith(expression); + break; + + default: + var msg = string.Format("'{0}' expressions are not yet implemented in the where clause expression tree parser.", method); + throw new NotImplementedException(msg); + } + + return expression; + } + + protected override Expression VisitMemberAccess(MemberExpression expression) + { + var tableName = expression?.Expression?.Type != null ? TableMapping.Mapper.TableNameMapping(expression.Expression.Type) : null; + var gotValue = TryGetRightValue(expression, out var value); + + // Only use the SQL condition if the expression didn't resolve to an actual value + if (tableName != null && !gotValue) + { + _sb.Append($"\"{tableName}\".\"{expression.Member.Name}\""); + } + else + { + if (value != null) + { + // string is IEnumerable but we don't want to pick up that case + var type = value.GetType(); + var typeInfo = type.GetTypeInfo(); + var isEnumerable = + type != typeof(string) && ( + typeInfo.ImplementedInterfaces.Any(ti => ti.IsGenericType && ti.GetGenericTypeDefinition() == typeof(IEnumerable<>)) || + (typeInfo.IsGenericType && typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>))); + + var paramName = isEnumerable ? AddParameter(value, EnumerableMultiParameter) : AddParameter(value); + _sb.Append(paramName); + } + else + { + _gotConcreteValue = true; + _sb.Append("NULL"); + } + } + + return expression; + } + + protected override Expression VisitConstant(ConstantExpression expression) + { + if (expression.Value != null) + { + var paramName = AddParameter(expression.Value); + _sb.Append(paramName); + } + else + { + _gotConcreteValue = true; + _sb.Append("NULL"); + } + + return expression; + } + + private bool TryGetConstantValue(Expression expression, out object result) + { + result = null; + + if (expression is ConstantExpression constExp) + { + result = constExp.Value; + return true; + } + + return false; + } + + private bool TryGetPropertyValue(MemberExpression expression, out object result) + { + result = null; + + if (expression.Expression is MemberExpression nested) + { + // Value is passed in as a property on a parent entity + var container = (nested.Expression as ConstantExpression)?.Value; + + if (container == null) + { + return false; + } + + var entity = GetFieldValue(container, nested.Member); + result = GetFieldValue(entity, expression.Member); + return true; + } + + return false; + } + + private bool TryGetVariableValue(MemberExpression expression, out object result) + { + result = null; + + // Value is passed in as a variable + if (expression.Expression is ConstantExpression nested) + { + result = GetFieldValue(nested.Value, expression.Member); + return true; + } + + return false; + } + + private bool TryGetRightValue(Expression expression, out object value) + { + value = null; + + if (TryGetConstantValue(expression, out value)) + { + return true; + } + + var memberExp = expression as MemberExpression; + + if (TryGetPropertyValue(memberExp, out value)) + { + return true; + } + + if (TryGetVariableValue(memberExp, out value)) + { + return true; + } + + return false; + } + + private object GetFieldValue(object entity, MemberInfo member) + { + if (member.MemberType == MemberTypes.Field) + { + return (member as FieldInfo).GetValue(entity); + } + + if (member.MemberType == MemberTypes.Property) + { + return (member as PropertyInfo).GetValue(entity); + } + + throw new ArgumentException(string.Format("WhereBuilder could not get the value for {0}.{1}.", entity.GetType().Name, member.Name)); + } + + private bool IsNullVariable(Expression expression) + { + if (expression.NodeType == ExpressionType.Constant && + TryGetConstantValue(expression, out var constResult) && + constResult == null) + { + return true; + } + + if (expression.NodeType == ExpressionType.MemberAccess && + expression is MemberExpression member && + ((TryGetPropertyValue(member, out var result) && result == null) || + (TryGetVariableValue(member, out result) && result == null))) + { + return true; + } + + return false; + } + + private string Decode(BinaryExpression expression) + { + if (IsNullVariable(expression.Right)) + { + switch (expression.NodeType) + { + case ExpressionType.Equal: return "IS"; + case ExpressionType.NotEqual: return "IS NOT"; + } + } + + switch (expression.NodeType) + { + case ExpressionType.AndAlso: return "AND"; + case ExpressionType.And: return "AND"; + case ExpressionType.Equal: return "="; + case ExpressionType.GreaterThan: return ">"; + case ExpressionType.GreaterThanOrEqual: return ">="; + case ExpressionType.LessThan: return "<"; + case ExpressionType.LessThanOrEqual: return "<="; + case ExpressionType.NotEqual: return "<>"; + case ExpressionType.OrElse: return "OR"; + case ExpressionType.Or: return "OR"; + default: throw new NotSupportedException(string.Format("{0} statement is not supported", expression.NodeType.ToString())); + } + } + + private void ParseContainsExpression(MethodCallExpression expression) + { + var list = expression.Object; + + if (list != null && (list.Type == typeof(string))) + { + ParseStringContains(expression); + return; + } + + ParseEnumerableContains(expression); + } + + private void ParseEnumerableContains(MethodCallExpression body) + { + // Fish out the list and the item to compare + // It's in a different form for arrays and Lists + var list = body.Object; + Expression item; + + if (list != null) + { + // Generic collection + item = body.Arguments[0]; + } + else + { + // Static method + // Must be Enumerable.Contains(source, item) + if (body.Method.DeclaringType != typeof(Enumerable) || body.Arguments.Count != 2) + { + throw new NotSupportedException("Unexpected form of Enumerable.Contains"); + } + + list = body.Arguments[0]; + item = body.Arguments[1]; + } + + _sb.Append('('); + + Visit(item); + + _sb.Append(" = ANY ("); + + Visit(list); + + _sb.Append("))"); + } + + private void ParseStringContains(MethodCallExpression body) + { + _sb.Append('('); + + Visit(body.Object); + + _sb.Append(" LIKE '%' || "); + + Visit(body.Arguments[0]); + + _sb.Append(" || '%')"); + } + + private void ParseStartsWith(MethodCallExpression body) + { + _sb.Append('('); + + Visit(body.Object); + + _sb.Append(" LIKE "); + + Visit(body.Arguments[0]); + + _sb.Append(" || '%')"); + } + + private void ParseEndsWith(MethodCallExpression body) + { + _sb.Append('('); + + Visit(body.Object); + + _sb.Append(" LIKE '%' || "); + + Visit(body.Arguments[0]); + + _sb.Append(')'); + } + + public override string ToString() + { + var sql = _sb.ToString(); + + if (_requireConcreteValue && !_gotConcreteValue) + { + var e = new InvalidOperationException("WhereBuilder requires a concrete condition"); + e.Data.Add("sql", sql); + throw e; + } + + return sql; + } + } +} diff --git a/src/NzbDrone.Core/Datastore/WhereBuilderSqlite.cs b/src/NzbDrone.Core/Datastore/WhereBuilderSqlite.cs new file mode 100644 index 000000000..54b24d7e1 --- /dev/null +++ b/src/NzbDrone.Core/Datastore/WhereBuilderSqlite.cs @@ -0,0 +1,387 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Text; +using Dapper; + +namespace NzbDrone.Core.Datastore +{ + public class WhereBuilderSqlite : WhereBuilder + { + protected StringBuilder _sb; + + private const DbType EnumerableMultiParameter = (DbType)(-1); + private readonly string _paramNamePrefix; + private readonly bool _requireConcreteValue = false; + private int _paramCount = 0; + private bool _gotConcreteValue = false; + + public WhereBuilderSqlite(Expression filter, bool requireConcreteValue, int seq) + { + _paramNamePrefix = string.Format("Clause{0}", seq + 1); + _requireConcreteValue = requireConcreteValue; + _sb = new StringBuilder(); + + Parameters = new DynamicParameters(); + + if (filter != null) + { + Visit(filter); + } + } + + private string AddParameter(object value, DbType? dbType = null) + { + _gotConcreteValue = true; + _paramCount++; + var name = _paramNamePrefix + "_P" + _paramCount; + Parameters.Add(name, value, dbType); + return '@' + name; + } + + protected override Expression VisitBinary(BinaryExpression expression) + { + _sb.Append('('); + + Visit(expression.Left); + + _sb.AppendFormat(" {0} ", Decode(expression)); + + Visit(expression.Right); + + _sb.Append(')'); + + return expression; + } + + protected override Expression VisitMethodCall(MethodCallExpression expression) + { + var method = expression.Method.Name; + + switch (expression.Method.Name) + { + case "Contains": + ParseContainsExpression(expression); + break; + + case "StartsWith": + ParseStartsWith(expression); + break; + + case "EndsWith": + ParseEndsWith(expression); + break; + + default: + var msg = string.Format("'{0}' expressions are not yet implemented in the where clause expression tree parser.", method); + throw new NotImplementedException(msg); + } + + return expression; + } + + protected override Expression VisitMemberAccess(MemberExpression expression) + { + var tableName = expression?.Expression?.Type != null ? TableMapping.Mapper.TableNameMapping(expression.Expression.Type) : null; + var gotValue = TryGetRightValue(expression, out var value); + + // Only use the SQL condition if the expression didn't resolve to an actual value + if (tableName != null && !gotValue) + { + _sb.Append($"\"{tableName}\".\"{expression.Member.Name}\""); + } + else + { + if (value != null) + { + // string is IEnumerable but we don't want to pick up that case + var type = value.GetType(); + var typeInfo = type.GetTypeInfo(); + var isEnumerable = + type != typeof(string) && ( + typeInfo.ImplementedInterfaces.Any(ti => ti.IsGenericType && ti.GetGenericTypeDefinition() == typeof(IEnumerable<>)) || + (typeInfo.IsGenericType && typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>))); + + var paramName = isEnumerable ? AddParameter(value, EnumerableMultiParameter) : AddParameter(value); + _sb.Append(paramName); + } + else + { + _gotConcreteValue = true; + _sb.Append("NULL"); + } + } + + return expression; + } + + protected override Expression VisitConstant(ConstantExpression expression) + { + if (expression.Value != null) + { + var paramName = AddParameter(expression.Value); + _sb.Append(paramName); + } + else + { + _gotConcreteValue = true; + _sb.Append("NULL"); + } + + return expression; + } + + private bool TryGetConstantValue(Expression expression, out object result) + { + result = null; + + if (expression is ConstantExpression constExp) + { + result = constExp.Value; + return true; + } + + return false; + } + + private bool TryGetPropertyValue(MemberExpression expression, out object result) + { + result = null; + + if (expression.Expression is MemberExpression nested) + { + // Value is passed in as a property on a parent entity + var container = (nested.Expression as ConstantExpression)?.Value; + + if (container == null) + { + return false; + } + + var entity = GetFieldValue(container, nested.Member); + result = GetFieldValue(entity, expression.Member); + return true; + } + + return false; + } + + private bool TryGetVariableValue(MemberExpression expression, out object result) + { + result = null; + + // Value is passed in as a variable + if (expression.Expression is ConstantExpression nested) + { + result = GetFieldValue(nested.Value, expression.Member); + return true; + } + + return false; + } + + private bool TryGetRightValue(Expression expression, out object value) + { + value = null; + + if (TryGetConstantValue(expression, out value)) + { + return true; + } + + var memberExp = expression as MemberExpression; + + if (TryGetPropertyValue(memberExp, out value)) + { + return true; + } + + if (TryGetVariableValue(memberExp, out value)) + { + return true; + } + + return false; + } + + private object GetFieldValue(object entity, MemberInfo member) + { + if (member.MemberType == MemberTypes.Field) + { + return (member as FieldInfo).GetValue(entity); + } + + if (member.MemberType == MemberTypes.Property) + { + return (member as PropertyInfo).GetValue(entity); + } + + throw new ArgumentException(string.Format("WhereBuilder could not get the value for {0}.{1}.", entity.GetType().Name, member.Name)); + } + + private bool IsNullVariable(Expression expression) + { + if (expression.NodeType == ExpressionType.Constant && + TryGetConstantValue(expression, out var constResult) && + constResult == null) + { + return true; + } + + if (expression.NodeType == ExpressionType.MemberAccess && + expression is MemberExpression member && + ((TryGetPropertyValue(member, out var result) && result == null) || + (TryGetVariableValue(member, out result) && result == null))) + { + return true; + } + + return false; + } + + private string Decode(BinaryExpression expression) + { + if (IsNullVariable(expression.Right)) + { + switch (expression.NodeType) + { + case ExpressionType.Equal: return "IS"; + case ExpressionType.NotEqual: return "IS NOT"; + } + } + + switch (expression.NodeType) + { + case ExpressionType.AndAlso: return "AND"; + case ExpressionType.And: return "AND"; + case ExpressionType.Equal: return "="; + case ExpressionType.GreaterThan: return ">"; + case ExpressionType.GreaterThanOrEqual: return ">="; + case ExpressionType.LessThan: return "<"; + case ExpressionType.LessThanOrEqual: return "<="; + case ExpressionType.NotEqual: return "<>"; + case ExpressionType.OrElse: return "OR"; + case ExpressionType.Or: return "OR"; + default: throw new NotSupportedException(string.Format("{0} statement is not supported", expression.NodeType.ToString())); + } + } + + private void ParseContainsExpression(MethodCallExpression expression) + { + var list = expression.Object; + + if (list != null && (list.Type == typeof(string))) + { + ParseStringContains(expression); + return; + } + + ParseEnumerableContains(expression); + } + + private void ParseEnumerableContains(MethodCallExpression body) + { + // Fish out the list and the item to compare + // It's in a different form for arrays and Lists + var list = body.Object; + Expression item; + + if (list != null) + { + // Generic collection + item = body.Arguments[0]; + } + else + { + // Static method + // Must be Enumerable.Contains(source, item) + if (body.Method.DeclaringType != typeof(Enumerable) || body.Arguments.Count != 2) + { + throw new NotSupportedException("Unexpected form of Enumerable.Contains"); + } + + list = body.Arguments[0]; + item = body.Arguments[1]; + } + + _sb.Append('('); + + Visit(item); + + _sb.Append(" IN "); + + // hardcode the integer list if it exists to bypass parameter limit + if (item.Type == typeof(int) && TryGetRightValue(list, out var value)) + { + var items = (IEnumerable)value; + _sb.Append('('); + _sb.Append(string.Join(", ", items)); + _sb.Append(')'); + + _gotConcreteValue = true; + } + else + { + Visit(list); + } + + _sb.Append(')'); + } + + private void ParseStringContains(MethodCallExpression body) + { + _sb.Append('('); + + Visit(body.Object); + + _sb.Append(" LIKE '%' || "); + + Visit(body.Arguments[0]); + + _sb.Append(" || '%')"); + } + + private void ParseStartsWith(MethodCallExpression body) + { + _sb.Append('('); + + Visit(body.Object); + + _sb.Append(" LIKE "); + + Visit(body.Arguments[0]); + + _sb.Append(" || '%')"); + } + + private void ParseEndsWith(MethodCallExpression body) + { + _sb.Append('('); + + Visit(body.Object); + + _sb.Append(" LIKE '%' || "); + + Visit(body.Arguments[0]); + + _sb.Append(')'); + } + + public override string ToString() + { + var sql = _sb.ToString(); + + if (_requireConcreteValue && !_gotConcreteValue) + { + var e = new InvalidOperationException("WhereBuilder requires a concrete condition"); + e.Data.Add("sql", sql); + throw e; + } + + return sql; + } + } +} diff --git a/src/NzbDrone.Core/History/HistoryRepository.cs b/src/NzbDrone.Core/History/HistoryRepository.cs index e5e535fd8..1471d30f2 100644 --- a/src/NzbDrone.Core/History/HistoryRepository.cs +++ b/src/NzbDrone.Core/History/HistoryRepository.cs @@ -100,7 +100,7 @@ namespace NzbDrone.Core.History public int CountSince(int indexerId, DateTime date, List eventTypes) { - var builder = new SqlBuilder() + var builder = new SqlBuilder(_database.DatabaseType) .SelectCount() .Where(x => x.IndexerId == indexerId) .Where(x => x.Date >= date)