/*  Copyright (C) 2008 - 2011 Jordan Marr

This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.

This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public
License along with this library. If not, see <http://www.gnu.org/licenses/>. */

using System;
using System.Collections.Generic;
using System.Text;
using System.Data;
using System.Data.Common;
using System.Reflection;
using System.Collections;
using Marr.Data.Mapping;
using Marr.Data.Parameters;
using Marr.Data.QGen;
using System.Linq.Expressions;
using System.Diagnostics;

namespace Marr.Data
{
    /// <summary>
    /// This class is the main access point for making database related calls.
    /// A single DbCommand (and its connection) is lazily created and reused by every
    /// operation; after each operation the command's parameters and text are cleared
    /// and the connection is closed unless a transaction is in progress.
    /// </summary>
    public class DataMapper : IDataMapper
    {

        #region - Constructor, Members -

        private DbCommand _command;

        /// <summary>
        /// Initializes a DataMapper for the given provider type and connection string.
        /// </summary>
        /// <param name="providerName">The invariant name of a registered ADO.NET provider (ex: "System.Data.SqlClient").</param>
        /// <param name="connectionString">The database connection string.</param>
        public DataMapper(string providerName, string connectionString)
            : this(DbProviderFactories.GetFactory(providerName), connectionString)
        { }

        /// <summary>
        /// A database provider agnostic initialization.
        /// The SqlMode defaults to <c>SqlModes.StoredProcedure</c>.
        /// </summary>
        /// <param name="dbProviderFactory">The provider factory used to create connections and data adapters.</param>
        /// <param name="connectionString">The database connection string.</param>
        /// <exception cref="ArgumentNullException">
        /// Thrown when dbProviderFactory is null, or connectionString is null or empty.
        /// </exception>
        public DataMapper(DbProviderFactory dbProviderFactory, string connectionString)
        {
            SqlMode = SqlModes.StoredProcedure;
            if (dbProviderFactory == null)
                throw new ArgumentNullException("dbProviderFactory");

            if (string.IsNullOrEmpty(connectionString))
                throw new ArgumentNullException("connectionString");

            ProviderFactory = dbProviderFactory;

            ConnectionString = connectionString;
        }

        public string ConnectionString { get; private set; }

        public DbProviderFactory ProviderFactory { get; private set; }

        /// <summary>
        /// Creates a new (unopened) connection from the provider factory and a command bound to it.
        /// </summary>
        private DbCommand CreateNewCommand()
        {
            DbConnection conn = ProviderFactory.CreateConnection();
            conn.ConnectionString = ConnectionString;
            DbCommand cmd = conn.CreateCommand();
            SetSqlMode(cmd);
            return cmd;
        }

        /// <summary>
        /// Creates a new command utilizing the connection string with a given SQL command.
        /// </summary>
        private DbCommand CreateNewCommand(string sql)
        {
            DbCommand cmd = CreateNewCommand();
            cmd.CommandText = sql;
            return cmd;
        }

        /// <summary>
        /// Gets the shared DbCommand, creating it on first access.
        /// The command's CommandType is re-synchronized with SqlMode on every access.
        /// </summary>
        public DbCommand Command
        {
            get
            {
                // Lazy load
                if (_command == null)
                    _command = CreateNewCommand();
                else
                    SetSqlMode(_command); // Set SqlMode every time.

                return _command;
            }
        }

        #endregion

        #region - Parameters -

        /// <summary>
        /// Gets the parameter collection of the shared command.
        /// </summary>
        public DbParameterCollection Parameters
        {
            get
            {
                return Command.Parameters;
            }
        }

        /// <summary>
        /// Starts a fluent parameter definition on the shared command.
        /// </summary>
        public ParameterChainMethods AddParameter(string name, object value)
        {
            return new ParameterChainMethods(Command, name, value);
        }

