import functools
import json
import logging
import threading
import warnings
from typing import Any, Iterator, List, Optional, Type, Union

# Module-level logger; the name is fixed to "pycountry.db".
logger = logging.getLogger("pycountry.db")


class Data:
    """Generic record whose fields are exposed as attributes.

    The keyword arguments given to the constructor are kept in the
    ``_fields`` mapping and resolved through ``__getattr__``. Attributes
    assigned later are mirrored into ``_fields`` as well, so they appear in
    ``repr()``, ``dir()`` and iteration.
    """

    def __init__(self, **fields: str):
        self._fields = fields

    def __getattr__(self, key: str) -> str:
        """Return the field named *key*.

        Only invoked when regular attribute lookup has failed.

        Raises:
            AttributeError: if there is no field called *key*.
        """
        # ``_fields`` does not exist on instances created without
        # ``__init__`` (``copy`` and ``pickle`` build the object via
        # ``__new__`` and then probe it for ``__setstate__``). Reading it
        # through ``__dict__`` keeps this lookup from re-entering
        # ``__getattr__`` and recursing until RecursionError.
        fields = self.__dict__.get("_fields")
        if fields is not None and key in fields:
            return fields[key]
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {key!r}"
        )

    def __setattr__(self, key: str, value: str) -> None:
        """Set the attribute and mirror it into ``_fields``."""
        # ``_fields`` itself must not be recorded as a field.
        if key != "_fields":
            self._fields[key] = value
        super().__setattr__(key, value)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        fields = ", ".join(
            f"{name}={value!r}" for name, value in sorted(self._fields.items())
        )
        return f"{cls_name}({fields})"

    def __dir__(self) -> List[str]:
        """List class attributes followed by the field names."""
        return dir(self.__class__) + list(self._fields)

    def __iter__(self):
        """Yield ``(field, value)`` pairs, which allows ``dict(record)``."""
        for field in self._fields:
            yield field, getattr(self, field)


class Country(Data):
    """Record whose ``common_name``/``official_name`` fall back to ``name``."""

    def __getattr__(self, key):
        """Return the field named *key*.

        For ``common_name`` and ``official_name`` a missing (or ``None``)
        value is replaced by the ``name`` field, and a ``UserWarning`` is
        emitted to signal the substitution.

        Raises:
            AttributeError: if the field, and for the two name variants
                also ``name``, is not available.
        """
        # Read ``_fields`` through ``__dict__``: on instances created
        # without ``__init__`` (as ``copy`` and ``pickle`` do) a plain
        # ``self._fields`` would re-enter this method and recurse forever.
        fields = self.__dict__.get("_fields")
        if fields is not None:
            if key in ("common_name", "official_name"):
                # First try the requested name variant itself.
                value = fields.get(key)
                if value is not None:
                    return value
                # Fall back to the plain name when the variant is missing.
                name = fields.get("name")
                if name is not None:
                    warning_message = (
                        f"Country's {key} not found. "
                        "Country name provided instead."
                    )
                    # stacklevel=2 attributes the warning to the code that
                    # accessed the attribute rather than to this method.
                    warnings.warn(warning_message, UserWarning, stacklevel=2)
                    return name
            elif key in fields:
                # Any other key is returned as stored.
                return fields[key]
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {key!r}"
        )


class Subdivision(Data):
    """``Data`` subclass that adds no behavior; gives these records their own type."""


def lazy_load(f):
    """Decorate a method so that its object is loaded before the call.

    The instance the method is bound to must provide an ``_is_loaded``
    flag, a ``_load_lock`` context manager and a ``_load()`` method.

    Args:
        f: the method to wrap.

    Returns:
        A wrapper with the same metadata as *f* (name, docstring,
        annotations) that triggers ``_load()`` while ``_is_loaded`` is
        false and then delegates to *f*.
    """

    @functools.wraps(f)
    def load_if_needed(self, *args, **kw):
        # Cheap unlocked check first; the lock is only taken while the
        # object still reports itself as not loaded.
        if not self._is_loaded:
            with self._load_lock:
                self._load()
        return f(self, *args, **kw)

    return load_if_needed


