import os
import sys
from tempfile import NamedTemporaryFile
from abc import ABCMeta, abstractmethod
from threading import Thread
from datetime import datetime, timedelta
from collections import namedtuple
import wave
import subprocess
from queue import Queue, Empty
from .io import _guess_audio_format
from .util import AudioDataSource, make_duration_formatter
from .core import split
from .exceptions import (
    EndOfProcessing,
    AudioEncodingError,
    AudioEncodingWarning,
)


# Sentinel message: a worker that reads it from its inbox stops its loop.
_STOP_PROCESSING = "STOP_PROCESSING"

# Summary record of one detection, as stored by the tokenizer worker.
_Detection = namedtuple("_Detection", ("id", "start", "end", "duration"))


def _run_subprocess(command):
    """Run `command` to completion and capture its output.

    The child process gets an empty standard input; its standard output and
    standard error are collected in memory.

    :param command: the program and its arguments, passed to
        `subprocess.Popen` as is.
    :returns: a `(returncode, stdout, stderr)` tuple, `stdout` and `stderr`
        being bytes.
    :raises AudioEncodingError: if the process could not be started or
        communicating with it failed (e.g., the program is not installed).
    """
    try:
        # subprocess.DEVNULL lets Popen manage the null device itself, so no
        # file object is left open after the call.
        with subprocess.Popen(
            command,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        ) as proc:
            stdout, stderr = proc.communicate()
            return proc.returncode, stdout, stderr
    except Exception as exc:
        err_msg = "Couldn't export audio using command: '{}'".format(command)
        # keep the original failure attached for easier debugging
        raise AudioEncodingError(err_msg) from exc


class Worker(Thread, metaclass=ABCMeta):
    """Base class for threads that consume messages from a private inbox.

    Messages are queued with `send`. The thread's main loop hands every
    message to `_process_message` until the stop sentinel is received, then
    calls `_post_process` once before the thread ends.

    :param timeout: maximum time, in seconds, the main loop blocks waiting
        for a message before polling the inbox again.
    :param logger: optional object with an `info(message)` method used by
        `_log`; when None nothing is logged.
    """

    def __init__(self, timeout=0.5, logger=None):
        self._timeout = timeout
        self._logger = logger
        self._inbox = Queue()
        Thread.__init__(self)

    def run(self):
        """Process inbox messages until the stop sentinel is received."""
        while True:
            message = self._get_message()
            if message == _STOP_PROCESSING:
                break
            # None means the wait timed out without any message
            if message is not None:
                self._process_message(message)
        self._post_process()

    @abstractmethod
    def _process_message(self, message):
        """Process incoming messages"""

    def _post_process(self):
        """Hook called once after the main loop ends; does nothing here."""
        pass

    def _log(self, message):
        """Send `message` to the logger at info level, if there is one."""
        if self._logger is not None:
            self._logger.info(message)

    def _stop_requested(self):
        """Check, without blocking, whether a stop request is pending.

        One message is removed from the inbox if there is any. A message
        that is not the stop sentinel is discarded, so this is only
        suitable for workers whose inbox carries nothing but stop requests.

        :returns: True if the removed message is the stop sentinel, False
            if the inbox was empty or held another message.
        """
        try:
            message = self._inbox.get_nowait()
        except Empty:
            return False
        if message == _STOP_PROCESSING:
            return True
        return False

    def stop(self):
        """Ask the worker to stop and wait for its thread to finish."""
        self.send(_STOP_PROCESSING)
        self.join()

    def send(self, message):
        """Queue `message` in the worker's inbox."""
        self._inbox.put(message)

    def _get_message(self):
        """Return the next message, or None if none arrives in time."""
        try:
            message = self._inbox.get(timeout=self._timeout)
            return message
        except Empty:
            return None