        /// <summary>
        /// Adds a pre-built parameter to the shared command.
        /// A null Value is converted to DBNull.Value before the parameter is added.
        /// </summary>
        /// <exception cref="ArgumentNullException">Thrown when parameter is null.</exception>
        public IDbDataParameter AddParameter(IDbDataParameter parameter)
        {
            if (parameter == null)
                throw new ArgumentNullException("parameter");

            // Convert null values to DBNull.Value
            if (parameter.Value == null)
                parameter.Value = DBNull.Value;

            Parameters.Add(parameter);
            return parameter;
        }

        #endregion

        #region - SP / SQL Mode -

        /// <summary>
        /// Gets or sets a value that determines whether the DataMapper will 
        /// use a stored procedure or a sql text command to access 
        /// the database.  The default is stored procedure.
        /// </summary>
        public SqlModes SqlMode { get; set; }

        /// <summary>
        /// Sets the DbCommand object's CommandType to the current SqlMode.
        /// </summary>
        /// <param name="command">The DbCommand object we are modifying.</param>
        /// <returns>Returns the same DbCommand that was passed in.</returns>
        private DbCommand SetSqlMode(DbCommand command)
        {
            if (SqlMode == SqlModes.StoredProcedure)
                command.CommandType = CommandType.StoredProcedure;
            else
                command.CommandType = CommandType.Text;

            return command;
        }

        #endregion

        #region - ExecuteScalar, ExecuteNonQuery, ExecuteReader -

        /// <summary>
        /// Executes a command (stored procedure or SQL text, depending on SqlMode) that returns a scalar value.
        /// </summary>
        /// <param name="sql">The SQL command or stored procedure name to execute.</param>
        /// <returns>A scalar value</returns>
        /// <exception cref="ArgumentNullException">Thrown when sql is null or empty.</exception>
        public object ExecuteScalar(string sql)
        {
            if (string.IsNullOrEmpty(sql))
                throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");
            Command.CommandText = sql;

            try
            {
                OpenConnection();
                return Command.ExecuteScalar();
            }
            finally
            {
                CloseConnection();
            }
        }

        /// <summary>
        /// Executes a non query (stored procedure or SQL text, depending on SqlMode).
        /// </summary>
        /// <param name="sql">The SQL command or stored procedure name to execute.</param>
        /// <returns>The value returned by DbCommand.ExecuteNonQuery (typically the number of rows affected).</returns>
        /// <exception cref="ArgumentNullException">Thrown when sql is null or empty.</exception>
        public int ExecuteNonQuery(string sql)
        {
            if (string.IsNullOrEmpty(sql))
                throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");
            Command.CommandText = sql;

            try
            {
                OpenConnection();
                return Command.ExecuteNonQuery();
            }
            finally
            {
                CloseConnection();
            }
        }

        /// <summary>
        /// Executes a DataReader that can be controlled using a Func delegate.
        /// (Note that reader.Read() will be called automatically).
        /// The results are fully materialized into a list before the connection is closed.
        /// </summary>
        /// <typeparam name="TResult">The type that will be return in the result set.</typeparam>
        /// <param name="sql">The sql statement that will be executed.</param>
        /// <param name="func">The function that will build a TResult from the current record.</param>
        /// <returns>An IEnumerable of TResult.</returns>
        public IEnumerable<TResult> ExecuteReader<TResult>(string sql, Func<DbDataReader, TResult> func)
        {
            if (string.IsNullOrEmpty(sql))
                throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");
            Command.CommandText = sql;

            try
            {
                OpenConnection();

                var list = new List<TResult>();
                using (DbDataReader reader = Command.ExecuteReader())
                {
                    while (reader.Read())
                    {
                        list.Add(func(reader));
                    }
                }

                return list;
            }
            finally
            {
                CloseConnection();
            }
        }

        /// <summary>
        /// Executes a DataReader that can be controlled using an Action delegate.
        /// (Note that reader.Read() will be called automatically).
        /// </summary>
        /// <param name="sql">The sql statement that will be executed.</param>
        /// <param name="action">The delegate that will be invoked once per record.</param>
        public void ExecuteReader(string sql, Action<DbDataReader> action)
        {
            if (string.IsNullOrEmpty(sql))
                throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");

            Command.CommandText = sql;

            try
            {
                OpenConnection();

                using (DbDataReader reader = Command.ExecuteReader())
                {
                    while (reader.Read())
                    {
                        action(reader);
                    }
                }
            }
            finally
            {
                CloseConnection();
            }
        }

