Skip to content

Commit

Permalink
Add rand_augment processing layer (#20716)
Browse files Browse the repository at this point in the history
* Add rand_augment init

* Update rand_augment init

* Add rand_augment

* Add NotImplementedError

* Add some test cases

* Fix failed test case

* Update rand_augment

* Update rand_augment test

* Fix random_rotation bug

* Add build method to suppress warning.

* Add implementation for transform_bboxes
  • Loading branch information
shashaka authored Jan 7, 2025
1 parent 8f04616 commit ab3c8f5
Show file tree
Hide file tree
Showing 7 changed files with 361 additions and 4 deletions.
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@
MaxNumBoundingBoxes,
)
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
RandAugment,
)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@
MaxNumBoundingBoxes,
)
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
RandAugment,
)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/src/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@
MaxNumBoundingBoxes,
)
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
RandAugment,
)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
Expand Down
235 changes: 235 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/rand_augment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
import random

import keras.src.layers as layers
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.random import SeedGenerator
from keras.src.utils import backend_utils


@keras_export("keras.layers.RandAugment")
class RandAugment(BaseImagePreprocessingLayer):
    """RandAugment performs the Rand Augment operation on input images.

    This layer can be thought of as an all-in-one image augmentation layer. The
    policy implemented by this layer has been benchmarked extensively and is
    effective on a wide variety of datasets.

    On every call in training mode, `num_ops` augmentation sub-layers are
    picked at random from the layers named in `_AUGMENT_LAYERS` and applied
    one after another.

    References:
        - [RandAugment](https://arxiv.org/abs/1909.13719)

    Args:
        value_range: The range of values the input image can take.
            Default is `(0, 255)`. Typically, this would be `(0, 1)`
            for normalized images or `(0, 255)` for raw images.
        num_ops: The number of augmentation operations to apply sequentially
            to each image. Default is 2. Values larger than the number of
            available operations select every operation once.
        factor: The strength of the augmentation as a normalized value
            between 0 and 1. Default is 0.5.
        interpolation: The interpolation method forwarded to the shear,
            translation and rotation sub-layers.
            Options include `nearest`, `bilinear`. Default is `bilinear`.
        seed: Integer. Used to create a random seed. It is forwarded to the
            sub-layers and to this layer's `SeedGenerator`. The choice of
            *which* operations run is made with Python's module-level
            `random` generator and is not controlled by this value.
        data_format: Forwarded to the base layer and to every sub-layer.
        **kwargs: Forwarded to the base layer and to every sub-layer.
    """

    _USE_BASE_FACTOR = False
    _FACTOR_BOUNDS = (0, 1)

    # Attribute names of the sub-layers created in `__init__`. This list is
    # shared by all instances and must never be mutated in place.
    _AUGMENT_LAYERS = [
        "random_shear",
        "random_translation",
        "random_rotation",
        "random_brightness",
        "random_color_degeneration",
        "random_contrast",
        "random_sharpness",
        "random_posterization",
        "solarization",
        "auto_contrast",
        "equalization",
    ]

    def __init__(
        self,
        value_range=(0, 255),
        num_ops=2,
        factor=0.5,
        interpolation="bilinear",
        seed=None,
        data_format=None,
        **kwargs,
    ):
        super().__init__(data_format=data_format, **kwargs)

        self.value_range = value_range
        self.num_ops = num_ops
        # `_set_factor` stores the factor on `self.factor`; it is indexed as
        # a pair below (`self.factor[1]`).
        self._set_factor(factor)
        self.interpolation = interpolation
        self.seed = seed
        self.generator = SeedGenerator(seed)

        self.random_shear = layers.RandomShear(
            x_factor=self.factor,
            y_factor=self.factor,
            interpolation=interpolation,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.random_translation = layers.RandomTranslation(
            height_factor=self.factor,
            width_factor=self.factor,
            interpolation=interpolation,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.random_rotation = layers.RandomRotation(
            factor=self.factor,
            interpolation=interpolation,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.random_brightness = layers.RandomBrightness(
            factor=self.factor,
            value_range=self.value_range,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.random_color_degeneration = layers.RandomColorDegeneration(
            factor=self.factor,
            value_range=self.value_range,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.random_contrast = layers.RandomContrast(
            factor=self.factor,
            value_range=self.value_range,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.random_sharpness = layers.RandomSharpness(
            factor=self.factor,
            value_range=self.value_range,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.solarization = layers.Solarization(
            addition_factor=self.factor,
            threshold_factor=self.factor,
            value_range=self.value_range,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        # Posterization takes an integer factor: scale the upper bound of
        # the normalized factor by 8 and keep it at least 1.
        self.random_posterization = layers.RandomPosterization(
            factor=max(1, int(8 * self.factor[1])),
            value_range=self.value_range,
            seed=self.seed,
            data_format=data_format,
            **kwargs,
        )

        self.auto_contrast = layers.AutoContrast(
            value_range=self.value_range, data_format=data_format, **kwargs
        )

        self.equalization = layers.Equalization(
            value_range=self.value_range, data_format=data_format, **kwargs
        )

    def build(self, input_shape):
        """Build every augmentation sub-layer with `input_shape`."""
        for layer_name in self._AUGMENT_LAYERS:
            augmentation_layer = getattr(self, layer_name)
            augmentation_layer.build(input_shape)

    def get_random_transformation(self, data, training=True, seed=None):
        """Pick `num_ops` sub-layers and sample a transformation for each.

        Args:
            data: Input passed through to each selected sub-layer's
                `get_random_transformation`.
            training: When falsy, no transformation is sampled.
            seed: Unused; sub-layers receive the seed generator returned by
                `self._get_seed_generator`.

        Returns:
            `None` when not training, otherwise a dict mapping the selected
            sub-layer attribute name to that sub-layer's transformation, in
            application order.
        """
        if not training:
            return None

        if backend_utils.in_tf_graph():
            self.backend.set_backend("tensorflow")

            for layer_name in self._AUGMENT_LAYERS:
                augmentation_layer = getattr(self, layer_name)
                augmentation_layer.backend.set_backend("tensorflow")

        # Shuffle a private copy: `_AUGMENT_LAYERS` is a class attribute
        # shared by every instance, so shuffling it in place would leak
        # state between layers (and between concurrent callers).
        # NOTE(review): this is a Python-level choice, so inside a traced
        # tf.data function it presumably runs once per trace rather than
        # once per batch - confirm.
        shuffled_layer_names = list(self._AUGMENT_LAYERS)
        random.shuffle(shuffled_layer_names)

        transformation = {}
        for layer_name in shuffled_layer_names[: self.num_ops]:
            augmentation_layer = getattr(self, layer_name)
            transformation[layer_name] = (
                augmentation_layer.get_random_transformation(
                    data,
                    training=training,
                    seed=self._get_seed_generator(self.backend._backend),
                )
            )

        return transformation

    def transform_images(self, images, transformation, training=True):
        """Apply the selected sub-layers to `images` in order.

        The result is always cast to `self.compute_dtype`; the sub-layers
        are only applied when `training` is truthy.
        """
        if training:
            images = self.backend.cast(images, self.compute_dtype)

            for layer_name, transformation_value in transformation.items():
                augmentation_layer = getattr(self, layer_name)
                images = augmentation_layer.transform_images(
                    images, transformation_value
                )

        images = self.backend.cast(images, self.compute_dtype)
        return images

    def transform_labels(self, labels, transformation, training=True):
        """Return `labels` unchanged."""
        return labels

    def transform_bounding_boxes(
        self,
        bounding_boxes,
        transformation,
        training=True,
    ):
        """Pass `bounding_boxes` through each selected sub-layer in order.

        Returns the input unchanged when `training` is falsy.
        """
        if training:
            for layer_name, transformation_value in transformation.items():
                augmentation_layer = getattr(self, layer_name)
                bounding_boxes = augmentation_layer.transform_bounding_boxes(
                    bounding_boxes, transformation_value, training=training
                )
        return bounding_boxes

    def transform_segmentation_masks(
        self, segmentation_masks, transformation, training=True
    ):
        """Transform masks with the same code path as `transform_images`."""
        return self.transform_images(
            segmentation_masks, transformation, training=training
        )

    def compute_output_shape(self, input_shape):
        """Return `input_shape` unchanged."""
        return input_shape

    def get_config(self):
        """Return the base config extended with this layer's arguments."""
        config = {
            "value_range": self.value_range,
            "num_ops": self.num_ops,
            "factor": self.factor,
            "interpolation": self.interpolation,
            "seed": self.seed,
        }
        base_config = super().get_config()
        return {**base_config, **config}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import numpy as np
import pytest
from tensorflow import data as tf_data

from keras.src import backend
from keras.src import layers
from keras.src import testing


class RandAugmentTest(testing.TestCase):
    """Tests for `layers.RandAugment`."""

    def _random_batch(self):
        """Return `(data_format, images)` for two random 8x8 3-channel images.

        The channel axis position follows the configured image data format.
        """
        data_format = backend.config.image_data_format()
        if data_format == "channels_last":
            shape = (2, 8, 8, 3)
        else:
            shape = (2, 3, 8, 8)
        return data_format, np.random.random(shape)

    @pytest.mark.requires_trainable_backend
    def test_layer(self):
        init_kwargs = {
            "value_range": (0, 255),
            "num_ops": 2,
            "factor": 1,
            "interpolation": "nearest",
            "seed": 1,
            "data_format": "channels_last",
        }
        self.run_layer_test(
            layers.RandAugment,
            init_kwargs=init_kwargs,
            input_shape=(8, 3, 4, 3),
            supports_masking=False,
            expected_output_shape=(8, 3, 4, 3),
        )

    def test_rand_augment_inference(self):
        # With training=False the output must match the input.
        layer = layers.RandAugment()

        np.random.seed(3481)
        inputs = np.random.randint(0, 255, size=(224, 224, 3))
        self.assertAllClose(inputs, layer(inputs, training=False))

    def test_rand_augment_basic(self):
        data_format, input_data = self._random_batch()
        layer = layers.RandAugment(data_format=data_format)

        augmented_image = layer(input_data)
        self.assertEqual(augmented_image.shape, input_data.shape)

    def test_rand_augment_no_operations(self):
        # num_ops=0 selects no sub-layer, so values must be preserved.
        data_format, input_data = self._random_batch()
        layer = layers.RandAugment(num_ops=0, data_format=data_format)

        result = backend.convert_to_numpy(layer(input_data))
        self.assertAllClose(result, input_data)

    def test_random_augment_randomness(self):
        # num_ops=11 selects every available operation.
        data_format, input_data = self._random_batch()
        layer = layers.RandAugment(num_ops=11, data_format=data_format)

        result = backend.convert_to_numpy(layer(input_data))
        self.assertNotAllClose(result, input_data)

    def test_tf_data_compatibility(self):
        data_format, input_data = self._random_batch()
        layer = layers.RandAugment(data_format=data_format)

        dataset = tf_data.Dataset.from_tensor_slices(input_data)
        dataset = dataset.batch(2).map(layer)
        for output in dataset.take(1):
            output.numpy()

    def test_rand_augment_tf_data_bounding_boxes(self):
        data_format = backend.config.image_data_format()
        channels_last = data_format == "channels_last"
        image_shape = (1, 10, 8, 3) if channels_last else (1, 3, 10, 8)
        input_image = np.random.random(image_shape)

        boxes = np.array(
            [
                [
                    [2, 1, 4, 3],
                    [6, 4, 8, 6],
                ]
            ]
        )
        bounding_boxes = {"boxes": boxes, "labels": np.array([[1, 2]])}
        input_data = {"images": input_image, "bounding_boxes": bounding_boxes}

        ds = tf_data.Dataset.from_tensor_slices(input_data)
        layer = layers.RandAugment(
            data_format=data_format,
            seed=42,
            bounding_box_format="xyxy",
        )
        # The mapped dataset is neither kept nor iterated; the test only
        # checks that mapping the layer over the dataset does not raise.
        ds.map(layer)
ds.map(layer)
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_correctness(self):
seed = 2390

# Always scale up, but randomly between 0 ~ 255
layer = layers.RandomBrightness([0, 1.0])
layer = layers.RandomBrightness([0.1, 1.0])
np.random.seed(seed)
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = backend.convert_to_numpy(layer(inputs))
Expand All @@ -44,7 +44,7 @@ def test_correctness(self):
self.assertTrue(np.mean(diff) > 0)

# Always scale down, but randomly between 0 ~ 255
layer = layers.RandomBrightness([-1.0, 0.0])
layer = layers.RandomBrightness([-1.0, -0.1])
np.random.seed(seed)
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = backend.convert_to_numpy(layer(inputs))
Expand Down
Loading

0 comments on commit ab3c8f5

Please sign in to comment.