Add log1p elementwise op #993

Open · wants to merge 1 commit into main
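This PR adds a log1p elementwise op end to end: an acc_ops tracer op, an fx2ait converter, GPU backend dtype-to-function mappings, CUDA half/bfloat16 device helpers, a new FuncEnum value, and a Python frontend wrapper.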
6 changes: 6 additions & 0 deletions fx2ait/fx2ait/acc_tracer/acc_ops.py
@@ -1782,6 +1782,12 @@ def log(*, input):
return torch.log(input=input)


@register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary)
@register_acc_op
def log1p(*, input):
return torch.log1p(input=input)


@register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary)
@register_acc_op_mapping(op_and_target=("call_function", torch.sqrt))
@register_acc_op_mapping(op_and_target=("call_method", "sqrt"))
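For reference, a minimal sketch (not part of this PR) of exercising the new acc op directly; it assumes the fx2ait package layout above and simply mirrors torch.log1p:

import torch
from fx2ait.acc_tracer import acc_ops

x = torch.tensor([0.0, 1e-8, 1.0])
# acc ops take keyword-only arguments, per the acc_ops convention above
assert torch.allclose(acc_ops.log1p(input=x), torch.log1p(x))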
14 changes: 14 additions & 0 deletions fx2ait/fx2ait/converters/ait_converters.py
@@ -479,6 +479,20 @@ def acc_ops_log(
return elementwise(FuncEnum.LOGE)(input_val)


@ait_converter(acc_ops.log1p)
def acc_ops_log1p(
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> ConverterOutput:
input_val = kwargs["input"]
if not isinstance(input_val, AITTensor):
raise RuntimeError(f"Unexpected input for {name}: {input_val}")

return elementwise(FuncEnum.LOG1P)(input_val)


@ait_converter(acc_ops.var)
def acc_ops_var(
target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str
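A hedged sketch of what this converter produces; building the same elementwise node directly in AIT (the FuncEnum import path comes from the epilogue.py change below; the Tensor constructor is the standard AIT frontend API):

from aitemplate.compiler import ops
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.frontend import Tensor

x = Tensor(shape=[4, 16], dtype="float16", name="x", is_input=True)
# same node the converter emits for acc_ops.log1p
y = ops.elementwise(FuncEnum.LOG1P)(x)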
7 changes: 7 additions & 0 deletions python/aitemplate/backend/backend_spec.py
@@ -183,6 +183,13 @@ class GPUBackendSpec(BackendSpec):
"bfloat16": "hlog",
"float": "logf",
},
FuncEnum.LOG1P: {
"half2": "h2log1p",
"bfloat16_2": "h2log1p",
"half": "hlog1p",
"bfloat16": "hlog1p",
"float": "log1pf",
},
FuncEnum.EXP: {
"half2": "h2exp",
"bfloat16_2": "h2exp",
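Each entry maps a dtype to the function name the CUDA codegen emits: log1pf is the CUDA math library's float routine, while the half, bfloat16, and vectorized half2/bfloat16_2 names resolve to the device helpers added in custom_math.cuh below.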
24 changes: 24 additions & 0 deletions python/aitemplate/backend/cuda/elementwise/custom_math.cuh
@@ -1075,4 +1075,28 @@ __device__ bfloat16_2 h2celu(const bfloat16_2 a, const bfloat16_2 alpha) {
#endif
}

__device__ half hlog1p(const half a) {
return half(log1pf(float(a)));
}

__device__ bfloat16 hlog1p(const bfloat16 a) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
return bfloat16(log1pf(float(a)));
#else
NOT_IMPLEMENTED();
#endif
}

__device__ half2 h2log1p(const half2 a) {
return half2(log1pf(float(a.x)), log1pf(float(a.y)));
}

__device__ bfloat16_2 h2log1p(const bfloat16_2 a) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
return bfloat16_2(log1pf(float(a.x)), log1pf(float(a.y)));
#else
NOT_IMPLEMENTED();
#endif
}

#endif
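All four helpers upcast to float and call log1pf rather than computing log(1 + x) at half precision, which would lose small inputs to rounding. A quick PyTorch illustration of that rounding behavior (not part of this PR):

import torch

x = torch.tensor([1e-4], dtype=torch.float16)
print(torch.log((1 + x).float()))  # prints 0.0: 1 + 1e-4 rounds to 1.0 in fp16
print(torch.log1p(x.float()))      # prints ~1e-4: the small value survives in fp32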
1 change: 1 addition & 0 deletions python/aitemplate/compiler/ops/common/epilogue.py
@@ -66,3 +66,4 @@ class FuncEnum(Enum):
FLOOR_DIV = 28
CELU = 29
FLOOR = 30
LOG1P = 31
4 changes: 4 additions & 0 deletions python/aitemplate/compiler/ops/common/math.py
@@ -47,6 +47,10 @@ def log(tensor: Any) -> Tensor:
return OP_REGISTRY.get("LOGE")(tensor)


def log1p(tensor: Any) -> Tensor:
return OP_REGISTRY.get("LOG1P")(tensor)


def exp(tensor: Any) -> Tensor:
return OP_REGISTRY.get("EXP")(tensor)

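Finally, a hedged end-to-end sketch of the new frontend wrapper, importing math.py by its module path and assuming OP_REGISTRY picks up the new FuncEnum value as it does for the neighboring log and exp ops:

from aitemplate.compiler.ops.common import math as ait_math
from aitemplate.frontend import Tensor

x = Tensor(shape=[4, 16], dtype="float16", name="x", is_input=True)
# dispatches through OP_REGISTRY to elementwise(FuncEnum.LOG1P)
y = ait_math.log1p(x)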