        #endregion

        #region - DataSets -

        /// <summary>
        /// Fills a new DataSet with the results of the given query.
        /// </summary>
        public DataSet GetDataSet(string sql)
        {
            return GetDataSet(sql, new DataSet(), null);
        }

        /// <summary>
        /// Fills the given DataSet (or a new one when ds is null) with the results of the given query.
        /// </summary>
        /// <param name="sql">The SQL command or stored procedure name to execute.</param>
        /// <param name="ds">The DataSet to fill; a new DataSet is created when null.</param>
        /// <param name="tableName">Optional source table name passed to DbDataAdapter.Fill.</param>
        public DataSet GetDataSet(string sql, DataSet ds, string tableName)
        {
            if (string.IsNullOrEmpty(sql))
                throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");

            try
            {
                using (DbDataAdapter adapter = ProviderFactory.CreateDataAdapter())
                {
                    Command.CommandText = sql;
                    adapter.SelectCommand = Command;

                    if (ds == null)
                        ds = new DataSet();

                    OpenConnection();

                    if (string.IsNullOrEmpty(tableName))
                        adapter.Fill(ds);
                    else
                        adapter.Fill(ds, tableName);

                    return ds;
                }
            }
            finally
            {
                CloseConnection();  // Clears parameters
            }
        }

        /// <summary>
        /// Fills a new DataTable with the results of the given query.
        /// </summary>
        public DataTable GetDataTable(string sql)
        {
            return GetDataTable(sql, null, null);
        }

        /// <summary>
        /// Fills the given DataTable (or a new one when dt is null) with the results of the given query.
        /// </summary>
        /// <param name="sql">The SQL command or stored procedure name to execute.</param>
        /// <param name="dt">The DataTable to fill; a new DataTable is created when null.</param>
        /// <param name="tableName">When not empty, assigned to the DataTable's TableName after filling.</param>
        public DataTable GetDataTable(string sql, DataTable dt, string tableName)
        {
            if (string.IsNullOrEmpty(sql))
                throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");

            try
            {
                using (DbDataAdapter adapter = ProviderFactory.CreateDataAdapter())
                {
                    Command.CommandText = sql;
                    adapter.SelectCommand = Command;

                    if (dt == null)
                        dt = new DataTable();

                    // Open through the mapper (like GetDataSet) so OpeningConnection is raised
                    // symmetrically with the ClosingConnection raised in the finally block.
                    OpenConnection();

                    adapter.Fill(dt);

                    if (!string.IsNullOrEmpty(tableName))
                        dt.TableName = tableName;

                    return dt;
                }
            }
            finally
            {
                CloseConnection();  // Clears parameters
            }
        }

        /// <summary>
        /// Pushes modified rows of the DataSet to the database using the given update command.
        /// </summary>
        /// <param name="ds">The DataSet whose changes will be saved.</param>
        /// <param name="sql">The update SQL command or stored procedure name.</param>
        /// <returns>The number of rows successfully updated (from DbDataAdapter.Update).</returns>
        public int UpdateDataSet(DataSet ds, string sql)
        {
            if (string.IsNullOrEmpty(sql))
                throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");

            if (ds == null)
                throw new ArgumentNullException("ds", "DataSet cannot be null.");

            try
            {
                using (DbDataAdapter adapter = ProviderFactory.CreateDataAdapter())
                {
                    // The adapter borrows the shared command; it must not be disposed here,
                    // otherwise the mapper would keep handing out a disposed command.
                    adapter.UpdateCommand = Command;
                    adapter.UpdateCommand.CommandText = sql;

                    OpenConnection();
                    return adapter.Update(ds);
                }
            }
            finally
            {
                CloseConnection();  // Clears parameters so they do not leak into the next query
            }
        }

        /// <summary>
        /// Inserts the rows of the DataTable using the given insert command.
        /// </summary>
        public int InsertDataTable(DataTable table, string insertSP)
        {
            return InsertDataTable(table, insertSP, UpdateRowSource.None);
        }

