-
Notifications
You must be signed in to change notification settings - Fork 19.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add rand_augment processing layer (#20716)
* Add rand_augment init * Update rand_augment init * Add rand_augment * Add NotImplementedError * Add some test cases * Fix failed test case * Update rand_augment * Update rand_augment test * Fix random_rotation bug * Add build method to suppress warning. * Add implementation for transform_bboxes
- Loading branch information
Showing
7 changed files
with
361 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
235 changes: 235 additions & 0 deletions
235
keras/src/layers/preprocessing/image_preprocessing/rand_augment.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,235 @@ | ||
import random | ||
|
||
import keras.src.layers as layers | ||
from keras.src.api_export import keras_export | ||
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 | ||
BaseImagePreprocessingLayer, | ||
) | ||
from keras.src.random import SeedGenerator | ||
from keras.src.utils import backend_utils | ||
|
||
|
||
@keras_export("keras.layers.RandAugment")
class RandAugment(BaseImagePreprocessingLayer):
    """RandAugment performs the Rand Augment operation on input images.

    This layer can be thought of as an all-in-one image augmentation layer. The
    policy implemented by this layer has been benchmarked extensively and is
    effective on a wide variety of datasets.

    On every training call, `num_ops` augmentation sub-layers are picked at
    random from `_AUGMENT_LAYERS` and applied one after another.

    References:
        - [RandAugment](https://arxiv.org/abs/1909.13719)

    Args:
        value_range: The range of values the input image can take.
            Default is `(0, 255)`. Typically, this would be `(0, 1)`
            for normalized images or `(0, 255)` for raw images.
        num_ops: The number of augmentation operations to apply sequentially
            to each image. Default is 2.
        factor: The strength of the augmentation as a normalized value
            between 0 and 1. Default is 0.5.
        interpolation: The interpolation method to use for resizing operations.
            Options include `nearest`, `bilinear`. Default is `bilinear`.
        seed: Integer. Used to create a random seed.
        data_format: Forwarded to the base layer and to every augmentation
            sub-layer.
        **kwargs: Forwarded to the base layer and to every augmentation
            sub-layer.
    """

    _USE_BASE_FACTOR = False
    _FACTOR_BOUNDS = (0, 1)

    # Attribute names of the candidate augmentation sub-layers. This list is
    # shared by all instances, so it must never be mutated in place.
    _AUGMENT_LAYERS = [
        "random_shear",
        "random_translation",
        "random_rotation",
        "random_brightness",
        "random_color_degeneration",
        "random_contrast",
        "random_sharpness",
        "random_posterization",
        "solarization",
        "auto_contrast",
        "equalization",
    ]

    def __init__(
        self,
        value_range=(0, 255),
        num_ops=2,
        factor=0.5,
        interpolation="bilinear",
        seed=None,
        data_format=None,
        **kwargs,
    ):
        super().__init__(data_format=data_format, **kwargs)

        self.value_range = value_range
        self.num_ops = num_ops
        # `_USE_BASE_FACTOR` is False, so the factor is set explicitly here.
        self._set_factor(factor)
        self.interpolation = interpolation
        self.seed = seed
        self.generator = SeedGenerator(seed)

        # One sub-layer per entry of `_AUGMENT_LAYERS`; each is stored under
        # the attribute name listed there so it can be fetched via `getattr`.
        self.random_shear = layers.RandomShear(
            x_factor=self.factor,
            y_factor=self.factor,
            interpolation=interpolation,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.random_translation = layers.RandomTranslation(
            height_factor=self.factor,
            width_factor=self.factor,
            interpolation=interpolation,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.random_rotation = layers.RandomRotation(
            factor=self.factor,
            interpolation=interpolation,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.random_brightness = layers.RandomBrightness(
            factor=self.factor,
            value_range=self.value_range,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.random_color_degeneration = layers.RandomColorDegeneration(
            factor=self.factor,
            value_range=self.value_range,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.random_contrast = layers.RandomContrast(
            factor=self.factor,
            value_range=self.value_range,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.random_sharpness = layers.RandomSharpness(
            factor=self.factor,
            value_range=self.value_range,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.solarization = layers.Solarization(
            addition_factor=self.factor,
            threshold_factor=self.factor,
            value_range=self.value_range,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        # Posterization takes an integer factor: scale the upper bound of
        # the normalized factor by 8 and keep it at least 1.
        self.random_posterization = layers.RandomPosterization(
            factor=max(1, int(8 * self.factor[1])),
            value_range=self.value_range,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.auto_contrast = layers.AutoContrast(
            value_range=self.value_range, data_format=data_format, **kwargs
        )

        self.equalization = layers.Equalization(
            value_range=self.value_range, data_format=data_format, **kwargs
        )

    def build(self, input_shape):
        """Build every augmentation sub-layer with `input_shape`."""
        for layer_name in self._AUGMENT_LAYERS:
            augmentation_layer = getattr(self, layer_name)
            augmentation_layer.build(input_shape)

    def get_random_transformation(self, data, training=True, seed=None):
        """Pick `num_ops` sub-layers at random and sample their parameters.

        Args:
            data: Input data, passed through to each selected sub-layer's
                `get_random_transformation`.
            training: Whether the layer is called in training mode.
            seed: Not used by this method; sub-layers receive the seed
                generator returned by `self._get_seed_generator`.

        Returns:
            `None` when `training` is false. Otherwise a dict mapping the
            name of each selected sub-layer to the transformation that
            sub-layer sampled, in application order.
        """
        if not training:
            return None

        if backend_utils.in_tf_graph():
            self.backend.set_backend("tensorflow")

            for layer_name in self._AUGMENT_LAYERS:
                augmentation_layer = getattr(self, layer_name)
                augmentation_layer.backend.set_backend("tensorflow")

        # Shuffle a per-call copy. `_AUGMENT_LAYERS` is a class attribute
        # shared by all instances (and iterated by `build()`), so shuffling
        # it in place would leak state between layers and calls.
        # NOTE: the selection uses Python's global `random` module, which is
        # not driven by this layer's `seed`.
        shuffled_layer_names = list(self._AUGMENT_LAYERS)
        random.shuffle(shuffled_layer_names)

        transformation = {}
        for layer_name in shuffled_layer_names[: self.num_ops]:
            augmentation_layer = getattr(self, layer_name)
            transformation[layer_name] = (
                augmentation_layer.get_random_transformation(
                    data,
                    training=training,
                    seed=self._get_seed_generator(self.backend._backend),
                )
            )

        return transformation

    def transform_images(self, images, transformation, training=True):
        """Apply the selected sub-layers to `images` in order.

        Args:
            images: Images to augment.
            transformation: Dict produced by `get_random_transformation`.
                Only read when `training` is true.
            training: Whether the layer is called in training mode.

        Returns:
            The images cast to `self.compute_dtype`, augmented when
            `training` is true.
        """
        if training:
            images = self.backend.cast(images, self.compute_dtype)

            for layer_name, transformation_value in transformation.items():
                augmentation_layer = getattr(self, layer_name)
                images = augmentation_layer.transform_images(
                    images, transformation_value
                )

        images = self.backend.cast(images, self.compute_dtype)
        return images

    def transform_labels(self, labels, transformation, training=True):
        """Return `labels` unchanged."""
        return labels

    def transform_bounding_boxes(
        self,
        bounding_boxes,
        transformation,
        training=True,
    ):
        """Pass `bounding_boxes` through each selected sub-layer in order.

        When `training` is false the boxes are returned unchanged.
        """
        if training:
            for layer_name, transformation_value in transformation.items():
                augmentation_layer = getattr(self, layer_name)
                bounding_boxes = augmentation_layer.transform_bounding_boxes(
                    bounding_boxes, transformation_value, training=training
                )
        return bounding_boxes

    def transform_segmentation_masks(
        self, segmentation_masks, transformation, training=True
    ):
        """Transform masks by delegating to `transform_images`."""
        return self.transform_images(
            segmentation_masks, transformation, training=training
        )

    def compute_output_shape(self, input_shape):
        """Return `input_shape`; the output shape equals the input shape."""
        return input_shape

    def get_config(self):
        """Return the layer config, merged over the base layer's config."""
        config = {
            "value_range": self.value_range,
            "num_ops": self.num_ops,
            "factor": self.factor,
            "interpolation": self.interpolation,
            "seed": self.seed,
        }
        base_config = super().get_config()
        return {**base_config, **config}
114 changes: 114 additions & 0 deletions
114
keras/src/layers/preprocessing/image_preprocessing/rand_augment_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import numpy as np | ||
import pytest | ||
from tensorflow import data as tf_data | ||
|
||
from keras.src import backend | ||
from keras.src import layers | ||
from keras.src import testing | ||
|
||
|
||
class RandAugmentTest(testing.TestCase):
    """Tests for `layers.RandAugment`."""

    @pytest.mark.requires_trainable_backend
    def test_layer(self):
        """Run the generic layer test with explicit constructor arguments."""
        init_kwargs = {
            "value_range": (0, 255),
            "num_ops": 2,
            "factor": 1,
            "interpolation": "nearest",
            "seed": 1,
            "data_format": "channels_last",
        }
        self.run_layer_test(
            layers.RandAugment,
            init_kwargs=init_kwargs,
            input_shape=(8, 3, 4, 3),
            supports_masking=False,
            expected_output_shape=(8, 3, 4, 3),
        )

    def test_rand_augment_inference(self):
        """With `training=False` the output equals the input."""
        seed = 3481
        layer = layers.RandAugment()

        np.random.seed(seed)
        inputs = np.random.randint(0, 255, size=(224, 224, 3))
        output = layer(inputs, training=False)
        self.assertAllClose(inputs, output)

    def test_rand_augment_basic(self):
        """The augmented batch has the same shape as the input batch."""
        data_format = backend.config.image_data_format()
        shape = (2, 8, 8, 3) if data_format == "channels_last" else (2, 3, 8, 8)
        input_data = np.random.random(shape)
        layer = layers.RandAugment(data_format=data_format)

        augmented_image = layer(input_data)
        self.assertEqual(augmented_image.shape, input_data.shape)

    def test_rand_augment_no_operations(self):
        """With `num_ops=0` the images come back unchanged."""
        data_format = backend.config.image_data_format()
        shape = (2, 8, 8, 3) if data_format == "channels_last" else (2, 3, 8, 8)
        input_data = np.random.random(shape)
        layer = layers.RandAugment(num_ops=0, data_format=data_format)

        augmented_image = layer(input_data)
        result = backend.convert_to_numpy(augmented_image)
        self.assertAllClose(result, input_data)

    def test_random_augment_randomness(self):
        """With `num_ops=11` the output differs from the input."""
        data_format = backend.config.image_data_format()
        shape = (2, 8, 8, 3) if data_format == "channels_last" else (2, 3, 8, 8)
        input_data = np.random.random(shape)

        layer = layers.RandAugment(num_ops=11, data_format=data_format)
        augmented_image = layer(input_data)

        result = backend.convert_to_numpy(augmented_image)
        self.assertNotAllClose(result, input_data)

    def test_tf_data_compatibility(self):
        """The layer can be mapped over a batched `tf.data` dataset."""
        data_format = backend.config.image_data_format()
        shape = (2, 8, 8, 3) if data_format == "channels_last" else (2, 3, 8, 8)
        input_data = np.random.random(shape)
        layer = layers.RandAugment(data_format=data_format)

        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
        for batch in ds.take(1):
            batch.numpy()

    def test_rand_augment_tf_data_bounding_boxes(self):
        """The layer can be mapped over a dataset of images and boxes."""
        data_format = backend.config.image_data_format()
        image_shape = (
            (1, 10, 8, 3) if data_format == "channels_last" else (1, 3, 10, 8)
        )
        input_image = np.random.random(image_shape)

        boxes = np.array(
            [
                [
                    [2, 1, 4, 3],
                    [6, 4, 8, 6],
                ]
            ]
        )
        labels = np.array([[1, 2]])
        bounding_boxes = {"boxes": boxes, "labels": labels}

        input_data = {"images": input_image, "bounding_boxes": bounding_boxes}

        ds = tf_data.Dataset.from_tensor_slices(input_data)
        layer = layers.RandAugment(
            data_format=data_format,
            seed=42,
            bounding_box_format="xyxy",
        )
        ds.map(layer)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.