Hi @sasharosca,
I use `checkpoint` instead of `checkpoint_id` only because `checkpoint_id` is deprecated.
def update_state(
    self,
    thread_id: str,
    values: dict[str, Any] | Sequence[dict] | None,
    *,
    as_node: str | None = None,
    checkpoint: Checkpoint | None = None,
    checkpoint_id: str | None = None,  # deprecated: prefer `checkpoint`
    headers: Mapping[str, str] | None = None,
    params: QueryParamTypes | None = None,
) -> ThreadUpdateStateResponse:
    """Update the state of a thread.

    Args:
        thread_id: ID of the thread whose state is updated.
        values: State values to apply — a single dict, a sequence of dicts,
            or None.
        as_node: Name of the node to attribute the update to, if any.
        checkpoint: The checkpoint to update; supersedes `checkpoint_id`.
        checkpoint_id: Deprecated — use `checkpoint` instead.
        headers: Presumably extra request headers to send — confirm against
            the client implementation.
        params: Presumably extra query parameters for the request — confirm
            against the client implementation.

    Returns:
        ThreadUpdateStateResponse: The result of the update operation.

    NOTE(review): only the signature is visible in this excerpt; the body
    is not shown here.
    """
I also found `__copy__` in the source code.
I am not sure whether it works, since I couldn't test it yet.
def bulk_update_state(
    self,
    config: RunnableConfig,
    supersteps: Sequence[Sequence[StateUpdate]],
) -> RunnableConfig:
    """Apply updates to the graph state in bulk. Requires a checkpointer to be set.

    Args:
        config: The config to apply the updates to.
        supersteps: A list of supersteps, each including a list of updates to apply sequentially to a graph state.
            Each update is a tuple of the form `(values, as_node, task_id)` where `task_id` is optional.

    Raises:
        ValueError: If no checkpointer is set or no updates are provided.
        InvalidUpdateError: If an invalid update is provided.

    Returns:
        RunnableConfig: The updated config.
    """
    # prefer a checkpointer passed in via config (set when invoked as a
    # subgraph), falling back to this graph's own checkpointer
    checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")
    if len(supersteps) == 0:
        raise ValueError("No supersteps provided")
    if any(len(u) == 0 for u in supersteps):
        raise ValueError("No updates provided")
    # delegate to subgraph
    if (
        checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
    ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
        # remove task_ids from checkpoint_ns
        recast = recast_checkpoint_ns(checkpoint_ns)
        # find the subgraph with the matching name
        for _, pregel in self.get_subgraphs(namespace=recast, recurse=True):
            # delegate to the first match, propagating our checkpointer so the
            # subgraph does not take this delegation branch again
            return pregel.bulk_update_state(
                patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
                supersteps,
            )
        else:
            # for/else: the generator yielded nothing, i.e. no matching subgraph
            raise ValueError(f"Subgraph {recast} not found")

    def perform_superstep(
        input_config: RunnableConfig, updates: Sequence[StateUpdate]
    ) -> RunnableConfig:
        # ... (setup elided in this excerpt; it evidently binds `saved`,
        # `checkpoint`, `checkpoint_config`, `checkpoint_previous_versions`,
        # `channels`, `managed`, `step`, `values` and `as_node`, all used below)
        # copy checkpoint
        if as_node == "__copy__":
            # a fork may only be combined with updates via the nested-list
            # `values` form handled further down
            if len(updates) > 1:
                raise InvalidUpdateError(
                    "Cannot copy checkpoint with multiple updates"
                )
            if saved is None:
                raise InvalidUpdateError("Cannot copy a non-existent checkpoint")
            next_checkpoint = create_checkpoint(checkpoint, None, step)
            # copy checkpoint: fork from the parent config (or the saved config
            # with its checkpoint id cleared) so the copy becomes a sibling of
            # `saved` rather than its child
            next_config = checkpointer.put(
                saved.parent_config
                or patch_configurable(
                    saved.config, {CONFIG_KEY_CHECKPOINT_ID: None}
                ),
                next_checkpoint,
                {
                    "source": "fork",
                    "step": step + 1,
                    "parents": saved.metadata.get("parents", {}),
                },
                {},
            )
            # we want to both clone a checkpoint and update state in one go.
            # reuse the same task ID if possible.
            if isinstance(values, list) and len(values) > 0:
                # figure out the task IDs for the next update checkpoint
                next_tasks = prepare_next_tasks(
                    next_checkpoint,
                    saved.pending_writes or [],
                    self.nodes,
                    channels,
                    managed,
                    next_config,
                    step + 2,
                    step + 4,
                    for_execution=True,
                    store=self.store,
                    checkpointer=checkpointer,
                    manager=None,
                )
                tasks_group_by = defaultdict(list)
                user_group_by: dict[str, list[StateUpdate]] = defaultdict(list)
                for task in next_tasks.values():
                    tasks_group_by[task.name].append(task.id)
                for item in values:
                    if not isinstance(item, Sequence):
                        raise InvalidUpdateError(
                            f"Invalid update item: {item} when copying checkpoint"
                        )
                    # rebinding `values`/`as_node` here is safe: the loop
                    # iterates the original list object captured at loop start
                    values, as_node = item[:2]
                    user_group = user_group_by[as_node]
                    tasks_group = tasks_group_by[as_node]
                    # pair the i-th user update for a node with that node's
                    # i-th prepared task id; None if there is no prepared id left
                    target_idx = len(user_group)
                    task_id = (
                        tasks_group[target_idx]
                        if target_idx < len(tasks_group)
                        else None
                    )
                    user_group_by[as_node].append(
                        StateUpdate(values=values, as_node=as_node, task_id=task_id)
                    )
                # recurse once to apply the flattened updates on top of the fork
                return perform_superstep(
                    patch_checkpoint_map(next_config, saved.metadata),
                    [item for lst in user_group_by.values() for item in lst],
                )
            # plain fork with no updates to apply
            # (`saved` is non-None here; checked above)
            return patch_checkpoint_map(next_config, saved.metadata)
        # task ids can be provided in the StateUpdate, but if not,
        # we use the task id generated by prepare_next_tasks
        node_to_task_ids: dict[str, deque[str]] = defaultdict(deque)
        if saved is not None and saved.pending_writes is not None:
            # we call prepare_next_tasks to discover the task IDs that
            # would have been generated, so we can reuse them and
            # properly populate task.result in state history
            next_tasks = prepare_next_tasks(
                checkpoint,
                saved.pending_writes,
                self.nodes,
                channels,
                managed,
                saved.config,
                step + 1,
                step + 3,
                for_execution=True,
                store=self.store,
                checkpointer=checkpointer,
                manager=None,
            )
            # collect task ids to reuse so we can properly attach task results
            for t in next_tasks.values():
                node_to_task_ids[t.name].append(t.id)
        valid_updates: list[tuple[str, dict[str, Any] | None, str | None]] = []
        if len(updates) == 1:
            values, as_node, task_id = updates[0]
            # find last node that updated the state, if not provided
            if as_node is None and len(self.nodes) == 1:
                # only one node exists, so the update is unambiguous
                as_node = tuple(self.nodes)[0]
            elif as_node is None and not any(
                v
                for vv in checkpoint["versions_seen"].values()
                for v in vv.values()
            ):
                # no node has observed any channel version yet, i.e. nothing
                # has run: attribute the update to the input entrypoint.
                # NOTE(review): the async twin (abulk_update_state) guards this
                # branch with `not saved` instead of inspecting versions_seen —
                # confirm which condition is intended and align the two.
                if (
                    isinstance(self.input_channels, str)
                    and self.input_channels in self.nodes
                ):
                    as_node = self.input_channels
            elif as_node is None:
                last_seen_by_node = sorted(
                    (v, n)
                    for n, seen in checkpoint["versions_seen"].items()
                    if n in self.nodes
                    for v in seen.values()
                )
                # if two nodes updated the state at the same time, it's ambiguous
                if last_seen_by_node:
                    if len(last_seen_by_node) == 1:
                        as_node = last_seen_by_node[0][1]
                    elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]:
                        as_node = last_seen_by_node[-1][1]
            if as_node is None:
                raise InvalidUpdateError("Ambiguous update, specify as_node")
            if as_node not in self.nodes:
                raise InvalidUpdateError(f"Node {as_node} does not exist")
            valid_updates.append((as_node, values, task_id))
        else:
            # multiple updates in one superstep: each must name its node explicitly
            for values, as_node, task_id in updates:
                if as_node is None:
                    raise InvalidUpdateError(
                        "as_node is required when applying multiple updates"
                    )
                if as_node not in self.nodes:
                    raise InvalidUpdateError(f"Node {as_node} does not exist")
                valid_updates.append((as_node, values, task_id))
        run_tasks: list[PregelTaskWrites] = []
        run_task_ids: list[str] = []
        for as_node, values, provided_task_id in valid_updates:
            # create task to run all writers of the chosen node
            writers = self.nodes[as_node].flat_writers
            if not writers:
                raise InvalidUpdateError(f"Node {as_node} has no writers")
            writes: deque[tuple[str, Any]] = deque()
            task = PregelTaskWrites((), as_node, writes, [INTERRUPT])
            # get the task ids that were prepared for this node
            # if a task id was provided in the StateUpdate, we use it
            # otherwise, we use the next available task id
            prepared_task_ids = node_to_task_ids.get(as_node, deque())
            task_id = provided_task_id or (
                prepared_task_ids.popleft()
                if prepared_task_ids
                else str(uuid5(UUID(checkpoint["id"]), INTERRUPT))
                # ^ deterministic fallback id derived from the checkpoint id
            )
            run_tasks.append(task)
            run_task_ids.append(task_id)
            run = RunnableSequence(*writers) if len(writers) > 1 else writers[0]
            # execute task
            run.invoke(
                values,
                patch_config(
                    config,
                    run_name=self.name + "UpdateState",
                    configurable={
                        # deque.extend is thread-safe
                        CONFIG_KEY_SEND: writes.extend,
                        CONFIG_KEY_TASK_ID: task_id,
                        CONFIG_KEY_READ: partial(
                            local_read,
                            _scratchpad(
                                None,
                                [],
                                task_id,
                                "",
                                None,
                                step,
                                step + 2,
                            ),
                            channels,
                            managed,
                            task,
                        ),
                    },
                ),
            )
        # save task writes
        for task_id, task in zip(run_task_ids, run_tasks):
            # channel writes are saved to current checkpoint
            channel_writes = [w for w in task.writes if w[0] != PUSH]
            if saved and channel_writes:
                checkpointer.put_writes(checkpoint_config, channel_writes, task_id)
        # apply to checkpoint and save
        apply_writes(
            checkpoint,
            channels,
            run_tasks,
            checkpointer.get_next_version,
            self.trigger_to_nodes,
        )
        checkpoint = create_checkpoint(checkpoint, channels, step + 1)
        # persist the new checkpoint, recording only channel versions that changed
        next_config = checkpointer.put(
            checkpoint_config,
            checkpoint,
            {
                "source": "update",
                "step": step + 1,
                "parents": saved.metadata.get("parents", {}) if saved else {},
            },
            get_new_channel_versions(
                checkpoint_previous_versions, checkpoint["channel_versions"]
            ),
        )
        for task_id, task in zip(run_task_ids, run_tasks):
            # save push writes
            # (push writes attach to the *new* checkpoint, unlike channel
            # writes above which attach to the current one)
            if push_writes := [w for w in task.writes if w[0] == PUSH]:
                checkpointer.put_writes(next_config, push_writes, task_id)
        return patch_checkpoint_map(next_config, saved.metadata if saved else None)

    # normalize thread_id to str, then apply supersteps sequentially, threading
    # each returned config (holding the new checkpoint id) into the next one
    current_config = patch_configurable(
        config, {CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID])}
    )
    for superstep in supersteps:
        current_config = perform_superstep(current_config, superstep)
    return current_config