        /// <summary>
        /// Inserts the rows of the DataTable using the given insert command.
        /// </summary>
        /// <param name="dt">The DataTable whose rows will be inserted.</param>
        /// <param name="sql">The insert SQL command or stored procedure name.</param>
        /// <param name="updateRowSource">How command results are applied back to the inserted rows.</param>
        /// <returns>The number of rows successfully inserted (from DbDataAdapter.Update).</returns>
        public int InsertDataTable(DataTable dt, string sql, UpdateRowSource updateRowSource)
        {
            if (string.IsNullOrEmpty(sql))
                throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");

            if (dt == null)
                throw new ArgumentNullException("dt", "DataTable cannot be null.");

            try
            {
                using (DbDataAdapter adapter = ProviderFactory.CreateDataAdapter())
                {
                    // The adapter borrows the shared command; it must not be disposed here.
                    adapter.InsertCommand = Command;
                    adapter.InsertCommand.CommandText = sql;

                    adapter.InsertCommand.UpdatedRowSource = updateRowSource;

                    OpenConnection();
                    return adapter.Update(dt);
                }
            }
            finally
            {
                CloseConnection();  // Clears parameters so they do not leak into the next query
            }
        }

        /// <summary>
        /// Deletes the rows marked as deleted in the DataTable using the given delete command.
        /// </summary>
        /// <param name="dt">The DataTable containing deleted rows.</param>
        /// <param name="sql">The delete SQL command or stored procedure name.</param>
        /// <returns>The number of rows successfully deleted (from DbDataAdapter.Update).</returns>
        public int DeleteDataTable(DataTable dt, string sql)
        {
            if (string.IsNullOrEmpty(sql))
                throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");

            if (dt == null)
                throw new ArgumentNullException("dt", "DataSet cannot be null.");

            try
            {
                using (DbDataAdapter adapter = ProviderFactory.CreateDataAdapter())
                {
                    // The adapter borrows the shared command; it must not be disposed here.
                    adapter.DeleteCommand = Command;
                    adapter.DeleteCommand.CommandText = sql;

                    OpenConnection();
                    return adapter.Update(dt);
                }
            }
            finally
            {
                CloseConnection();  // Clears parameters so they do not leak into the next query
            }
        }

        #endregion

        #region - Find -

        /// <summary>
        /// Returns a new entity of type T loaded from the first record of the query,
        /// or default(T) when the query returns no records.
        /// </summary>
        public T Find<T>(string sql)
        {
            return Find<T>(sql, default(T));
        }

        /// <summary>
        /// Returns an entity of type T loaded from the first record of the query.
        /// For simple types the first column of the first record is returned.
        /// When ent is supplied (non-simple types), it is loaded in place instead of creating a new instance.
        /// If the query returns no records, ent is returned unchanged.
        /// </summary>
        /// <typeparam name="T">The type of entity that is to be instantiated and loaded with values.</typeparam>
        /// <param name="sql">The SQL command to execute.</param>
        /// <param name="ent">An optional existing entity to load; may be null/default.</param>
        /// <returns>An instantiated and loaded entity of type T.</returns>
        public T Find<T>(string sql, T ent)
        {
            if (string.IsNullOrEmpty(sql))
                throw new ArgumentNullException("sql", "A stored procedure name has not been specified for 'Find'.");

            Type entityType = typeof(T);
            Command.CommandText = sql;

            MapRepository repository = MapRepository.Instance;
            ColumnMapCollection mappings = repository.GetColumns(entityType);

            bool isSimpleType = DataHelper.IsSimpleType(typeof(T));

            try
            {
                OpenConnection();
                var mappingHelper = new MappingHelper(this);

                using (DbDataReader reader = Command.ExecuteReader())
                {
                    if (reader.Read())
                    {
                        if (isSimpleType)
                        {
                            return mappingHelper.LoadSimpleValueFromFirstColumn<T>(reader);
                        }
                        else
                        {
                            if (ent == null)
                                ent = (T)mappingHelper.CreateAndLoadEntity<T>(mappings, reader, false);
                            else
                                mappingHelper.LoadExistingEntity(mappings, reader, ent, false);
                        }
                    }
                }
            }
            finally
            {
                CloseConnection();
            }

            return ent;
        }

