diff --git a/imagecorruptions/__init__.py b/imagecorruptions/__init__.py index 7c9ac1f..ba6303b 100644 --- a/imagecorruptions/__init__.py +++ b/imagecorruptions/__init__.py @@ -12,7 +12,7 @@ corruption_tuple} -def corrupt(image, severity=1, corruption_name=None, corruption_number=-1): +def corrupt(image, severity=1, corruption_name=None, corruption_number=-1, **kwargs): """This function returns a corrupted version of the given image. Args: @@ -59,10 +59,10 @@ def corrupt(image, severity=1, corruption_name=None, corruption_number=-1): if not (corruption_name is None): image_corrupted = corruption_dict[corruption_name](Image.fromarray(image), - severity) + severity, **kwargs) elif corruption_number != -1: image_corrupted = corruption_tuple[corruption_number](Image.fromarray(image), - severity) + severity, **kwargs) else: raise ValueError("Either corruption_name or corruption_number must be passed") diff --git a/imagecorruptions/corruptions.py b/imagecorruptions/corruptions.py index c9e8d38..0efe925 100644 --- a/imagecorruptions/corruptions.py +++ b/imagecorruptions/corruptions.py @@ -15,6 +15,7 @@ from pkg_resources import resource_filename from numba import njit +SK_VERSION = {k:int(v) for k,v in zip(['major', 'minor'], sk.__version__.split('.')[:2])} def disk(radius, alias_blur=0.1, dtype=np.float32): if radius <= 8: @@ -154,7 +155,9 @@ def _motion_blur(x, radius, sigma, angle): @njit() -def _shuffle_pixels_njit_glass_blur(d0, d1, x, c): +def _shuffle_pixels_njit_glass_blur(d0, d1, x, c, seed=None): + + np.random.seed(seed) # locally shuffle pixels for i in range(c[2]): @@ -185,10 +188,10 @@ def shot_noise(x, severity=1): return np.clip(np.random.poisson(x * c) / float(c), 0, 1) * 255 -def impulse_noise(x, severity=1): +def impulse_noise(x, severity=1, seed=None): c = [.03, .06, .09, 0.17, 0.27][severity - 1] - - x = sk.util.random_noise(np.array(x) / 255., mode='s&p', amount=c) + mode = 's&p' + x = sk.util.random_noise(np.array(x) / 255., 's&p', seed, amount=c) return np.clip(x, 0, 1) * 255 @@ -202,21 +205,30 @@ def speckle_noise(x, severity=1): def gaussian_blur(x, severity=1): c = [1, 2, 3, 4, 6][severity - 1] - x = gaussian(np.array(x) / 255., sigma=c, channel_axis=-1) - return np.clip(x, 0, 1) * 255 + if SK_VERSION['major'] >= 0 and SK_VERSION['minor'] >= 19: + kwargs = {'channel_axis': -1} + else: # pre scikit-image 0.19 + kwargs = {'multichannel': True} + x = gaussian(np.array(x) / 255., sigma=c, **kwargs) + return np.clip(x, 0, 1) * 255 -def glass_blur(x, severity=1): +def glass_blur(x, severity=1, seed=None): # sigma, max_delta, iterations c = [(0.7, 1, 2), (0.9, 2, 1), (1, 2, 3), (1.1, 3, 2), (1.5, 4, 2)][ severity - 1] + if SK_VERSION['major'] >= 0 and SK_VERSION['minor'] >= 19: + kwargs = {'channel_axis': -1} + else: # pre scikit-image 0.19 + kwargs = {'multichannel': True} + x = np.uint8( - gaussian(np.array(x) / 255., sigma=c[0], channel_axis=-1) * 255) + gaussian(np.array(x) / 255., sigma=c[0], **kwargs) * 255) - x = _shuffle_pixels_njit_glass_blur(np.array(x).shape[0], np.array(x).shape[1], x, c) + x = _shuffle_pixels_njit_glass_blur(np.array(x).shape[0], np.array(x).shape[1], x, c, seed) - return np.clip(gaussian(x / 255., sigma=c[0], channel_axis=-1), 0, + return np.clip(gaussian(x / 255., sigma=c[0], **kwargs), 0, 1) * 255