class Database:
    """Lazily loaded, case-insensitively indexed collection of records.

    Subclasses configure the class attributes:

    * ``data_class`` -- the record class, or a string from which a ``Data``
      subclass of that name is created;
    * ``root_key`` -- key of the JSON document that holds the entry list;
    * ``no_index`` -- field names that are not indexed (``lookup`` still
      searches them by scanning all records).

    The JSON file is read on the first call of any public method.
    """

    data_class: Union[Type, str]
    root_key: Optional[str] = None
    no_index: List[str] = []

    def __init__(self, filename: str) -> None:
        """Remember *filename* and build the record factory; no file I/O."""
        self.filename = filename
        self._is_loaded = False
        self._load_lock = threading.Lock()

        if isinstance(self.data_class, str):
            # Create a dedicated record type named after ``data_class``.
            self.factory = type(self.data_class, (Data,), {})
        else:
            self.factory = self.data_class

    def _clear(self):
        """Mark the database as not loaded and drop objects and indices."""
        self._is_loaded = False
        self.objects = []
        self.index_names = set()
        self.indices = {}

    def _load(self) -> None:
        """Read the JSON file and build ``objects`` and ``indices``.

        Does nothing when the data has already been loaded.
        """
        if self._is_loaded:
            # Re-checked here so callers (the lazy_load wrapper) stay
            # simple and a second caller does not reload the data.
            return
        self._clear()

        with open(self.filename, encoding="utf-8") as f:
            tree = json.load(f)

        for entry in tree[self.root_key]:
            obj = self.factory(**entry)
            self.objects.append(obj)
            # Inject into index.
            for key, value in entry.items():
                if key in self.no_index:
                    continue
                # Lookups and searches are case insensitive. Normalize
                # here.
                index = self.indices.setdefault(key, {})
                value = value.lower()
                if value in index:
                    # NOTE(review): the message says the duplicate is
                    # ignored, but the assignment below replaces the
                    # earlier record. Behavior kept as is -- confirm which
                    # one is intended.
                    logger.debug(
                        "%s %r already taken in index %r and will be "
                        "ignored. This is an error in the databases.",
                        self.factory.__name__,
                        value,
                        key,
                    )
                index[value] = obj

        self._is_loaded = True

    # Public API

    @lazy_load
    def add_entry(self, **kw):
        """Create a record from the keyword fields and index it.

        Every field not listed in ``no_index`` is indexed under its
        lower-cased value; an existing index slot with the same value is
        replaced.
        """
        # create the object with the correct dynamic type
        obj = self.factory(**kw)

        # append object
        self.objects.append(obj)

        # update indices
        for key, value in kw.items():
            if key in self.no_index:
                continue
            value = value.lower()
            index = self.indices.setdefault(key, {})
            index[value] = obj

    @lazy_load
    def remove_entry(self, **kw):
        """Remove the record matching the single ``field=value`` criterion.

        A ``default`` keyword is discarded. Accepts the same criterion as
        :meth:`get` and propagates its exceptions.

        Raises:
            KeyError: if no record matches.
        """
        # Drop ``default`` so that a missing entry is reported as None.
        kw.pop("default", None)
        obj = self.get(**kw)
        if obj is None:
            raise KeyError(
                f"{self.factory.__name__} not found and cannot be removed: {kw}"
            )

        # remove object
        self.objects.remove(obj)

        # update indices
        for key, value in obj:
            # Values that are not strings cannot have been indexed.
            if key in self.no_index or not isinstance(value, str):
                continue
            index = self.indices.get(key)
            if index is None:
                continue
            value = value.lower()
            # Only drop slots that still point at the removed record; a
            # duplicate value may have been claimed by another record.
            if index.get(value) is obj:
                del index[value]

    @lazy_load
    def __iter__(self) -> Iterator[Any]:
        """Iterate over all records."""
        return iter(self.objects)

    @lazy_load
    def __len__(self) -> int:
        """Return the number of records."""
        return len(self.objects)

    @lazy_load
    def get(self, **kw: Optional[str]) -> Optional[Any]:
        """Return the record whose indexed field equals the given value.

        Exactly one ``field=value`` criterion must be passed; matching is
        case-insensitive. An optional ``default`` keyword (``None`` when
        omitted) is returned if no record matches.

        Raises:
            TypeError: if not exactly one criterion is given.
            LookupError: if the value is not a string.
            KeyError: if there is no index for the given field.
        """
        default = kw.pop("default", None)
        if len(kw) != 1:
            raise TypeError("Only one criteria may be given")
        field, value = kw.popitem()
        if not isinstance(value, str):
            raise LookupError(
                f"Value for {field!r} must be a string, got {value!r}"
            )
        # Normalize for case-insensitivity
        value = value.lower()
        index = self.indices[field]
        # Pythonic APIs implementing get() shouldn't raise KeyErrors for a
        # missing value. Those are a bit unexpected and they should rather
        # support returning `None` by default and allow customization.
        return index.get(value, default)

    @lazy_load
    def lookup(self, value: str) -> Any:
        """Return the first record having *value* in any field.

        Indexed fields are checked first, then the ``no_index`` fields of
        every record. Matching is case-insensitive.

        Raises:
            LookupError: if *value* is not a string or nothing matches.
        """
        if not isinstance(value, str):
            raise LookupError(f"Lookup value must be a string, got {value!r}")

        # Normalize for case-insensitivity
        value = value.lower()

        # Use indexes first
        for index in self.indices.values():
            if value in index:
                return index[value]

        # Use non-indexed values now. Avoid going through indexed values.
        for candidate in self:
            for k in self.no_index:
                v = candidate._fields.get(k)
                if v is None:
                    continue
                if v.lower() == value:
                    return candidate

        raise LookupError("Could not find a record for %r" % value)