        #endregion

        #region - Query -

        /// <summary>
        /// Creates a QueryBuilder that allows you to build a query.
        /// </summary>
        /// <typeparam name="T">The type of object that will be queried.</typeparam>
        /// <returns>Returns a QueryBuilder of T.</returns>
        public QueryBuilder<T> Query<T>()
        {
            var dialect = QueryFactory.CreateDialect(this);
            return new QueryBuilder<T>(this, dialect);
        }

        /// <summary>
        /// Returns the results of a query.
        /// Uses a List of type T to return the data.
        /// </summary>
        /// <returns>Returns a list of the specified type.</returns>
        public List<T> Query<T>(string sql)
        {
            return (List<T>)Query<T>(sql, new List<T>());
        }

        /// <summary>
        /// Loads the results of a query (stored procedure or SQL text, depending on SqlMode)
        /// into the given collection.
        /// </summary>
        /// <returns>Returns the same collection that was passed in.</returns>
        public ICollection<T> Query<T>(string sql, ICollection<T> entityList)
        {
            return Query<T>(sql, entityList, false);
        }

        /// <summary>
        /// Loads one item per record into entityList.
        /// For simple types the first column of each record is used.
        /// </summary>
        /// <param name="useAltName">Passed through to MappingHelper.CreateAndLoadEntity.</param>
        internal ICollection<T> Query<T>(string sql, ICollection<T> entityList, bool useAltName)
        {
            if (entityList == null)
                throw new ArgumentNullException("entityList", "ICollection instance cannot be null.");

            if (string.IsNullOrEmpty(sql))
                throw new ArgumentNullException("sql", "A query or stored procedure has not been specified for 'Query'.");

            var mappingHelper = new MappingHelper(this);
            Type entityType = typeof(T);
            Command.CommandText = sql;
            ColumnMapCollection mappings = MapRepository.Instance.GetColumns(entityType);

            bool isSimpleType = DataHelper.IsSimpleType(typeof(T));

            try
            {
                OpenConnection();
                using (DbDataReader reader = Command.ExecuteReader())
                {
                    while (reader.Read())
                    {
                        if (isSimpleType)
                        {
                            entityList.Add(mappingHelper.LoadSimpleValueFromFirstColumn<T>(reader));
                        }
                        else
                        {
                            entityList.Add((T)mappingHelper.CreateAndLoadEntity<T>(mappings, reader, useAltName));
                        }
                    }
                }
            }
            finally
            {
                CloseConnection();
            }

            return entityList;
        }

        #endregion

        #region - Query to Graph -

        /// <summary>
        /// Queries a view that joins multiple tables and returns a list of root entities with their graph loaded.
        /// </summary>
        public List<T> QueryToGraph<T>(string sql)
        {
            return (List<T>)QueryToGraph<T>(sql, new List<T>());
        }

        /// <summary>
        /// Queries a view that joins multiple tables and loads root entities (with their graph) into entityList.
        /// Note: entityList is cast to IList, so it must implement the non-generic IList interface.
        /// </summary>
        public ICollection<T> QueryToGraph<T>(string sql, ICollection<T> entityList)
        {
            EntityGraph graph = new EntityGraph(typeof(T), (IList)entityList);
            return QueryToGraph<T>(sql, graph, new List<MemberInfo>());
        }

