using System;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using Dapper;
using NzbDrone.Core.Datastore.Events;
using NzbDrone.Core.Messaging.Events;

namespace NzbDrone.Core.Datastore
{
    /// <summary>
    /// CRUD and paging operations over the table that stores <typeparamref name="TModel"/> rows.
    /// </summary>
    /// <typeparam name="TModel">Mapped model type; its <c>Id</c> property is the table key.</typeparam>
    public interface IBasicRepository<TModel>
        where TModel : ModelBase, new()
    {
        IEnumerable<TModel> All();
        int Count();
        TModel Find(int id);
        TModel Get(int id);
        TModel Insert(TModel model);
        TModel Update(TModel model);
        TModel Upsert(TModel model);
        void SetFields(TModel model, params Expression<Func<TModel, object>>[] properties);
        void Delete(TModel model);
        void Delete(int id);
        IEnumerable<TModel> Get(IEnumerable<int> ids);
        void InsertMany(IList<TModel> model);
        void UpdateMany(IList<TModel> model);
        void SetFields(IList<TModel> models, params Expression<Func<TModel, object>>[] properties);
        void DeleteMany(List<TModel> model);
        void DeleteMany(IEnumerable<int> ids);
        void Purge(bool vacuum = false);
        bool HasItems();
        TModel Single();
        TModel SingleOrDefault();
        PagingSpec<TModel> GetPaged(PagingSpec<TModel> pagingSpec);
    }

