Hello!
I’m trying to persist HITL payloads in the AIMessage additional_kwargs. To do so, I send() a channel update with the updated message to include it in a checkpoint save.
It works when I set durability="exit" but I then lose the progressive persistence. With either “async” (default) or “sync”, I get unreliable results, but most of the time, the last message in the checkpoint won’t contain the extra payload when I resume the conversation. From what I can read, “async” will only save when the next node runs. If a user reloads the conversation before approving/rejecting the HITL prompt, the state won’t be saved, right?
Is there a way to combine the benefits of “exit” and “async”? Or is my approach to retain pending interrupts fundamentally wrong?
Note: I’ve had a look at these related threads, but couldn’t find exactly what I’m looking for:
- Streaming interrupt, if unresolved on reload bad state
- How to update graph state while preserving interrupts?
I’m attaching a minimal, self-contained test case. Replace durability="exit" with something else to make it fail.
from __future__ import annotations
from typing import Annotated, TypedDict
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph._internal._constants import CONFIG_KEY_SEND
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.config import get_config
from langgraph.constants import START
from langgraph.graph import StateGraph
from langgraph.graph.message import AnyMessage, add_messages
from langgraph.types import Command, interrupt
class State(TypedDict):
    """Graph state: a list of chat messages accumulated via the add_messages reducer."""

    # add_messages is an id-aware reducer: re-sending a message with an existing
    # id (as the repro below does) updates that message in place instead of
    # appending a duplicate — TODO confirm against the langgraph docs in use.
    messages: Annotated[list[AnyMessage], add_messages]
@pytest.mark.anyio
async def test_human_hitl_interrupt_checkpoint_behavior():
    """Minimal repro: persist AIMessage additional_kwargs before interrupt and after resume.

    A node stamps a payload onto the latest AIMessage, pushes the updated
    message through the internal send() channel writer so the update is part
    of the checkpointed state, then interrupts. The test asserts the payload
    is present in the latest checkpoint both before and after the interrupt
    is resumed. Replace durability="exit" with "async"/"sync" to reproduce
    the reported failure.
    """
    checkpointer = InMemorySaver()
    node_name = "pause_node"
    persisted_payload = "payload"

    def _latest_ai(messages) -> AIMessage | None:
        # Most recent AIMessage in the list, or None if there is none.
        return next(
            (msg for msg in reversed(messages) if isinstance(msg, AIMessage)),
            None,
        )

    def node(state: State) -> None:
        # Persist metadata onto the latest AI message before pausing.
        latest_ai_message = _latest_ai(state["messages"])
        assert latest_ai_message is not None
        latest_ai_message.additional_kwargs["hitl_payload"] = persisted_payload
        # NOTE(review): CONFIG_KEY_SEND is a private langgraph internal. It is
        # used here only to emit a channel write for the updated message so the
        # checkpointer persists it before interrupt() suspends the node.
        configurable = get_config().get("configurable", {})
        send = configurable.get(CONFIG_KEY_SEND)
        assert callable(send)
        send([("messages", [latest_ai_message])])
        interrupt("pause")
        return None

    graph = (
        StateGraph(State)
        .add_node(node_name, node)
        .add_edge(START, node_name)
        .compile(checkpointer=checkpointer)
    )
    config = RunnableConfig(
        configurable={"thread_id": "test-thread-interrupt-repro", "checkpoint_ns": ""}
    )

    async def _assert_payload_in_latest_checkpoint() -> None:
        # The latest checkpoint's "messages" channel must still contain the
        # tagged AIMessage with its extra payload intact. Shared by the
        # before-resume and after-resume checks (previously duplicated).
        latest = await anext(checkpointer.alist(config=config, limit=1), None)
        assert latest is not None
        channel_values = latest.checkpoint.get("channel_values", {})
        persisted_ai = _latest_ai(channel_values.get("messages", []))
        assert persisted_ai is not None
        assert persisted_ai.id == "assistant-1"
        assert (
            persisted_ai.additional_kwargs.get("hitl_payload") == persisted_payload
        )

    result = graph.invoke(
        {
            "messages": [
                HumanMessage(content="Do something"),
                AIMessage(content="Request approval", id="assistant-1"),
            ]
        },
        config,
        durability="exit",
    )
    # The run must have paused at interrupt(), not completed.
    assert "__interrupt__" in result
    await _assert_payload_in_latest_checkpoint()

    graph.invoke(Command(resume="approved"), config, durability="exit")
    await _assert_payload_in_latest_checkpoint()