You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
bazarr/libs/ffsubsync/aligners.py

121 lines
4.7 KiB

# -*- coding: utf-8 -*-
import logging
import math
import numpy as np
from .constants import FRAMERATE_RATIOS
from .golden_section_search import gss
from .sklearn_shim import TransformerMixin
# Configure logging at import time with an INFO threshold, then create the
# module-level logger that MaxScoreAligner.fit_gss uses to report scores.
# NOTE(review): calling basicConfig() from library code touches the root
# logger of whatever application imports this module -- confirm intended.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class FailedToFindAlignmentException(Exception):
    """Raised when no candidate alignment is acceptable.

    ``MaxScoreAligner.transform`` raises this when no recorded score is left
    to choose from after the maximum-offset filter has been applied.
    """
class FFTAligner(TransformerMixin):
    """Locate the best offset between two signals with an FFT cross-correlation.

    ``fit`` stores its result in ``best_offset_`` / ``best_score_``;
    ``transform`` hands that result back.

    Args:
        max_offset_samples: Optional bound on the absolute offset (in
            samples) that is preferred when picking the best alignment.
            ``None`` disables the bound.
    """

    def __init__(self, max_offset_samples=None):
        self.max_offset_samples = max_offset_samples
        self.best_offset_ = None
        self.best_score_ = None
        # Whether transform() returns (score, offset) or just the offset.
        self.get_score_ = False

    def _zero_out_extreme_offsets(self, convolve, substring):
        """Return a copy of ``convolve`` with out-of-window entries set to 0.

        Entries whose index lies outside the window derived from
        ``max_offset_samples`` are zeroed. With no limit configured the copy
        is returned unchanged.
        """
        convolve = np.copy(convolve)
        if self.max_offset_samples is None:
            return convolve

        def offset_to_index(offset):
            index = len(convolve) - 1 + offset - len(substring)
            # Clamp into [0, len(convolve)]: a negative slice bound would be
            # interpreted by NumPy as "counted from the end" and wipe out
            # candidates that are actually inside the allowed window.
            return min(max(index, 0), len(convolve))

        convolve[:offset_to_index(-self.max_offset_samples)] = 0
        convolve[offset_to_index(self.max_offset_samples):] = 0
        return convolve

    def _compute_argmax(self, convolve, substring):
        """Record the highest entry of ``convolve`` and its offset on ``self``."""
        best_idx = np.argmax(convolve)
        self.best_offset_ = len(convolve) - 1 - best_idx - len(substring)
        self.best_score_ = convolve[best_idx]

    def fit(self, refstring, substring, get_score=False):
        """Cross-correlate ``substring`` against ``refstring`` and keep the peak.

        Args:
            refstring: Reference signal; a ``str`` of digits or a numeric
                sequence.
            substring: Signal to align; same accepted forms as ``refstring``.
            get_score: When true, ``transform`` returns ``(score, offset)``
                instead of the offset alone.

        Returns:
            ``self``.
        """
        # Strings are expanded character by character into integers.
        refstring, substring = [
            list(map(int, s)) if isinstance(s, str) else s
            for s in (refstring, substring)
        ]
        # Rescale x -> 2x - 1 (a 0/1 input becomes -1/+1).
        refstring, substring = [
            2 * np.array(s).astype(float) - 1 for s in (refstring, substring)
        ]
        # Pad the combined length up to a power of two for the FFT.
        total_bits = math.log(len(substring) + len(refstring), 2)
        total_length = int(2 ** math.ceil(total_bits))
        extra_zeros = total_length - len(substring) - len(refstring)
        subft = np.fft.fft(
            np.append(np.zeros(extra_zeros + len(refstring)), substring)
        )
        # The reference is flipped so that the product of the two spectra
        # yields a correlation rather than a plain convolution.
        refft = np.fft.fft(
            np.flip(np.append(refstring, np.zeros(len(substring) + extra_zeros)), 0)
        )
        convolve = np.real(np.fft.ifft(subft * refft))
        self._compute_argmax(
            self._zero_out_extreme_offsets(convolve, substring), substring
        )
        if self.best_score_ == 0.:
            # Nothing scored inside the window; fall back to the
            # unrestricted maximum.
            self._compute_argmax(convolve, substring)
        self.get_score_ = get_score
        return self

    def transform(self, *_):
        """Return the result of the last ``fit``.

        Returns:
            ``(best_score_, best_offset_)`` when ``fit`` was called with
            ``get_score=True``; otherwise ``best_offset_`` alone.
        """
        if self.get_score_:
            return self.best_score_, self.best_offset_
        else:
            return self.best_offset_