        /// <summary>
        /// Queries a view that joins multiple tables and returns an object graph.
        /// </summary>
        /// <typeparam name="T">The root entity type.</typeparam>
        /// <param name="sql">The query or stored procedure to execute.</param>
        /// <param name="graph">Coordinates loading all objects in the graph.</param>
        /// <param name="childrenToLoad">
        /// When non-empty, only these child relationships (plus the root) are loaded.
        /// </param>
        /// <returns>The graph's root list, cast to ICollection of T.</returns>
        internal ICollection<T> QueryToGraph<T>(string sql, EntityGraph graph, List<MemberInfo> childrenToLoad)
        {
            if (string.IsNullOrEmpty(sql))
                throw new ArgumentNullException("sql", "A query or stored procedure has not been specified for 'QueryToGraph'.");

            var mappingHelper = new MappingHelper(this);
            Command.CommandText = sql;

            try
            {
                OpenConnection();
                using (DbDataReader reader = Command.ExecuteReader())
                {
                    while (reader.Read())
                    {
                        // The entire EntityGraph is traversed for each record, 
                        // and multiple entities are created from each view record.
                        foreach (EntityGraph lvl in graph)
                        {
                            if (lvl.IsParentReference)
                            {
                                // A child specified a circular reference to its previously loaded parent
                                lvl.AddParentReference();
                            }
                            else if (childrenToLoad.Count > 0 && !lvl.IsRoot && !childrenToLoad.ContainsMember(lvl.Member))
                            {
                                // A list of relationships-to-load was specified and this relationship was not included
                                continue;
                            }
                            else if (lvl.IsNewGroup(reader))
                            {
                                // Create a new entity with the data reader
                                var newEntity = mappingHelper.CreateAndLoadEntity(lvl.EntityType, lvl.Columns, reader, true);

                                // Add entity to the appropriate place in the object graph
                                lvl.AddEntity(newEntity);
                            }
                        }
                    }
                }
            }
            finally
            {
                CloseConnection();
            }

            return (ICollection<T>)graph.RootList;
        }

        #endregion

        #region - Update -

        /// <summary>
        /// Creates an UpdateQueryBuilder that allows you to build an update statement.
        /// </summary>
        public UpdateQueryBuilder<T> Update<T>()
        {
            return new UpdateQueryBuilder<T>(this);
        }

        /// <summary>
        /// Generates and executes an update query for the given entity, filtered by the given expression.
        /// </summary>
        public int Update<T>(T entity, Expression<Func<T, bool>> filter)
        {
            return Update<T>()
                .Entity(entity)
                .Where(filter)
                .Execute();
        }

        /// <summary>
        /// Generates and executes an update query against the given table, filtered by the given expression.
        /// </summary>
        public int Update<T>(string tableName, T entity, Expression<Func<T, bool>> filter)
        {
            return Update<T>()
                .TableName(tableName)
                .Entity(entity)
                .Where(filter)
                .Execute();
        }

        /// <summary>
        /// Executes an update for the given entity using the given sql statement.
        /// </summary>
        public int Update<T>(T entity, string sql)
        {
            return Update<T>()
                .Entity(entity)
                .QueryText(sql)
                .Execute();
        }

        #endregion

        #region - Insert -

        /// <summary>
        /// Creates an InsertQueryBuilder that allows you to build an insert statement.
        /// This method gives you the flexibility to manually configure all options of your insert statement.
        /// Note: You must manually call the Execute() chaining method to run the query.
        /// </summary>
        public InsertQueryBuilder<T> Insert<T>()
        {
            return new InsertQueryBuilder<T>(this);
        }

        /// <summary>
        /// Generates and executes an insert query for the given entity.
        /// This overload will automatically run an identity query if you have mapped an auto-incrementing column,
        /// and if an identity query has been implemented for your current database dialect.
        /// </summary>
        public object Insert<T>(T entity)
        {
            var columns = MapRepository.Instance.GetColumns(typeof(T));
            var dialect = QueryFactory.CreateDialect(this);
            var builder = Insert<T>().Entity(entity);

            // If an auto-increment column exists and this dialect has an identity query...
            if (columns.Exists(c => c.ColumnInfo.IsAutoIncrement) && dialect.HasIdentityQuery)
            {
                builder.GetIdentity();
            }

            return builder.Execute();
        }

        /// <summary>
        /// Generates and executes an insert query for the given entity against the given table.
        /// This overload will automatically run an identity query if you have mapped an auto-incrementing column,
        /// and if an identity query has been implemented for your current database dialect.
        /// </summary>
        public object Insert<T>(string tableName, T entity)
        {
            var columns = MapRepository.Instance.GetColumns(typeof(T));
            var dialect = QueryFactory.CreateDialect(this);
            var builder = Insert<T>().Entity(entity).TableName(tableName);

