diff --git a/src/vercel/workflow/__init__.py b/src/vercel/workflow/__init__.py index 329729d..e487277 100644 --- a/src/vercel/workflow/__init__.py +++ b/src/vercel/workflow/__init__.py @@ -1,4 +1,4 @@ -from .core import sleep, step, workflow +from .core import HookEvent, HookMixin, sleep, step, workflow from .runtime import Run, start -__all__ = ["step", "workflow", "sleep", "start", "Run"] +__all__ = ["step", "workflow", "sleep", "start", "Run", "HookMixin", "HookEvent"] diff --git a/src/vercel/workflow/core.py b/src/vercel/workflow/core.py index 0438756..7b5a44f 100644 --- a/src/vercel/workflow/core.py +++ b/src/vercel/workflow/core.py @@ -1,11 +1,26 @@ +from __future__ import annotations + +import dataclasses import datetime -from collections.abc import Callable, Coroutine -from typing import Any, Generic, ParamSpec, TypeVar +import json +import sys +from collections.abc import AsyncIterator, Callable, Coroutine, Generator +from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +import pydantic + +if TYPE_CHECKING: + from . 
import world as w P = ParamSpec("P") T = TypeVar("T") -_workflows: dict[str, "Workflow[Any, Any]"] = {} -_steps: dict[str, "Step[Any, Any]"] = {} +_workflows: dict[str, Workflow[Any, Any]] = {} +_steps: dict[str, Step[Any, Any]] = {} class Workflow(Generic[P, T]): @@ -65,3 +80,83 @@ async def sleep(param: int | float | datetime.datetime | str) -> None: raise RuntimeError("cannot call sleep outside workflow") from None await ctx.run_wait(param) + + +class HookEvent(Generic[T]): + def __init__(self, *, correlation_id: str, token: str) -> None: + self._correlation_id = correlation_id + self._token = token + self._disposed = False + + def __await__(self) -> Generator[Any, None, T | None]: + async def next_or_none() -> T | None: + try: + return await self.__anext__() + except StopAsyncIteration: + return None + + return next_or_none().__await__() + + def __aiter__(self) -> AsyncIterator[T]: + return self + + async def __anext__(self) -> T: + from . import runtime + + try: + ctx = runtime.WorkflowOrchestratorContext.current() + except LookupError: + raise RuntimeError("cannot iterate HookEvent outside workflow") from None + + return await ctx.run_hook(correlation_id=self._correlation_id) + + def dispose(self) -> None: + if self._disposed: + return + + from . import runtime + + try: + ctx = runtime.WorkflowOrchestratorContext.current() + except LookupError: + raise RuntimeError("cannot call dispose() outside workflow") from None + + self._disposed = True + ctx.dispose_hook(correlation_id=self._correlation_id) + + +class HookMixin: + @classmethod + def wait(cls, *, token: str | None = None) -> HookEvent[Self]: + from . import runtime + + try: + ctx = runtime.WorkflowOrchestratorContext.current() + except LookupError: + raise RuntimeError("cannot call wait() outside workflow") from None + else: + return ctx.create_hook(token, cls) + + async def resume(self, token_or_hook: str | w.Hook, **kwargs) -> w.Hook: + from . 
import runtime + + try: + runtime.WorkflowOrchestratorContext.current() + except LookupError: + pass + else: + raise RuntimeError("cannot call resume() inside workflow") + + if isinstance(self, pydantic.BaseModel): + json_str = self.model_dump_json(**kwargs) + elif dataclasses.is_dataclass(self): + obj = dataclasses.asdict(self, dict_factory=kwargs.pop("dict_factory", dict)) + json_str = json.dumps(obj, **kwargs) + else: + raise TypeError("resume only supports pydantic models or dataclasses") + + return await runtime.resume_hook(token_or_hook, json_str) + + +# must not import in workflow +# celery/temporal/CLI style, how to import subapps diff --git a/src/vercel/workflow/nanoid.py b/src/vercel/workflow/nanoid.py new file mode 100644 index 0000000..e8eaa4d --- /dev/null +++ b/src/vercel/workflow/nanoid.py @@ -0,0 +1,228 @@ +""" +Minimal Nano ID implementation in Python. + +Based on the JavaScript Nano ID library: https://github.com/ai/nanoid +This implementation provides URL-friendly unique string IDs with customizable +alphabet and size, with support for custom PRNG functions. + +Reference: https://github.com/ai/nanoid/blob/main/index.js +""" + +import os +from collections.abc import Callable +from math import ceil + +# This alphabet uses `A-Za-z0-9_-` symbols. +# The order of characters is optimized for better gzip and brotli compression. +# Same as JS: https://github.com/ai/nanoid/blob/main/url-alphabet/index.js +# References to the same file (works both for gzip and brotli): +# `'use`, `andom`, and `rict'` +# References to the brotli default dictionary: +# `-26T`, `1983`, `40px`, `75px`, `bush`, `jack`, `mind`, `very`, and `wolf` +URL_ALPHABET = "useandom-26T198340PX75pxJACKVERYMINDBUSHWOLF_GQZbfghjklqvwyzrict" +DEFAULT_SIZE = 21 + + +def _detect_prng() -> Callable[[int], bytes]: + """ + Create a default random bytes generator using os.urandom() for + cryptographically secure randomness. + + Matches JavaScript's crypto.getRandomValues() behavior. 
+ """ + return os.urandom + + +def _custom_alphabet_generator( + alphabet: str, + default_size: int, + get_random: Callable[[int], bytes], +) -> Callable[..., str]: + """ + Core nanoid generator function (matches JS nanoid/index.js customRandom). + + Args: + alphabet: Characters to use for ID generation. + default_size: Default length for generated IDs. + get_random: Function to generate random bytes. + + Returns: + A function that generates IDs with the custom alphabet. + + Reference: https://github.com/ai/nanoid/blob/main/index.js customRandom() + """ + alphabet_len = len(alphabet) + + if alphabet_len == 0 or alphabet_len > 256: + raise ValueError("Alphabet must contain between 1 and 256 symbols") + + # First, a bitmask is necessary to generate the ID. The bitmask makes bytes + # values closer to the alphabet size. The bitmask calculates the closest + # `2^31 - 1` number, which exceeds the alphabet size. + # For example, the bitmask for the alphabet size 30 is 31 (00011111). + # Matches JS: mask = (2 << (31 - Math.clz32((alphabet.length - 1) | 1))) - 1 + mask = (2 << (31 - _clz32((alphabet_len - 1) | 1))) - 1 + + # Though, the bitmask solution is not perfect since the bytes exceeding + # the alphabet size are refused. Therefore, to reliably generate the ID, + # the random bytes redundancy has to be satisfied. + + # Note: every hardware random generator call is performance expensive, + # because the system call for entropy collection takes a lot of time. + # So, to avoid additional system calls, extra bytes are requested in advance. + + # Next, a step determines how many random bytes to generate. + # The number of random bytes gets decided upon the ID size, mask, + # alphabet size, and magic number 1.6 (using 1.6 peaks at performance + # according to benchmarks). 
+ # Matches JS: step = Math.ceil((1.6 * mask * defaultSize) / alphabet.length) + step = ceil((1.6 * mask * default_size) / alphabet_len) + + def generate_id(size: int = default_size) -> str: + """Generate a nano ID of the specified size.""" + if size <= 0: + raise ValueError("Size must be positive") + + id_str = "" + while True: + random_bytes = get_random(step) + + # A compact alternative for `for (let i = 0; i < step; i++)`. + # Matches JS nanoid implementation + i = step + while i > 0: + i -= 1 + # Adding `|| ''` refuses a random byte that exceeds the alphabet size. + # Matches JS: id += alphabet[bytes[i] & mask] || '' + byte_index = random_bytes[i] & mask + if byte_index < alphabet_len: + id_str += alphabet[byte_index] + if len(id_str) >= size: + return id_str + + return generate_id + + +def _clz32(n: int) -> int: + """ + Count leading zeros in 32-bit integer. + Matches JavaScript's Math.clz32(). + + Examples: + _clz32(1) == 31 + _clz32(2) == 30 + _clz32(3) == 30 + _clz32(4) == 29 + """ + if n == 0: + return 32 + # Convert to 32-bit unsigned integer + n = n & 0xFFFFFFFF + if n == 0: + return 32 + # Check each bit from MSB + for i in range(31, -1, -1): + if n & (1 << i): + return 31 - i + return 32 + + +def custom_alphabet( + alphabet: str, + size: int = DEFAULT_SIZE, +) -> Callable[..., str]: + """ + Create a custom ID generator with a specific alphabet. + + This factory function returns a generator that uses the specified alphabet. + Matches JS nanoid customAlphabet() function. + + Args: + alphabet: Characters to use for ID generation. + size: Default length for generated IDs. + + Returns: + A function that generates IDs with the custom alphabet. 
+ + Examples: + # Create a hex ID generator + hex_id = custom_alphabet('0123456789abcdef', 16) + id1 = hex_id() # '4f3a2b1c9d8e7f6a' + id2 = hex_id(8) # '9d8e7f6a' + + # Create a custom alphabet generator + safe_id = custom_alphabet('ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789') + id3 = safe_id() # 'K3PQ9XY7ZR2J8Y4H3L6M9' + """ + return _custom_alphabet_generator(alphabet, size, _detect_prng()) + + +def custom_random( + alphabet: str, + size: int, + prng: Callable[[], float], +) -> Callable[..., str]: + """ + Create a custom ID generator with a specific alphabet and PRNG. + + This matches the style of ulid.monotonic_factory() and allows for + deterministic ID generation (useful for testing or workflow replay). + + Args: + alphabet: Characters to use for ID generation. + size: Default length for generated IDs. + prng: Pseudo-random number generator function that returns float in [0, 1). + Example: random.Random(seed).random + + Returns: + A function that generates IDs with the custom alphabet and PRNG. + + Examples: + # Create a deterministic generator for testing + import random + prng = random.Random(42).random + test_id = custom_random('0123456789', 10, prng) + id1 = test_id() # Always generates the same ID with the same seed + + # Usage in workflow context (similar to ulid) + prng = random.Random(workflow_seed).random + nanoid_gen = custom_random(URL_ALPHABET, 21, prng) + token = nanoid_gen() + """ + + def get_random_bytes(n: int) -> bytes: + """Convert PRNG floats to random bytes.""" + # Match JavaScript: use floor(prng() * 256) to get byte values + return bytes(int(prng() * 256) for _ in range(n)) + + return _custom_alphabet_generator(alphabet, size, get_random_bytes) + + +def generate( + alphabet: str = URL_ALPHABET, + size: int = DEFAULT_SIZE, +) -> str: + """ + Generate a Nano ID string using default cryptographically secure randomness. + + This is the main function that matches JS nanoid() behavior. + + Args: + alphabet: Characters to use for ID generation. 
Default is URL-safe alphabet. + size: Length of the generated ID. Default is 21. + + Returns: + A random string ID of specified size using the specified alphabet. + + Examples: + # Basic usage with defaults + id1 = generate() # 'V1StGXR8_Z5jdHi6B-myT' + + # Custom size + id2 = generate(size=10) # 'IRFa-VaY2b' + + # Custom alphabet (numbers only) + id3 = generate(alphabet='0123456789', size=6) # '482014' + """ + generator = _custom_alphabet_generator(alphabet, size, _detect_prng()) + return generator(size) diff --git a/src/vercel/workflow/runtime.py b/src/vercel/workflow/runtime.py index 849a850..cd70a62 100644 --- a/src/vercel/workflow/runtime.py +++ b/src/vercel/workflow/runtime.py @@ -5,34 +5,61 @@ import json import random import re +import sys import traceback -from datetime import UTC, datetime, timedelta -from typing import Any, Generic, Literal, ParamSpec, Self, TypeVar +from collections import deque +from datetime import datetime, timedelta, timezone +from typing import Any, Generic, Literal, ParamSpec, TypeVar + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self import anyio +import pydantic -from . import core, ulid, world as w +from . 
import core, nanoid, ulid, world as w P = ParamSpec("P") T = TypeVar("T") SUSPENDED_MESSAGE = "" -@dataclasses.dataclass -class Suspension(Generic[T]): +@dataclasses.dataclass(kw_only=True) +class BaseSuspension: correlation_id: str + has_created_event: bool = False + + +@dataclasses.dataclass(kw_only=True) +class Suspension(BaseSuspension, Generic[T]): step: core.Step[Any, T] input: bytes future: asyncio.Future[T] = dataclasses.field(default_factory=asyncio.Future) - has_created_event: bool = False -@dataclasses.dataclass -class Wait: - correlation_id: str +@dataclasses.dataclass(kw_only=True) +class Wait(BaseSuspension): resume_at: datetime future: asyncio.Future[None] = dataclasses.field(default_factory=asyncio.Future) - has_created_event: bool = False + + +@dataclasses.dataclass(kw_only=True) +class Hook(BaseSuspension, Generic[T]): + token: str + disposed: bool = False + futures: deque[asyncio.Future[T]] = dataclasses.field(default_factory=deque) + hook_cls: type[T] + + def set_result(self, raw_data: Any) -> None: + if dataclasses.is_dataclass(self.hook_cls): + res = self.hook_cls(**raw_data) + elif issubclass(self.hook_cls, pydantic.BaseModel): + res = self.hook_cls.model_validate(raw_data) + else: + raise RuntimeError(f"Invalid hook type for {self.hook_cls}") + self.futures.popleft().set_result(res) class WorkflowOrchestratorContext: @@ -43,8 +70,10 @@ def __init__(self, events: list[w.Event], *, seed: str, started_at: int): self.replay_index = 0 prng = random.Random(seed) self.generate_ulid = functools.partial(ulid.monotonic_factory(prng.random), started_at) + self.generate_nanoid = nanoid.custom_random(nanoid.URL_ALPHABET, 21, prng.random) self._fut: asyncio.Future[Any] | None = None - self.suspensions: dict[str, Wait | Suspension[Any]] = {} + self.suspensions: dict[str, BaseSuspension] = {} + self.hooks: dict[str, Hook] = {} self.resume_handle: asyncio.Handle | None = None @classmethod @@ -83,6 +112,32 @@ async def run_wait(self, param: int | float | 
datetime | str) -> None: self.resume_handle = asyncio.get_running_loop().call_soon(self.resume) await wait.future + def create_hook(self, token: str | None, hook_cls: type[T]) -> core.HookEvent[T]: + hook = Hook( + correlation_id=f"hook_{self.generate_ulid()}", + token=token or self.generate_nanoid(), + hook_cls=hook_cls, + ) + self.hooks[hook.correlation_id] = hook + return core.HookEvent(correlation_id=hook.correlation_id, token=hook.token) + + async def run_hook(self, *, correlation_id: str) -> T: + hook = self.hooks[correlation_id] + if hook.disposed: + raise StopAsyncIteration + self.suspensions[hook.correlation_id] = hook + fut = asyncio.Future[T]() + hook.futures.append(fut) + if self.resume_handle is None: + self.resume_handle = asyncio.get_running_loop().call_soon(self.resume) + return await fut + + def dispose_hook(self, *, correlation_id: str) -> None: + hook = self.hooks[correlation_id] + hook.disposed = True + while hook.futures: + hook.futures.popleft().set_exception(StopAsyncIteration) + def resume(self) -> None: self.resume_handle = None @@ -94,11 +149,12 @@ def resume(self) -> None: self.replay_index += 1 match event: - case w.StepCreatedEvent() | w.WaitCreatedEvent(): + case w.StepCreatedEvent() | w.HookCreatedEvent() | w.WaitCreatedEvent(): self.suspensions[event.correlation_id].has_created_event = True case w.StepCompletedEvent(event_data=w.StepCompletedEventData(result=data)): sus = self.suspensions.pop(event.correlation_id) + assert isinstance(sus, Suspension) if data[0].startswith(b"json"): result = json.loads(data[0][len(b"json") :].decode()) else: @@ -111,12 +167,44 @@ def resume(self) -> None: case w.WaitCompletedEvent(): wait = self.suspensions.pop(event.correlation_id) + assert isinstance(wait, Wait) wait.future.set_result(None) case w.StepFailedEvent(event_data=w.StepFailedEventData(error=e)): sus = self.suspensions.pop(event.correlation_id) + assert isinstance(sus, Suspension) sus.future.set_exception(RuntimeError(e)) + case 
w.HookConflictEvent(event_data=w.HookConflictEventData(token=token)): + hook = self.suspensions.pop(event.correlation_id, None) + if hook is not None: + assert isinstance(hook, Hook) + while hook.futures: + hook.futures.popleft().set_exception( + RuntimeError( + f'Hook token "{token}" is already in use by another workflow' + ) + ) + + case w.HookReceivedEvent(event_data=w.HookReceivedEventData(payload=data)): + hook = self.suspensions[event.correlation_id] + assert isinstance(hook, Hook) + if data[0].startswith(b"json"): + result = json.loads(data[0][len(b"json") :].decode()) + else: + self._fut.cancel( + f"Unsupported step result encoding for " + f"correlation ID {event.correlation_id}" + ) + return + hook.set_result(result) + if not hook.futures: + self.suspensions.pop(event.correlation_id) + + case w.HookDisposedEvent(): + self.suspensions.pop(event.correlation_id, None) + self.hooks.pop(event.correlation_id) + if self.suspensions: self._fut.cancel(SUSPENDED_MESSAGE) @@ -152,7 +240,7 @@ async def workflow_handler( events = await get_all_workflow_run_events(run_id) # Check for any elapsed waits and create wait_completed events - now = datetime.now(UTC) + now = datetime.now(timezone.utc) # Pre-compute completed correlation IDs for O(n) lookup instead of O(n²) completed_wait_ids = {e.correlation_id for e in events if e.event_type == "wait_completed"} @@ -214,14 +302,25 @@ async def workflow_handler( workflowRunId=run_id, workflowStartedAt=workflow_started_at, stepId=sus.correlation_id, - requestedAt=datetime.now(UTC), + requestedAt=datetime.now(timezone.utc), ), ) elif isinstance(sus, Wait): wait_data = w.WaitCreatedEventData(resumeAt=sus.resume_at) tg.start_soon(world.events_create, run_id, wait_data.into_event(sus.correlation_id)) + elif isinstance(sus, Hook): + hook_data = w.HookCreatedEventData(token=sus.token) + tg.start_soon(world.events_create, run_id, hook_data.into_event(sus.correlation_id)) + + for hook in context.hooks.values(): + if hook.disposed: + 
tg.start_soon( + world.events_create, + run_id, + w.HookDisposedEvent(correlationId=hook.correlation_id), + ) - now = datetime.now(UTC) + now = datetime.now(timezone.utc) min_timeout_seconds = -1.0 for sus in context.suspensions.values(): if isinstance(sus, Wait): @@ -248,7 +347,7 @@ async def step_handler( step = core.get_step(step_run.step_name) # Check if retry_after timestamp hasn't been reached yet - now = datetime.now(UTC) + now = datetime.now(timezone.utc) if step_run.retry_after and step_run.retry_after > now: timeout_seconds = max(1, int((step_run.retry_after - now).total_seconds())) return timeout_seconds @@ -275,7 +374,10 @@ async def step_handler( # Re-invoke the workflow to handle the failed step await world.queue( f"__wkf_workflow_{req.workflow_name}", - w.WorkflowInvokePayload(runId=req.workflow_run_id, requestedAt=datetime.now(UTC)), + w.WorkflowInvokePayload( + runId=req.workflow_run_id, + requestedAt=datetime.now(timezone.utc), + ), ) return None @@ -383,7 +485,7 @@ async def step_handler( # Re-invoke the workflow to continue execution await world.queue( f"__wkf_workflow_{req.workflow_name}", - w.WorkflowInvokePayload(runId=req.workflow_run_id, requestedAt=datetime.now(UTC)), + w.WorkflowInvokePayload(runId=req.workflow_run_id, requestedAt=datetime.now(timezone.utc)), ) return None @@ -454,12 +556,12 @@ def parse_duration_to_date(param: int | float | datetime | str) -> datetime: ms = sum(items) if ms < 0: raise RuntimeError(f"Duration parameter must be non-negative: {param}") - return datetime.now(UTC) + timedelta(milliseconds=ms) + return datetime.now(timezone.utc) + timedelta(milliseconds=ms) elif isinstance(param, (int, float)): if param < 0: raise RuntimeError(f"Duration parameter must be non-negative: {param}") - return datetime.now(UTC) + timedelta(milliseconds=param) + return datetime.now(timezone.utc) + timedelta(milliseconds=param) elif isinstance(param, datetime): if param.tzinfo is None: @@ -524,3 +626,20 @@ async def start(wf: 
core.Workflow[P, T], *args: P.args, **kwargs: P.kwargs) -> R ) return Run(run_id) + + +async def resume_hook(token_or_hook: str | w.Hook, payload_json: str) -> w.Hook: + world = w.get_world() + if isinstance(token_or_hook, str): + hook = await world.hooks_get_by_token(token_or_hook) + else: + hook = token_or_hook + run = await world.runs_get(hook.run_id) + payload = b"json" + payload_json.encode() + data = w.HookReceivedEventData(payload=[payload]) + await world.events_create(hook.run_id, data.into_event(hook.hook_id)) + await world.queue( + f"__wkf_workflow_{run.workflow_name}", + w.WorkflowInvokePayload(runId=hook.run_id), + ) + return hook diff --git a/src/vercel/workflow/world.py b/src/vercel/workflow/world.py index dfebcd9..ae522ed 100644 --- a/src/vercel/workflow/world.py +++ b/src/vercel/workflow/world.py @@ -2,6 +2,7 @@ import dataclasses import json import os +import sys from datetime import datetime from typing import ( Annotated, @@ -9,12 +10,16 @@ Generic, Literal, Protocol, - Self, TypeAlias, TypeVar, overload, ) +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + import pydantic T = TypeVar("T") @@ -123,34 +128,6 @@ class BaseWorkflowRun(BaseModel): created_at: datetime = pydantic.Field(alias="createdAt") updated_at: datetime = pydantic.Field(alias="updatedAt") - # @pydantic.model_validator(mode="wrap") - # @classmethod - # def discriminate( - # cls, - # data: Any, - # handler: pydantic.ModelWrapValidatorHandler[Self], - # info: pydantic.ValidationInfo, - # ) -> Self: - # if isinstance(info.context, _ContextWrapper): - # return handler(data) - # - # config = info.config or {} - # args = {"context": _ContextWrapper(info.context)} - # if "strict" in config: - # args["strict"] = config["strict"] - # if "extra_fields_behavior" in config: - # args["extra"] = config["extra_fields_behavior"] - # if "validate_by_alias" in config: - # args["by_alias"] = config["validate_by_alias"] - # if "validate_by_name" 
in config: - # args["by_name"] = config["validate_by_name"] - # if info.mode == "python": - # if "from_attributes" in config: - # args["from_attributes"] = config["from_attributes"] - # return WorkflowRunAdaptor.validate_python(data, **args) - # else: - # return WorkflowRunAdaptor.validate_json(data, **args) - class NonFinalWorkflowRun(BaseWorkflowRun): status: Literal["pending", "running"] @@ -451,6 +428,87 @@ class StepFailedEvent(BaseEvent): event_data: StepFailedEventData = pydantic.Field(alias="eventData") +class Hook(BaseModel): + run_id: str = pydantic.Field(alias="runId") + hook_id: str = pydantic.Field(alias="hookId") + token: str + owner_id: str = pydantic.Field(alias="ownerId") + project_id: str = pydantic.Field(alias="projectId") + environment: str + metadata: list[bytes] | None = None + created_at: datetime = pydantic.Field(alias="createdAt") + spec_version: int | None = pydantic.Field(default=None, alias="specVersion") + is_webhook: bool | None = pydantic.Field(default=None, alias="isWebhook") + + +class HookCreatedEventData(BaseModel): + token: str + metadata: list[bytes] | None = pydantic.Field(default=None, exclude_if=lambda e: e is None) + + def into_event(self, correlation_id: str) -> "HookCreatedEvent": + return HookCreatedEvent(correlationId=correlation_id, eventData=self) + + +class HookCreatedEvent(BaseEvent): + """ + Event created when a hook is first invoked. The World implementation + atomically creates both the event and the hook entity. 
+ """ + + event_type: Literal["hook_created"] = pydantic.Field( + default="hook_created", + alias="eventType", + ) + correlation_id: str = pydantic.Field(alias="correlationId") + event_data: HookCreatedEventData = pydantic.Field(alias="eventData") + + +class HookReceivedEventData(BaseModel): + payload: list[bytes] + + def into_event(self, correlation_id: str) -> "HookReceivedEvent": + return HookReceivedEvent(correlationId=correlation_id, eventData=self) + + +class HookReceivedEvent(BaseEvent): + event_type: Literal["hook_received"] = pydantic.Field( + default="hook_received", + alias="eventType", + ) + correlation_id: str = pydantic.Field(alias="correlationId") + event_data: HookReceivedEventData = pydantic.Field(alias="eventData") + + +class HookDisposedEvent(BaseEvent): + event_type: Literal["hook_disposed"] = pydantic.Field( + default="hook_disposed", + alias="eventType", + ) + correlation_id: str = pydantic.Field(alias="correlationId") + + +class HookConflictEventData(BaseModel): + token: str + + +class HookConflictEvent(BaseEvent): + """ + Event created by World implementations when a hook_created request + conflicts with an existing hook token. This event is NOT user-creatable - + it is only returned by the World when a token conflict is detected. + + When the hook consumer sees this event, it should reject any awaited + promises with a HookTokenConflictError. 
+ """ + + event_type: Literal["hook_conflict"] = pydantic.Field( + default="hook_conflict", + alias="eventType", + ) + correlation_id: str = pydantic.Field(alias="correlationId") + event_data: HookConflictEventData = pydantic.Field(alias="eventData") + + class WaitCreatedEventData(BaseModel): resume_at: datetime = pydantic.Field(alias="resumeAt") @@ -484,6 +542,9 @@ class WaitCompletedEvent(BaseEvent): | StepRetryingEvent | StepCompletedEvent | StepFailedEvent + | HookCreatedEvent + | HookReceivedEvent + | HookDisposedEvent | WaitCreatedEvent | WaitCompletedEvent ) @@ -498,6 +559,10 @@ class WaitCompletedEvent(BaseEvent): | StepRetryingEvent | StepCompletedEvent | StepFailedEvent + | HookCreatedEvent + | HookReceivedEvent + | HookDisposedEvent + | HookConflictEvent | WaitCreatedEvent | WaitCompletedEvent ), @@ -616,6 +681,9 @@ async def runs_get(self, run_id: str) -> WorkflowRun: ... @abc.abstractmethod async def steps_get(self, run_id: str, step_id: str) -> WorkflowStep: ... + @abc.abstractmethod + async def hooks_get_by_token(self, token: str) -> Hook: ... 
+
     @overload
     async def events_create(self, run_id: None, data: RunCreatedEvent) -> EventResult:
         """
