From 18e6ac2247adb247ed4d71408fdd41507f078a58 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Mon, 12 Aug 2024 18:23:33 +0100 Subject: [PATCH] run yapf --- envpool/python/xla_template.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/envpool/python/xla_template.py b/envpool/python/xla_template.py index 6aad7e48..5894d2c6 100644 --- a/envpool/python/xla_template.py +++ b/envpool/python/xla_template.py @@ -91,12 +91,8 @@ def translation(c: Any, *args: Any, platform: str = "cpu") -> Any: prim.multiple_results = (len(out_specs) > 1) prim.def_impl(partial(xla.apply_primitive, prim)) prim.def_abstract_eval(abstract) - interpreters.mlir["cpu"][prim] = partial( - translation, platform="cpu" - ) - interpreters.mlir["gpu"][prim] = partial( - translation, platform="gpu" - ) + interpreters.mlir["cpu"][prim] = partial(translation, platform="cpu") + interpreters.mlir["gpu"][prim] = partial(translation, platform="gpu") def call(*args: Any) -> Any: return prim.bind(*args)