            // If an auto-increment column exists and this dialect has an identity query...
            if (columns.Exists(c => c.ColumnInfo.IsAutoIncrement) && dialect.HasIdentityQuery)
            {
                builder.GetIdentity();
            }

            return builder.Execute();
        }

        /// <summary>
        /// Executes an insert query for the given entity using the given sql insert statement.
        /// This overload will automatically run an identity query if you have mapped an auto-incrementing column,
        /// and if an identity query has been implemented for your current database dialect.
        /// </summary>
        public object Insert<T>(T entity, string sql)
        {
            var columns = MapRepository.Instance.GetColumns(typeof(T));
            var dialect = QueryFactory.CreateDialect(this);
            var builder = Insert<T>().Entity(entity).QueryText(sql);

            // If an auto-increment column exists and this dialect has an identity query...
            if (columns.Exists(c => c.ColumnInfo.IsAutoIncrement) && dialect.HasIdentityQuery)
            {
                builder.GetIdentity();
            }

            return builder.Execute();
        }

        #endregion

        #region - Delete -

        /// <summary>
        /// Generates and executes a delete query for entities of type T matching the filter.
        /// </summary>
        public int Delete<T>(Expression<Func<T, bool>> filter)
        {
            return Delete<T>(null, filter);
        }

        /// <summary>
        /// Generates and executes a delete query for entities of type T matching the filter.
        /// The query is always run as SQL text; the caller's SqlMode is restored afterwards,
        /// even when query generation or execution fails.
        /// </summary>
        /// <returns>The value returned by DbCommand.ExecuteNonQuery (typically the number of rows affected).</returns>
        public int Delete<T>(string tableName, Expression<Func<T, bool>> filter)
        {
            // Remember sql mode
            var previousSqlMode = SqlMode;
            SqlMode = SqlModes.Text;

            try
            {
                if (tableName == null)
                {
                    tableName = MapRepository.Instance.GetTableName(typeof(T));
                }
                // NOTE(review): tableName is resolved above but never applied to the generated
                // query; the table below is always built from typeof(T). Confirm intended behavior.

                try
                {
                    var dialect = QueryFactory.CreateDialect(this);
                    TableCollection tables = new TableCollection();
                    tables.Add(new Table(typeof(T)));

                    // WhereBuilder may add parameters to Command, so generation happens inside the
                    // try whose finally clears them if anything below throws.
                    var where = new WhereBuilder<T>(Command, dialect, filter, tables, false, false);
                    IQuery query = QueryFactory.CreateDeleteQuery(dialect, tables[0], where.ToString());
                    Command.CommandText = query.Generate();

                    OpenConnection();
                    return Command.ExecuteNonQuery();
                }
                finally
                {
                    CloseConnection();
                }
            }
            finally
            {
                // Return to previous sql mode
                SqlMode = previousSqlMode;
            }
        }

        #endregion

        #region - Events -

        /// <summary>
        /// Raised before each operation opens the connection.
        /// Subscriptions are cleared every time CloseConnection runs.
        /// </summary>
        public event EventHandler OpeningConnection;

        /// <summary>
        /// Raised when each operation finishes, before parameters are cleared.
        /// Subscriptions are cleared every time CloseConnection runs.
        /// </summary>
        public event EventHandler ClosingConnection;

        #endregion

        #region - Connections / Transactions -

        protected virtual void OnOpeningConnection()
        {
            // Copy to a local so the delegate cannot change between the null check and the invocation.
            EventHandler handler = OpeningConnection;
            if (handler != null)
                handler(this, EventArgs.Empty);
        }

        protected virtual void OnClosingConnection()
        {
            WriteToTraceLog();

            EventHandler handler = ClosingConnection;
            if (handler != null)
                handler(this, EventArgs.Empty);
        }

        /// <summary>
        /// Raises OpeningConnection and opens the shared connection if it is not already open.
        /// </summary>
        protected internal void OpenConnection()
        {
            OnOpeningConnection();

            if (Command.Connection.State != ConnectionState.Open)
                Command.Connection.Open();
        }