class MaxScoreAligner(TransformerMixin):
    """Score several candidate signals with a base aligner and keep the best.

    Args:
        base_aligner: Either an aligner class (instantiated here with
            ``max_offset_samples``) or an already constructed aligner.
        srtin: Input handed to each candidate pipeline's ``transform`` /
            ``fit_transform``.
        sample_rate: Samples per second; combined with
            ``max_offset_seconds`` to derive ``max_offset_samples``.
        max_offset_seconds: Largest absolute offset, in seconds, that
            ``transform`` accepts. Either value being ``None`` disables the
            limit.
    """

    def __init__(self, base_aligner, srtin=None, sample_rate=None, max_offset_seconds=None):
        self.srtin = srtin
        limit_known = sample_rate is not None and max_offset_seconds is not None
        self.max_offset_samples = (
            abs(int(max_offset_seconds * sample_rate)) if limit_known else None
        )
        # A class is instantiated with the derived limit; an instance is
        # used as given.
        self.base_aligner = (
            base_aligner(max_offset_samples=self.max_offset_samples)
            if isinstance(base_aligner, type)
            else base_aligner
        )
        self.max_offset_seconds = max_offset_seconds
        # Accumulated ((score, offset), subpipe) results from fit()/fit_gss().
        self._scores = []

    def fit_gss(self, refstring, subpipe_maker):
        """Search framerate ratios in [0.9, 1.1] via ``gss`` and record the last result.

        ``subpipe_maker`` is called with a framerate ratio and must return a
        pipeline exposing ``fit_transform``. Only the evaluation flagged as
        the last iteration is appended to the recorded scores.

        Returns:
            ``self``.
        """
        def opt_func(framerate_ratio, is_last_iter):
            candidate = subpipe_maker(framerate_ratio)
            candidate_signal = candidate.fit_transform(self.srtin)
            result = self.base_aligner.fit_transform(
                refstring, candidate_signal, get_score=True
            )
            logger.info(
                'got score %.0f (offset %d) for ratio %.3f',
                result[0], result[1], framerate_ratio,
            )
            if is_last_iter:
                self._scores.append((result, candidate))
            # Negated score: presumably gss minimizes its objective -- verify
            # against golden_section_search.
            return -result[0]

        gss(opt_func, 0.9, 1.1)
        return self

    def fit(self, refstring, subpipes):
        """Score every entry of ``subpipes`` against ``refstring``.

        An entry may be a callable (delegated to ``fit_gss``), an object
        with a ``transform`` method (applied to ``srtin``), or a ready-made
        signal. A non-list argument is treated as a single entry.

        Returns:
            ``self``.
        """
        candidates = subpipes if isinstance(subpipes, list) else [subpipes]
        for subpipe in candidates:
            if callable(subpipe):
                self.fit_gss(refstring, subpipe)
                continue
            if hasattr(subpipe, 'transform'):
                substring = subpipe.transform(self.srtin)
            else:
                substring = subpipe
            result = self.base_aligner.fit_transform(
                refstring, substring, get_score=True
            )
            self._scores.append((result, subpipe))
        return self

    def transform(self, *_):
        """Return the highest-scoring recorded result.

        Returns:
            ``((score, offset), subpipe)`` for the entry with the largest
            score among those within ``max_offset_samples`` (when set).

        Raises:
            FailedToFindAlignmentException: If no recorded result remains.
        """
        limit = self.max_offset_samples
        if limit is None:
            eligible = self._scores
        else:
            eligible = [entry for entry in self._scores if abs(entry[0][1]) <= limit]
        if not eligible:
            raise FailedToFindAlignmentException('Synchronization failed; consider passing '
                                                 '--max-offset-seconds with a number larger than '
                                                 '{}'.format(self.max_offset_seconds))
        (score, offset), subpipe = max(eligible, key=lambda entry: entry[0][0])
        return (score, offset), subpipe