# -*- coding: utf-8 -*-
import os
from contextlib import contextmanager
import logging
import io
import subprocess
import sys
from datetime import timedelta
from typing import cast, Callable, Dict, List, Optional, Union

import ffmpeg
import numpy as np
import tqdm

from ffsubsync.constants import (
    DEFAULT_ENCODING,
    DEFAULT_MAX_SUBTITLE_SECONDS,
    DEFAULT_SCALE_FACTOR,
    DEFAULT_START_SECONDS,
    SAMPLE_RATE,
)
from ffsubsync.ffmpeg_utils import ffmpeg_bin_path, subprocess_args
from ffsubsync.generic_subtitles import GenericSubtitle
from ffsubsync.sklearn_shim import TransformerMixin
from ffsubsync.sklearn_shim import Pipeline
from ffsubsync.subtitle_parser import make_subtitle_parser
from ffsubsync.subtitle_transformers import SubtitleScaler


logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)


def make_subtitle_speech_pipeline(
    fmt: str = "srt",
    encoding: str = DEFAULT_ENCODING,
    caching: bool = False,
    max_subtitle_seconds: int = DEFAULT_MAX_SUBTITLE_SECONDS,
    start_seconds: int = DEFAULT_START_SECONDS,
    scale_factor: float = DEFAULT_SCALE_FACTOR,
    parser=None,
    **kwargs,
) -> Union[Pipeline, Callable[[float], Pipeline]]:
    """Build a parse -> scale -> speech-extract pipeline for subtitles.

    Args:
        fmt: Subtitle format forwarded to ``make_subtitle_parser``.
        encoding: Encoding forwarded to ``make_subtitle_parser``.
        caching: Forwarded to ``make_subtitle_parser``.
        max_subtitle_seconds: Forwarded to ``make_subtitle_parser``.
        start_seconds: Forwarded to the parser and to
            ``SubtitleSpeechTransformer``.
        scale_factor: Framerate ratio used for the pipeline. When ``None``,
            a factory taking the ratio is returned instead of a pipeline.
        parser: Optional pre-built parser. When given, it must agree with
            ``encoding``, ``max_subtitle_seconds`` and ``start_seconds``.
        **kwargs: Extra arguments forwarded to ``make_subtitle_parser``
            (only used when ``parser`` is ``None``).

    Returns:
        A ``Pipeline`` when ``scale_factor`` is not ``None``; otherwise a
        callable mapping a framerate ratio to a ``Pipeline``.
    """
    if parser is None:
        parser = make_subtitle_parser(
            fmt,
            encoding=encoding,
            caching=caching,
            max_subtitle_seconds=max_subtitle_seconds,
            start_seconds=start_seconds,
            **kwargs,
        )
    # NOTE(review): these consistency checks are asserts and therefore vanish
    # under ``python -O``; kept as-is so callers see the same exception type.
    assert parser.encoding == encoding
    assert parser.max_subtitle_seconds == max_subtitle_seconds
    assert parser.start_seconds == start_seconds

    def subpipe_maker(framerate_ratio):
        return Pipeline(
            [
                ("parse", parser),
                ("scale", SubtitleScaler(framerate_ratio)),
                (
                    "speech_extract",
                    SubtitleSpeechTransformer(
                        sample_rate=SAMPLE_RATE,
                        start_seconds=start_seconds,
                        framerate_ratio=framerate_ratio,
                    ),
                ),
            ]
        )

    if scale_factor is None:
        return subpipe_maker
    else:
        return subpipe_maker(scale_factor)


