diff --git a/README.md b/README.md index ce8967f..75bf649 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ thesis `Testing robustness of DeepFake recognition methods against non-malicious * [How it works?](#how-it-works-) * [TL;DR:](#tl-dr-) * [Tell me more!](#tell-me-more-) +* [Favourite code snippet](#favourite-code-snippet) * [Design](#design) * [Installation](#installation) * [GPU configuration](#gpu-configuration) @@ -23,10 +24,9 @@ recognition method used to properly discriminate DeepFake from authentic video w ## What is this? * CLI program allowing to: - * Preprocess dataset containing real and fake videos. Chosen part of negatives is modified by selected non-malicious - image modifications based on settings assigned by user. + * Preprocess dataset containing real and fake videos. Chosen part of negatives is modified by selected modifications based on settings assigned by user. * Preprocessing can be done via single command or gradually in steps where single step represents activity such - as `extract faces from fake videos.` The letter method allows preprocessing large datasets[^2] on without the + as `extract faces from fake videos`. The letter method allows preprocessing large datasets[^2] without the need of keeping program on for 24+ hours. * Train detection model and evaluate it in assigned settings. @@ -59,28 +59,57 @@ alter negatives are configured via YAML files. Here is a sample: --- # Example modifications settings. # In total half of negatives are altered. -modifications: - - name: RedEyesEffectModification - share: 0.125 - options: - brightness_threshold: 50 - - name: CLAHEModification - share: 0.125 - options: - clip_limit: 2.0 - grid_width: 8 - grid_height: 8 - - name: HistogramEqualizationModification - share: 0.125 - - name: GaussianBlurModification - share: 0.125 - options: - kernel_width: 9 - kernel_height: 9 +modifications_chains: + - share: 0.25 + modifications: + - name: GammaCorrectionModification + options: + gamma_value: 0.75 + - name: CLAHEModification + options: + clip_limit: 2.0 + grid_width: 8 + grid_height: 8 + - share: 0.25 + modifications: + - name: GammaCorrectionModification + options: + gamma_value: 0.75 + - name: HistogramEqualizationModification ``` The names of supported modifications can be found in [this file](src/dfd/datasets/modifications/register.py). +## Favourite code snippet + +```python +@given( + given_specifications=st.lists( + st.builds( + ModificationStub, + name=st.text(min_size=1), + no_repeats=st.integers(min_value=1), + ), + min_size=1, + max_size=5, + ) +) +def test_combine_multiple_specifications(given_specifications): + # GIVEN + image_mock = Mock(spec_set=np.ndarray) + # WHEN specifications are combined + combined_specification = functools.reduce(operator.and_, given_specifications) + # THEN specification names are combined + expected_name = "__".join([spec.name for spec in given_specifications]) + assert combined_specification.name == expected_name + # And specifications are performed in order + combined_specification.perform(image_mock) + expected_calls_in_order = [call.repeat(spec.no_repeats) for spec in given_specifications] + image_mock.assert_has_calls(expected_calls_in_order) +``` + +It's uses two cool concepts: property-based testing and specification pattern[^4]. + ## Design The application design is loosely inspired @@ -121,3 +150,4 @@ pip install git+https://github.com/cicheck/dfd.git [^1]: Currently the only supported detection method is [Meso-4](https://arxiv.org/abs/1809.00888). [^2]: Such as [Celeb-DF](https://github.com/yuezunli/celeb-deepfakeforensics). [^3]: Half of negatives modified, 4 modifications used with naive parameters. +[^4]: Or rather design loosely inspired by specification pattern :stuck_out_tongue: diff --git a/docs/diagrams/app_architecture.png b/docs/diagrams/app_architecture.png index 7ca0251..0338c95 100644 Binary files a/docs/diagrams/app_architecture.png and b/docs/diagrams/app_architecture.png differ diff --git a/docs/diagrams/app_architecture.puml b/docs/diagrams/app_architecture.puml index e9f99e6..9617534 100644 --- a/docs/diagrams/app_architecture.puml +++ b/docs/diagrams/app_architecture.puml @@ -24,7 +24,7 @@ package "Business Logic (Core)" #skyblue { } package Modifications { - interface ModificationInterface { + abstract ModificationSpecification { + name -- + perform() @@ -100,8 +100,8 @@ MesoNet -up-- ModelRegistry : register > MesoNet -up--|> ModelInterface -GammaCorrectionModification -up--|> ModificationInterface -CLAHEModification -up--|> ModificationInterface +GammaCorrectionModification -up--|> ModificationSpecification +CLAHEModification -up--|> ModificationSpecification GammaCorrectionModification -up-- ModificationRegistry : register > CLAHEModification -up-- ModificationRegistry : register > diff --git a/example_settings.yaml b/example_settings.yaml index 8fb166b..98b90df 100644 --- a/example_settings.yaml +++ b/example_settings.yaml @@ -1,25 +1,19 @@ --- # Example ModificationGenerator settings file -modifications: - - name: RedEyesEffectModification - share: 0.125 - options: - brightness_threshold: 50 - face_landmarks_detector_path: "/media/cicheck/Extreme Pro/\ - models/shape_predictor_68_face_landmarks.dat" - - name: CLAHEModification - share: 0.125 - options: - clip_limit: 2.0 - grid_width: 8 - grid_height: 8 - - name: HistogramEqualizationModification - share: 0.125 - - name: GammaCorrectionModification - share: 0.0625 - options: - gamma_value: 0.75 - - name: GammaCorrectionModification - share: 0.0625 - options: - gamma_value: 1.25 +modifications_chains: + - share: 0.25 + modifications: + - name: GammaCorrectionModification + options: + gamma_value: 0.75 + - name: CLAHEModification + options: + clip_limit: 2.0 + grid_width: 8 + grid_height: 8 + - share: 0.25 + modifications: + - name: GammaCorrectionModification + options: + gamma_value: 0.75 + - name: HistogramEqualizationModification diff --git a/setup.cfg b/setup.cfg index 5b00244..5ab3398 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,6 +47,7 @@ tests = pytest-cov==2.12.1 pytest>=6.2.4 coverage[toml] + hypothesis==6.34.1 style = darglint>=1.8.0 flake8>=3.9.2 diff --git a/src/dfd/datasets/frames_generators/modification.py b/src/dfd/datasets/frames_generators/modification.py index 82e5c3e..6eee013 100644 --- a/src/dfd/datasets/frames_generators/modification.py +++ b/src/dfd/datasets/frames_generators/modification.py @@ -1,23 +1,24 @@ """Generate new frames after performing set non malicious modifications on original frames.""" import functools import itertools +import operator import pathlib -from typing import Generator, List, NamedTuple, Optional +from typing import Generator, List, NamedTuple, Optional, Sequence import cv2 as cv import numpy as np -from dfd.exceptions import DfdError from dfd.datasets.modifications.definitions import IdentityModification -from dfd.datasets.modifications.interfaces import ModificationInterface from dfd.datasets.modifications.register import ModificationRegister -from dfd.datasets.settings import GeneratorSettings +from dfd.datasets.modifications.specification import ModificationSpecification +from dfd.datasets.settings import GeneratorSettings, ModificationSettings +from dfd.exceptions import DfdError class ModificationShare(NamedTuple): """Share of frames on which modification will be performed.""" - modification: ModificationInterface + modification: ModificationSpecification share: float @@ -31,7 +32,7 @@ class ModificationRange(NamedTuple): """ - modification: ModificationInterface + modification: ModificationSpecification lower_bound: int upper_bound: int @@ -93,7 +94,7 @@ def from_directory( input_frame = cv.imread(str(input_frame_path)) modified_frame = modification.perform(input_frame) yield ModifiedFrame( - modification_used=str(modification), + modification_used=modification.name, frame=modified_frame, original_path=input_frame_path, ) @@ -101,14 +102,9 @@ def from_directory( @functools.lru_cache(maxsize=1) def _get_modifications_share(self) -> List[ModificationShare]: modifications_share: List[ModificationShare] = [] - for modification_settings in self._setting.modifications: - mame = modification_settings.name - options = modification_settings.options - share = modification_settings.share - - modification_class = self._register.get_modification_class(mame) - # TODO: fix typing - modification = modification_class(**options) # type: ignore + for modification_chain_settings in self._setting.modifications_chains: + share = modification_chain_settings.share + modification = self._chain_modifications(modification_chain_settings.modifications) modifications_share.append(ModificationShare(modification, share)) self._check_modifications_are_unique( @@ -117,7 +113,7 @@ def _get_modifications_share(self) -> List[ModificationShare]: return modifications_share @staticmethod - def _check_modifications_are_unique(modifications: List[ModificationInterface]): + def _check_modifications_are_unique(modifications: List[ModificationSpecification]): """Check if modifications are unique. Raises: @@ -160,7 +156,7 @@ def _get_frames_permutation(self, no_frames: int) -> np.ndarray: def _choose_modification( self, frame_index: int, input_frame_path: pathlib.Path, no_frames: int - ) -> ModificationInterface: + ) -> ModificationSpecification: frames_permutation = self._get_frames_permutation(no_frames) modifications_range = self._get_modifications_range(no_frames) permuted_index = frames_permutation[frame_index] @@ -170,3 +166,14 @@ def _choose_modification( # TODO: log error # This should never happen raise DfdError("Could not select modification.") + + def _chain_modifications( + self, modifications_settings: Sequence[ModificationSettings] + ) -> ModificationSpecification: + modifications: List[ModificationSpecification] = [] + for modification_settings in modifications_settings: + modification_class = self._register.get_modification_class(modification_settings.name) + # TODO: fix typing + modification = modification_class(**modification_settings.options) # type: ignore + modifications.append(modification) + return functools.reduce(operator.and_, modifications) diff --git a/src/dfd/datasets/modifications/definitions/clahe.py b/src/dfd/datasets/modifications/definitions/clahe.py index 3ea8787..3409b90 100644 --- a/src/dfd/datasets/modifications/definitions/clahe.py +++ b/src/dfd/datasets/modifications/definitions/clahe.py @@ -3,10 +3,10 @@ import cv2 as cv import numpy as np -from dfd.datasets.modifications.interfaces import ModificationInterface +from dfd.datasets.modifications.specification import ModificationSpecification -class CLAHEModification(ModificationInterface): +class CLAHEModification(ModificationSpecification): """Modification CLAHE (Contrast Limited Adaptive Histogram Equalization)""" def __init__(self, clip_limit: float, grid_width: int, grid_height: int) -> None: @@ -21,6 +21,17 @@ def __init__(self, clip_limit: float, grid_width: int, grid_height: int) -> None self._clip_limit = clip_limit self._title_grid_size = (grid_width, grid_height) + @property + def name(self) -> str: + """Get specification name. + + Returns: + The name of specification. + + """ + width, height = self._title_grid_size + return f"clahe_{width}_{height}_{self._clip_limit}" + def perform(self, image: np.ndarray) -> np.ndarray: """Perform CLAHE on image. @@ -44,5 +55,4 @@ def perform(self, image: np.ndarray) -> np.ndarray: return cv.cvtColor(ycrcb_image, cv.COLOR_YCrCb2BGR) def __str__(self) -> str: - width, height = self._title_grid_size - return f"clahe_{width}_{height}_{self._clip_limit}" + return self.name diff --git a/src/dfd/datasets/modifications/definitions/gamma_correction.py b/src/dfd/datasets/modifications/definitions/gamma_correction.py index cc4d81b..9a8e0ca 100644 --- a/src/dfd/datasets/modifications/definitions/gamma_correction.py +++ b/src/dfd/datasets/modifications/definitions/gamma_correction.py @@ -2,10 +2,10 @@ import cv2 as cv import numpy as np -from dfd.datasets.modifications.interfaces import ModificationInterface +from dfd.datasets.modifications.specification import ModificationSpecification -class GammaCorrectionModification(ModificationInterface): +class GammaCorrectionModification(ModificationSpecification): """Modification Gamma Correction.""" def __init__(self, gamma_value: float) -> None: @@ -17,6 +17,16 @@ def __init__(self, gamma_value: float) -> None: """ self._gamma_value = gamma_value + @property + def name(self) -> str: + """Get specification name. + + Returns: + The name of specification. + + """ + return f"gamma_correction_{self._gamma_value}" + def perform(self, image: np.ndarray) -> np.ndarray: """Perform gamma correction on provided image. @@ -38,4 +48,4 @@ def perform(self, image: np.ndarray) -> np.ndarray: return cv.LUT(image, look_up_table) def __str__(self) -> str: - return f"gamma_correction_{self._gamma_value}" + return self.name diff --git a/src/dfd/datasets/modifications/definitions/gaussian_blur.py b/src/dfd/datasets/modifications/definitions/gaussian_blur.py index 12e7183..ab14b03 100644 --- a/src/dfd/datasets/modifications/definitions/gaussian_blur.py +++ b/src/dfd/datasets/modifications/definitions/gaussian_blur.py @@ -3,10 +3,10 @@ import cv2 as cv import numpy as np -from dfd.datasets.modifications.interfaces import ModificationInterface +from dfd.datasets.modifications.specification import ModificationSpecification -class GaussianBlurModification(ModificationInterface): +class GaussianBlurModification(ModificationSpecification): """Modification Gaussian blur (AKA Gaussian smoothing).""" def __init__( @@ -30,6 +30,16 @@ def __init__( self._sigma_y = sigma_y def __str__(self) -> str: + return self.name + + @property + def name(self) -> str: + """Get specification name. + + Returns: + The name of specification. + + """ width, height = self._kernel_size return f"gaussian_blur{width}_{height}_{self._sigma_x}_{self._sigma_y}" diff --git a/src/dfd/datasets/modifications/definitions/gaussian_noise.py b/src/dfd/datasets/modifications/definitions/gaussian_noise.py index 34878f0..cbccef0 100644 --- a/src/dfd/datasets/modifications/definitions/gaussian_noise.py +++ b/src/dfd/datasets/modifications/definitions/gaussian_noise.py @@ -2,10 +2,10 @@ import numpy as np -from dfd.datasets.modifications.interfaces import ModificationInterface +from dfd.datasets.modifications.specification import ModificationSpecification -class GaussianNoiseModification(ModificationInterface): +class GaussianNoiseModification(ModificationSpecification): """Modification Gaussian noise.""" def __init__( @@ -24,6 +24,16 @@ def __init__( self._mean = mean self._standard_deviation = standard_deviation + @property + def name(self) -> str: + """Get specification name. + + Returns: + The name of specification. + + """ + return f"gaussian_noise{self._mean}_{self._standard_deviation}" + def perform(self, image: np.ndarray) -> np.ndarray: """Add Gaussian noise to provided image. diff --git a/src/dfd/datasets/modifications/definitions/histogram_equalization.py b/src/dfd/datasets/modifications/definitions/histogram_equalization.py index 49d3b25..513aff5 100644 --- a/src/dfd/datasets/modifications/definitions/histogram_equalization.py +++ b/src/dfd/datasets/modifications/definitions/histogram_equalization.py @@ -3,10 +3,10 @@ import cv2 as cv import numpy as np -from dfd.datasets.modifications.interfaces import ModificationInterface +from dfd.datasets.modifications.specification import ModificationSpecification -class HistogramEqualizationModification(ModificationInterface): +class HistogramEqualizationModification(ModificationSpecification): """Modification Histogram Equalization.""" def perform(self, image: np.ndarray) -> np.ndarray: @@ -29,5 +29,15 @@ def perform(self, image: np.ndarray) -> np.ndarray: # Convert back to BGR return cv.cvtColor(ycrcb_image, cv.COLOR_YCrCb2BGR) - def __str__(self) -> str: + @property + def name(self) -> str: + """Get specification name. + + Returns: + The name of specification. + + """ return "histogram_equalization" + + def __str__(self) -> str: + return self.name diff --git a/src/dfd/datasets/modifications/definitions/identity.py b/src/dfd/datasets/modifications/definitions/identity.py index a2302f5..58841c4 100644 --- a/src/dfd/datasets/modifications/definitions/identity.py +++ b/src/dfd/datasets/modifications/definitions/identity.py @@ -1,10 +1,10 @@ """Modification Identity.""" import numpy as np -from dfd.datasets.modifications.interfaces import ModificationInterface +from dfd.datasets.modifications.specification import ModificationSpecification -class IdentityModification(ModificationInterface): +class IdentityModification(ModificationSpecification): """Modification Identity.""" def perform(self, image: np.ndarray) -> np.ndarray: @@ -16,5 +16,15 @@ def perform(self, image: np.ndarray) -> np.ndarray: """ return image - def __str__(self) -> str: + @property + def name(self) -> str: + """Get specification name. + + Returns: + The name of specification. + + """ return "identity" + + def __str__(self) -> str: + return self.name diff --git a/src/dfd/datasets/modifications/definitions/median_filter.py b/src/dfd/datasets/modifications/definitions/median_filter.py index 8a84b66..5af1e45 100644 --- a/src/dfd/datasets/modifications/definitions/median_filter.py +++ b/src/dfd/datasets/modifications/definitions/median_filter.py @@ -3,10 +3,10 @@ import cv2 as cv import numpy as np -from dfd.datasets.modifications.interfaces import ModificationInterface +from dfd.datasets.modifications.specification import ModificationSpecification -class MedianFilterModification(ModificationInterface): +class MedianFilterModification(ModificationSpecification): """Modification Median filter.""" def __init__(self, aperture_size: int) -> None: @@ -19,6 +19,16 @@ def __init__(self, aperture_size: int) -> None: self._aperture_size = aperture_size def __str__(self) -> str: + return self.name + + @property + def name(self) -> str: + """Get specification name. + + Returns: + The name of specification. + + """ return "median_filter_{aperture_size}".format(aperture_size=self._aperture_size) def perform(self, image: np.ndarray) -> np.ndarray: diff --git a/src/dfd/datasets/modifications/definitions/red_eyes_effect.py b/src/dfd/datasets/modifications/definitions/red_eyes_effect.py index e3cf2ea..877ddd5 100644 --- a/src/dfd/datasets/modifications/definitions/red_eyes_effect.py +++ b/src/dfd/datasets/modifications/definitions/red_eyes_effect.py @@ -5,7 +5,7 @@ import dlib import numpy as np -from dfd.datasets.modifications.interfaces import ModificationInterface +from dfd.datasets.modifications.specification import ModificationSpecification def _convert_dlib_shape_to_np_array(dlib_shape) -> np.array: @@ -21,7 +21,7 @@ def _convert_dlib_shape_to_np_array(dlib_shape) -> np.array: return np.array([[point.x, point.y] for point in dlib_shape.parts()], dtype="int") -class RedEyesEffectModification(ModificationInterface): +class RedEyesEffectModification(ModificationSpecification): """Modification red-eyes effect.""" def __init__(self, face_landmarks_detector_path: str, brightness_threshold: int = 50) -> None: @@ -37,6 +37,16 @@ def __init__(self, face_landmarks_detector_path: str, brightness_threshold: int self._brightness_threshold = brightness_threshold def __str__(self) -> str: + return self.name + + @property + def name(self) -> str: + """Get specification name. + + Returns: + The name of specification. + + """ return f"red_eyes_effect_{self._brightness_threshold}" def perform(self, image: np.ndarray) -> np.ndarray: diff --git a/src/dfd/datasets/modifications/interfaces.py b/src/dfd/datasets/modifications/interfaces.py deleted file mode 100644 index bbd79a6..0000000 --- a/src/dfd/datasets/modifications/interfaces.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Interfaces used in modifications package.""" -import abc - -import numpy as np - - -class ModificationInterface(abc.ABC): - """Modification interface.""" - - @classmethod - def name(cls) -> str: - """Returns modification name.""" - - return cls.__name__ - - @abc.abstractmethod - def perform(self, image: np.ndarray) -> np.ndarray: - """Perform modification on image. - - Args: - image: OpenCV image. - - Returns: - Modified image. - """ diff --git a/src/dfd/datasets/modifications/register.py b/src/dfd/datasets/modifications/register.py index 5b513f6..758b26f 100644 --- a/src/dfd/datasets/modifications/register.py +++ b/src/dfd/datasets/modifications/register.py @@ -11,9 +11,9 @@ MedianFilterModification, RedEyesEffectModification, ) -from .interfaces import ModificationInterface +from .specification import ModificationSpecification -NameToModificationTypeMap = Dict[str, Type[ModificationInterface]] +NameToModificationTypeMap = Dict[str, Type[ModificationSpecification]] class ModificationRegister: @@ -41,10 +41,10 @@ def default(cls) -> "ModificationRegister": } return cls( - {modification.name(): modification for modification in default_modifications}, + {modification.class_name(): modification for modification in default_modifications}, ) - def get_modification_class(self, modification_name: str) -> Type[ModificationInterface]: + def get_modification_class(self, modification_name: str) -> Type[ModificationSpecification]: """Get registered modification via name. Raises: diff --git a/src/dfd/datasets/modifications/specification.py b/src/dfd/datasets/modifications/specification.py new file mode 100644 index 0000000..50d37b9 --- /dev/null +++ b/src/dfd/datasets/modifications/specification.py @@ -0,0 +1,91 @@ +"""Interfaces used in modifications package.""" +from __future__ import annotations + +import abc + +import numpy as np + + +class ModificationSpecification(abc.ABC): + """Base class for modification used to alter frames. + + Design is loosely inspired by specification pattern. + + """ + + def __call__(self, image: np.ndarray): + """Perform modifications, convenience overwrite of dunder method. + + Args: + image: OpenCV image that will be modified. + + Returns: + Modified image. + + """ + return self.perform(image) + + def __and__(self, other: ModificationSpecification) -> _AndSpecification: + return _AndSpecification(self, other) + + @classmethod + def class_name(cls) -> str: + """Get class name. + + Returns: + Specification class name. + + """ + return cls.__name__ + + @property + @abc.abstractmethod + def name(self) -> str: + """Get specification name. + + Name should include each modification used in specification. + + Returns: + Specification name. + """ + + @abc.abstractmethod + def perform(self, image: np.ndarray) -> np.ndarray: + """Perform modification defined in specification. + + Args: + image: OpenCV image. + + Returns: + Modified image. + """ + + +class _AndSpecification(ModificationSpecification): + """Combine two specifications.""" + + def __init__( + self, first_spec: ModificationSpecification, sec_spec: ModificationSpecification + ) -> None: + self._first_spec = first_spec + self._sec_spec = sec_spec + + @property + def name(self) -> str: + """Get specification name. + + Returns: + Name, combination of names of both specification used to create this one. + """ + return f"{self._first_spec.name}__{self._sec_spec.name}" + + def perform(self, image: np.ndarray) -> np.ndarray: + """Apply modifications defined in both specifications. + + Modifications are applied in order left -> right. + + Returns: + Modified image. + + """ + return self._sec_spec(self._first_spec(image)) diff --git a/src/dfd/datasets/preprocessor.py b/src/dfd/datasets/preprocessor.py index b869b1b..ffeb877 100644 --- a/src/dfd/datasets/preprocessor.py +++ b/src/dfd/datasets/preprocessor.py @@ -56,7 +56,7 @@ def _generate_frame_and_filename_batches( frame_and_name_pair = (frame, frame_path.name) # Frame has different shape than previous ones (i.e. is from different video) # TODO: ugly use named tuple instead of [0][0] - if len(batch) > 0 and batch[0][0].shape != frame_and_name_pair[0].shape: + if len(batch) > 0 and batch[0][0].shape != frame_and_name_pair[0].no_repeats: yield batch batch = [frame_and_name_pair] continue diff --git a/src/dfd/datasets/settings.py b/src/dfd/datasets/settings.py index eaf9704..cf26689 100644 --- a/src/dfd/datasets/settings.py +++ b/src/dfd/datasets/settings.py @@ -14,20 +14,32 @@ class ModificationSettings(pydantic.BaseModel): Args: modification_name: name, used to retrieve modification from ModificationRegistry - share: share of frames on which modification should be applied - options: modification options, i.e. parameters provided to modification __init__ + options: modification options, i.e. parameters provided to modification `__init__` """ name: str - share: float options: dict = {} +class ModificationsChainSettings(pydantic.BaseModel): + """Settings for single chain of modification applied in order. + + Args: + share: Share of frames on which modifications chain will be applied. + modifications: List of modifications that will be applied on sectioned frames. + Modification on top will be applied first, modification on bootom last. + + """ + + share: float + modifications: List[ModificationSettings] = [] + + class GeneratorSettings(pydantic.BaseModel): """Generator settings.""" - modifications: List[ModificationSettings] + modifications_chains: List[ModificationsChainSettings] @classmethod def from_yaml(cls, yaml_filepath: pathlib.Path) -> "GeneratorSettings": @@ -52,38 +64,60 @@ def default(cls) -> "GeneratorSettings": """ return cls( - modifications=[ - ModificationSettings( - name="RedEyesEffectModification", + modifications_chains=[ + ModificationsChainSettings( share="0.125", - options={ - "brightness_threshold": 50, - "face_landmarks_detector_path": str(assets.FACE_LANDMARKS_MODEL_PATH), - }, + modifications=[ + ModificationSettings( + name="RedEyesEffectModification", + options={ + "brightness_threshold": 50, + "face_landmarks_detector_path": str( + assets.FACE_LANDMARKS_MODEL_PATH + ), + }, + ), + ], ), - ModificationSettings( - name="CLAHEModification", + ModificationsChainSettings( share="0.125", - options={ - "clip_limit": 2.0, - "grid_width": 8, - "grid_height": 8, - }, + modifications=[ + ModificationSettings( + name="CLAHEModification", + options={ + "clip_limit": 2.0, + "grid_width": 8, + "grid_height": 8, + }, + ), + ], ), - ModificationSettings( - name="HistogramEqualizationModification", + ModificationsChainSettings( share="0.125", - options={}, + modifications=[ + ModificationSettings( + name="HistogramEqualizationModification", + options={}, + ), + ], ), - ModificationSettings( - name="GammaCorrectionModification", + ModificationsChainSettings( share="0.0625", - options={"gamma_value": 0.75}, + modifications=[ + ModificationSettings( + name="GammaCorrectionModification", + options={"gamma_value": 0.75}, + ), + ], ), - ModificationSettings( - name="GammaCorrectionModification", + ModificationsChainSettings( share="0.0625", - options={"gamma_value": 1.25}, + modifications=[ + ModificationSettings( + name="GammaCorrectionModification", + options={"gamma_value": 1.25}, + ), + ], ), ] ) diff --git a/tests/ut/datasets/modifications/test_specificaion.py b/tests/ut/datasets/modifications/test_specificaion.py new file mode 100644 index 0000000..e420eeb --- /dev/null +++ b/tests/ut/datasets/modifications/test_specificaion.py @@ -0,0 +1,49 @@ +import functools +import operator +import typing as t +from unittest.mock import Mock, call + +import numpy as np +from hypothesis import given +from hypothesis import strategies as st + +from dfd.datasets.modifications.specification import ModificationSpecification + + +class ModificationStub(ModificationSpecification): + def __init__(self, name: str, no_repeats: int) -> None: + self._name = name + self.no_repeats = no_repeats + + @property + def name(self) -> str: + return self._name + + def perform(self, image: np.ndarray) -> np.ndarray: + image.repeat(self.no_repeats) + return image + + +@given( + given_specifications=st.lists( + st.builds( + ModificationStub, + name=st.text(min_size=1), + no_repeats=st.integers(min_value=1), + ), + min_size=1, + max_size=5, + ) +) +def test_combine_multiple_specifications(given_specifications): + # GIVEN + image_mock = Mock(spec_set=np.ndarray) + # WHEN specifications are combined + combined_specification = functools.reduce(operator.and_, given_specifications) + # THEN specification names are combined + expected_name = "__".join([spec.name for spec in given_specifications]) + assert combined_specification.name == expected_name + # And specifications are performed in order + combined_specification.perform(image_mock) + expected_calls_in_order = [call.repeat(spec.no_repeats) for spec in given_specifications] + image_mock.assert_has_calls(expected_calls_in_order)