How to use human-in-the-loop with stream_mode="messages"

@zjt148680

I made this (messages, updates, and custom are all streamable at once). Passing a list for stream_mode makes graph.stream() yield (mode, chunk) tuples: tokens arrive on "messages", the interrupt surfaces as a "__interrupt__" key on "updates", and the tool can push a preview over "custom" via get_stream_writer(). Resuming is just another stream() call on the same thread_id with Command(resume=...):

import os

from dotenv import load_dotenv

from typing import Annotated

from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import START, END, INTERRUPT  # INTERRUPT == "__interrupt__"
from langgraph.graph import StateGraph, add_messages
from langgraph.prebuilt import ToolNode
from langgraph.config import get_stream_writer
from langgraph.types import interrupt, Command
from pydantic import BaseModel

load_dotenv()

# --------------------tool-------------------------
@tool
def human_assistance(query: str) -> str:
    """Request assistance from a human."""
    # Emit a preview of the interrupt prompt to the client via custom stream
    try:
        writer = get_stream_writer()
        # You can emit structured data; here we send a simple dict payload
        writer({"type": "interrupt_preview", "prompt": query})
    except Exception:
        # get_stream_writer is only available within a LangGraph execution context
        # If not available, silently skip emitting custom data
        pass

    # interrupt() pauses the run here and checkpoints state. On resume with
    # Command(resume=...), this function re-executes from the top, so the
    # writer() preview above fires a second time before interrupt() returns.
    human_response = interrupt(query)
    return human_response


tools = [human_assistance]

# -----------------------llm----------------------
llm = init_chat_model(
    "claude-3-7-sonnet-latest",
    model_provider="anthropic",
    api_key=os.getenv("ANTHROPIC_API_KEY", ""),
)

llm_with_tools = llm.bind_tools(tools)

# ---------------state-----------------------
class State(BaseModel):
    messages: Annotated[list, add_messages]


# ----------------node------------------
def chat_bot(state: State):
    print("Chat history:")
    for msg in state.messages:
        role = getattr(msg, "role", None) or getattr(msg, "type", None) or "unknown"
        content = getattr(msg, "content", None) or str(msg)
        print(f"  [{role}] {content}")
    ai_msg = llm_with_tools.invoke(state.messages)
    # IMPORTANT: return new message(s); do NOT mutate state (append returns None)
    return {"messages": [ai_msg]}


tool_node = ToolNode(tools)
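# ToolNode executes the model's tool calls; when human_assistance calls
# interrupt(), the run suspends inside this node until it is resumed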


def route_tools(state: State) -> str:
    if not state.messages:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")
    ai_message = state.messages[-1]
    # Only AI messages carry tool_calls; route to the tool node if any exist
    if getattr(ai_message, "tool_calls", None):
        return "tool_node"
    return END

graph = (
    StateGraph(State)
    .add_node("chat_bot", chat_bot)
    .add_node("tool_node", tool_node)

    .add_edge(START, "chat_bot")
    .add_conditional_edges(
        "chat_bot",
        route_tools,
        {"tool_node": "tool_node", END: END},
    )
    .add_edge("tool_node", "chat_bot")

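    # NOTE: a checkpointer is required for interrupt()/resume; InMemorySaver
    # keeps per-thread_id state for the life of the process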
    .compile(name="New Graph", checkpointer=InMemorySaver())
)

def stream_graph(user_input: str, thread_id: str):
    config = {"configurable": {"thread_id": thread_id}}

    # First invocation uses the user's message. Resumes use Command(...)
    pending_input = {"messages": [{"role": "user", "content": user_input}]}

    while True:
        interrupted = False

        for mode, chunk in graph.stream(
            pending_input,
            config=config,
            stream_mode=["messages", "updates", "custom"],  # tokens + interrupts + custom tool events
        ):
            # print(f"\n--- Stream mode: {mode} ---")

            if mode == "messages":
                token, meta = chunk
                # Only print real token content (function-call chunks can be None)
                if token.content and meta.get("langgraph_node") == "chat_bot":
                    content = token.content
                    if isinstance(content, str):
                        text = content
                    elif isinstance(content, list):
                        # Extract text blocks for providers that return content blocks
                        parts = []
                        for block in content:
                            if isinstance(block, dict) and block.get("type") == "text":
                                parts.append(block.get("text", ""))
                        text = "".join(parts)
                    else:
                        text = str(content)
                    if text:
                        print(text, end="", flush=True)

            elif mode == "custom":
                # Print any custom tool emissions (e.g., interrupt previews)
                try:
                    print(chunk.prompt, flush=True)
                except Exception:
                    pass

            elif mode == "updates":
                event = chunk
                # Interrupts are surfaced via updates/debug
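                # e.g. event == {"__interrupt__": (Interrupt(value=query, ...),)}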
                if isinstance(event, dict) and INTERRUPT in event:
                    print()  # newline before prompting
                    # The prompt is also available as event[INTERRUPT][0].value,
                    # but the custom stream already printed it above
                    human = input("interrupt: ")
                    # Resume the same run using Command(resume=...)
                    pending_input = Command(resume=human)
                    interrupted = True
                    break  # exit inner for-loop and resume

        if not interrupted:
            print()  # finish the line of streamed tokens
            return


def main():
    while True:
        try:
            thread_id = input("thread_id: ")
            user_input = input("user: ")
            if user_input.lower() in ["quit", "exit", "q"]:
                print("Goodbye!")
                break

            stream_graph(user_input, thread_id)
        except (EOFError, KeyboardInterrupt):
            print("\nGoodbye!")
            break
        except Exception as e:
            print(e)


if "__main__" == __name__:
    main()
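
If you're driving this from an async web handler instead of input(), the same two-call pattern works with graph.astream. A rough sketch, reusing the compiled graph above (run_until_interrupt is just an illustrative name; error handling omitted):

# Async sketch of the same interrupt/resume flow. The first call streams until
# the run finishes or interrupts; the caller resumes the same thread later by
# passing Command(resume=...) as the payload.
async def run_until_interrupt(payload, thread_id: str) -> dict:
    config = {"configurable": {"thread_id": thread_id}}
    async for mode, chunk in graph.astream(
        payload, config=config, stream_mode=["messages", "updates"]
    ):
        if mode == "updates" and isinstance(chunk, dict) and INTERRUPT in chunk:
            # Hand the prompt back to the caller (e.g., as an HTTP response)
            return {"status": "interrupted", "prompt": chunk[INTERRUPT][0].value}
    return {"status": "done"}

# First call:   await run_until_interrupt({"messages": [...]}, thread_id)
# After the human answers:
#               await run_until_interrupt(Command(resume=answer), thread_id)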