The Problem:
I’m performing an action using a tool and an interrupt. Once the tool executes a certain process, the graph should interrupt. Because that process produces a different result on every execution, I can’t place the interrupt inside the tool, so I’m using a flag to redirect to an interrupt node, in which I do the preprocessing and postprocessing for the interrupt.
When resuming the graph, it should start from the interrupt_node, perform the postprocessing (updating the tool message), and provide that new tool message to the agent, which should produce an AI message. But in my case, it produces 2 AI messages.
Understanding I got from Debugging:
After receiving a signal from the tool_node, I proceed to the interrupt_node. I use the interrupt keyword to pause the graph. When it resumes, it first goes to the agent_node. At the agent_node, it reaches the line where the llm chain is invoked. Instead of executing this line, it continues from the interrupt_node. While at the interrupt_node, I update the tool message (since I didn’t use the interrupt within the tool_node, as I needed to perform some operations first, which conflicted with the interrupt requirements). After updating the tool message, I update the state to reflect this specific tool message. Finally, the flow returns to the agent_node as initially outlined in the graph’s setup.
Now, it resumes its execution from the llm chain invocation line, thus generates a message.
After that the agent node is being executed 2nd time from the first line of the agent_node. Thus generating the second message.
If anyone can explain this behaviour, then it would be quite helpful.
Approaches I have used:
I’ve tried using Command(goto="agent_node", update=...), returning the updated state, and defining the edge during initialization as well. Nothing worked. I’ve also asked 'Chat LangChain', but it gets confused while helping me and returns the same result.
Graph Initialization
# Assemble the graph: START -> agent_node; agent_node routes itself via Command;
# tool_node conditionally hands off to interrupt_node or loops back to the agent.
builder = StateGraph(states.GraphState, states.ContextSchema)

# Register every node from one table so additions stay in a single place.
for node_name, node_impl in (
    ("agent_node", nodes.agent_node),
    ("tool_node", ToolNode(list(tools.tools_map.values()))),
    ("interrupt_node", nodes.interrupt_node),
):
    builder.add_node(node_name, node_impl)

builder.add_edge(START, "agent_node")

# After tools run, either pause for human feedback or go straight back to the agent.
builder.add_conditional_edges(
    "tool_node",
    nodes.conditional_edge_for_tool_node_to_interrupt_node,
    {"interrupt_node": "interrupt_node", "agent_node": "agent_node"},
)

# Explicit edge from interrupt_node to agent_node to prevent double routing
builder.add_edge("interrupt_node", "agent_node")

graph = builder.compile(checkpointer=checkpointer, store=memory_store)
- Above snippet is how I create my graph.
Remaining Nodes:
async def agent_node(state: GraphState, config, runtime: Runtime, store: BaseStore):
    """Invoke the LLM chain and route the graph based on its answer.

    Builds the prompt inputs (profile, current time, available skills, selected
    subject), runs ``tuvaed_runnable``, then:
      * if the response contains tool calls -> stream a tool-call notice and
        go to ``tool_node``;
      * otherwise -> append the AI message and go to ``END``.
    """
    student_profile = ...  # user profile fetched from the store
    grade_subject_id = state.get("grade_subject_id")
    grade_subject_name = "Subject not selected yet."
    if grade_subject_id:
        # FIX: the original chained bare .get() calls, which raises
        # AttributeError if "grade_subject_response" (or any nested key) is
        # missing/None while an id is set. Fall back to the default label.
        grade_subject_name = (
            (state.get("grade_subject_response") or {})
            .get("grade_subjects", {})
            .get(grade_subject_id, {})
            .get("grade_subject_name", grade_subject_name)
        )
    input_dict = {
        "messages": state["messages"],
        "current_time": datetime.now(ZoneInfo("Asia/Kolkata")).strftime("%Y-%m-%d %H:%M:%S"),
        "student_profile": student_profile,
        "available_skills": SKILLS_PROMPT,
        "selected_subject": grade_subject_name,
    }
    res = await tuvaed_runnable.ainvoke(input_dict)
    logger.info(res.pretty_repr())
    if res.tool_calls:
        # Surface the tool calls to the client over the custom stream channel.
        writer = get_stream_writer()
        writer({"is_tool_call": True, "tool_calls": res.tool_calls})
        return Command(
            goto="tool_node",
            update={
                "messages": [res],
            },
        )
    return Command(
        goto=END,
        update={
            "messages": [res],
        },
    )