diff --git a/src/vercel/workflow/worlds/local.py b/src/vercel/workflow/worlds/local.py
index a264814..392ca7d 100644
--- a/src/vercel/workflow/worlds/local.py
+++ b/src/vercel/workflow/worlds/local.py
@@ -1,8 +1,9 @@
+import hashlib
 import json
 import os
 import pathlib
 import traceback
-from datetime import UTC, datetime
+from datetime import datetime, timezone
 from typing import Any, TypeVar
 
 import cbor2
@@ -53,11 +54,34 @@ def write_json(path: pathlib.Path, data: w.BaseModel | dict, *, overwrite: bool
         cbor2.dump(data, f)
 
 
+def write_exclusive(path: pathlib.Path, data: str) -> bool:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    try:
+        with path.open("x") as f:
+            f.write(data)
+    except FileExistsError:
+        return False
+    else:
+        return True
+
+
 class LocalWorld(w.World):
     def __init__(self) -> None:
         self.monotonic_ulid = monotonic_factory()
         self.data_dir = pathlib.Path(os.getenv("WORKFLOW_LOCAL_DATA_DIR", ".workflow-data"))
 
+    def delete_all_hooks_for_run(self, run_id: str) -> None:
+        hooks_dir = self.data_dir / "hooks"
+        for hook_path in (hooks_dir.iterdir() if hooks_dir.exists() else []):  # dir may not exist yet
+            if hook_path.suffix != ".json":
+                continue
+            hook = read_json(hook_path, w.Hook)
+            if hook is not None and hook.run_id == run_id:
+                hashed_token = hashlib.sha256(hook.token.encode()).hexdigest()
+                constraint_path = hooks_dir / "tokens" / f"{hashed_token}.json"
+                constraint_path.unlink(missing_ok=True)
+                hook_path.unlink(missing_ok=True)
+
     async def get_deployment_id(self) -> str:
         return ""
 
@@ -191,9 +215,20 @@ async def steps_get(self, run_id: str, step_id: str) -> w.WorkflowStep:
             raise RuntimeError(f"Step {step_id} not found in run {run_id}")
         return step
 
+    async def hooks_get_by_token(self, token: str) -> w.Hook:
+        hooks_dir = self.data_dir / "hooks"
+        if hooks_dir.exists():
+            for hook_path in hooks_dir.iterdir():
+                if hook_path.suffix != ".json":
+                    continue
+                hook = read_json(hook_path, w.Hook)
+                if hook is
not None and hook.token == token: + return hook + raise RuntimeError(f"Hook with token {token!r} not found") + async def events_create(self, run_id: str | None, data: w.Event) -> w.EventResult: event_id = f"evnt_{self.monotonic_ulid(None)}" - now = datetime.now(UTC) + now = datetime.now(timezone.utc) if data.event_type == "run_created" and not run_id: effective_run_id = f"wrun_{self.monotonic_ulid(None)}" @@ -257,6 +292,13 @@ async def events_create(self, run_id: str | None, data: w.Event) -> w.EventResul f'"{current_run.status}"' ) + hook_events_requiring_existance = ["hook_disposed", "hook_received"] + if data.event_type in hook_events_requiring_existance and data.correlation_id: + hook_path = self.data_dir / "hooks" / f"{data.correlation_id}.json" + existing_hook = read_json(hook_path, w.Hook) + if existing_hook is None: + raise RuntimeError(f"Hook {data.correlation_id!r} not found") + event = w.EventAdaptor.validate_python( data.model_dump() | { @@ -322,6 +364,7 @@ async def events_create(self, run_id: str | None, data: w.Event) -> w.EventResul ) run_path = self.data_dir / "runs" / f"{effective_run_id}.json" write_json(run_path, run, overwrite=True) + self.delete_all_hooks_for_run(effective_run_id) elif data.event_type == "run_failed" and hasattr(data, "event_data"): failed_data = data.event_data @@ -361,6 +404,7 @@ async def events_create(self, run_id: str | None, data: w.Event) -> w.EventResul ) run_path = self.data_dir / "runs" / f"{effective_run_id}.json" write_json(run_path, run, overwrite=True) + self.delete_all_hooks_for_run(effective_run_id) elif data.event_type == "run_cancelled": if current_run: @@ -380,6 +424,7 @@ async def events_create(self, run_id: str | None, data: w.Event) -> w.EventResul ) run_path = self.data_dir / "runs" / f"{effective_run_id}.json" write_json(run_path, run, overwrite=True) + self.delete_all_hooks_for_run(effective_run_id) elif data.event_type == "step_created" and hasattr(data, "event_data"): step_data = data.event_data @@ 
-472,6 +517,69 @@ async def events_create(self, run_id: str | None, data: w.Event) -> w.EventResul ) write_json(step_path, step, overwrite=True) + elif data.event_type == "hook_created" and hasattr(data, "event_data"): + hook_data = data.event_data + hashed_token = hashlib.sha256(hook_data.token.encode()).hexdigest() + constraint_path = self.data_dir / "hooks" / "tokens" / f"{hashed_token}.json" + token_claimed = write_exclusive( + constraint_path, + json.dumps( + { + "token": hook_data.token, + "hookId": data.correlation_id, + "runId": effective_run_id, + } + ), + ) + if not token_claimed: + conflict_event = w.HookConflictEvent( + correlationId=data.correlation_id, + eventData=w.HookConflictEventData(token=hook_data.token), + server_props=w.ServerProps( + runId=effective_run_id, + eventId=event_id, + createdAt=now, + ), + ) + assert conflict_event.server_props is not None + composite_key = f"{effective_run_id}-{event_id}" + event_path = self.data_dir / "events" / f"{composite_key}.json" + write_json( + event_path, + conflict_event.model_dump() | conflict_event.server_props.model_dump(), + ) + return w.EventResult( + event=conflict_event, + run=run, + step=step, + hook=None, + ) + hook = w.Hook( + runId=effective_run_id, + hookId=data.correlation_id, + token=hook_data.token, + metadata=hook_data.metadata, + ownerId="local-owner", + projectId="local-project", + environment="local", + createdAt=now, + specVersion=2, + isWebhook=False, + ) + hook_path = self.data_dir / "hooks" / f"{data.correlation_id}.json" + write_json(hook_path, hook) + + elif data.event_type == "hook_disposed": + hook_path = self.data_dir / "hooks" / f"{data.correlation_id}.json" + existing_hook = read_json(hook_path, w.Hook) + if existing_hook is not None: + hashed_token = hashlib.sha256(existing_hook.token.encode()).hexdigest() + disposed_constraint_path = ( + self.data_dir / "hooks" / "tokens" / f"{hashed_token}.json" + ) + disposed_constraint_path.unlink(missing_ok=True) + 
hook_path.unlink(missing_ok=True) + composite_key = f"{effective_run_id}-{event_id}" event_path = self.data_dir / "events" / f"{composite_key}.json" if event.server_props: diff --git a/src/vercel/workflow/worlds/vercel.py b/src/vercel/workflow/worlds/vercel.py index 282c7fa..851d899 100644 --- a/src/vercel/workflow/worlds/vercel.py +++ b/src/vercel/workflow/worlds/vercel.py @@ -105,7 +105,7 @@ async def _cbor_request( headers["Accept"] = "application/cbor" # NOTE: Add a unique header to bypass RSC request memoization. # See: https://github.com/vercel/workflow/issues/618 - headers["X-Request-Time"] = datetime.datetime.now(datetime.UTC).isoformat() + "Z" + headers["X-Request-Time"] = datetime.datetime.now(datetime.timezone.utc).isoformat() + "Z" # Encode body as CBOR if data is provided body: bytes | None = None @@ -270,6 +270,13 @@ async def steps_get(self, run_id: str, step_id: str) -> w.WorkflowStep: schema=w.WorkflowStepAdaptor, ) + async def hooks_get_by_token(self, token: str) -> w.Hook: + return await self._cbor_request( + "GET", + f"/v2/hooks/by-token?token={token}", + schema=w.Hook, + ) + async def events_create(self, run_id: str | None, data: w.Event) -> w.EventResult: run_id_path = "null" if run_id is None else run_id remote_ref_behavior = ( diff --git a/tests/test_nanoid.py b/tests/test_nanoid.py new file mode 100644 index 0000000..5293c8c --- /dev/null +++ b/tests/test_nanoid.py @@ -0,0 +1,174 @@ +"""Test nanoid implementation.""" + +from vercel.workflow import nanoid + + +def test_generate_default(): + """Test basic ID generation with defaults.""" + id1 = nanoid.generate() + assert len(id1) == 21 + assert all(c in nanoid.URL_ALPHABET for c in id1) + + # Generate multiple IDs and check they're unique + ids = {nanoid.generate() for _ in range(100)} + assert len(ids) == 100, "Generated IDs should be unique" + + +def test_generate_custom_size(): + """Test ID generation with custom size.""" + id1 = nanoid.generate(size=10) + assert len(id1) == 10 + + id2 = 
nanoid.generate(size=5) + assert len(id2) == 5 + + id3 = nanoid.generate(size=50) + assert len(id3) == 50 + + +def test_generate_custom_alphabet(): + """Test ID generation with custom alphabet.""" + # Numbers only + id1 = nanoid.generate(alphabet="0123456789", size=10) + assert len(id1) == 10 + assert all(c in "0123456789" for c in id1) + + # Uppercase only + id2 = nanoid.generate(alphabet="ABCDEFGHIJKLMNOPQRSTUVWXYZ", size=15) + assert len(id2) == 15 + assert all(c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ" for c in id2) + + # Binary + id3 = nanoid.generate(alphabet="01", size=20) + assert len(id3) == 20 + assert all(c in "01" for c in id3) + + +def test_custom_alphabet_factory(): + """Test custom alphabet factory function.""" + hex_id = nanoid.custom_alphabet("0123456789abcdef", 16) + + id1 = hex_id() + assert len(id1) == 16 + assert all(c in "0123456789abcdef" for c in id1) + + id2 = hex_id(8) + assert len(id2) == 8 + assert all(c in "0123456789abcdef" for c in id2) + + +def test_custom_random(): + """Test custom random function with PRNG (similar to ulid style).""" + import random + + # Create a deterministic PRNG for testing + prng = random.Random(42).random + + generator = nanoid.custom_random(nanoid.URL_ALPHABET, 21, prng) + id1 = generator() + + # Reset PRNG with same seed + prng = random.Random(42).random + generator2 = nanoid.custom_random(nanoid.URL_ALPHABET, 21, prng) + id2 = generator2() + + # Both should be the same because we're using deterministic random + assert id1 == id2 + assert len(id1) == 21 + + +def test_custom_random_with_custom_size(): + """Test custom random with different sizes.""" + import random + + prng = random.Random(42).random + generator = nanoid.custom_random("0123456789", 10, prng) + + id1 = generator() + assert len(id1) == 10 + assert all(c in "0123456789" for c in id1) + + id2 = generator(5) + assert len(id2) == 5 + assert all(c in "0123456789" for c in id2) + + +def test_collision_resistance(): + """Test that IDs have low collision 
probability.""" + # Generate a large number of IDs and check for uniqueness + ids = {nanoid.generate() for _ in range(10000)} + assert len(ids) == 10000, "Should have no collisions in 10,000 IDs" + + +def test_alphabet_validation(): + """Test alphabet validation.""" + import pytest + + # Empty alphabet should raise error + with pytest.raises(ValueError, match="Alphabet must contain between 1 and 256 symbols"): + generator = nanoid.custom_alphabet("") + generator() + + # Alphabet too large should raise error + with pytest.raises(ValueError, match="Alphabet must contain between 1 and 256 symbols"): + generator = nanoid.custom_alphabet("a" * 257) + generator() + + +def test_size_validation(): + """Test size validation.""" + import pytest + + generator = nanoid.custom_alphabet("0123456789") + + # Zero size should raise error + with pytest.raises(ValueError, match="Size must be positive"): + generator(0) + + # Negative size should raise error + with pytest.raises(ValueError, match="Size must be positive"): + generator(-1) + + +def test_single_character_alphabet(): + """Test with single character alphabet.""" + generator = nanoid.custom_alphabet("A", 10) + id1 = generator() + assert id1 == "AAAAAAAAAA" + + +def test_distribution(): + """Test that character distribution is reasonably uniform.""" + # Generate many short IDs and check distribution + alphabet = "0123456789" + generator = nanoid.custom_alphabet(alphabet, 10) + counts = dict.fromkeys(alphabet, 0) + + for _ in range(1000): + id_str = generator() + for c in id_str: + counts[c] += 1 + + # Each character should appear roughly 1000 times (10% of 10000) + # Allow for some variance (between 800 and 1200) + for c, count in counts.items(): + assert 800 <= count <= 1200, f"Character {c} appeared {count} times (expected ~1000)" + + +def test_deterministic_with_seed(): + """Test that same seed produces same sequence of IDs.""" + import random + + # Create two generators with same seed + prng1 = random.Random(12345).random 
+ gen1 = nanoid.custom_random(nanoid.URL_ALPHABET, 21, prng1) + + prng2 = random.Random(12345).random + gen2 = nanoid.custom_random(nanoid.URL_ALPHABET, 21, prng2) + + # Generate multiple IDs from each + ids1 = [gen1() for _ in range(10)] + ids2 = [gen2() for _ in range(10)] + + # All IDs should match + assert ids1 == ids2