diff --git a/fx2ait/fx2ait/acc_tracer/acc_ops.py b/fx2ait/fx2ait/acc_tracer/acc_ops.py index f3a787cb0..6a15ccf6a 100644 --- a/fx2ait/fx2ait/acc_tracer/acc_ops.py +++ b/fx2ait/fx2ait/acc_tracer/acc_ops.py @@ -1251,23 +1251,6 @@ def softsign(*, input): return nn.functional.softsign(input=input) -@register_custom_acc_mapper_fn( - op_and_target=("call_function", torch.log1p), - arg_replacement_tuples=[ - ("input", "input"), - ], -) -def torch_log1p_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node: - with node.graph.inserting_before(node): - add_kwargs = {"input": node.kwargs["input"], "other": 1.0} - add_node = node.graph.call_function(add, kwargs=add_kwargs) - add_node.meta = node.meta.copy() - log_kwargs = {"input": add_node} - log_node = node.graph.call_function(log, kwargs=log_kwargs) - log_node.meta = node.meta.copy() - return log_node - - def reduce_op_mapper( node: torch.fx.Node, mod: torch.fx.GraphModule, func ) -> torch.fx.Node: @@ -1782,6 +1765,13 @@ def log(*, input): return torch.log(input=input) +@register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) +@register_acc_op_mapping(op_and_target=("call_function", torch.log1p)) +@register_acc_op +def log1p(*, input): + return torch.log1p(input=input) + + @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.sqrt)) @register_acc_op_mapping(op_and_target=("call_method", "sqrt")) diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py index b04448c34..00bd71e9b 100644 --- a/fx2ait/fx2ait/converters/ait_converters.py +++ b/fx2ait/fx2ait/converters/ait_converters.py @@ -468,6 +468,20 @@ def acc_ops_log( return elementwise(FuncEnum.LOGE)(input_val) +@ait_converter(acc_ops.log1p) +def acc_ops_log1p( + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> ConverterOutput: + input_val = kwargs["input"] + if not isinstance(input_val, AITTensor): + raise RuntimeError(f"Unexpected input for {name}: {input_val}") + + return elementwise(FuncEnum.LOG1P)(input_val) + + @ait_converter(acc_ops.var) def acc_ops_var( target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str diff --git a/python/aitemplate/backend/backend_spec.py b/python/aitemplate/backend/backend_spec.py index 6e9f700ae..f7b4eea1d 100644 --- a/python/aitemplate/backend/backend_spec.py +++ b/python/aitemplate/backend/backend_spec.py @@ -183,6 +183,13 @@ class GPUBackendSpec(BackendSpec): "bfloat16": "hlog", "float": "logf", }, + FuncEnum.LOG1P: { + "half2": "h2log1p", + "bfloat16_2": "h2log1p", + "half": "hlog1p", + "bfloat16": "hlog1p", + "float": "log1pf", + }, FuncEnum.EXP: { "half2": "h2exp", "bfloat16_2": "h2exp", diff --git a/python/aitemplate/backend/cuda/elementwise/custom_math.cuh b/python/aitemplate/backend/cuda/elementwise/custom_math.cuh index 255e420b5..397644284 100644 --- a/python/aitemplate/backend/cuda/elementwise/custom_math.cuh +++ b/python/aitemplate/backend/cuda/elementwise/custom_math.cuh @@ -1048,4 +1048,28 @@ __device__ bfloat16_2 h2celu(const bfloat16_2 a, const bfloat16_2 alpha) { #endif } +__device__ half hlog1p(const half a) { + return half(log1pf(float(a))); +} + +__device__ bfloat16 hlog1p(const bfloat16 a) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16(log1pf(float(a))); +#else + NOT_IMPLEMENTED(); +#endif +} + +__device__ half2 h2log1p(const half2 a) { + return half2(log1pf(float(a.x)), log1pf(float(a.y))); +} + +__device__ bfloat16_2 h2log1p(const bfloat16_2 a) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_2(log1pf(float(a.x)), log1pf(float(a.y))); +#else + NOT_IMPLEMENTED(); +#endif +} + #endif diff --git a/python/aitemplate/compiler/ops/common/epilogue.py b/python/aitemplate/compiler/ops/common/epilogue.py index fd684bf6e..d94f20c72 100644 --- a/python/aitemplate/compiler/ops/common/epilogue.py +++ b/python/aitemplate/compiler/ops/common/epilogue.py @@ -65,3 +65,4 @@ class FuncEnum(Enum): SOFTSIGN = 27 FLOOR_DIV = 28 CELU = 29 + LOG1P = 30 diff --git a/python/aitemplate/compiler/ops/common/math.py b/python/aitemplate/compiler/ops/common/math.py index 4534628e6..d782288de 100644 --- a/python/aitemplate/compiler/ops/common/math.py +++ b/python/aitemplate/compiler/ops/common/math.py @@ -47,6 +47,10 @@ def log(tensor: Any) -> Tensor: return OP_REGISTRY.get("LOGE")(tensor) +def log1p(tensor: Any) -> Tensor: + return OP_REGISTRY.get("LOG1P")(tensor) + + def exp(tensor: Any) -> Tensor: return OP_REGISTRY.get("EXP")(tensor)