Restoring to a checkpoint doesn't resume from that checkpoint's node

Hi @sasharosca,

I pass `checkpoint` instead of `checkpoint_id` only because `checkpoint_id` is deprecated (see the `update_state` signature below):

def update_state(
        self,
        thread_id: str,
        values: dict[str, Any] | Sequence[dict] | None,
        *,
        as_node: str | None = None,
        checkpoint: Checkpoint | None = None,
        checkpoint_id: str | None = None,  # deprecated
        headers: Mapping[str, str] | None = None,
        params: QueryParamTypes | None = None,
    ) -> ThreadUpdateStateResponse:

I also found `__copy__` in the source code :slight_smile:, but I'm not sure whether it works because I haven't been able to test it yet.


    def bulk_update_state(
        self,
        config: RunnableConfig,
        supersteps: Sequence[Sequence[StateUpdate]],
    ) -> RunnableConfig:
        """Apply updates to the graph state in bulk. Requires a checkpointer to be set.

        Supersteps are applied in order. Each superstep starts from the config
        returned by the previous one. All updates inside a superstep are run
        through their node's writers, applied together, and saved as a single
        new checkpoint (metadata ``source="update"``).

        If the single update of a superstep has ``as_node == "__copy__"``, the
        saved checkpoint is forked instead of updated. A copy is written
        against the saved checkpoint's parent config (metadata
        ``source="fork"``). The saved checkpoint's pending writes are NOT
        copied onto the fork. If that update's values are a non-empty list of
        ``(values, as_node, ...)`` items, they are then applied to the fork as
        a follow-up superstep, reusing the task IDs the fork would schedule.

        When ``config`` targets a subgraph namespace and no checkpointer has
        been injected into it yet, the call is delegated to the first matching
        subgraph.

        Args:
            config: The config to apply the updates to. Its ``thread_id`` is
                coerced to ``str`` before the first superstep is applied.
            supersteps: A list of supersteps, each including a list of updates to apply sequentially to a graph state.
                        Each update is a tuple of the form `(values, as_node, task_id)` where `task_id` is optional.

        Raises:
            ValueError: If no checkpointer is set, no supersteps or updates are
                provided, or the targeted subgraph cannot be found.
            InvalidUpdateError: If an invalid update is provided.

        Returns:
            RunnableConfig: The updated config.
        """

        # a checkpointer injected via config (e.g. by a parent graph) wins over our own
        checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get(
            CONFIG_KEY_CHECKPOINTER, self.checkpointer
        )
        if not checkpointer:
            raise ValueError("No checkpointer set")

        if len(supersteps) == 0:
            raise ValueError("No supersteps provided")

        if any(len(u) == 0 for u in supersteps):
            raise ValueError("No updates provided")

        # delegate to subgraph
        if (
            checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
        ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
            # remove task_ids from checkpoint_ns
            recast = recast_checkpoint_ns(checkpoint_ns)
            # find the subgraph with the matching name
            for _, pregel in self.get_subgraphs(namespace=recast, recurse=True):
                # only the first match is used; injecting the checkpointer
                # stops the subgraph from delegating again
                return pregel.bulk_update_state(
                    patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
                    supersteps,
                )
            else:
                raise ValueError(f"Subgraph {recast} not found")

        def perform_superstep(
            input_config: RunnableConfig, updates: Sequence[StateUpdate]
        ) -> RunnableConfig:
            # Apply one superstep of `updates` on top of the checkpoint addressed
            # by `input_config` and return the config of the checkpoint written.
            # NOTE(review): `saved`, `checkpoint`, `checkpoint_config`,
            # `checkpoint_previous_versions`, `channels`, `managed`, `step`,
            # `values` and `as_node` are bound in the elided section below,
            # presumably loaded from `input_config` and `updates[0]`. Confirm.

            # ...

            # copy checkpoint
            # `__copy__` forks the saved checkpoint instead of writing to it
            if as_node == "__copy__":
                if len(updates) > 1:
                    raise InvalidUpdateError(
                        "Cannot copy checkpoint with multiple updates"
                    )

                if saved is None:
                    raise InvalidUpdateError("Cannot copy a non-existent checkpoint")

                # clone the loaded checkpoint at the same step; channels=None
                # presumably keeps its existing channel values (confirm
                # against create_checkpoint)
                next_checkpoint = create_checkpoint(checkpoint, None, step)

                # copy checkpoint
                # Written against the saved checkpoint's parent (or with the
                # checkpoint id cleared when there is no parent), so the fork
                # branches off next to `saved` rather than after it. No new
                # channel versions are reported ({}), and no put_writes call is
                # made, so `saved.pending_writes` are not carried over to the
                # fork. They are only used below to derive task IDs.
                next_config = checkpointer.put(
                    saved.parent_config
                    or patch_configurable(
                        saved.config, {CONFIG_KEY_CHECKPOINT_ID: None}
                    ),
                    next_checkpoint,
                    {
                        "source": "fork",
                        "step": step + 1,
                        "parents": saved.metadata.get("parents", {}),
                    },
                    {},
                )

                # we want to both clone a checkpoint and update state in one go.
                # reuse the same task ID if possible.
                # Here `values` is the payload of the `__copy__` update; a
                # non-empty list is treated as follow-up (values, as_node, ...)
                # items to apply on top of the fork.
                if isinstance(values, list) and len(values) > 0:
                    # figure out the task IDs for the next update checkpoint
                    # NOTE(review): step + 2 presumably matches the step the
                    # recursive call derives from the fork's metadata
                    # (step + 1); confirm against the elided loading code.
                    next_tasks = prepare_next_tasks(
                        next_checkpoint,
                        saved.pending_writes or [],
                        self.nodes,
                        channels,
                        managed,
                        next_config,
                        step + 2,
                        step + 4,
                        for_execution=True,
                        store=self.store,
                        checkpointer=checkpointer,
                        manager=None,
                    )

                    # node name -> task IDs the fork would schedule for it
                    tasks_group_by = defaultdict(list)
                    # node name -> user updates targeting it (first-seen order)
                    user_group_by: dict[str, list[StateUpdate]] = defaultdict(list)

                    for task in next_tasks.values():
                        tasks_group_by[task.name].append(task.id)

                    for item in values:
                        if not isinstance(item, Sequence):
                            raise InvalidUpdateError(
                                f"Invalid update item: {item} when copying checkpoint"
                            )

                        # NOTE: rebinds the outer `values`/`as_node`; the loop
                        # iterator is already bound, so iteration is unaffected
                        values, as_node = item[:2]

                        user_group = user_group_by[as_node]
                        tasks_group = tasks_group_by[as_node]

                        # the k-th update for a node reuses the k-th task ID
                        # scheduled for it; surplus updates get None and fall
                        # back to the usual task-id selection in the recursive call
                        target_idx = len(user_group)
                        task_id = (
                            tasks_group[target_idx]
                            if target_idx < len(tasks_group)
                            else None
                        )

                        user_group_by[as_node].append(
                            StateUpdate(values=values, as_node=as_node, task_id=task_id)
                        )

                    # apply all follow-up updates as one superstep on the fork
                    return perform_superstep(
                        patch_checkpoint_map(next_config, saved.metadata),
                        [item for lst in user_group_by.values() for item in lst],
                    )

                return patch_checkpoint_map(next_config, saved.metadata)

            # task ids can be provided in the StateUpdate, but if not,
            # we use the task id generated by prepare_next_tasks
            node_to_task_ids: dict[str, deque[str]] = defaultdict(deque)
            if saved is not None and saved.pending_writes is not None:
                # we call prepare_next_tasks to discover the task IDs that
                # would have been generated, so we can reuse them and
                # properly populate task.result in state history
                next_tasks = prepare_next_tasks(
                    checkpoint,
                    saved.pending_writes,
                    self.nodes,
                    channels,
                    managed,
                    saved.config,
                    step + 1,
                    step + 3,
                    for_execution=True,
                    store=self.store,
                    checkpointer=checkpointer,
                    manager=None,
                )
                # collect task ids to reuse so we can properly attach task results
                for t in next_tasks.values():
                    node_to_task_ids[t.name].append(t.id)

            # (as_node, values, task_id) triples after as_node inference/validation
            valid_updates: list[tuple[str, dict[str, Any] | None, str | None]] = []
            if len(updates) == 1:
                values, as_node, task_id = updates[0]
                # find last node that updated the state, if not provided
                if as_node is None and len(self.nodes) == 1:
                    as_node = tuple(self.nodes)[0]
                elif as_node is None and not any(
                    v
                    for vv in checkpoint["versions_seen"].values()
                    for v in vv.values()
                ):
                    # no node has seen any channel version yet (e.g. a fresh
                    # thread): attribute the update to the input node, if the
                    # single input channel is also a node name
                    if (
                        isinstance(self.input_channels, str)
                        and self.input_channels in self.nodes
                    ):
                        as_node = self.input_channels
                elif as_node is None:
                    # otherwise pick the node that saw the highest version
                    last_seen_by_node = sorted(
                        (v, n)
                        for n, seen in checkpoint["versions_seen"].items()
                        if n in self.nodes
                        for v in seen.values()
                    )
                    # if two nodes updated the state at the same time, it's ambiguous
                    if last_seen_by_node:
                        if len(last_seen_by_node) == 1:
                            as_node = last_seen_by_node[0][1]
                        elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]:
                            as_node = last_seen_by_node[-1][1]
                if as_node is None:
                    raise InvalidUpdateError("Ambiguous update, specify as_node")
                if as_node not in self.nodes:
                    raise InvalidUpdateError(f"Node {as_node} does not exist")
                valid_updates.append((as_node, values, task_id))
            else:
                # with several updates no inference is attempted
                for values, as_node, task_id in updates:
                    if as_node is None:
                        raise InvalidUpdateError(
                            "as_node is required when applying multiple updates"
                        )
                    if as_node not in self.nodes:
                        raise InvalidUpdateError(f"Node {as_node} does not exist")

                    valid_updates.append((as_node, values, task_id))

            run_tasks: list[PregelTaskWrites] = []
            run_task_ids: list[str] = []

            for as_node, values, provided_task_id in valid_updates:
                # create task to run all writers of the chosen node
                writers = self.nodes[as_node].flat_writers
                if not writers:
                    raise InvalidUpdateError(f"Node {as_node} has no writers")
                writes: deque[tuple[str, Any]] = deque()
                task = PregelTaskWrites((), as_node, writes, [INTERRUPT])
                # get the task ids that were prepared for this node
                # if a task id was provided in the StateUpdate, we use it
                # otherwise, we use the next available task id
                # NOTE(review): the uuid5 fallback depends only on the
                # checkpoint id, so every update in this superstep that lacks
                # a provided/prepared id gets the same task id.
                prepared_task_ids = node_to_task_ids.get(as_node, deque())
                task_id = provided_task_id or (
                    prepared_task_ids.popleft()
                    if prepared_task_ids
                    else str(uuid5(UUID(checkpoint["id"]), INTERRUPT))
                )
                run_tasks.append(task)
                run_task_ids.append(task_id)
                run = RunnableSequence(*writers) if len(writers) > 1 else writers[0]
                # execute task
                # writers emit into `writes` via CONFIG_KEY_SEND and read the
                # in-progress state via local_read
                run.invoke(
                    values,
                    patch_config(
                        config,
                        run_name=self.name + "UpdateState",
                        configurable={
                            # deque.extend is thread-safe
                            CONFIG_KEY_SEND: writes.extend,
                            CONFIG_KEY_TASK_ID: task_id,
                            CONFIG_KEY_READ: partial(
                                local_read,
                                _scratchpad(
                                    None,
                                    [],
                                    task_id,
                                    "",
                                    None,
                                    step,
                                    step + 2,
                                ),
                                channels,
                                managed,
                                task,
                            ),
                        },
                    ),
                )
            # save task writes
            for task_id, task in zip(run_task_ids, run_tasks):
                # channel writes are saved to current checkpoint
                # (only when a saved checkpoint exists to attach them to)
                channel_writes = [w for w in task.writes if w[0] != PUSH]
                if saved and channel_writes:
                    checkpointer.put_writes(checkpoint_config, channel_writes, task_id)
            # apply to checkpoint and save
            apply_writes(
                checkpoint,
                channels,
                run_tasks,
                checkpointer.get_next_version,
                self.trigger_to_nodes,
            )
            checkpoint = create_checkpoint(checkpoint, channels, step + 1)
            # save checkpoint, after applying writes
            next_config = checkpointer.put(
                checkpoint_config,
                checkpoint,
                {
                    "source": "update",
                    "step": step + 1,
                    "parents": saved.metadata.get("parents", {}) if saved else {},
                },
                get_new_channel_versions(
                    checkpoint_previous_versions, checkpoint["channel_versions"]
                ),
            )
            for task_id, task in zip(run_task_ids, run_tasks):
                # save push writes
                # PUSH writes are attached to the NEW checkpoint
                if push_writes := [w for w in task.writes if w[0] == PUSH]:
                    checkpointer.put_writes(next_config, push_writes, task_id)

            return patch_checkpoint_map(next_config, saved.metadata if saved else None)

        # normalize thread_id to str; each superstep builds on the previous result
        current_config = patch_configurable(
            config, {CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID])}
        )
        for superstep in supersteps:
            current_config = perform_superstep(current_config, superstep)
        return current_config

    async def abulk_update_state(
        self,
        config: RunnableConfig,
        supersteps: Sequence[Sequence[StateUpdate]],
    ) -> RunnableConfig:
        """Asynchronously apply updates to the graph state in bulk. Requires a checkpointer to be set.

        Supersteps are applied in order. Each superstep starts from the config
        returned by the previous one. All updates inside a superstep are run
        through their node's writers, applied together, and saved as a single
        new checkpoint (metadata ``source="update"``).

        If the single update of a superstep has ``as_node == "__copy__"``, the
        saved checkpoint is forked instead of updated. A copy is written
        against the saved checkpoint's parent config (metadata
        ``source="fork"``). The saved checkpoint's pending writes are NOT
        copied onto the fork. If that update's values are a non-empty list of
        ``(values, as_node, ...)`` items, they are then applied to the fork as
        a follow-up superstep, reusing the task IDs the fork would schedule.

        When ``config`` targets a subgraph namespace and no checkpointer has
        been injected into it yet, the call is delegated to the first matching
        subgraph.

        Args:
            config: The config to apply the updates to. Its ``thread_id`` is
                coerced to ``str`` before the first superstep is applied.
            supersteps: A list of supersteps, each including a list of updates to apply sequentially to a graph state.
                        Each update is a tuple of the form `(values, as_node, task_id)` where `task_id` is optional.

        Raises:
            ValueError: If no checkpointer is set, no supersteps or updates are
                provided, or the targeted subgraph cannot be found.
            InvalidUpdateError: If an invalid update is provided.

        Returns:
            RunnableConfig: The updated config.
        """

        # a checkpointer injected via config (e.g. by a parent graph) wins over our own
        checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get(
            CONFIG_KEY_CHECKPOINTER, self.checkpointer
        )
        if not checkpointer:
            raise ValueError("No checkpointer set")

        if len(supersteps) == 0:
            raise ValueError("No supersteps provided")

        if any(len(u) == 0 for u in supersteps):
            raise ValueError("No updates provided")

        # delegate to subgraph
        if (
            checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
        ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
            # remove task_ids from checkpoint_ns
            recast = recast_checkpoint_ns(checkpoint_ns)
            # find the subgraph with the matching name
            async for _, pregel in self.aget_subgraphs(namespace=recast, recurse=True):
                # only the first match is used; injecting the checkpointer
                # stops the subgraph from delegating again
                return await pregel.abulk_update_state(
                    patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
                    supersteps,
                )
            else:
                raise ValueError(f"Subgraph {recast} not found")

        async def aperform_superstep(
            input_config: RunnableConfig, updates: Sequence[StateUpdate]
        ) -> RunnableConfig:
            # Apply one superstep of `updates` on top of the checkpoint addressed
            # by `input_config` and return the config of the checkpoint written.
            # NOTE(review): `saved`, `checkpoint`, `checkpoint_config`,
            # `checkpoint_previous_versions`, `channels`, `managed`, `step`,
            # `values` and `as_node` are bound in the elided section below,
            # presumably loaded from `input_config` and `updates[0]`. Confirm.

            # ...

            # copy (fork) checkpoint; `values` may still carry follow-up updates
            if as_node == "__copy__":
                if len(updates) > 1:
                    raise InvalidUpdateError(
                        "Cannot copy checkpoint with multiple updates"
                    )

                if saved is None:
                    raise InvalidUpdateError("Cannot copy a non-existent checkpoint")

                # clone the loaded checkpoint at the same step; channels=None
                # presumably keeps its existing channel values (confirm
                # against create_checkpoint)
                next_checkpoint = create_checkpoint(checkpoint, None, step)

                # copy checkpoint
                # Written against the saved checkpoint's parent (or with the
                # checkpoint id cleared when there is no parent), so the fork
                # branches off next to `saved` rather than after it. No new
                # channel versions are reported ({}), and no aput_writes call
                # is made, so `saved.pending_writes` are not carried over to
                # the fork. They are only used below to derive task IDs.
                next_config = await checkpointer.aput(
                    saved.parent_config
                    or patch_configurable(
                        saved.config, {CONFIG_KEY_CHECKPOINT_ID: None}
                    ),
                    next_checkpoint,
                    {
                        "source": "fork",
                        "step": step + 1,
                        "parents": saved.metadata.get("parents", {}),
                    },
                    {},
                )

                # we want to both clone a checkpoint and update state in one go.
                # reuse the same task ID if possible.
                # Here `values` is the payload of the `__copy__` update; a
                # non-empty list is treated as follow-up (values, as_node, ...)
                # items to apply on top of the fork.
                if isinstance(values, list) and len(values) > 0:
                    # figure out the task IDs for the next update checkpoint
                    # NOTE(review): step + 2 presumably matches the step the
                    # recursive call derives from the fork's metadata
                    # (step + 1); confirm against the elided loading code.
                    next_tasks = prepare_next_tasks(
                        next_checkpoint,
                        saved.pending_writes or [],
                        self.nodes,
                        channels,
                        managed,
                        next_config,
                        step + 2,
                        step + 4,
                        for_execution=True,
                        store=self.store,
                        checkpointer=checkpointer,
                        manager=None,
                    )

                    # node name -> task IDs the fork would schedule for it
                    tasks_group_by = defaultdict(list)
                    # node name -> user updates targeting it (first-seen order)
                    user_group_by: dict[str, list[StateUpdate]] = defaultdict(list)

                    for task in next_tasks.values():
                        tasks_group_by[task.name].append(task.id)

                    for item in values:
                        if not isinstance(item, Sequence):
                            raise InvalidUpdateError(
                                f"Invalid update item: {item} when copying checkpoint"
                            )

                        # NOTE: rebinds the outer `values`/`as_node`; the loop
                        # iterator is already bound, so iteration is unaffected
                        values, as_node = item[:2]
                        user_group = user_group_by[as_node]
                        tasks_group = tasks_group_by[as_node]

                        # the k-th update for a node reuses the k-th task ID
                        # scheduled for it; surplus updates get None and fall
                        # back to the usual task-id selection in the recursive call
                        target_idx = len(user_group)
                        task_id = (
                            tasks_group[target_idx]
                            if target_idx < len(tasks_group)
                            else None
                        )

                        user_group_by[as_node].append(
                            StateUpdate(values=values, as_node=as_node, task_id=task_id)
                        )

                    # apply all follow-up updates as one superstep on the fork
                    return await aperform_superstep(
                        patch_checkpoint_map(next_config, saved.metadata),
                        [item for lst in user_group_by.values() for item in lst],
                    )

                # `saved` is guaranteed non-None here (checked above)
                return patch_checkpoint_map(next_config, saved.metadata)

            # task ids can be provided in the StateUpdate, but if not,
            # we use the task id generated by prepare_next_tasks
            node_to_task_ids: dict[str, deque[str]] = defaultdict(deque)
            if saved is not None and saved.pending_writes is not None:
                # we call prepare_next_tasks to discover the task IDs that
                # would have been generated, so we can reuse them and
                # properly populate task.result in state history
                next_tasks = prepare_next_tasks(
                    checkpoint,
                    saved.pending_writes,
                    self.nodes,
                    channels,
                    managed,
                    saved.config,
                    step + 1,
                    step + 3,
                    for_execution=True,
                    store=self.store,
                    checkpointer=checkpointer,
                    manager=None,
                )
                # collect task ids to reuse so we can properly attach task results
                for t in next_tasks.values():
                    node_to_task_ids[t.name].append(t.id)

            # (as_node, values, task_id) triples after as_node inference/validation
            valid_updates: list[tuple[str, dict[str, Any] | None, str | None]] = []
            if len(updates) == 1:
                values, as_node, task_id = updates[0]
                # find last node that updated the state, if not provided
                if as_node is None and len(self.nodes) == 1:
                    as_node = tuple(self.nodes)[0]
                elif as_node is None and not any(
                    v
                    for vv in checkpoint["versions_seen"].values()
                    for v in vv.values()
                ):
                    # No node has seen any channel version yet: attribute the
                    # update to the input node, if the single input channel is
                    # also a node name. This checks versions_seen (not just
                    # whether a checkpoint was saved) so that a saved checkpoint
                    # written before any node ran is handled like a fresh
                    # thread, matching the sync bulk_update_state.
                    if (
                        isinstance(self.input_channels, str)
                        and self.input_channels in self.nodes
                    ):
                        as_node = self.input_channels
                elif as_node is None:
                    # otherwise pick the node that saw the highest version
                    last_seen_by_node = sorted(
                        (v, n)
                        for n, seen in checkpoint["versions_seen"].items()
                        if n in self.nodes
                        for v in seen.values()
                    )
                    # if two nodes updated the state at the same time, it's ambiguous
                    if last_seen_by_node:
                        if len(last_seen_by_node) == 1:
                            as_node = last_seen_by_node[0][1]
                        elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]:
                            as_node = last_seen_by_node[-1][1]
                if as_node is None:
                    raise InvalidUpdateError("Ambiguous update, specify as_node")
                if as_node not in self.nodes:
                    raise InvalidUpdateError(f"Node {as_node} does not exist")
                valid_updates.append((as_node, values, task_id))
            else:
                # with several updates no inference is attempted
                for values, as_node, task_id in updates:
                    if as_node is None:
                        raise InvalidUpdateError(
                            "as_node is required when applying multiple updates"
                        )
                    if as_node not in self.nodes:
                        raise InvalidUpdateError(f"Node {as_node} does not exist")

                    valid_updates.append((as_node, values, task_id))

            run_tasks: list[PregelTaskWrites] = []
            run_task_ids: list[str] = []

            for as_node, values, provided_task_id in valid_updates:
                # create task to run all writers of the chosen node
                writers = self.nodes[as_node].flat_writers
                if not writers:
                    raise InvalidUpdateError(f"Node {as_node} has no writers")
                writes: deque[tuple[str, Any]] = deque()
                task = PregelTaskWrites((), as_node, writes, [INTERRUPT])
                # get the task ids that were prepared for this node
                # if a task id was provided in the StateUpdate, we use it
                # otherwise, we use the next available task id
                # NOTE(review): the uuid5 fallback depends only on the
                # checkpoint id, so every update in this superstep that lacks
                # a provided/prepared id gets the same task id.
                prepared_task_ids = node_to_task_ids.get(as_node, deque())
                task_id = provided_task_id or (
                    prepared_task_ids.popleft()
                    if prepared_task_ids
                    else str(uuid5(UUID(checkpoint["id"]), INTERRUPT))
                )
                run_tasks.append(task)
                run_task_ids.append(task_id)
                run = RunnableSequence(*writers) if len(writers) > 1 else writers[0]
                # execute task
                # writers emit into `writes` via CONFIG_KEY_SEND and read the
                # in-progress state via local_read
                await run.ainvoke(
                    values,
                    patch_config(
                        config,
                        run_name=self.name + "UpdateState",
                        configurable={
                            # deque.extend is thread-safe
                            CONFIG_KEY_SEND: writes.extend,
                            CONFIG_KEY_TASK_ID: task_id,
                            CONFIG_KEY_READ: partial(
                                local_read,
                                _scratchpad(
                                    None,
                                    [],
                                    task_id,
                                    "",
                                    None,
                                    step,
                                    step + 2,
                                ),
                                channels,
                                managed,
                                task,
                            ),
                        },
                    ),
                )
            # save task writes
            for task_id, task in zip(run_task_ids, run_tasks):
                # channel writes are saved to current checkpoint
                # (only when a saved checkpoint exists to attach them to)
                channel_writes = [w for w in task.writes if w[0] != PUSH]
                if saved and channel_writes:
                    await checkpointer.aput_writes(
                        checkpoint_config, channel_writes, task_id
                    )
            # apply to checkpoint and save
            apply_writes(
                checkpoint,
                channels,
                run_tasks,
                checkpointer.get_next_version,
                self.trigger_to_nodes,
            )
            checkpoint = create_checkpoint(checkpoint, channels, step + 1)
            # save checkpoint, after applying writes
            next_config = await checkpointer.aput(
                checkpoint_config,
                checkpoint,
                {
                    "source": "update",
                    "step": step + 1,
                    "parents": saved.metadata.get("parents", {}) if saved else {},
                },
                get_new_channel_versions(
                    checkpoint_previous_versions, checkpoint["channel_versions"]
                ),
            )
            for task_id, task in zip(run_task_ids, run_tasks):
                # save push writes
                # PUSH writes are attached to the NEW checkpoint
                if push_writes := [w for w in task.writes if w[0] == PUSH]:
                    await checkpointer.aput_writes(next_config, push_writes, task_id)
            return patch_checkpoint_map(next_config, saved.metadata if saved else None)

        # normalize thread_id to str; each superstep builds on the previous result
        current_config = patch_configurable(
            config, {CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID])}
        )
        for superstep in supersteps:
            current_config = await aperform_superstep(current_config, superstep)
        return current_config