def _make_auditok_detector(
    sample_rate: int, frame_rate: int, non_speech_label: float
) -> Callable[[bytes], np.ndarray]:
    """Return a speech detector backed by auditok's energy tokenizer.

    Args:
        sample_rate: Number of output windows per second of audio.
        frame_rate: Sampling rate of the incoming audio.
        non_speech_label: Value used to build the label that follows the
            end of a detected token.

    Returns:
        A callable mapping a buffer of raw audio (2 bytes per frame) to one
        value per window, clipped to ``[0.0, 1.0]``.

    Raises:
        ImportError: If auditok is not installed.
    """
    try:
        from auditok import (
            BufferAudioSource,
            ADSFactory,
            AudioEnergyValidator,
            StreamTokenizer,
        )
    except ImportError:
        logger.error(
            "Error: auditok not installed! "
            "Consider installing it with `pip install auditok`. Note that auditok "
            "is GPLv3 licensed, which means that successfully importing it at "
            "runtime creates a derivative work that is GPLv3 licensed. For personal "
            "use this is fine, but note that any commercial use that relies on "
            "auditok must be open source as per the GPLv3!* "
            "*Not legal advice. Consult with a lawyer. "
        )
        # Bare ``raise`` re-raises the active exception with its traceback.
        raise
    bytes_per_frame = 2
    frames_per_window = frame_rate // sample_rate
    validator = AudioEnergyValidator(sample_width=bytes_per_frame, energy_threshold=50)
    tokenizer = StreamTokenizer(
        validator=validator,
        min_length=0.2 * sample_rate,
        max_length=int(5 * sample_rate),
        max_continuous_silence=0.25 * sample_rate,
    )

    def _detect(asegment: bytes) -> np.ndarray:
        asource = BufferAudioSource(
            data_buffer=asegment,
            sampling_rate=frame_rate,
            sample_width=bytes_per_frame,
            channels=1,
        )
        ads = ADSFactory.ads(audio_source=asource, block_dur=1.0 / sample_rate)
        ads.open()
        tokens = tokenizer.tokenize(ads)
        # Number of windows, rounding the last partial window up.
        length = (
            len(asegment) // bytes_per_frame + frames_per_window - 1
        ) // frames_per_window
        # Difference array: +1 at each token start and a drop just past each
        # token end; the cumulative sum turns it into per-window labels.
        media_bstring = np.zeros(length + 1)
        for token in tokens:
            media_bstring[token[1]] = 1.0
            media_bstring[token[2] + 1] = non_speech_label - 1.0
        return np.clip(np.cumsum(media_bstring)[:-1], 0.0, 1.0)

    return _detect


def _make_webrtcvad_detector(
    sample_rate: int, frame_rate: int, non_speech_label: float
) -> Callable[[bytes], np.ndarray]:
    """Return a speech detector backed by webrtcvad.

    Args:
        sample_rate: Number of output windows per second of audio.
        frame_rate: Sampling rate of the incoming audio.
        non_speech_label: Value emitted for windows not classified as speech.

    Returns:
        A callable mapping a buffer of raw audio (2 bytes per frame) to one
        value per window: ``1.0`` for speech, ``non_speech_label`` otherwise.
    """
    import webrtcvad

    vad = webrtcvad.Vad()
    vad.set_mode(3)  # set non-speech pruning aggressiveness from 0 to 3
    window_duration = 1.0 / sample_rate  # duration in seconds
    frames_per_window = int(window_duration * frame_rate + 0.5)
    bytes_per_frame = 2

    def _detect(asegment: bytes) -> np.ndarray:
        media_bstring = []
        failures = 0
        num_frames = len(asegment) // bytes_per_frame
        for start in range(0, num_frames, frames_per_window):
            stop = min(start + frames_per_window, num_frames)
            try:
                is_speech = vad.is_speech(
                    asegment[start * bytes_per_frame : stop * bytes_per_frame],
                    sample_rate=frame_rate,
                )
            except Exception:
                # Best effort: a window the VAD rejects counts as non-speech.
                is_speech = False
                failures += 1
            # webrtcvad has low recall on mode 3, so treat non-speech as "not sure"
            media_bstring.append(1.0 if is_speech else non_speech_label)
        if failures:
            # Previously the counter was never read; surface it for debugging
            # without changing the detector's output.
            logger.debug(
                "webrtcvad failed on %d window(s); treated as non-speech", failures
            )
        return np.array(media_bstring)

    return _detect


