Skip to content

Commit

Permalink
feat: initializer
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-friday committed Sep 13, 2024
1 parent 3ef42db commit 3b21ce8
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/timeout_executor/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

__all__ = ["TIMEOUT_EXECUTOR_INPUT_FILE", "SUBPROCESS_COMMAND"]
TIMEOUT_EXECUTOR_INPUT_FILE = "_TIMEOUT_EXECUTOR_INPUT_FILE"
TIMEOUT_EXECUTOR_INIT_FILE = "_TIMEOUT_EXECUTOR_INIT_FILE"
SUBPROCESS_COMMAND = (
"from timeout_executor.subprocess import run_in_subprocess;run_in_subprocess()"
)
89 changes: 75 additions & 14 deletions src/timeout_executor/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,21 @@
import cloudpickle
from typing_extensions import ParamSpec, Self, TypeVar, override

from timeout_executor.const import SUBPROCESS_COMMAND, TIMEOUT_EXECUTOR_INPUT_FILE
from timeout_executor.const import (
SUBPROCESS_COMMAND,
TIMEOUT_EXECUTOR_INIT_FILE,
TIMEOUT_EXECUTOR_INPUT_FILE,
)
from timeout_executor.logging import logger
from timeout_executor.result import AsyncResult
from timeout_executor.terminate import Terminator
from timeout_executor.types import Callback, CallbackArgs, ExecutorArgs, ProcessCallback
from timeout_executor.types import (
Callback,
CallbackArgs,
ExecutorArgs,
InitializerArgs,
ProcessCallback,
)

if TYPE_CHECKING:
from collections.abc import Awaitable, Iterable
Expand All @@ -42,20 +52,22 @@ def __init__(
timeout: float,
func: Callable[P, T],
callbacks: Callable[[], Iterable[ProcessCallback[P, T]]] | None = None,
initializer: InitializerArgs[..., Any] | None = None,
) -> None:
self._timeout = timeout
self._func = func
self._func_name = func_name(func)
self._unique_id = uuid4()
self._init_callbacks = callbacks
self._callbacks: deque[ProcessCallback[P, T]] = deque()
self._initializer = initializer

@property
def unique_id(self) -> UUID:
return self._unique_id

def _create_temp_files(self) -> tuple[Path, Path]:
"""create temp files for input and output"""
def _create_temp_files(self) -> tuple[Path, Path, Path]:
"""create temp files for input, output and init"""
temp_dir = Path(tempfile.gettempdir()) / "timeout_executor"
temp_dir.mkdir(exist_ok=True)

Expand All @@ -64,8 +76,9 @@ def _create_temp_files(self) -> tuple[Path, Path]:

input_file = unique_dir / "input.b"
output_file = unique_dir / "output.b"
init_file = unique_dir / "init.b"

return input_file, output_file
return input_file, output_file, init_file

def _command(self, stacklevel: int = 2) -> list[str]:
"""create subprocess command"""
Expand All @@ -85,15 +98,37 @@ def _dump_args(
)
return input_args_as_bytes

def _dump_initializer(self) -> bytes | None:
if self._initializer is None:
logger.debug("%r initializer is None", self)
return None
init_args = (
self._initializer.function,
self._initializer.args,
self._initializer.kwargs,
)
logger.debug("%r before dump initializer", self)
init_args_as_bytes = cloudpickle.dumps(init_args)
logger.debug(
"%r after dump initializer :: size: %d", self, len(init_args_as_bytes)
)
return init_args_as_bytes