def process_interrupt_response(interrupt_response: InterruptResponse):
    """Translate the user's resume payload into tool feedback plus state updates.

    Args:
        interrupt_response: Parsed resume payload (reason enum + optional feedback).

    Returns:
        tuple: ``(tool_call_feedback, state_update_dict)`` — text to append to
        the interrupted tool message, and state keys that clear the interrupt.

    Raises:
        ValueError: if the payload carries an unknown interruption reason.
    """
    interrupt_reason_res = interrupt_response.reason.value
    # Always clear the interrupt bookkeeping so the tool_node conditional edge
    # routes back to agent_node on the next pass.
    state_update_dict = {
        "is_flow_interrupted": False,
        "interrupt_reason": None,
        "interrupt_tool_call_id": None,
    }
    tool_call_feedback = ""
    if interrupt_reason_res == InterruptionReason.ASSESSMENT_CREATED.value:
        assessment_feedback = "some feedback string"
        state_update_dict["assessment_meta"] = None
        tool_call_feedback = assessment_feedback
    else:
        raise ValueError(f"Invalid reason passed. Received: `{interrupt_reason_res}`")
    # FIX: leftover print() debugging replaced with the module logger so output
    # respects the logging configuration; unused `feedback_dict` local removed.
    logger.debug(
        "tool_call_feedback: %s, state_update_dict: %s",
        tool_call_feedback,
        state_update_dict,
    )
    return tool_call_feedback, state_update_dict
def prepare_interrupt_payload(state: GraphState):
    """Build the payload handed to ``interrupt()`` for the frontend to consume.

    Args:
        state: Current graph state; must carry the interrupt bookkeeping keys.

    Returns:
        dict: tool_call_id / flag / reason, plus ``data_payload`` for
        ASSESSMENT_CREATED interrupts.

    Raises:
        ValueError: if any required interrupt field is missing/falsy, or the
        assessment metadata is absent for an ASSESSMENT_CREATED interrupt.
    """
    final = {
        "interrupt_tool_call_id": state.get("interrupt_tool_call_id"),
        "is_flow_interrupted": state.get("is_flow_interrupted"),
        "interrupt_reason": state.get("interrupt_reason"),
    }
    if not all(final.values()):
        # BUG FIX: the original string was missing the f-prefix, so the error
        # message printed the literal text "{final}" instead of the values.
        raise ValueError(f"Missing required fields in final payload, values: {final}")
    interrupt_reason = final["interrupt_reason"]
    if interrupt_reason == InterruptionReason.ASSESSMENT_CREATED.value:
        assessment_meta = state.get("assessment_meta")
        if not assessment_meta:
            raise ValueError("Assessment meta not found")
        final["data_payload"] = assessment_meta
    return final
async def interrupt_node(state: GraphState, config, runtime: Runtime, store: BaseStore):
    """
    Interrupts the flow and waits for the feedback from the user.
    When resuming, this node processes the resume response and continues to agent_node.

    NOTE(review): interrupt() re-executes this ENTIRE node from the top on
    resume, so everything above the interrupt() call runs twice — keep that
    section idempotent (prepare_interrupt_payload currently is).
    """
    final = prepare_interrupt_payload(state)
    logger.info(f"***Final payload before interrupting***: {final}")
    interrupt_response = interrupt(final)
    logger.info(f"***Interrupt response***: {interrupt_response}")
    interrupt_response = InterruptResponse(**interrupt_response)
    tool_call_feedback, state_update_dict = process_interrupt_response(interrupt_response)

    # Find the most recent ToolMessage belonging to the interrupted tool call.
    existing_tool_message = next(
        (
            msg
            for msg in reversed(state["messages"])
            if isinstance(msg, ToolMessage)
            and msg.tool_call_id == interrupt_response.tool_call_id
        ),
        None,
    )
    if not existing_tool_message:
        logger.error(f"Interrupt tool call id {interrupt_response.tool_call_id} not found in messages")
        return state_update_dict

    # FIX: build a fresh ToolMessage (same id, so the add_messages reducer
    # REPLACES the old one) instead of mutating checkpointed state in place —
    # in-place mutation can be lost or double-applied across interrupt replays.
    updated_tool_message = ToolMessage(
        content=existing_tool_message.content + tool_call_feedback,
        tool_call_id=existing_tool_message.tool_call_id,
        id=existing_tool_message.id,
    )
    return {
        "messages": [updated_tool_message],
        **state_update_dict,
    }
async def conditional_edge_for_tool_node_to_interrupt_node(state: GraphState, config, runtime: Runtime, store: BaseStore):
    """
    Conditional edge for tool node to interrupt node.
    This should only be evaluated when coming FROM tool_node, not when coming from interrupt_node.
    """
    interrupted = state.get("is_flow_interrupted")
    logger.info(f"Conditional edge evaluation: is_flow_interrupted={interrupted}")
    # Guard clause: the common, non-interrupted path returns immediately.
    if not interrupted:
        logger.info("Flow not interrupted - routing to agent_node")
        return "agent_node"
    logger.info("Interrupting the flow - routing to interrupt_node")
    return "interrupt_node"
It would be helpful if anyone can explain the problem and possible solution.