async def abulk_update_state(
    self,
    config: RunnableConfig,
    supersteps: Sequence[Sequence[StateUpdate]],
) -> RunnableConfig:
    """Asynchronously apply updates to the graph state in bulk. Requires a checkpointer to be set.

    Args:
        config: The config to apply the updates to.
        supersteps: A list of supersteps, each including a list of updates to apply sequentially to a graph state.
            Each update is a tuple of the form `(values, as_node, task_id)` where `task_id` is optional.

    Raises:
        ValueError: If no checkpointer is set or no updates are provided.
        InvalidUpdateError: If an invalid update is provided.

    Returns:
        RunnableConfig: The updated config.
    """
    # prefer a checkpointer passed in via config (set when invoked as a
    # subgraph), falling back to this graph's own checkpointer
    checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")
    if len(supersteps) == 0:
        raise ValueError("No supersteps provided")
    if any(len(u) == 0 for u in supersteps):
        raise ValueError("No updates provided")
    # delegate to subgraph
    if (
        checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
    ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
        # remove task_ids from checkpoint_ns
        recast = recast_checkpoint_ns(checkpoint_ns)
        # find the subgraph with the matching name
        async for _, pregel in self.aget_subgraphs(namespace=recast, recurse=True):
            # delegate to the first match, propagating our checkpointer so the
            # subgraph does not take this delegation branch again
            return await pregel.abulk_update_state(
                patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
                supersteps,
            )
        else:
            # for/else: the async generator yielded nothing, i.e. no matching subgraph
            raise ValueError(f"Subgraph {recast} not found")

    async def aperform_superstep(
        input_config: RunnableConfig, updates: Sequence[StateUpdate]
    ) -> RunnableConfig:
        # ... (setup elided in this excerpt; it evidently binds `saved`,
        # `checkpoint`, `checkpoint_config`, `checkpoint_previous_versions`,
        # `channels`, `managed`, `step`, `values` and `as_node`, all used below)
        # no values, copy checkpoint
        if as_node == "__copy__":
            # a fork may only be combined with updates via the nested-list
            # `values` form handled further down
            if len(updates) > 1:
                raise InvalidUpdateError(
                    "Cannot copy checkpoint with multiple updates"
                )
            if saved is None:
                raise InvalidUpdateError("Cannot copy a non-existent checkpoint")
            next_checkpoint = create_checkpoint(checkpoint, None, step)
            # copy checkpoint: fork from the parent config (or the saved config
            # with its checkpoint id cleared) so the copy becomes a sibling of
            # `saved` rather than its child
            next_config = await checkpointer.aput(
                saved.parent_config
                or patch_configurable(
                    saved.config, {CONFIG_KEY_CHECKPOINT_ID: None}
                ),
                next_checkpoint,
                {
                    "source": "fork",
                    "step": step + 1,
                    "parents": saved.metadata.get("parents", {}),
                },
                {},
            )
            # we want to both clone a checkpoint and update state in one go.
            # reuse the same task ID if possible.
            if isinstance(values, list) and len(values) > 0:
                # figure out the task IDs for the next update checkpoint
                next_tasks = prepare_next_tasks(
                    next_checkpoint,
                    saved.pending_writes or [],
                    self.nodes,
                    channels,
                    managed,
                    next_config,
                    step + 2,
                    step + 4,
                    for_execution=True,
                    store=self.store,
                    checkpointer=checkpointer,
                    manager=None,
                )
                tasks_group_by = defaultdict(list)
                user_group_by: dict[str, list[StateUpdate]] = defaultdict(list)
                for task in next_tasks.values():
                    tasks_group_by[task.name].append(task.id)
                for item in values:
                    if not isinstance(item, Sequence):
                        raise InvalidUpdateError(
                            f"Invalid update item: {item} when copying checkpoint"
                        )
                    # rebinding `values`/`as_node` here is safe: the loop
                    # iterates the original list object captured at loop start
                    values, as_node = item[:2]
                    user_group = user_group_by[as_node]
                    tasks_group = tasks_group_by[as_node]
                    # pair the i-th user update for a node with that node's
                    # i-th prepared task id; None if there is no prepared id left
                    target_idx = len(user_group)
                    task_id = (
                        tasks_group[target_idx]
                        if target_idx < len(tasks_group)
                        else None
                    )
                    user_group_by[as_node].append(
                        StateUpdate(values=values, as_node=as_node, task_id=task_id)
                    )
                # recurse once to apply the flattened updates on top of the fork
                return await aperform_superstep(
                    patch_checkpoint_map(next_config, saved.metadata),
                    [item for lst in user_group_by.values() for item in lst],
                )
            # plain fork with no updates to apply.
            # NOTE(review): `saved` is non-None here (checked above), so the
            # guard is redundant; the sync twin uses `saved.metadata` directly.
            return patch_checkpoint_map(
                next_config, saved.metadata if saved else None
            )
        # task ids can be provided in the StateUpdate, but if not,
        # we use the task id generated by prepare_next_tasks
        node_to_task_ids: dict[str, deque[str]] = defaultdict(deque)
        if saved is not None and saved.pending_writes is not None:
            # we call prepare_next_tasks to discover the task IDs that
            # would have been generated, so we can reuse them and
            # properly populate task.result in state history
            next_tasks = prepare_next_tasks(
                checkpoint,
                saved.pending_writes,
                self.nodes,
                channels,
                managed,
                saved.config,
                step + 1,
                step + 3,
                for_execution=True,
                store=self.store,
                checkpointer=checkpointer,
                manager=None,
            )
            # collect task ids to reuse so we can properly attach task results
            for t in next_tasks.values():
                node_to_task_ids[t.name].append(t.id)
        valid_updates: list[tuple[str, dict[str, Any] | None, str | None]] = []
        if len(updates) == 1:
            values, as_node, task_id = updates[0]
            # find last node that updated the state, if not provided
            if as_node is None and len(self.nodes) == 1:
                # only one node exists, so the update is unambiguous
                as_node = tuple(self.nodes)[0]
            elif as_node is None and not saved:
                # NOTE(review): the sync twin (bulk_update_state) guards this
                # branch by checking that checkpoint["versions_seen"] is empty
                # rather than `not saved` — confirm which condition is intended
                # and align the two implementations.
                if (
                    isinstance(self.input_channels, str)
                    and self.input_channels in self.nodes
                ):
                    as_node = self.input_channels
            elif as_node is None:
                last_seen_by_node = sorted(
                    (v, n)
                    for n, seen in checkpoint["versions_seen"].items()
                    if n in self.nodes
                    for v in seen.values()
                )
                # if two nodes updated the state at the same time, it's ambiguous
                if last_seen_by_node:
                    if len(last_seen_by_node) == 1:
                        as_node = last_seen_by_node[0][1]
                    elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]:
                        as_node = last_seen_by_node[-1][1]
            if as_node is None:
                raise InvalidUpdateError("Ambiguous update, specify as_node")
            if as_node not in self.nodes:
                raise InvalidUpdateError(f"Node {as_node} does not exist")
            valid_updates.append((as_node, values, task_id))
        else:
            # multiple updates in one superstep: each must name its node explicitly
            for values, as_node, task_id in updates:
                if as_node is None:
                    raise InvalidUpdateError(
                        "as_node is required when applying multiple updates"
                    )
                if as_node not in self.nodes:
                    raise InvalidUpdateError(f"Node {as_node} does not exist")
                valid_updates.append((as_node, values, task_id))
        run_tasks: list[PregelTaskWrites] = []
        run_task_ids: list[str] = []
        for as_node, values, provided_task_id in valid_updates:
            # create task to run all writers of the chosen node
            writers = self.nodes[as_node].flat_writers
            if not writers:
                raise InvalidUpdateError(f"Node {as_node} has no writers")
            writes: deque[tuple[str, Any]] = deque()
            task = PregelTaskWrites((), as_node, writes, [INTERRUPT])
            # get the task ids that were prepared for this node
            # if a task id was provided in the StateUpdate, we use it
            # otherwise, we use the next available task id
            prepared_task_ids = node_to_task_ids.get(as_node, deque())
            task_id = provided_task_id or (
                prepared_task_ids.popleft()
                if prepared_task_ids
                else str(uuid5(UUID(checkpoint["id"]), INTERRUPT))
                # ^ deterministic fallback id derived from the checkpoint id
            )
            run_tasks.append(task)
            run_task_ids.append(task_id)
            run = RunnableSequence(*writers) if len(writers) > 1 else writers[0]
            # execute task
            await run.ainvoke(
                values,
                patch_config(
                    config,
                    run_name=self.name + "UpdateState",
                    configurable={
                        # deque.extend is thread-safe
                        CONFIG_KEY_SEND: writes.extend,
                        CONFIG_KEY_TASK_ID: task_id,
                        CONFIG_KEY_READ: partial(
                            local_read,
                            _scratchpad(
                                None,
                                [],
                                task_id,
                                "",
                                None,
                                step,
                                step + 2,
                            ),
                            channels,
                            managed,
                            task,
                        ),
                    },
                ),
            )
        # save task writes
        for task_id, task in zip(run_task_ids, run_tasks):
            # channel writes are saved to current checkpoint
            channel_writes = [w for w in task.writes if w[0] != PUSH]
            if saved and channel_writes:
                await checkpointer.aput_writes(
                    checkpoint_config, channel_writes, task_id
                )
        # apply to checkpoint and save
        apply_writes(
            checkpoint,
            channels,
            run_tasks,
            checkpointer.get_next_version,
            self.trigger_to_nodes,
        )
        checkpoint = create_checkpoint(checkpoint, channels, step + 1)
        # save checkpoint, after applying writes
        # (recording only channel versions that changed)
        next_config = await checkpointer.aput(
            checkpoint_config,
            checkpoint,
            {
                "source": "update",
                "step": step + 1,
                "parents": saved.metadata.get("parents", {}) if saved else {},
            },
            get_new_channel_versions(
                checkpoint_previous_versions, checkpoint["channel_versions"]
            ),
        )
        for task_id, task in zip(run_task_ids, run_tasks):
            # save push writes
            # (push writes attach to the *new* checkpoint, unlike channel
            # writes above which attach to the current one)
            if push_writes := [w for w in task.writes if w[0] == PUSH]:
                await checkpointer.aput_writes(next_config, push_writes, task_id)
        return patch_checkpoint_map(next_config, saved.metadata if saved else None)

    # normalize thread_id to str, then apply supersteps sequentially, threading
    # each returned config (holding the new checkpoint id) into the next one
    current_config = patch_configurable(
        config, {CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID])}
    )
    for superstep in supersteps:
        current_config = await aperform_superstep(current_config, superstep)
    return current_config