class TokenizerWorker(Worker, AudioDataSource):
    """Worker that tokenizes an audio stream and dispatches detections.

    The worker passes itself to `split` as the audio source: `read` forwards
    to the wrapped reader and any attribute not found on the worker is looked
    up on the reader. Each audio region produced by `split` is recorded in
    `detections` and sent to all observers as an `(id, audio_region)` tuple,
    ids starting at 1.

    :param reader: audio reader; must provide `open`, `read` and `close`.
    :param observers: optional list of workers notified of each detection;
        they receive the stop sentinel when tokenization ends.
    :param logger: optional logger used to report each detection.
    :param kwargs: extra keyword arguments passed through to `split`.
    """

    def __init__(self, reader, observers=None, logger=None, **kwargs):
        self._observers = observers if observers is not None else []
        self._reader = reader
        # NOTE(review): this relies on `split` not reading audio data before
        # `run` opens the reader -- confirm against `split`.
        self._audio_region_gen = split(self, **kwargs)
        self._detections = []
        self._log_format = "[DET]: Detection {0.id} (start: {0.start:.3f}, "
        self._log_format += "end: {0.end:.3f}, duration: {0.duration:.3f})"
        Worker.__init__(self, timeout=0.2, logger=logger)

    def _process_message(self, message=None):
        """Do nothing: `run` is overridden and does not dispatch messages.

        `message` is accepted to match the base class signature.
        """
        pass

    @property
    def detections(self):
        """List of `_Detection` records collected so far."""
        return self._detections

    def _notify_observers(self, message):
        """Send `message` to every observer."""
        for observer in self._observers:
            observer.send(message)

    def run(self):
        """Tokenize the stream and forward each detection to observers.

        Observers always receive the stop sentinel, and the reader is closed
        if it was opened, even when tokenization fails, so that observer
        threads are not left waiting for messages that will never arrive.
        """
        reader_opened = False
        try:
            self._reader.open()
            reader_opened = True
            start_processing_timestamp = datetime.now()
            for _id, audio_region in enumerate(
                self._audio_region_gen, start=1
            ):
                # wall-clock time of the detection: processing start time
                # shifted by the region's start offset in seconds
                timestamp = start_processing_timestamp + timedelta(
                    seconds=audio_region.meta.start
                )
                audio_region.meta.timestamp = timestamp
                detection = _Detection(
                    _id,
                    audio_region.meta.start,
                    audio_region.meta.end,
                    audio_region.duration,
                )
                self._detections.append(detection)
                if self._logger is not None:
                    message = self._log_format.format(detection)
                    self._log(message)
                self._notify_observers((_id, audio_region))
        finally:
            self._notify_observers(_STOP_PROCESSING)
            if reader_opened:
                self._reader.close()

    def start_all(self):
        """Start all observers, then this worker."""
        for observer in self._observers:
            observer.start()
        self.start()

    def stop_all(self):
        """Stop this worker, then all observers, and close the reader."""
        self.stop()
        for observer in self._observers:
            observer.stop()
        self._reader.close()

    def read(self):
        """Read the next block of data from the reader.

        :returns: the reader's data, or None if a stop request is pending
            (presumably treated by `split` as the end of the stream).
        """
        if self._stop_requested():
            return None
        else:
            return self._reader.read()

    def __getattr__(self, name):
        """Delegate lookup of unknown attributes to the wrapped reader."""
        # Without this guard, a missing `_reader` would make the lookup below
        # call __getattr__ again and recurse without end.
        if name == "_reader":
            raise AttributeError(name)
        return getattr(self._reader, name)