    /// <summary>
    /// Dapper-backed default implementation of <see cref="IBasicRepository{TModel}"/>.
    /// The full-row INSERT and UPDATE statements are built once in the constructor and reused;
    /// partial updates (<c>SetFields</c>) build their UPDATE statement on demand.
    /// </summary>
    /// <typeparam name="TModel">Mapped model type; its <c>Id</c> property is the table key.</typeparam>
    public class BasicRepository<TModel> : IBasicRepository<TModel>
        where TModel : ModelBase, new()
    {
        private readonly IEventAggregator _eventAggregator;
        private readonly PropertyInfo _keyProperty;
        private readonly List<PropertyInfo> _properties;
        private readonly string _updateSql;
        private readonly string _insertSql;

        protected readonly IDatabase _database;
        protected readonly string _table;

        public BasicRepository(IDatabase database, IEventAggregator eventAggregator)
        {
            _database = database;
            _eventAggregator = eventAggregator;

            var type = typeof(TModel);

            _table = TableMapping.Mapper.TableNameMapping(type);
            _keyProperty = type.GetProperty(nameof(ModelBase.Id));

            // The key column is left out of the column list: its value comes back from the database
            // on insert (see Insert below) and it is only used in the WHERE clause on update.
            var excluded = TableMapping.Mapper.ExcludeProperties(type).Select(x => x.Name).ToList();
            excluded.Add(_keyProperty.Name);
            _properties = type.GetProperties().Where(x => x.IsMappableProperty() && !excluded.Contains(x.Name)).ToList();

            _insertSql = GetInsertSql();
            _updateSql = GetUpdateSql(_properties);
        }

        protected virtual SqlBuilder Builder() => new SqlBuilder(_database.DatabaseType);

        protected virtual List<TModel> Query(SqlBuilder builder) => _database.Query<TModel>(builder).ToList();

        protected List<TModel> Query(Expression<Func<TModel, bool>> where) => Query(Builder().Where(where));

        protected virtual List<TModel> QueryDistinct(SqlBuilder builder) => _database.QueryDistinct<TModel>(builder).ToList();

        /// <summary>Returns the number of rows in the table.</summary>
        public int Count()
        {
            using (var conn = _database.OpenConnection())
            {
                return conn.ExecuteScalar<int>($"SELECT COUNT(*) FROM \"{_table}\"");
            }
        }

        public virtual IEnumerable<TModel> All()
        {
            return Query(Builder());
        }

        /// <summary>Returns the row with the given id, or <c>null</c> when none exists.</summary>
        public TModel Find(int id)
        {
            var model = Query(x => x.Id == id).FirstOrDefault();
            return model;
        }

        /// <summary>Returns the row with the given id.</summary>
        /// <exception cref="ModelNotFoundException">No row has that id.</exception>
        public TModel Get(int id)
        {
            var model = Find(id);
            if (model == null)
            {
                throw new ModelNotFoundException(typeof(TModel), id);
            }

            return model;
        }

        /// <summary>
        /// Returns the rows with the given ids. Duplicate ids are treated as a single id.
        /// </summary>
        /// <exception cref="ApplicationException">One or more of the ids has no matching row.</exception>
        public IEnumerable<TModel> Get(IEnumerable<int> ids)
        {
            // Materialize once, because the sequence is inspected several times below and may be lazy.
            // De-duplicate so that a repeated id does not trip the row-count check.
            // Declared as IEnumerable<int> so that Contains still binds to Enumerable.Contains, just as it
            // did with the original parameter, and the expression handed to the SQL translator keeps its shape.
            IEnumerable<int> idList = ids.Distinct().ToList();
            var expected = idList.Count();

            if (expected == 0)
            {
                return Array.Empty<TModel>();
            }

            var result = Query(x => idList.Contains(x.Id));

            if (result.Count != expected)
            {
                throw new ApplicationException($"Expected query to return {expected} rows but returned {result.Count}");
            }

            return result;
        }

        public TModel SingleOrDefault()
        {
            return All().SingleOrDefault();
        }

        public TModel Single()
        {
            return All().Single();
        }

        /// <summary>
        /// Inserts a new row, writes the database-assigned key back into <paramref name="model"/>
        /// and raises a Created model event (only when <see cref="PublishModelEvents"/> is true).
        /// </summary>
        /// <exception cref="InvalidOperationException"><paramref name="model"/> already has a non-zero Id.</exception>
        public TModel Insert(TModel model)
        {
            if (model.Id != 0)
            {
                throw new InvalidOperationException("Can't insert model with existing ID " + model.Id);
            }

            using (var conn = _database.OpenConnection())
            {
                model = Insert(conn, null, model);
            }

            ModelCreated(model);

            return model;
        }

        // Builds the full-row INSERT statement. The statement returns the new key: PostgreSQL uses
        // RETURNING, and the other provider (SQLite, given last_insert_rowid()) uses a follow-up SELECT.
        private string GetInsertSql()
        {
            var columnList = string.Join(", ", _properties.Select(p => $"\"{p.Name}\""));
            var parameterList = string.Join(", ", _properties.Select(p => $"@{p.Name}"));

            if (_database.DatabaseType == DatabaseType.PostgreSQL)
            {
                return $"INSERT INTO \"{_table}\" ({columnList}) VALUES ({parameterList}) RETURNING \"Id\"";
            }

            // The table name is quoted here too, consistent with Count/Purge/UPDATE, which already
            // quote it for every database type.
            return $"INSERT INTO \"{_table}\" ({columnList}) VALUES ({parameterList}); SELECT last_insert_rowid() id";
        }

        private TModel Insert(IDbConnection connection, IDbTransaction transaction, TModel model)
        {
            SqlBuilderExtensions.LogQuery(_insertSql, model);

            // GridReader is IDisposable. Dispose it so that the data reader and command are released
            // even if reading the returned key fails.
            using (var multi = connection.QueryMultiple(_insertSql, model, transaction))
            {
                var row = multi.Read().First();

                // The key column is named "Id" for PostgreSQL (RETURNING "Id") and "id" for the
                // last_insert_rowid() variant.
                var id = (int)(row.id ?? row.Id);
                _keyProperty.SetValue(model, id);
            }

            return model;
        }

        /// <summary>
        /// Inserts all models in one ReadCommitted transaction and assigns their keys.
        /// Does not raise model events.
        /// </summary>
        /// <exception cref="InvalidOperationException">Any model already has a non-zero Id.</exception>
        public void InsertMany(IList<TModel> models)
        {
            if (models.Any(x => x.Id != 0))
            {
                throw new InvalidOperationException("Can't insert model with existing ID != 0");
            }

            using (var conn = _database.OpenConnection())
            {
                using (IDbTransaction tran = conn.BeginTransaction(IsolationLevel.ReadCommitted))
                {
                    foreach (var model in models)
                    {
                        Insert(conn, tran, model);
                    }

                    tran.Commit();
                }
            }
        }

        /// <summary>
        /// Writes every mapped column of <paramref name="model"/> and raises an Updated model event
        /// (only when <see cref="PublishModelEvents"/> is true).
        /// </summary>
        /// <exception cref="InvalidOperationException"><paramref name="model"/> has Id 0.</exception>
        public TModel Update(TModel model)
        {
            if (model.Id == 0)
            {
                throw new InvalidOperationException("Can't update model with ID 0");
            }

            using (var conn = _database.OpenConnection())
            {
                UpdateFields(conn, null, model, _properties);
            }

            ModelUpdated(model);

            return model;
        }

        /// <summary>Writes every mapped column of each model. Does not raise model events.</summary>
        /// <exception cref="InvalidOperationException">Any model has Id 0.</exception>
        public void UpdateMany(IList<TModel> models)
        {
            if (models.Any(x => x.Id == 0))
            {
                throw new InvalidOperationException("Can't update model with ID 0");
            }

            using (var conn = _database.OpenConnection())
            {
                UpdateFields(conn, null, models, _properties);
            }
        }

        protected void Delete(Expression<Func<TModel, bool>> where)
        {
            Delete(Builder().Where(where));
        }

        protected void Delete(SqlBuilder builder)
        {
            var sql = builder.AddDeleteTemplate(typeof(TModel)).LogQuery();

            using (var conn = _database.OpenConnection())
            {
                conn.Execute(sql.RawSql, sql.Parameters);
            }
        }

        public void Delete(TModel model)
        {
            Delete(model.Id);
        }

        public void Delete(int id)
        {
            Delete(x => x.Id == id);
        }

        /// <summary>Deletes the rows with the given ids. An empty sequence is a no-op.</summary>
        public void DeleteMany(IEnumerable<int> ids)
        {
            // Materialize once, because Any() and the SQL translation would otherwise each enumerate a
            // possibly lazy sequence (DeleteMany(List<TModel>) passes a deferred Select).
            // The IEnumerable<int> declaration keeps Contains bound to Enumerable.Contains.
            IEnumerable<int> idList = ids.ToList();

            if (idList.Any())
            {
                Delete(x => idList.Contains(x.Id));
            }
        }

        public void DeleteMany(List<TModel> models)
        {
            DeleteMany(models.Select(m => m.Id));
        }

        /// <summary>Inserts the model when its Id is 0; otherwise updates it.</summary>
        public TModel Upsert(TModel model)
        {
            if (model.Id == 0)
            {
                Insert(model);
                return model;
            }

            Update(model);
            return model;
        }

        /// <summary>Deletes every row in the table and, optionally, vacuums the database afterwards.</summary>
        public void Purge(bool vacuum = false)
        {
            using (var conn = _database.OpenConnection())
            {
                conn.Execute($"DELETE FROM \"{_table}\"");
            }

            if (vacuum)
            {
                Vacuum();
            }
        }

        protected void Vacuum()
        {
            _database.Vacuum();
        }

        public bool HasItems()
        {
            return Count() > 0;
        }

        /// <summary>
        /// Writes only the selected properties of <paramref name="model"/> and raises an Updated model event
        /// (only when <see cref="PublishModelEvents"/> is true).
        /// </summary>
        /// <exception cref="InvalidOperationException"><paramref name="model"/> has Id 0.</exception>
        public void SetFields(TModel model, params Expression<Func<TModel, object>>[] properties)
        {
            if (model.Id == 0)
            {
                throw new InvalidOperationException("Attempted to update model without ID");
            }

            var propertiesToUpdate = properties.Select(x => x.GetMemberName()).ToList();

            using (var conn = _database.OpenConnection())
            {
                UpdateFields(conn, null, model, propertiesToUpdate);
            }

            ModelUpdated(model);
        }

        /// <summary>
        /// Writes only the selected properties of each model and raises an Updated model event for each one
        /// (only when <see cref="PublishModelEvents"/> is true).
        /// </summary>
        /// <exception cref="InvalidOperationException">Any model has Id 0.</exception>
        public void SetFields(IList<TModel> models, params Expression<Func<TModel, object>>[] properties)
        {
            if (models.Any(x => x.Id == 0))
            {
                throw new InvalidOperationException("Attempted to update model without ID");
            }

            var propertiesToUpdate = properties.Select(x => x.GetMemberName()).ToList();

            using (var conn = _database.OpenConnection())
            {
                UpdateFields(conn, null, models, propertiesToUpdate);
            }

            foreach (var model in models)
            {
                ModelUpdated(model);
            }
        }

        // Builds: UPDATE "Table" SET "A" = @A, "B" = @B WHERE "Id" = @Id
        // NOTE(review): an empty property list yields "SET  WHERE", which the database will reject.
        private string GetUpdateSql(List<PropertyInfo> propertiesToUpdate)
        {
            var sb = new StringBuilder();
            sb.AppendFormat("UPDATE \"{0}\" SET ", _table);
            sb.Append(string.Join(", ", propertiesToUpdate.Select(p => $"\"{p.Name}\" = @{p.Name}")));
            sb.Append($" WHERE \"{_keyProperty.Name}\" = @{_keyProperty.Name}");

            return sb.ToString();
        }

        private void UpdateFields(IDbConnection connection, IDbTransaction transaction, TModel model, List<PropertyInfo> propertiesToUpdate)
        {
            // Reference comparison: reuse the prebuilt full-row statement only for the cached property list.
            var sql = propertiesToUpdate == _properties ? _updateSql : GetUpdateSql(propertiesToUpdate);

            SqlBuilderExtensions.LogQuery(sql, model);

            connection.Execute(sql, model, transaction: transaction);
        }

        private void UpdateFields(IDbConnection connection, IDbTransaction transaction, IList<TModel> models, List<PropertyInfo> propertiesToUpdate)
        {
            var sql = propertiesToUpdate == _properties ? _updateSql : GetUpdateSql(propertiesToUpdate);

            foreach (var model in models)
            {
                SqlBuilderExtensions.LogQuery(sql, model);
            }

            // Passing the list makes Dapper execute the statement once per element.
            connection.Execute(sql, models, transaction: transaction);
        }

        protected virtual SqlBuilder PagedBuilder() => Builder();
        protected virtual IEnumerable<TModel> PagedQuery(SqlBuilder sql) => Query(sql);

        /// <summary>
        /// Fills <c>Records</c> with the requested page and <c>TotalRecords</c> with the filtered row count.
        /// </summary>
        public virtual PagingSpec<TModel> GetPaged(PagingSpec<TModel> pagingSpec)
        {
            pagingSpec.Records = GetPagedRecords(PagedBuilder(), pagingSpec, PagedQuery);
            pagingSpec.TotalRecords = GetPagedRecordCount(PagedBuilder().SelectCount(), pagingSpec);

            return pagingSpec;
        }

        private void AddFilters(SqlBuilder builder, PagingSpec<TModel> pagingSpec)
        {
            var filters = pagingSpec.FilterExpressions;

            foreach (var filter in filters)
            {
                builder.Where(filter);
            }
        }

        protected List<TModel> GetPagedRecords(SqlBuilder builder, PagingSpec<TModel> pagingSpec, Func<SqlBuilder, IEnumerable<TModel>> queryFunc)
        {
            AddFilters(builder, pagingSpec);

            if (pagingSpec.SortKey == null)
            {
                pagingSpec.SortKey = $"{_table}.{_keyProperty.Name}";
            }

            // NOTE(review): the sort key is interpolated into the SQL text. This presumably relies on
            // GetSortKey mapping only to known columns; confirm before exposing SortKey to untrusted input.
            var sortKey = TableMapping.Mapper.GetSortKey(pagingSpec.SortKey);
            var sortDirection = pagingSpec.SortDirection == SortDirection.Descending ? "DESC" : "ASC";

            // Page is 1-based; pages below 1 are clamped to the first page.
            var pagingOffset = Math.Max(pagingSpec.Page - 1, 0) * pagingSpec.PageSize;
            builder.OrderBy($"\"{sortKey}\" {sortDirection} LIMIT {pagingSpec.PageSize} OFFSET {pagingOffset}");

            return queryFunc(builder).ToList();
        }

        protected int GetPagedRecordCount(SqlBuilder builder, PagingSpec<TModel> pagingSpec, string template = null)
        {
            AddFilters(builder, pagingSpec);

            SqlBuilder.Template sql;
            if (template != null)
            {
                sql = builder.AddTemplate(template).LogQuery();
            }
            else
            {
                sql = builder.AddPageCountTemplate(typeof(TModel));
            }

            using (var conn = _database.OpenConnection())
            {
                return conn.ExecuteScalar<int>(sql.RawSql, sql.Parameters);
            }
        }

        protected void ModelCreated(TModel model)
        {
            PublishModelEvent(model, ModelAction.Created);
        }

        protected void ModelUpdated(TModel model)
        {
            PublishModelEvent(model, ModelAction.Updated);
        }

        protected void ModelDeleted(TModel model)
        {
            PublishModelEvent(model, ModelAction.Deleted);
        }

        private void PublishModelEvent(TModel model, ModelAction action)
        {
            if (PublishModelEvents)
            {
                _eventAggregator.PublishEvent(new ModelEvent<TModel>(model, action));
            }
        }

        /// <summary>
        /// When true, ModelCreated, ModelUpdated and ModelDeleted publish a model event through the
        /// event aggregator. Defaults to false; subclasses override it to opt in.
        /// </summary>
        protected virtual bool PublishModelEvents => false;
    }
}