def _create_process(
self, input_file: Path | anyio.Path, stacklevel: int = 2
self,
input_file: Path | anyio.Path,
init_file: Path | anyio.Path | None,
stacklevel: int = 2,
) -> subprocess.Popen[str]:
"""create new process"""
command = self._command(stacklevel=stacklevel + 1)
logger.debug("%r before create new process", self, stacklevel=stacklevel)
process = subprocess.Popen( # noqa: S603
command,
env={TIMEOUT_EXECUTOR_INPUT_FILE: input_file.as_posix()},
env={
TIMEOUT_EXECUTOR_INPUT_FILE: str(input_file),
TIMEOUT_EXECUTOR_INIT_FILE: "" if init_file is None else str(init_file),
},
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
Expand Down Expand Up @@ -121,6 +156,7 @@ def _init_process(
self,
input_file: Path | anyio.Path,
output_file: Path | anyio.Path,
init_file: Path | anyio.Path | None,
stacklevel: int = 2,
) -> AsyncResult[P, T]:
"""init process.
Expand All @@ -144,7 +180,7 @@ def _init_process(
self._create_executor_args, input_file, output_file
)
terminator = Terminator(executor_args_builder, self.callbacks)
process = self._create_process(input_file, stacklevel=stacklevel + 1)
process = self._create_process(input_file, init_file, stacklevel=stacklevel + 1)
result: AsyncResult[P, T] = AsyncResult(process, terminator.executor_args)
terminator.callback_args = CallbackArgs(process=process, result=result)
terminator.start()
Expand All @@ -153,28 +189,50 @@ def _init_process(

def apply(self, *args: P.args, **kwargs: P.kwargs) -> AsyncResult[P, T]:
"""run function with deadline"""
input_file, output_file = self._create_temp_files()
input_file, output_file, init_file = self._create_temp_files()
input_args_as_bytes = self._dump_args(output_file, *args, **kwargs)

logger.debug("%r before write input file", self)
with input_file.open("wb+") as file:
file.write(input_args_as_bytes)
logger.debug("%r after write input file", self)

return self._init_process(input_file, output_file)
init_args_as_bytes = self._dump_initializer()
if init_args_as_bytes is None:
init_file = None
else:
logger.debug("%r before write init file", self)
with init_file.open("wb+") as file:
file.write(init_args_as_bytes)
logger.debug("%r after write init file", self)

return self._init_process(input_file, output_file, init_file)

async def delay(self, *args: P.args, **kwargs: P.kwargs) -> AsyncResult[P, T]:
"""run function with deadline"""
input_file, output_file = self._create_temp_files()
input_file, output_file = anyio.Path(input_file), anyio.Path(output_file)
input_file, output_file, init_file = self._create_temp_files()
input_file, output_file, init_file = (
anyio.Path(input_file),
anyio.Path(output_file),
anyio.Path(init_file),
)
input_args_as_bytes = self._dump_args(output_file, *args, **kwargs)

logger.debug("%r before write input file", self)
async with await input_file.open("wb+") as file:
await file.write(input_args_as_bytes)
logger.debug("%r after write input file", self)

return self._init_process(input_file, output_file)
init_args_as_bytes = self._dump_initializer()
if init_args_as_bytes is None:
init_file = None
else:
logger.debug("%r before write init file", self)
async with await init_file.open("wb+") as file:
await file.write(init_args_as_bytes)
logger.debug("%r after write init file", self)

return self._init_process(input_file, output_file, init_file)

@override
def __repr__(self) -> str:
Expand Down Expand Up @@ -237,7 +295,10 @@ def apply_func(
executor = Executor(timeout_or_executor, func)
else:
executor = Executor(
timeout_or_executor.timeout, func, timeout_or_executor.callbacks
timeout_or_executor.timeout,
func,
timeout_or_executor.callbacks,
timeout_or_executor.initializer,
)
return executor.apply(*args, **kwargs)

Expand Down
12 changes: 11 additions & 1 deletion src/timeout_executor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing_extensions import ParamSpec, Self, TypeVar, override

from timeout_executor.executor import apply_func, delay_func
from timeout_executor.types import Callback, ProcessCallback
from timeout_executor.types import Callback, InitializerArgs, ProcessCallback

if TYPE_CHECKING:
from collections.abc import Awaitable, Iterable
Expand All @@ -27,6 +27,7 @@ class TimeoutExecutor(Callback[Any, AnyT], Generic[AnyT]):
def __init__(self, timeout: float) -> None:
self._timeout = timeout
self._callbacks: deque[ProcessCallback[..., AnyT]] = deque()
self.initializer: InitializerArgs[..., Any] | None = None

@property
def timeout(self) -> float:
Expand Down Expand Up @@ -122,3 +123,12 @@ def remove_callback(self, callback: ProcessCallback[..., AnyT]) -> Self:
with suppress(ValueError):
self._callbacks.remove(callback)
return self

def set_initializer(
self, initializer: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
) -> Self:
"""set initializer"""
self.initializer = InitializerArgs(
function=initializer, args=args, kwargs=kwargs
)
return self
11 changes: 10 additions & 1 deletion src/timeout_executor/subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
import cloudpickle
from anyio.lowlevel import checkpoint

from timeout_executor.const import TIMEOUT_EXECUTOR_INPUT_FILE
from timeout_executor.const import (
TIMEOUT_EXECUTOR_INIT_FILE,
TIMEOUT_EXECUTOR_INPUT_FILE,
)

if TYPE_CHECKING:
from typing_extensions import ParamSpec, TypeVar
Expand All @@ -24,6 +27,12 @@


def run_in_subprocess() -> None:
init_file = environ.get(TIMEOUT_EXECUTOR_INIT_FILE, "")
if init_file:
with Path(init_file).open("rb") as file_io:
init_func, init_args, init_kwargs = cloudpickle.load(file_io)
init_func(*init_args, **init_kwargs)

input_file = Path(environ.get(TIMEOUT_EXECUTOR_INPUT_FILE, ""))
with input_file.open("rb") as file_io:
func, args, kwargs, output_file = cloudpickle.load(file_io)
Expand Down
7 changes: 7 additions & 0 deletions src/timeout_executor/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ class CallbackArgs(Generic[P, T]):
"""process state"""


@dataclass(**_DATACLASS_NON_FROZEN_KWARGS)
class InitializerArgs(Generic[P, T]):
function: Callable[P, T]
args: tuple[Any, ...]
kwargs: dict[str, Any]


class Callback(ABC, Generic[P, T]):
"""callback api interface"""

Expand Down

0 comments on commit 3b21ce8

Please sign in to comment.