Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V1 release candidate #161

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/functions/parallel_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def _step_1b() -> int:
await asyncio.sleep(2)
return 2

return await step.parallel(
return await ctx.group.parallel(
(
lambda: step.run("1a", _step_1a),
lambda: step.run("1b", _step_1b),
Expand All @@ -48,7 +48,7 @@ def _step_1b() -> int:
time.sleep(2)
return 2

return step.parallel(
return ctx.group.parallel_sync(
(
lambda: step.run("1a", _step_1a),
lambda: step.run("1b", _step_1b),
Expand Down
3 changes: 2 additions & 1 deletion inngest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
TriggerCron,
TriggerEvent,
)
from ._internal.step_lib import Step, StepMemos, StepSync
from ._internal.step_lib import Group, Step, StepMemos, StepSync
from ._internal.types import JSON

__all__ = [
Expand All @@ -33,6 +33,7 @@
"Debounce",
"Event",
"Function",
"Group",
"Inngest",
"JSON",
"Middleware",
Expand Down
2 changes: 2 additions & 0 deletions inngest/_internal/comm_lib/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ async def post(
attempt=request.ctx.attempt,
event=request.event,
events=events,
group=step_lib.Group(),
logger=self._client.logger,
run_id=request.ctx.run_id,
),
Expand Down Expand Up @@ -259,6 +260,7 @@ def post_sync(
attempt=request.ctx.attempt,
event=request.event,
events=events,
group=step_lib.Group(),
logger=self._client.logger,
run_id=request.ctx.run_id,
),
Expand Down
2 changes: 0 additions & 2 deletions inngest/_internal/execution_lib/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ class BaseExecution(typing.Protocol):
async def report_step(
self,
step_info: step_lib.StepInfo,
inside_parallel: bool,
) -> ReportedStep:
...

Expand All @@ -44,7 +43,6 @@ class BaseExecutionSync(typing.Protocol):
def report_step(
self,
step_info: step_lib.StepInfo,
inside_parallel: bool,
) -> ReportedStepSync:
...

Expand Down
1 change: 0 additions & 1 deletion inngest/_internal/execution_lib/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def __init__(
async def report_step(
self,
step_info: step_lib.StepInfo,
inside_parallel: bool,
) -> execution_lib.ReportedStep:
step_signal = asyncio.Future[execution_lib.ReportedStep]()
self._pending_steps[step_info.id] = execution_lib.ReportedStep(
Expand Down
1 change: 1 addition & 0 deletions inngest/_internal/execution_lib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class Context:
attempt: int
event: server_lib.Event
events: list[server_lib.Event]
group: step_lib.Group
logger: types.Logger
run_id: str

Expand Down
6 changes: 2 additions & 4 deletions inngest/_internal/execution_lib/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def _handle_skip(
async def report_step(
self,
step_info: step_lib.StepInfo,
inside_parallel: bool,
) -> ReportedStep:
step_signal = asyncio.Future[ReportedStep]()

Expand Down Expand Up @@ -87,7 +86,7 @@ async def report_step(
self._handle_skip(step_info)

is_targeting_enabled = self._target_hashed_id is not None
if inside_parallel and not is_targeting_enabled:
if step_lib.in_parallel.get() and not is_targeting_enabled:
if step_info.op == server_lib.Opcode.STEP_RUN:
step_info.op = server_lib.Opcode.PLANNED

Expand Down Expand Up @@ -244,7 +243,6 @@ def _handle_skip(
def report_step(
self,
step_info: step_lib.StepInfo,
inside_parallel: bool,
) -> ReportedStepSync:
step = ReportedStepSync(step_info)

Expand All @@ -269,7 +267,7 @@ def report_step(
self._handle_skip(step_info)

is_targeting_enabled = self._target_hashed_id is not None
if inside_parallel and not is_targeting_enabled:
if step_lib.in_parallel.get() and not is_targeting_enabled:
if step_info.op == server_lib.Opcode.STEP_RUN:
step_info.op = server_lib.Opcode.PLANNED

Expand Down
3 changes: 3 additions & 0 deletions inngest/_internal/step_lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
StepMemos,
StepResponse,
)
from .group import Group, in_parallel
from .step_async import Step
from .step_sync import StepSync

__all__ = [
"Group",
"ParsedStepID",
"ResponseInterrupt",
"SkipInterrupt",
Expand All @@ -20,4 +22,5 @@
"StepMemos",
"StepResponse",
"StepSync",
"in_parallel",
]
1 change: 0 additions & 1 deletion inngest/_internal/step_lib/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def __init__(
target_hashed_id: typing.Optional[str],
) -> None:
self._client = client
self._inside_parallel = False
self._memos = memos
self._middleware = middleware
self._step_id_counter = step_id_counter
Expand Down
75 changes: 75 additions & 0 deletions inngest/_internal/step_lib/group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import contextvars
import typing

from inngest._internal import types

from .base import ResponseInterrupt, SkipInterrupt, StepResponse

# Create a context variable to track if we're in a parallel group
in_parallel = contextvars.ContextVar("in_parallel", default=False)


class Group:
async def parallel(
self,
callables: tuple[typing.Callable[[], typing.Awaitable[types.T]], ...],
) -> tuple[types.T, ...]:
"""
Run multiple steps in parallel.

Args:
----
callables: An arbitrary number of step callbacks to run. These are callables that contain the step (e.g. `lambda: step.run("my_step", my_step_fn)`.
"""

token = in_parallel.set(True)
try:
outputs = tuple[types.T]()
responses: list[StepResponse] = []
for cb in callables:
try:
output = await cb()
outputs = (*outputs, output)
except ResponseInterrupt as interrupt:
responses = [*responses, *interrupt.responses]
except SkipInterrupt:
pass

if len(responses) > 0:
raise ResponseInterrupt(responses)

return outputs
finally:
in_parallel.reset(token)

def parallel_sync(
self,
callables: tuple[typing.Callable[[], types.T], ...],
) -> tuple[types.T, ...]:
"""
Run multiple steps in parallel.

Args:
----
callables: An arbitrary number of step callbacks to run. These are callables that contain the step (e.g. `lambda: step.run("my_step", my_step_fn)`.
"""

token = in_parallel.set(True)
try:
outputs = tuple[types.T]()
responses: list[StepResponse] = []
for cb in callables:
try:
output = cb()
outputs = (*outputs, output)
except ResponseInterrupt as interrupt:
responses = [*responses, *interrupt.responses]
except SkipInterrupt:
pass

if len(responses) > 0:
raise ResponseInterrupt(responses)

return outputs
finally:
in_parallel.reset(token)
51 changes: 4 additions & 47 deletions inngest/_internal/step_lib/step_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,7 @@ async def invoke_by_id(
opts=opts,
)

async with await self._execution.report_step(
step_info,
self._inside_parallel,
) as step:
async with await self._execution.report_step(step_info) as step:
if step.skip:
raise base.SkipInterrupt(parsed_step_id.user_facing)
if step.error is not None:
Expand All @@ -145,37 +142,6 @@ async def invoke_by_id(

raise Exception("unreachable")

async def parallel(
self,
callables: tuple[typing.Callable[[], typing.Awaitable[types.T]], ...],
) -> tuple[types.T, ...]:
"""
Run multiple steps in parallel.

Args:
----
callables: An arbitrary number of step callbacks to run. These are callables that contain the step (e.g. `lambda: step.run("my_step", my_step_fn)`.
"""

self._inside_parallel = True

outputs = tuple[types.T]()
responses: list[base.StepResponse] = []
for cb in callables:
try:
output = await cb()
outputs = (*outputs, output)
except base.ResponseInterrupt as interrupt:
responses = [*responses, *interrupt.responses]
except base.SkipInterrupt:
pass

if len(responses) > 0:
raise base.ResponseInterrupt(responses)

self._inside_parallel = False
return outputs

@typing.overload
async def run(
self,
Expand Down Expand Up @@ -232,10 +198,7 @@ async def run(
op=server_lib.Opcode.STEP_RUN,
)

async with await self._execution.report_step(
step_info,
self._inside_parallel,
) as step:
async with await self._execution.report_step(step_info) as step:
if step.skip:
raise base.SkipInterrupt(parsed_step_id.user_facing)
if step.error is not None:
Expand Down Expand Up @@ -364,10 +327,7 @@ async def sleep_until(
op=server_lib.Opcode.SLEEP,
)

async with await self._execution.report_step(
step_info,
self._inside_parallel,
) as step:
async with await self._execution.report_step(step_info) as step:
if step.skip:
raise base.SkipInterrupt(parsed_step_id.user_facing)
if step.error is not None:
Expand Down Expand Up @@ -417,10 +377,7 @@ async def wait_for_event(
opts=opts,
)

async with await self._execution.report_step(
step_info,
self._inside_parallel,
) as step:
async with await self._execution.report_step(step_info) as step:
if step.skip:
raise base.SkipInterrupt(parsed_step_id.user_facing)
if step.error is not None:
Expand Down
51 changes: 4 additions & 47 deletions inngest/_internal/step_lib/step_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,7 @@ def invoke_by_id(
opts=opts,
)

with self._execution.report_step(
step_info,
self._inside_parallel,
) as step:
with self._execution.report_step(step_info) as step:
if step.skip:
raise base.SkipInterrupt(parsed_step_id.user_facing)
if step.error is not None:
Expand All @@ -160,37 +157,6 @@ def invoke_by_id(

raise Exception("unreachable")

def parallel(
self,
callables: tuple[typing.Callable[[], types.T], ...],
) -> tuple[types.T, ...]:
"""
Run multiple steps in parallel.

Args:
----
callables: An arbitrary number of step callbacks to run. These are callables that contain the step (e.g. `lambda: step.run("my_step", my_step_fn)`.
"""

self._inside_parallel = True

outputs = tuple[types.T]()
responses: list[base.StepResponse] = []
for cb in callables:
try:
output = cb()
outputs = (*outputs, output)
except base.ResponseInterrupt as interrupt:
responses = [*responses, *interrupt.responses]
except base.SkipInterrupt:
pass

if len(responses) > 0:
raise base.ResponseInterrupt(responses)

self._inside_parallel = False
return outputs

def run(
self,
step_id: str,
Expand Down Expand Up @@ -219,10 +185,7 @@ def run(
op=server_lib.Opcode.STEP_RUN,
)

with self._execution.report_step(
step_info,
self._inside_parallel,
) as step:
with self._execution.report_step(step_info) as step:
if step.skip:
raise base.SkipInterrupt(parsed_step_id.user_facing)
if step.error is not None:
Expand Down Expand Up @@ -347,10 +310,7 @@ def sleep_until(
op=server_lib.Opcode.SLEEP,
)

with self._execution.report_step(
step_info,
self._inside_parallel,
) as step:
with self._execution.report_step(step_info) as step:
if step.skip:
raise base.SkipInterrupt(parsed_step_id.user_facing)
if step.error is not None:
Expand Down Expand Up @@ -418,10 +378,7 @@ def wait_for_event(
opts=opts,
)

with self._execution.report_step(
step_info,
self._inside_parallel,
) as step:
with self._execution.report_step(step_info) as step:
if step.skip:
raise base.SkipInterrupt(parsed_step_id.user_facing)
if step.error is not None:
Expand Down
Loading
Loading