def _make_silero_detector(
    sample_rate: int, frame_rate: int, non_speech_label: float
) -> Callable[[bytes], np.ndarray]:
    """Return a speech detector backed by the silero-vad torch model.

    Args:
        sample_rate: Number of output windows per second of audio.
        frame_rate: Sampling rate of the incoming audio; passed to the model.
        non_speech_label: Lower bound of the emitted values; the model's
            speech probability is mapped linearly onto
            ``[non_speech_label, 1.0]``.

    Returns:
        A callable mapping a buffer of 16-bit audio to one value per window.
    """
    import torch

    window_duration = 1.0 / sample_rate  # duration in seconds
    frames_per_window = int(window_duration * frame_rate + 0.5)
    # The buffer is converted to one tensor element per sample below, so
    # indices are already in frames.
    bytes_per_frame = 1

    model, _ = torch.hub.load(
        repo_or_dir="snakers4/silero-vad",
        model="silero_vad",
        force_reload=False,
        onnx=False,
    )

    exception_logged = False

    def _detect(asegment) -> np.ndarray:
        nonlocal exception_logged
        # Interpret the buffer as int16 samples and scale into [-1, 1).
        asegment = np.frombuffer(asegment, np.int16).astype(np.float32) / (1 << 15)
        asegment = torch.FloatTensor(asegment)
        media_bstring = []
        for start in range(0, len(asegment) // bytes_per_frame, frames_per_window):
            stop = min(start + frames_per_window, len(asegment))
            try:
                speech_prob = model(
                    asegment[start * bytes_per_frame : stop * bytes_per_frame],
                    frame_rate,
                ).item()
            except Exception:
                # Log only the first failure to avoid flooding the log.
                if not exception_logged:
                    exception_logged = True
                    logger.exception("exception occurred during speech detection")
                speech_prob = 0.0
            media_bstring.append(1.0 - (1.0 - speech_prob) * (1.0 - non_speech_label))
        return np.array(media_bstring)

    return _detect


class ComputeSpeechFrameBoundariesMixin:
    """Mixin recording the first and last index classified as speech."""

    def __init__(self) -> None:
        self.start_frame_: Optional[int] = None
        self.end_frame_: Optional[int] = None

    @property
    def num_frames(self) -> Optional[int]:
        """Return ``end_frame_ - start_frame_``, or ``None`` if either is unset."""
        if self.start_frame_ is None or self.end_frame_ is None:
            return None
        return self.end_frame_ - self.start_frame_

    def fit_boundaries(
        self, speech_frames: np.ndarray
    ) -> "ComputeSpeechFrameBoundariesMixin":
        """Store the smallest and largest index whose value exceeds 0.5.

        The boundaries are left untouched when no value exceeds 0.5.

        Returns:
            ``self``.
        """
        nz = np.nonzero(speech_frames > 0.5)[0]
        if len(nz) > 0:
            self.start_frame_ = int(np.min(nz))
            self.end_frame_ = int(np.max(nz))
        return self


class VideoSpeechTransformer(TransformerMixin):
    """Extract a per-window speech signal from a media file.

    The signal comes either from an embedded subtitle stream (when ``vad``
    contains ``"subs"``) or from decoding the audio with ffmpeg and running
    the detector selected by ``vad`` (``"webrtc"``, ``"auditok"`` or
    ``"silero"``, matched as substrings).
    """

    def __init__(
        self,
        vad: str,
        sample_rate: int,
        frame_rate: int,
        non_speech_label: float,
        start_seconds: int = 0,
        ffmpeg_path: Optional[str] = None,
        ref_stream: Optional[str] = None,
        vlc_mode: bool = False,
        gui_mode: bool = False,
    ) -> None:
        """Store the configuration; no work is done until ``fit``.

        Args:
            vad: Detector selector, matched by substring in ``fit``.
            sample_rate: Number of output windows per second.
            frame_rate: Audio sampling rate requested from ffmpeg (``-ar``).
            non_speech_label: Forwarded to the detector factories.
            start_seconds: Offset into the media; passed to ffmpeg as ``-ss``
                when positive and subtracted from the probed duration.
            ffmpeg_path: Forwarded to ``ffmpeg_bin_path``.
            ref_stream: Optional stream specifier; ``0:s:*`` selects an
                embedded subtitle stream, ``0:a:*`` an audio stream.
            vlc_mode: Disable the tqdm bar and print integer percentages.
            gui_mode: Forwarded to ``ffmpeg_bin_path``; also routes the tqdm
                bar to stdout.
        """
        super(VideoSpeechTransformer, self).__init__()
        self.vad: str = vad
        self.sample_rate: int = sample_rate
        self.frame_rate: int = frame_rate
        self._non_speech_label: float = non_speech_label
        self.start_seconds: int = start_seconds
        self.ffmpeg_path: Optional[str] = ffmpeg_path
        self.ref_stream: Optional[str] = ref_stream
        self.vlc_mode: bool = vlc_mode
        self.gui_mode: bool = gui_mode
        self.video_speech_results_: Optional[np.ndarray] = None

    def try_fit_using_embedded_subs(self, fname: str) -> None:
        """Fit from subtitle streams embedded in ``fname``.

        Each candidate stream is extracted as SRT with ffmpeg and run through
        the subtitle speech pipeline; the stream with the largest
        ``max_time_`` wins.

        Raises:
            ValueError: If no candidate stream could be extracted.
        """
        embedded_subs = []
        embedded_subs_times = []
        if self.ref_stream is None:
            # check first 5; should cover 99% of movies
            streams_to_try: List[str] = list(map("0:s:{}".format, range(5)))
        else:
            streams_to_try = [self.ref_stream]
        for stream in streams_to_try:
            ffmpeg_args = [
                ffmpeg_bin_path(
                    "ffmpeg", self.gui_mode, ffmpeg_resources_path=self.ffmpeg_path
                )
            ]
            ffmpeg_args.extend(
                [
                    "-loglevel",
                    "fatal",
                    "-nostdin",
                    "-i",
                    fname,
                    "-map",
                    "{}".format(stream),
                    "-f",
                    "srt",
                    "-",
                ]
            )
            process = subprocess.Popen(
                ffmpeg_args, **subprocess_args(include_stdout=True)
            )
            output = io.BytesIO(process.communicate()[0])
            # Stop at the first stream ffmpeg cannot extract.
            if process.returncode != 0:
                break
            pipe = cast(
                Pipeline,
                make_subtitle_speech_pipeline(start_seconds=self.start_seconds),
            ).fit(output)
            speech_step = pipe.steps[-1][1]
            embedded_subs.append(speech_step)
            embedded_subs_times.append(speech_step.max_time_)
        if len(embedded_subs) == 0:
            if self.ref_stream is None:
                error_msg = "Video file appears to lack subtitle stream"
            else:
                error_msg = "Stream {} not found".format(self.ref_stream)
            raise ValueError(error_msg)
        # use longest set of embedded subs
        subs_to_use = embedded_subs[int(np.argmax(embedded_subs_times))]
        self.video_speech_results_ = subs_to_use.subtitle_speech_results_

    def fit(self, fname: str, *_) -> "VideoSpeechTransformer":
        """Compute the speech signal for ``fname``.

        Embedded subtitles are tried first when ``vad`` contains ``"subs"``
        and ``ref_stream`` is unset or a ``0:s:`` stream; on failure the
        error is logged and audio-based detection is used instead.

        Returns:
            ``self`` with ``video_speech_results_`` populated.

        Raises:
            ValueError: If ``vad`` names no known detector, or if ffmpeg
                produced no audio data.
        """
        if "subs" in self.vad and (
            self.ref_stream is None or self.ref_stream.startswith("0:s:")
        ):
            try:
                logger.info("Checking video for subtitles stream...")
                self.try_fit_using_embedded_subs(fname)
                logger.info("...success!")
                return self
            except Exception as e:
                logger.info(e)
        try:
            total_duration = (
                float(
                    ffmpeg.probe(
                        fname,
                        cmd=ffmpeg_bin_path(
                            "ffprobe",
                            self.gui_mode,
                            ffmpeg_resources_path=self.ffmpeg_path,
                        ),
                    )["format"]["duration"]
                )
                - self.start_seconds
            )
        except Exception as e:
            # The duration only drives progress reporting, so carry on.
            logger.warning(e)
            total_duration = None
        if "webrtc" in self.vad:
            detector = _make_webrtcvad_detector(
                self.sample_rate, self.frame_rate, self._non_speech_label
            )
        elif "auditok" in self.vad:
            detector = _make_auditok_detector(
                self.sample_rate, self.frame_rate, self._non_speech_label
            )
        elif "silero" in self.vad:
            detector = _make_silero_detector(
                self.sample_rate, self.frame_rate, self._non_speech_label
            )
        else:
            raise ValueError("unknown vad: %s" % self.vad)
        media_bstring: List[np.ndarray] = []
        ffmpeg_args = [
            ffmpeg_bin_path(
                "ffmpeg", self.gui_mode, ffmpeg_resources_path=self.ffmpeg_path
            )
        ]
        if self.start_seconds > 0:
            ffmpeg_args.extend(
                [
                    "-ss",
                    str(timedelta(seconds=self.start_seconds)),
                ]
            )
        ffmpeg_args.extend(["-loglevel", "fatal", "-nostdin", "-i", fname])
        if self.ref_stream is not None and self.ref_stream.startswith("0:a:"):
            ffmpeg_args.extend(["-map", self.ref_stream])
        ffmpeg_args.extend(
            [
                "-f",
                "s16le",
                "-ac",
                "1",
                "-acodec",
                "pcm_s16le",
                "-af",
                "aresample=async=1",
                "-ar",
                str(self.frame_rate),
                "-",
            ]
        )
        bytes_per_frame = 2
        # Despite the name, this is the number of *bytes* per output window.
        frames_per_window = bytes_per_frame * self.frame_rate // self.sample_rate
        windows_per_buffer = 10000
        simple_progress = 0.0

        redirect_stderr = None
        tqdm_extra_args = {}
        should_print_redirected_stderr = self.gui_mode
        if self.gui_mode:
            try:
                from contextlib import redirect_stderr  # type: ignore

                tqdm_extra_args["file"] = sys.stdout
            except ImportError:
                should_print_redirected_stderr = False
        if redirect_stderr is None:

            @contextmanager
            def redirect_stderr(enter_result=None):
                yield enter_result

        assert redirect_stderr is not None
        pbar_output = io.StringIO()
        process = subprocess.Popen(ffmpeg_args, **subprocess_args(include_stdout=True))
        try:
            with redirect_stderr(pbar_output):
                with tqdm.tqdm(
                    total=total_duration, disable=self.vlc_mode, **tqdm_extra_args
                ) as pbar:
                    while True:
                        in_bytes = process.stdout.read(
                            frames_per_window * windows_per_buffer
                        )
                        if not in_bytes:
                            break
                        # Seconds of audio in this buffer, capped so the bar
                        # never runs past the probed duration.
                        newstuff = (
                            len(in_bytes) / float(bytes_per_frame) / self.frame_rate
                        )
                        if (
                            total_duration is not None
                            and simple_progress + newstuff > total_duration
                        ):
                            newstuff = total_duration - simple_progress
                        simple_progress += newstuff
                        pbar.update(newstuff)
                        if self.vlc_mode and total_duration is not None:
                            print("%d" % int(simple_progress * 100.0 / total_duration))
                            sys.stdout.flush()
                        if should_print_redirected_stderr:
                            assert self.gui_mode
                            # no need to flush since we pass -u to do unbuffered output for gui mode
                            # NOTE(review): the StringIO cursor sits at the end
                            # after writes, so read() yields "" here; left
                            # unchanged to keep GUI output as it was - confirm
                            # whether getvalue() was intended.
                            print(pbar_output.read())
                        if "silero" not in self.vad:
                            in_bytes = np.frombuffer(in_bytes, np.uint8)
                        media_bstring.append(detector(in_bytes))
        except BaseException:
            # Do not leave ffmpeg running if reading or detection fails.
            process.kill()
            raise
        finally:
            # Always release the pipe and reap the child process.
            if process.stdout is not None:
                process.stdout.close()
            process.wait()
        if len(media_bstring) == 0:
            raise ValueError(
                "Unable to detect speech. "
                "Perhaps try specifying a different stream / track, or a different vad."
            )
        self.video_speech_results_ = np.concatenate(media_bstring)
        logger.info("total of speech segments: %s", np.sum(self.video_speech_results_))
        return self

    def transform(self, *_) -> np.ndarray:
        """Return the fitted speech signal (``None`` if ``fit`` has not run)."""
        return self.video_speech_results_


# Opening bracket -> matching closing bracket; a cue fully wrapped in such a
# pair is treated as metadata by ``_is_metadata``.
_PAIRED_NESTER: Dict[str, str] = {
    "(": ")",
    "{": "}",
    "[": "]",
    # FIXME: False positive sometimes when there are html tags, e.g. Hello?
    # '<': '>',
}


# TODO: need way better metadata detector
def _is_metadata(content: str, is_beginning_or_end: bool) -> bool:
    """Heuristically decide whether a subtitle cue is metadata, not speech.

    Args:
        content: Text of the cue.
        is_beginning_or_end: Whether the cue is the first or last one, where
            additional checks apply.

    Returns:
        ``True`` for empty cues, cues wrapped in a bracket pair from
        ``_PAIRED_NESTER``, and first/last cues containing ``"english"``
        (case-insensitive) or ``" - "``; ``False`` otherwise.
    """
    content = content.strip()
    if not content:
        return True
    # ``get`` yields None for a non-bracket first character, which never
    # equals the last character.
    if _PAIRED_NESTER.get(content[0]) == content[-1]:
        return True
    if is_beginning_or_end:
        if "english" in content.lower():
            return True
        if " - " in content:
            return True
    return False


class SubtitleSpeechTransformer(TransformerMixin, ComputeSpeechFrameBoundariesMixin):
    """Turn a list of subtitles into a per-window speech signal."""

    def __init__(
        self, sample_rate: int, start_seconds: int = 0, framerate_ratio: float = 1.0
    ) -> None:
        """Store the configuration.

        Args:
            sample_rate: Number of output windows per second.
            start_seconds: Offset subtracted from every subtitle start time.
            framerate_ratio: Speech windows are set to
                ``min(1.0 / framerate_ratio, 1.0)``.
        """
        super(SubtitleSpeechTransformer, self).__init__()
        self.sample_rate: int = sample_rate
        self.start_seconds: int = start_seconds
        self.framerate_ratio: float = framerate_ratio
        self.subtitle_speech_results_: Optional[np.ndarray] = None
        self.max_time_: Optional[float] = None

    def fit(self, subs: List[GenericSubtitle], *_) -> "SubtitleSpeechTransformer":
        """Mark the windows covered by non-metadata subtitles as speech.

        Args:
            subs: Subtitles exposing ``start``, ``end`` and ``content``.

        Returns:
            ``self`` with ``subtitle_speech_results_``, ``max_time_`` and the
            speech frame boundaries populated.
        """
        max_time = 0
        for sub in subs:
            max_time = max(max_time, sub.end.total_seconds())
        self.max_time_ = max_time - self.start_seconds
        samples = np.zeros(int(max_time * self.sample_rate) + 2, dtype=float)
        speech_value = min(1.0 / self.framerate_ratio, 1.0)
        last_index = len(subs) - 1
        for i, sub in enumerate(subs):
            if _is_metadata(sub.content, i == 0 or i == last_index):
                continue
            start = int(
                round(
                    (sub.start.total_seconds() - self.start_seconds) * self.sample_rate
                )
            )
            duration = sub.end.total_seconds() - sub.start.total_seconds()
            end = start + int(round(duration * self.sample_rate))
            # A cue beginning before ``start_seconds`` yields negative
            # indices, which numpy would interpret relative to the end of
            # the array; clamp so only the in-range part is marked.
            samples[max(start, 0) : max(end, 0)] = speech_value
        self.subtitle_speech_results_ = samples
        self.fit_boundaries(self.subtitle_speech_results_)
        return self

    def transform(self, *_) -> np.ndarray:
        """Return the fitted speech signal; ``fit`` must have been called."""
        assert self.subtitle_speech_results_ is not None
        return self.subtitle_speech_results_


class DeserializeSpeechTransformer(TransformerMixin):
    """Load a previously serialized speech signal from disk."""

    def __init__(self, non_speech_label: float) -> None:
        """Store the label written over every value below 1.0 on load."""
        super(DeserializeSpeechTransformer, self).__init__()
        self._non_speech_label: float = non_speech_label
        self.deserialized_speech_results_: Optional[np.ndarray] = None

    def fit(self, fname, *_) -> "DeserializeSpeechTransformer":
        """Load the speech array from ``fname`` via ``np.load``.

        If the loaded object exposes ``files`` (an archive), the array stored
        under ``"speech"`` is used. Values below 1.0 are replaced in place
        with the non-speech label.

        Returns:
            ``self`` with ``deserialized_speech_results_`` populated.

        Raises:
            ValueError: If an archive does not contain a ``"speech"`` array.
        """
        loaded = np.load(fname)
        if hasattr(loaded, "files"):
            # Archives keep a file handle open until closed; extract the
            # array and release the handle even on error.
            try:
                if "speech" in loaded.files:
                    speech = loaded["speech"]
                else:
                    raise ValueError(
                        'could not find "speech" array in '
                        "serialized file; only contains: %s" % loaded.files
                    )
            finally:
                loaded.close()
        else:
            speech = loaded
        speech[speech < 1.0] = self._non_speech_label
        self.deserialized_speech_results_ = speech
        return self

    def transform(self, *_) -> np.ndarray:
        """Return the loaded speech signal; ``fit`` must have been called."""
        assert self.deserialized_speech_results_ is not None
        return self.deserialized_speech_results_