class StreamSaverWorker(Worker):
    """Worker that saves to a file all data read through it.

    The worker wraps an audio reader: each call to `read` returns the
    reader's data and also queues it for the worker's thread, which writes it
    to a WAV file in cached chunks. If the requested format is not WAV, data
    are first written to a temporary WAV file next to the output file and
    converted when `save_stream` is called.

    :param audio_reader: wrapped reader; must provide `sr`, `sw` and `ch`
        as well as `open`, `read` and `close`.
    :param filename: name of the output file.
    :param export_format: output audio format; when None it is guessed by
        `_guess_audio_format` and "wav" is used if that yields None.
    :param cache_size_sec: amount of audio, in seconds, to accumulate before
        writing to the WAV file.
    :param timeout: inbox polling timeout in seconds.
    """

    def __init__(
        self,
        audio_reader,
        filename,
        export_format=None,
        cache_size_sec=0.5,
        timeout=0.2,
    ):
        self._reader = audio_reader
        sample_size_bytes = self._reader.sw * self._reader.ch
        self._cache_size = cache_size_sec * self._reader.sr * sample_size_bytes
        self._output_filename = filename
        self._export_format = _guess_audio_format(export_format, filename)
        if self._export_format is None:
            self._export_format = "wav"
        # All bookkeeping state is set before the output file is opened so
        # that __del__ finds a consistent object even if opening fails.
        self._exported = False
        self._cache = []
        self._total_cached = 0
        Worker.__init__(self, timeout=timeout)
        self._init_output_stream()

    def _get_non_existent_filename(self):
        """Return a ".wav" name derived from the output name that is free.

        "<output>.wav" is tried first, then "<output>(1).wav",
        "<output>(2).wav", and so on.
        """
        filename = self._output_filename + ".wav"
        i = 0
        while os.path.exists(filename):
            i += 1
            filename = self._output_filename + "({}).wav".format(i)
        return filename

    def _init_output_stream(self):
        """Open the WAV file data are written to and set its parameters."""
        if self._export_format != "wav":
            self._tmp_output_filename = self._get_non_existent_filename()
        else:
            self._tmp_output_filename = self._output_filename
        self._wfp = wave.open(self._tmp_output_filename, "wb")
        self._wfp.setframerate(self._reader.sr)
        self._wfp.setsampwidth(self._reader.sw)
        self._wfp.setnchannels(self._reader.ch)

    @property
    def sr(self):
        """Sampling rate of the wrapped reader."""
        return self._reader.sampling_rate

    @property
    def sw(self):
        """Sample width of the wrapped reader."""
        return self._reader.sample_width

    @property
    def ch(self):
        """Number of channels of the wrapped reader."""
        return self._reader.channels

    def __del__(self):
        """Flush pending data and remove the temporary file once exported."""
        # __init__ may have failed before the output stream was opened, in
        # which case there is nothing to flush or remove.
        if "_wfp" not in self.__dict__:
            return
        self._post_process()

        if (
            (self._tmp_output_filename != self._output_filename)
            and self._exported
            and os.path.exists(self._tmp_output_filename)
        ):
            os.remove(self._tmp_output_filename)

    def _process_message(self, data):
        """Cache `data` and write the cache out once it is large enough."""
        self._cache.append(data)
        self._total_cached += len(data)
        if self._total_cached >= self._cache_size:
            self._write_cached_data()

    def _post_process(self):
        """Write everything still pending and close the WAV file."""
        # drain whatever is left in the inbox, ignoring stop requests
        while True:
            try:
                data = self._inbox.get_nowait()
                if data != _STOP_PROCESSING:
                    self._cache.append(data)
                    self._total_cached += len(data)
            except Empty:
                break
        self._write_cached_data()
        self._wfp.close()

    def _write_cached_data(self):
        """Write cached data, if any, to the WAV file and reset the cache."""
        if self._cache:
            data = b"".join(self._cache)
            self._wfp.writeframes(data)
            self._cache = []
            self._total_cached = 0

    def open(self):
        """Open the wrapped reader."""
        self._reader.open()

    def close(self):
        """Close the wrapped reader and stop the worker's thread."""
        self._reader.close()
        self.stop()

    def rewind(self):
        # ensure compatibility with AudioDataSource with record=True
        pass

    @property
    def data(self):
        """All frames stored in the WAV file, as bytes.

        NOTE(review): the file is re-opened for reading; the output stream
        presumably has to be closed first for it to be complete -- confirm.
        """
        with wave.open(self._tmp_output_filename, "rb") as wfp:
            return wfp.readframes(-1)

    def save_stream(self):
        """Export saved data to the output file in the requested format.

        "wav" needs no conversion and "raw" is written directly. Any other
        format is converted from the temporary WAV file with ffmpeg or
        avconv, then with sox if that fails.

        :returns: the name of the output file.
        :raises AudioEncodingWarning: if no converter succeeded. The
            temporary WAV file named in the message is then kept.
        """
        if self._exported:
            return self._output_filename

        if self._export_format in ("raw", "wav"):
            if self._export_format == "raw":
                self._export_raw()
            self._exported = True
            return self._output_filename
        try:
            self._export_with_ffmpeg_or_avconv()
        except AudioEncodingError:
            try:
                self._export_with_sox()
            except AudioEncodingError:
                warn_msg = "Couldn't save audio data in the desired format "
                warn_msg += "'{}'. Either none of 'ffmpeg', 'avconv' or 'sox' "
                warn_msg += "is installed or this format is not recognized.\n"
                warn_msg += "Audio file was saved as '{}'"
                raise AudioEncodingWarning(
                    warn_msg.format(
                        self._export_format, self._tmp_output_filename
                    )
                )
        # Only a successful export marks the stream as exported: __del__
        # removes the temporary WAV file of an exported stream, and after a
        # failure that file is the only copy of the data.
        self._exported = True
        return self._output_filename

    def _export_raw(self):
        """Write the frames of the WAV file, without header, to the output."""
        with open(self._output_filename, "wb") as wfp:
            wfp.write(self.data)

    def _export_with_ffmpeg_or_avconv(self):
        """Convert the WAV file with ffmpeg, or with avconv if that fails.

        avconv is tried both when ffmpeg exits with an error and when it
        cannot be run at all (e.g., it is not installed).

        :returns: `(stdout, stderr)` of the program that succeeded.
        :raises AudioEncodingError: if neither program succeeded.
        """
        command = [
            "-y",
            "-f",
            "wav",
            "-i",
            self._tmp_output_filename,
            "-f",
            self._export_format,
            self._output_filename,
        ]
        last_error = None
        for program in ("ffmpeg", "avconv"):
            try:
                returncode, stdout, stderr = _run_subprocess(
                    [program] + command
                )
            except AudioEncodingError as exc:
                # program could not be run, try the next one
                last_error = exc
                continue
            if returncode == 0:
                return stdout, stderr
            last_error = AudioEncodingError(stderr)
        raise last_error

    def _export_with_sox(self):
        """Convert the WAV file with sox.

        :returns: `(stdout, stderr)` of the sox process.
        :raises AudioEncodingError: if sox could not be run or failed.
        """
        command = [
            "sox",
            "-t",
            "wav",
            self._tmp_output_filename,
            self._output_filename,
        ]
        returncode, stdout, stderr = _run_subprocess(command)
        if returncode != 0:
            raise AudioEncodingError(stderr)
        return stdout, stderr

    def close_output(self):
        """Close the WAV file data are written to."""
        self._wfp.close()

    def read(self):
        """Read from the wrapped reader and queue the data for saving.

        When the reader returns None, a stop request is queued instead.

        :returns: whatever the reader returned.
        """
        data = self._reader.read()
        if data is not None:
            self.send(data)
        else:
            self.send(_STOP_PROCESSING)
        return data

    def __getattr__(self, name):
        """Delegate lookup of unknown attributes to the wrapped reader."""
        # `data` is a property: it only gets here if its getter raised
        # AttributeError, and reading it again would recurse without end.
        # Likewise a missing `_reader` must not be looked up on itself.
        if name in ("_reader", "data"):
            raise AttributeError(name)
        return getattr(self._reader, name)


