Renamed and fixed typos in documentation of BatchNorm... operations.
janpfeifer committed Jul 16, 2024
1 parent c62d20e commit bd08d5f
Showing 2 changed files with 19 additions and 18 deletions.
1 change: 1 addition & 0 deletions docs/CHANGELOG.md
@@ -16,6 +16,7 @@
* Added `xlabuilder.Shape.Memory` and `xlabuilder.NewArrayLiteralFromAny`.
* Added `xlabuilder.Op.Builder()`
* Added comments support to op_types.txt and added comments to several of the operations.
+* Renamed `xlabuilder.BatchNorm{Inference,Training}` to `xlabuilder.BatchNormFor{Inference,Training}`

# v0.1.2 SuppressAbseilLoggingHack

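As a quick illustration of the rename recorded in the changelog entry above, here is a minimal sketch of a call site. The import path, the helper name, and the epsilon/axis values are assumptions made for illustration and are not part of this commit; only the `BatchNormForInference` signature comes from the diff below.

```go
package example

import "github.com/gomlx/gopjrt/xlabuilder" // assumed import path

// applyBatchNormInference is a hypothetical helper: all ops are assumed to have been
// created elsewhere on the same XlaBuilder (e.g. as parameters or constants).
func applyBatchNormInference(x, scale, offset, mean, variance *xlabuilder.Op) (*xlabuilder.Op, error) {
	// Before this commit the call was xlabuilder.BatchNormInference; the signature is unchanged.
	return xlabuilder.BatchNormForInference(x, scale, offset, mean, variance, 1e-5, 1)
}
```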
36 changes: 18 additions & 18 deletions xlabuilder/special_ops.go
@@ -1098,12 +1098,12 @@ func DecodeReverse(op *Op) (x *Op, axes []int) {
return
}

-// BatchNormInference implements Batch Norm for inference. See details in
-// https://www.tensorflow.org/xla/operation_semantics#batchnorminference.
+// BatchNormForInference implements Batch Norm for inference. See details in
+// https://www.tensorflow.org/xla/operation_semantics#batchnorminference
//
// Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing
// Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
-func BatchNormInference(operand, scale, offset, mean, variance *Op, epsilon float32, axis int) (*Op, error) {
+func BatchNormForInference(operand, scale, offset, mean, variance *Op, epsilon float32, axis int) (*Op, error) {
builder := operand.builder
op := newOp(BatchNormInferenceOp, operand, scale, offset, mean, variance)
op.IntArg = axis
@@ -1115,8 +1115,8 @@ func BatchNormInference(operand, scale, offset, mean, variance *Op, epsilon floa
return op, nil
}

-// DecodeBatchNormInference retrieves the arguments for the BatchNormInference op.
-func DecodeBatchNormInference(op *Op) (operand, scale, offset, mean, variance *Op, epsilon float32, axis int) {
+// DecodeBatchNormForInference retrieves the arguments for the BatchNormForInference op.
+func DecodeBatchNormForInference(op *Op) (operand, scale, offset, mean, variance *Op, epsilon float32, axis int) {
operand = op.OpInputs[0]
scale = op.OpInputs[1]
offset = op.OpInputs[2]
@@ -1127,14 +1127,14 @@ func DecodeBatchNormInference(op *Op) (operand, scale, offset, mean, variance *O
return
}

-// BatchNormTraining implements Batch Norm for training. See details in
-// https://www.tensorflow.org/xla/operation_semantics#batchnormtraining.
+// BatchNormForTraining implements Batch Norm for training. See details in
+// https://www.tensorflow.org/xla/operation_semantics#batchnormtraining
//
// It returns the normalized tensor, the batchMean and the batchVariance.
//
// Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing
// Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
-func BatchNormTraining(operand, scale, offset *Op, epsilon float32, axis int) (normalized, batchMean, batchVariance *Op, err error) {
+func BatchNormForTraining(operand, scale, offset *Op, epsilon float32, axis int) (normalized, batchMean, batchVariance *Op, err error) {
builder := operand.builder
op := newOp(BatchNormTrainingOp, operand, scale, offset)
op.IntArg = axis
@@ -1146,19 +1146,19 @@ func BatchNormTraining(operand, scale, offset *Op, epsilon float32, axis int) (n
var parts []*Op
parts, err = SplitTuple(op)
if err != nil {
-err = errors.WithMessage(err, "failed to split results of BatchNormTraining")
+err = errors.WithMessage(err, "failed to split results of BatchNormForTraining")
return
}
if len(parts) != 3 {
-err = errors.Errorf("BatchNormTraining should have returned a tuple with 3 parts, got %s instead", op.Shape)
+err = errors.Errorf("BatchNormForTraining should have returned a tuple with 3 parts, got %s instead", op.Shape)
return
}
normalized, batchMean, batchVariance = parts[0], parts[1], parts[2]
return
}

-// DecodeBatchNormTraining retrieves the arguments for the BatchNormTraining op.
-func DecodeBatchNormTraining(op *Op) (operand, scale, offset *Op, epsilon float32, axis int) {
+// DecodeBatchNormForTraining retrieves the arguments for the BatchNormForTraining op.
+func DecodeBatchNormForTraining(op *Op) (operand, scale, offset *Op, epsilon float32, axis int) {
operand = op.OpInputs[0]
scale = op.OpInputs[1]
offset = op.OpInputs[2]
@@ -1167,8 +1167,8 @@ func DecodeBatchNormTraining(op *Op) (operand, scale, offset *Op, epsilon float3
return
}

-// BatchNormGrad implements Batch Norm for training. See details in
-// https://www.tensorflow.org/xla/operation_semantics#batchnormtraining.
+// BatchNormGradient calculates the BatchNorm gradient. See details in
+// https://openxla.org/xla/operation_semantics#batchnormgrad
//
// The gradOutput is the adjoint gradient, that is, the gradient with respect to the output of the
// batch normalization.
@@ -1177,7 +1177,7 @@ func DecodeBatchNormTraining(op *Op) (operand, scale, offset *Op, epsilon float3
//
// Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing
// Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
-func BatchNormGrad(operand, scale, mean, variance, gradOutput *Op, epsilon float32, axis int) (gradOperand, gradScale, gradOffset *Op, err error) {
+func BatchNormGradient(operand, scale, mean, variance, gradOutput *Op, epsilon float32, axis int) (gradOperand, gradScale, gradOffset *Op, err error) {
builder := operand.builder
op := newOp(BatchNormGradOp, operand, scale, mean, variance, gradOutput)
op.IntArg = axis
@@ -1189,18 +1189,18 @@ func BatchNormGrad(operand, scale, mean, variance, gradOutput *Op, epsilon float
var parts []*Op
parts, err = SplitTuple(op)
if err != nil {
-err = errors.WithMessage(err, "failed to split results of BatchNormGrad")
+err = errors.WithMessage(err, "failed to split results of BatchNormGradient")
return
}
if len(parts) != 3 {
-err = errors.Errorf("BatchNormGrad should have returned a tuple with 3 parts, got %s instead", op.Shape)
+err = errors.Errorf("BatchNormGradient should have returned a tuple with 3 parts, got %s instead", op.Shape)
return
}
gradOperand, gradScale, gradOffset = parts[0], parts[1], parts[2]
return
}

-// DecodeBatchNormGrad retrieves the arguments for the BatchNormGrad op.
+// DecodeBatchNormGrad retrieves the arguments for the BatchNormGradient op.
func DecodeBatchNormGrad(op *Op) (operand, scale, mean, variance, gradOutput *Op, epsilon float32, axis int) {
operand = op.OpInputs[0]
scale = op.OpInputs[1]
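For completeness, here is a hedged sketch of how the renamed training-side functions fit together: `BatchNormForTraining` for the forward pass and `BatchNormGradient` for the backward pass. The signatures come from the diff above; the import path, helper name, constants, and the way the adjoint gradient `gradOut` is obtained are assumptions made purely for illustration.

```go
package example

import "github.com/gomlx/gopjrt/xlabuilder" // assumed import path

// batchNormTrainStep is a hypothetical helper: it normalizes x for training and then
// requests the gradients, given the adjoint gradient gradOut of whatever follows the
// normalization. All *xlabuilder.Op inputs are assumed to come from the same XlaBuilder.
func batchNormTrainStep(x, scale, offset, gradOut *xlabuilder.Op) (gradX, gradScale, gradOffset *xlabuilder.Op, err error) {
	const epsilon = 1e-3  // numerical-stability term; an arbitrary example value
	const featureAxis = 1 // axis holding the feature (channel) dimension in this example

	// Forward pass: the normalized tensor plus the batch statistics.
	normalized, batchMean, batchVariance, err := xlabuilder.BatchNormForTraining(x, scale, offset, epsilon, featureAxis)
	if err != nil {
		return nil, nil, nil, err
	}
	_ = normalized // would feed the rest of the graph; elided in this sketch

	// Backward pass: gradients with respect to the operand, scale and offset.
	return xlabuilder.BatchNormGradient(x, scale, batchMean, batchVariance, gradOut, epsilon, featureAxis)
}
```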