        /// <summary>
        /// Raises ClosingConnection, clears the command's parameters and text,
        /// closes the connection (only when no transaction is attached), and unbinds all events.
        /// </summary>
        protected internal void CloseConnection()
        {
            OnClosingConnection();

            Command.Parameters.Clear();
            Command.CommandText = string.Empty;

            if (Command.Transaction == null)
                Command.Connection.Close(); // Only close if no transaction is present

            UnbindEvents();
        }

        /// <summary>
        /// Writes the command type, text and parameters to System.Diagnostics.Trace
        /// when MapRepository trace logging is enabled.
        /// </summary>
        private void WriteToTraceLog()
        {
            if (MapRepository.Instance.EnableTraceLogging)
            {
                var sb = new StringBuilder();
                sb.AppendLine();
                sb.AppendLine("==== Begin Query Trace ====");
                sb.AppendLine();
                sb.AppendLine("QUERY TYPE:");
                sb.AppendLine(Command.CommandType.ToString());
                sb.AppendLine();
                sb.AppendLine("QUERY TEXT:");
                sb.AppendLine(Command.CommandText);
                sb.AppendLine();
                sb.AppendLine("PARAMETERS:");
                foreach (IDbDataParameter p in Parameters)
                {
                    // 'is' already yields false for null, so no separate null check is needed.
                    object val = (p.Value is string) ? string.Format("\"{0}\"", p.Value) : p.Value;
                    sb.AppendFormat("{0} = [{1}]", p.ParameterName, val ?? "NULL").AppendLine();
                }
                sb.AppendLine();
                sb.AppendLine("==== End Query Trace ====");
                sb.AppendLine();

                Trace.Write(sb.ToString());
            }
        }

        private void UnbindEvents()
        {
            OpeningConnection = null;
            ClosingConnection = null;
        }

        /// <summary>
        /// Opens the connection and attaches a new transaction to the shared command.
        /// The connection stays open across operations until Commit or RollBack is called.
        /// </summary>
        public void BeginTransaction()
        {
            OpenConnection();
            DbTransaction trans = Command.Connection.BeginTransaction();
            Command.Transaction = trans;
        }

        /// <summary>
        /// Rolls back the current transaction (if any), detaches and disposes it, and closes the connection.
        /// </summary>
        public void RollBack()
        {
            try
            {
                if (Command.Transaction != null)
                    Command.Transaction.Rollback();
            }
            finally
            {
                ReleaseTransaction();
                Command.Connection.Close();
            }
        }

        /// <summary>
        /// Commits the current transaction (if any), detaches and disposes it, and closes the connection.
        /// </summary>
        public void Commit()
        {
            try
            {
                if (Command.Transaction != null)
                    Command.Transaction.Commit();
            }
            finally
            {
                ReleaseTransaction();
                Command.Connection.Close();
            }
        }

        /// <summary>
        /// Detaches the completed transaction from the command and disposes it.
        /// Without this, CloseConnection would keep seeing a non-null Transaction and
        /// would never close the connection for subsequent, non-transactional operations.
        /// </summary>
        private void ReleaseTransaction()
        {
            DbTransaction trans = Command.Transaction;
            if (trans != null)
            {
                Command.Transaction = null;
                trans.Dispose();
            }
        }

        #endregion

        #region - IDisposable Members -

        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this); // In case a derived class implements a finalizer
        }

        /// <summary>
        /// Disposes the transaction, connection and command, if a command was ever created.
        /// Uses the backing field rather than the Command property so that disposing
        /// (or disposing twice) never lazily creates a new connection and command.
        /// </summary>
        protected virtual void Dispose(bool disposing)
        {
            if (!disposing || _command == null)
                return;

            if (_command.Transaction != null)
            {
                _command.Transaction.Dispose();
                _command.Transaction = null;
            }

            if (_command.Connection != null)
            {
                _command.Connection.Dispose();
                _command.Connection = null;
            }

            _command.Dispose();
            _command = null;
        }

        #endregion

    }
}