class PlayerWorker(Worker):
    """Worker that plays each detection it receives.

    :param player: player object handed to the region's `play` method.
    :param progress_bar: forwarded to `play` as `progress_bar`.
    :param timeout: inbox polling timeout in seconds.
    :param logger: optional logger used to report each played detection.
    """

    def __init__(self, player, progress_bar=False, timeout=0.2, logger=None):
        Worker.__init__(self, timeout=timeout, logger=logger)
        self._log_format = "[PLAY]: Detection {id} played"
        self._progress_bar = progress_bar
        self._player = player

    def _process_message(self, message):
        """Log, if a logger is set, then play the received audio region.

        `message` is an `(id, audio_region)` tuple.
        """
        detection_id, region = message
        if self._logger is not None:
            self._log(self._log_format.format(id=detection_id))
        region.play(
            player=self._player,
            progress_bar=self._progress_bar,
            leave=False,
        )


class RegionSaverWorker(Worker):
    """Worker that saves each detection it receives to its own file.

    :param filename_format: template of the file name; it is formatted with
        the keywords `id`, `start`, `end` and `duration`.
    :param audio_format: audio format passed to the region's `save` method.
    :param timeout: inbox polling timeout in seconds.
    :param logger: optional logger used to report each saved file.
    :param audio_parameters: extra keyword arguments passed to `save`.
    """

    def __init__(
        self,
        filename_format,
        audio_format=None,
        timeout=0.2,
        logger=None,
        **audio_parameters
    ):
        Worker.__init__(self, timeout=timeout, logger=logger)
        self._debug_format = "[SAVE]: Detection {id} saved as '{filename}'"
        self._audio_parameters = audio_parameters
        self._audio_format = audio_format
        self._filename_format = filename_format

    def _process_message(self, message):
        """Save the received audio region and log the resulting file name.

        `message` is an `(id, audio_region)` tuple.
        """
        detection_id, region = message
        name_fields = {
            "id": detection_id,
            "start": region.meta.start,
            "end": region.meta.end,
            "duration": region.duration,
        }
        target = self._filename_format.format(**name_fields)
        # `save` returns the name the file was actually saved under
        saved_as = region.save(
            target, self._audio_format, **self._audio_parameters
        )
        if self._logger:
            self._log(
                self._debug_format.format(id=detection_id, filename=saved_as)
            )


