diff --git a/libs/subliminal_patch/core.py b/libs/subliminal_patch/core.py index 151eb5df1..5c7f34313 100644 --- a/libs/subliminal_patch/core.py +++ b/libs/subliminal_patch/core.py @@ -111,6 +111,36 @@ class _ProviderConfigs(dict): return super().update(items) +class _Banlist: + def __init__(self, must_not_contain, must_contain): + self.must_not_contain = must_not_contain + self.must_contain = must_contain + + def is_valid(self, subtitle): + if subtitle.release_info is None: + return True + + if any([x for x in self.must_not_contain + if re.search(x, subtitle.release_info, flags=re.IGNORECASE) is not None]): + logger.info("Skipping subtitle because release name contains prohibited string: %s", subtitle) + return False + if any([x for x in self.must_contain + if re.search(x, subtitle.release_info, flags=re.IGNORECASE) is None]): + logger.info("Skipping subtitle because release name does not contains required string: %s", subtitle) + return False + + return True + + +class _Blacklist(list): + def is_valid(self, provider, subtitle): + blacklisted = not (str(provider), str(subtitle.id)) in self + if blacklisted: + logger.debug("Skipping blacklisted subtitle: %s", subtitle) + + return blacklisted + + class SZProviderPool(ProviderPool): def __init__(self, providers=None, provider_configs=None, blacklist=None, ban_list=None, throttle_callback=None, pre_download_hook=None, post_download_hook=None, language_hook=None): @@ -123,10 +153,10 @@ class SZProviderPool(ProviderPool): #: Discarded providers self.discarded_providers = set() - self.blacklist = blacklist or [] + self.blacklist = _Blacklist(blacklist or []) #: Should be a dict of 2 lists of strings - self.ban_list = ban_list or {'must_contain': [], 'must_not_contain': []} + self.ban_list = _Banlist(**(ban_list or {'must_contain': [], 'must_not_contain': []})) self.throttle_callback = throttle_callback @@ -175,8 +205,8 @@ class SZProviderPool(ProviderPool): # self.provider_configs = provider_configs self.provider_configs.update(provider_configs) - self.blacklist = blacklist or [] - self.ban_list = ban_list or {'must_contain': [], 'must_not_contain': []} + self.blacklist = _Blacklist(blacklist or []) + self.ban_list = _Banlist(**ban_list or {'must_contain': [], 'must_not_contain': []}) return updated @@ -267,18 +297,12 @@ class SZProviderPool(ProviderPool): seen = [] out = [] for s in results: - if (str(provider), str(s.id)) in self.blacklist: - logger.info("Skipping blacklisted subtitle: %s", s) + if not self.blacklist.is_valid(provider, s): continue - if s.release_info is not None: - if any([x for x in self.ban_list["must_not_contain"] - if re.search(x, s.release_info, flags=re.IGNORECASE) is not None]): - logger.info("Skipping subtitle because release name contains prohibited string: %s", s) - continue - if any([x for x in self.ban_list["must_contain"] - if re.search(x, s.release_info, flags=re.IGNORECASE) is None]): - logger.info("Skipping subtitle because release name does not contains required string: %s", s) - continue + + if not self.ban_list.is_valid(s): + continue + if s.id in seen: continue