class CommandLineWorker(Worker):
    """Worker that runs a shell command on each detection it receives.

    :param command: command template; `{file}` is replaced by the name of a
        WAV file holding the detection.
    :param timeout: inbox polling timeout in seconds.
    :param logger: optional logger used to report each command run.
    """

    def __init__(self, command, timeout=0.2, logger=None):
        Worker.__init__(self, timeout=timeout, logger=logger)
        self._debug_format = "[COMMAND]: Detection {id} command: '{command}'"
        self._command = command

    def _process_message(self, message):
        """Save the region to a temporary WAV file and run the command on it.

        `message` is an `(id, audio_region)` tuple. The temporary file is
        created with `delete=False` and is not removed by this method.
        """
        detection_id, region = message
        with NamedTemporaryFile(delete=False) as tmp_file:
            wav_path = region.save(tmp_file.name, audio_format="wav")
            shell_command = self._command.format(file=wav_path)
            os.system(shell_command)
            if self._logger is None:
                return
            self._log(
                self._debug_format.format(
                    id=detection_id, command=shell_command
                )
            )


class PrintWorker(Worker):
    """Worker that prints a line of text for each detection it receives.

    :param print_format: template of the printed line; it is formatted with
        the keywords `id`, `start`, `end`, `duration` and `timestamp`.
    :param time_format: format given to `make_duration_formatter` to build
        the function that renders `start`, `end` and `duration`.
    :param timestamp_format: `strftime` format of the detection timestamp.
    :param timeout: inbox polling timeout in seconds.
    """

    def __init__(
        self,
        print_format="{start} {end}",
        time_format="%S",
        timestamp_format="%Y/%m/%d %H:%M:%S.%f",
        timeout=0.2,
    ):
        self._print_format = print_format
        self._format_time = make_duration_formatter(time_format)
        self._timestamp_format = timestamp_format
        # NOTE(review): nothing in this class adds to this list
        self.detections = []
        Worker.__init__(self, timeout=timeout)

    def _process_message(self, message):
        """Format and print the received detection.

        `message` is an `(id, audio_region)` tuple.
        """
        detection_id, region = message
        formatted_timestamp = region.meta.timestamp.strftime(
            self._timestamp_format
        )
        render = self._format_time
        line = self._print_format.format(
            id=detection_id,
            start=render(region.meta.start),
            end=render(region.meta.end),
            duration=render(region.duration),
            timestamp=formatted_timestamp,
        )
        print(line)