How to use LangChain v1.x middleware in LangGraph?

In LangChain v1.x, there are handy middlewares that can be passed into the create_agent method to manage chat histories. However, if the agent is built directly with LangGraph, how can I integrate it with those middlewares?

I appreciate the help!

I mean the predefined middlewares like "from langchain.agents.middleware import SummarizationMiddleware", etc.

Hi @DataNoob0723

If you already have a custom LangGraph graph:

  • Option A (compose): Wrap your existing graph as a tool and call it from an agent created via create_agent to gain middleware support.
  • Option B (manual integration): Instantiate the middleware and add nodes for its hooks (before_agent, before_model, after_model, after_agent) into your StateGraph, chaining the edges the same way create_agent does internally.
  • For chat history: LangGraph recommends persistence via a checkpointer (e.g., InMemorySaver, SqliteSaver, PostgresSaver) to get durable, multi-turn memory. Summarization/trimming can be done with the middleware, or with message transformers if you're building your own nodes. A minimal baseline is sketched right below, before the two full options.
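
For reference, here is a minimal sketch of the baseline those bullets describe. It uses the same v1.x imports as the fuller examples below; the threshold values are placeholders, and an OpenAI API key is assumed in the environment:

from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware
from langgraph.checkpoint.memory import InMemorySaver

agent = create_agent(
    model="openai:gpt-4o-mini",
    tools=[],
    middleware=[
        SummarizationMiddleware(
            model="openai:gpt-4o-mini",
            max_tokens_before_summary=500,  # placeholder threshold
            messages_to_keep=4,             # recent turns kept verbatim
        )
    ],
    checkpointer=InMemorySaver(),  # or SqliteSaver/PostgresSaver for durability
)

# Multi-turn memory comes from the checkpointer plus a stable thread_id.
result = agent.invoke(
    {"messages": [{"role": "user", "content": "hi"}]},
    {"configurable": {"thread_id": "demo"}},
)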
Option A - compose your existing graph as a tool
import os

from dotenv import load_dotenv
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import InMemorySaver
from typing import TypedDict

load_dotenv(verbose=True)

DEBUG_SUMMARY = os.getenv("DEBUG_SUMMARY", "y").lower() in {"1", "true", "yes", "y"}
SUMMARY_MAX_TOKENS = int(os.getenv("SUMMARY_MAX_TOKENS", "50"))
SUMMARY_MESSAGES_TO_KEEP = int(os.getenv("SUMMARY_MESSAGES_TO_KEEP", "4"))

# Assume you already have a custom LangGraph graph.
# NOTE: a plain `list` here means each node return replaces the message list;
# real graphs usually use Annotated[list, add_messages] to append instead.
class MessagesState(TypedDict):
    messages: list

def my_pipeline(state: MessagesState) -> MessagesState:
    # ... your custom pipeline logic ...
    return {"messages": state["messages"]}

builder = StateGraph(MessagesState)
builder.add_node("pipeline", my_pipeline)
builder.add_edge(START, "pipeline")
builder.add_edge("pipeline", END)
custom_graph = builder.compile(checkpointer=InMemorySaver())

# Wrap it as a tool to use from an agent with middleware
from langchain_core.tools import tool

@tool
def run_custom_graph(prompt: str) -> str:
    """Run the custom LangGraph pipeline with the given user prompt and
    return the content of the last message if present, otherwise 'ok'."""
    # Route the user input through your graph; return last AI message (if any)
    result = custom_graph.invoke(
        {"messages": [{"role": "user", "content": prompt}]},
        {"configurable": {"thread_id": "t1"}},
    )
    msgs = result.get("messages", [])
    return getattr(msgs[-1], "content", "ok") if msgs else "ok"

from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware
from langchain.agents.middleware.types import AgentMiddleware
from typing import Any, Dict

def _estimate_tokens(messages):
    text = []
    for msg in messages or []:
        if isinstance(msg, dict):
            content = msg.get("content")
        else:
            content = getattr(msg, "content", None)
        if content:
            text.append(str(content))
    return max(1, len(" ".join(text)) // 4)

def _debug_roles_and_summary(messages):
    roles = []
    system_snippets = []
    for msg in messages or []:
        if isinstance(msg, dict):
            role = msg.get("role") or msg.get("type")
            content = msg.get("content")
        else:
            role = getattr(msg, "type", getattr(msg, "role", None))
            content = getattr(msg, "content", None)
        roles.append(role)
        if role == "system" and content:
            system_snippets.append(str(content)[:120].replace("\n", " "))
    return roles, system_snippets

def _print_state_snapshot(label: str, messages):
    roles, _ = _debug_roles_and_summary(messages)
    print(f"[SUMMARY-STATE] {label}: messages={len(messages)} est_tokens={_estimate_tokens(messages)} roles={roles}")
    for idx, msg in enumerate(messages):
        if isinstance(msg, dict):
            role = msg.get("role") or msg.get("type")
            content = msg.get("content")
        else:
            role = getattr(msg, "type", getattr(msg, "role", None))
            content = getattr(msg, "content", None)
        snippet = (str(content) if content is not None else "")[:160].replace("\n", " ")
        print(f"[SUMMARY-STATE]   [{idx}] {role}: {snippet}")


class DebugSummarizationMiddleware(SummarizationMiddleware):
    def before_model(self, state: Dict[str, Any], runtime: Any):
        if DEBUG_SUMMARY:
            msgs = state.get("messages", [])
            print(
                f"[SUMMARY-DEBUG] before_model: messages={len(msgs)} est_tokens={_estimate_tokens(msgs)} "
                f"threshold={SUMMARY_MAX_TOKENS} keep={SUMMARY_MESSAGES_TO_KEEP} has_summary={bool(state.get('summary'))}"
            )
            print(
                f"[SUMMARY-DEBUG] mw={self.__class__.__name__} base_sync_impl={SummarizationMiddleware.before_model is not AgentMiddleware.before_model} "
                f"base_async_impl={SummarizationMiddleware.abefore_model is not AgentMiddleware.abefore_model}"
            )
        result = super().before_model(state, runtime)
        if DEBUG_SUMMARY:
            res_msgs = result.get("messages") if isinstance(result, dict) else None
            if res_msgs is not None:
                print(f"[SUMMARY-DEBUG] before_model result: messages_update={len(res_msgs)} est_tokens={_estimate_tokens(res_msgs)}")
                roles, sys_snips = _debug_roles_and_summary(res_msgs)
                print(f"[SUMMARY-DEBUG] before_model result roles: {roles}")
                if sys_snips:
                    print(f"[SUMMARY-DEBUG] before_model system snippet: {sys_snips[0]}")
                _print_state_snapshot("BEFORE summarization", state.get("messages", []))
                _print_state_snapshot("AFTER summarization", res_msgs)
            print(f"[SUMMARY-DEBUG] before_model keys: {list(result.keys()) if isinstance(result, dict) else type(result)}")
        return result

    async def abefore_model(self, state: Dict[str, Any], runtime: Any):
        if DEBUG_SUMMARY:
            msgs = state.get("messages", [])
            print(
                f"[SUMMARY-DEBUG] abefore_model: messages={len(msgs)} est_tokens={_estimate_tokens(msgs)} "
                f"threshold={SUMMARY_MAX_TOKENS} keep={SUMMARY_MESSAGES_TO_KEEP} has_summary={bool(state.get('summary'))}"
            )
            print(
                f"[SUMMARY-DEBUG] mw={self.__class__.__name__} base_sync_impl={SummarizationMiddleware.before_model is not AgentMiddleware.before_model} "
                f"base_async_impl={SummarizationMiddleware.abefore_model is not AgentMiddleware.abefore_model}"
            )
        result = await super().abefore_model(state, runtime)
        if DEBUG_SUMMARY:
            res_msgs = result.get("messages") if isinstance(result, dict) else None
            if res_msgs is not None:
                print(f"[SUMMARY-DEBUG] abefore_model result: messages_update={len(res_msgs)} est_tokens={_estimate_tokens(res_msgs)}")
                roles, sys_snips = _debug_roles_and_summary(res_msgs)
                print(f"[SUMMARY-DEBUG] abefore_model result roles: {roles}")
                if sys_snips:
                    print(f"[SUMMARY-DEBUG] abefore_model system snippet: {sys_snips[0]}")
                _print_state_snapshot("BEFORE summarization", state.get("messages", []))
                _print_state_snapshot("AFTER summarization", res_msgs)
            print(f"[SUMMARY-DEBUG] abefore_model keys: {list(result.keys()) if isinstance(result, dict) else type(result)}")
        return result


agent = create_agent(
    model="openai:gpt-4o-mini",
    tools=[run_custom_graph],            # <-- your graph is now a tool
    middleware=[
        DebugSummarizationMiddleware(
            model="openai:gpt-4o-mini",
            max_tokens_before_summary=SUMMARY_MAX_TOKENS,
            messages_to_keep=SUMMARY_MESSAGES_TO_KEEP,
        )
    ],
    checkpointer=InMemorySaver(),
)

if DEBUG_SUMMARY:
    base_has_sync = SummarizationMiddleware.before_model is not AgentMiddleware.before_model
    base_has_async = SummarizationMiddleware.abefore_model is not AgentMiddleware.abefore_model
    print(f"[SUMMARY-DEBUG] hook availability (base SummarizationMiddleware): has_sync={base_has_sync} has_async={base_has_async}")

def run_cli_conversation():
    """
    Simple CLI conversation loop that sends user inputs through the agent and
    prints the latest assistant response. Type 'exit' or 'quit' to leave.
    """
    config = {"configurable": {"thread_id": "cli"}}
    print("CLI started. Type 'exit' or 'quit' to end.")
    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break

        if user_input.lower() in {"exit", "quit"}:
            break

        final_state = None
        for state in agent.stream({"messages": [{"role": "user", "content": user_input}]}, config=config, stream_mode="values"):
            final_state = state

        if not final_state:
            print("AI: [no response]")
            continue

        ai_text = None
        messages = final_state.get("messages", [])

        # Try to find the last assistant/ai message, supporting dict or message objects
        for msg in reversed(messages):
            if isinstance(msg, dict):
                role = msg.get("role") or msg.get("type")
                content = msg.get("content")
            else:
                role = getattr(msg, "type", getattr(msg, "role", None))
                content = getattr(msg, "content", None)
            if role in ("assistant", "ai"):
                ai_text = content
                break

        if DEBUG_SUMMARY and final_state is not None:
            msgs = final_state.get("messages", [])
            has_summary = bool(final_state.get("summary"))
            print(f"[SUMMARY-DEBUG] after step: messages={len(msgs)} est_tokens={_estimate_tokens(msgs)} has_summary={has_summary}")
            roles, sys_snips = _debug_roles_and_summary(msgs)
            print(f"[SUMMARY-DEBUG] after step roles: {roles}")
            if sys_snips:
                print(f"[SUMMARY-DEBUG] after step system snippet: {sys_snips[0]}")

        if ai_text is not None:
            print(f"AI: {ai_text}")
        else:
            print("AI: [no response produced by agent]")


if __name__ == "__main__":
    run_cli_conversation()
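
A note on Option A before moving on: the wrapped graph runs as an ordinary tool call with its own hardcoded thread_id ("t1" in run_custom_graph above), so the SummarizationMiddleware manages the outer agent's chat history, not the messages inside your custom graph. If summarization needs to happen inside the graph itself, use Option B.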
Option B - manually insert middleware hooks into your StateGraph

When you own the LangGraph graph, you can add middleware hook nodes (e.g., before_model) directly before your model node. This mirrors what create_agent does internally.

import os
import asyncio
from dotenv import load_dotenv
# NOTE: RunnableCallable lives in a private module (langgraph._internal);
# it is used here because create_agent uses it internally, but it may move
# or change between langgraph releases.
from langgraph._internal._runnable import RunnableCallable
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import StateGraph, START, END
from langchain.agents.middleware import SummarizationMiddleware
from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage

load_dotenv(verbose=True)

# Chat model used by the model node
chat_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Debug controls
DEBUG_SUMMARY = os.getenv("DEBUG_SUMMARY", "y").lower() in {"1", "true", "yes", "y"}
SUMMARY_MAX_TOKENS = int(os.getenv("SUMMARY_MAX_TOKENS", "50"))
SUMMARY_MESSAGES_TO_KEEP = int(os.getenv("SUMMARY_MESSAGES_TO_KEEP", "4"))
FORCE_SUMMARY_ASYNC = os.getenv("FORCE_SUMMARY_ASYNC", "1").lower() in {"1", "true", "yes", "y"}

def _coerce_to_lc_messages(messages):
    """Convert tuple messages (role, content) to LangChain message objects."""
    converted = []
    for msg in messages or []:
        if isinstance(msg, tuple) and len(msg) >= 2:
            role, content = msg[0], msg[1]
            role_lower = (role or "").lower()
            if role_lower in ("user", "human"):
                converted.append(HumanMessage(content=content))
            elif role_lower in ("assistant", "ai"):
                converted.append(AIMessage(content=content))
            elif role_lower == "system":
                converted.append(SystemMessage(content=content))
            else:
                converted.append(HumanMessage(content=str(content)))
        else:
            converted.append(msg)
    return converted

def _estimate_tokens(messages):
    """Cheap token estimate based on characters; avoids extra deps."""
    text = []
    for msg in messages or []:
        if isinstance(msg, tuple) and len(msg) >= 2:
            text.append(str(msg[1]))
        else:
            content = getattr(msg, "content", None)
            if content:
                text.append(str(content))
    chars = len(" ".join(text))
    return max(1, chars // 4)

def _debug_roles_and_summary(messages):
    roles = []
    system_snippets = []
    for msg in messages or []:
        role = None
        content = None
        if isinstance(msg, tuple) and len(msg) >= 2:
            role = msg[0]
            content = msg[1]
        else:
            role = getattr(msg, "type", getattr(msg, "role", None))
            content = getattr(msg, "content", None)
        roles.append(role)
        if role == "system" and content:
            text = str(content)
            system_snippets.append(text[:120].replace("\n", " "))
    return roles, system_snippets

def _print_state_snapshot(label, messages):
    roles, sys_snips = _debug_roles_and_summary(messages)
    print(
        f"[SUMMARY-STATE] {label}: messages={len(messages)} est_tokens={_estimate_tokens(messages)} roles={roles}"
    )
    # Print each message role with a short content snippet for visibility
    for idx, msg in enumerate(messages):
        role = None
        content = None
        if isinstance(msg, tuple) and len(msg) >= 2:
            role = msg[0]
            content = msg[1]
        else:
            role = getattr(msg, "type", getattr(msg, "role", None))
            content = getattr(msg, "content", None)
        snippet = (str(content) if content is not None else "")[:160].replace("\n", " ")
        print(f"[SUMMARY-STATE]   [{idx}] {role}: {snippet}")


# Your model node: call the chat model and return the new AIMessage.
# AgentState's messages channel uses the add_messages reducer, so
# returning only the new message appends it to the existing history.
def model_node(state: AgentState):
    messages = state.get("messages", [])
    input_messages = _coerce_to_lc_messages(messages)
    ai_message = chat_llm.invoke(input_messages)
    return {"messages": [ai_message]}

mw = SummarizationMiddleware(
    model="openai:gpt-4o-mini",
    max_tokens_before_summary=SUMMARY_MAX_TOKENS,
    messages_to_keep=SUMMARY_MESSAGES_TO_KEEP,
)

def _before_model_debug_sync(state: AgentState, runtime):
    if DEBUG_SUMMARY:
        msgs = state.get("messages", [])
        print(
            f"[SUMMARY-DEBUG] before_model: messages={len(msgs)} est_tokens={_estimate_tokens(msgs)} "
            f"threshold={SUMMARY_MAX_TOKENS} keep={SUMMARY_MESSAGES_TO_KEEP} has_summary={bool(state.get('summary'))}"
        )
        try:
            print(
                f"[SUMMARY-DEBUG] mw={mw.__class__.__name__} sync_impl={mw.__class__.before_model is not AgentMiddleware.before_model} "
                f"async_impl={mw.__class__.abefore_model is not AgentMiddleware.abefore_model} "
                f"attrs={{max_tokens_before_summary={getattr(mw, 'max_tokens_before_summary', None)}, messages_to_keep={getattr(mw, 'messages_to_keep', None)}, model={getattr(mw, 'model', None)}}}"
            )
        except Exception as e:
            print(f"[SUMMARY-DEBUG] mw introspection error: {e}")

    result = mw.before_model(state, runtime)
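    # Fallback: if the sync hook returned no update but the middleware also
    # overrides the async hook, try abefore_model; asyncio.run only works
    # here when no event loop is already running in this thread.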
    if result is None and FORCE_SUMMARY_ASYNC and (mw.__class__.abefore_model is not AgentMiddleware.abefore_model):
        if DEBUG_SUMMARY:
            print("[SUMMARY-DEBUG] before_model returned None; attempting async abefore_model from sync context…")
        try:
            async def _run():
                return await mw.abefore_model(state, runtime)

            result = asyncio.run(_run())
        except RuntimeError as e:
            if DEBUG_SUMMARY:
                print(f"[SUMMARY-DEBUG] asyncio.run failed (likely event loop running): {e}")
        except Exception as e:
            if DEBUG_SUMMARY:
                print(f"[SUMMARY-DEBUG] abefore_model call errored: {e}")
    if DEBUG_SUMMARY:
        res_msgs = result.get("messages") if isinstance(result, dict) else None
        if res_msgs is not None:
            print(
                f"[SUMMARY-DEBUG] before_model result: messages_update={len(res_msgs)} est_tokens={_estimate_tokens(res_msgs)}"
            )
            roles, sys_snips = _debug_roles_and_summary(res_msgs)
            print(f"[SUMMARY-DEBUG] before_model result roles: {roles}")
            if sys_snips:
                print(f"[SUMMARY-DEBUG] before_model system snippet: {sys_snips[0]}")
            # Print full state snapshot after summarization (messages patch applied)
            before_msgs = state.get("messages", [])
            _print_state_snapshot("BEFORE summarization", before_msgs)
            _print_state_snapshot("AFTER summarization", res_msgs)
        print(f"[SUMMARY-DEBUG] before_model keys: {list(result.keys()) if isinstance(result, dict) else type(result)}")
    return result

async def _before_model_debug_async(state: AgentState, runtime):
    if DEBUG_SUMMARY:
        msgs = state.get("messages", [])
        print(
            f"[SUMMARY-DEBUG] abefore_model: messages={len(msgs)} est_tokens={_estimate_tokens(msgs)} "
            f"threshold={SUMMARY_MAX_TOKENS} keep={SUMMARY_MESSAGES_TO_KEEP} has_summary={bool(state.get('summary'))}"
        )
    result = await mw.abefore_model(state, runtime)
    if DEBUG_SUMMARY:
        res_msgs = result.get("messages") if isinstance(result, dict) else None
        if res_msgs is not None:
            print(
                f"[SUMMARY-DEBUG] abefore_model result: messages_update={len(res_msgs)} est_tokens={_estimate_tokens(res_msgs)}"
            )
            roles, sys_snips = _debug_roles_and_summary(res_msgs)
            print(f"[SUMMARY-DEBUG] abefore_model result roles: {roles}")
            if sys_snips:
                print(f"[SUMMARY-DEBUG] abefore_model system snippet: {sys_snips[0]}")
            # Print full state snapshot after summarization (messages patch applied)
            before_msgs = state.get("messages", [])
            _print_state_snapshot("BEFORE summarization", before_msgs)
            _print_state_snapshot("AFTER summarization", res_msgs)
        print(f"[SUMMARY-DEBUG] abefore_model keys: {list(result.keys()) if isinstance(result, dict) else type(result)}")
    return result

# Wrap middleware hook as a runnable so it receives (state, runtime) with debug
# Only provide sync/async implementations if the middleware actually overrides them.
_has_sync = mw.__class__.before_model is not AgentMiddleware.before_model
_has_async = mw.__class__.abefore_model is not AgentMiddleware.abefore_model

before_model_node = RunnableCallable(
    _before_model_debug_sync if _has_sync else None,
    _before_model_debug_async if _has_async else None,
    trace=False,
)

if DEBUG_SUMMARY:
    print(f"[SUMMARY-DEBUG] hook availability: has_sync={_has_sync} has_async={_has_async}")

graph = StateGraph(AgentState)
graph.add_node("Summarization.before_model", before_model_node, input_schema=AgentState)
graph.add_node("model", model_node)

# Run middleware before your model, then go to END
graph.add_edge(START, "Summarization.before_model")
graph.add_edge("Summarization.before_model", "model")
graph.add_edge("model", END)

compiled = graph.compile(checkpointer=InMemorySaver())


def run_cli_conversation():
    """
    Simple CLI conversation loop that sends user inputs through the compiled graph
    and prints the latest assistant response. Type 'exit' or 'quit' to leave.
    """
    config = {"configurable": {"thread_id": "cli"}}
    print("CLI started. Type 'exit' or 'quit' to end.")
    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break

        if user_input.lower() in {"exit", "quit"}:
            break

        final_state = None
        for state in compiled.stream({"messages": [HumanMessage(content=user_input)]}, config=config, stream_mode="values"):
            final_state = state

        if not final_state:
            print("AI: [no response]")
            continue

        ai_text = None
        messages = final_state.get("messages", [])

        # Try to find the last assistant/ai message, supporting tuple or message objects
        for msg in reversed(messages):
            if isinstance(msg, tuple) and len(msg) >= 2:
                role, content = msg[0], msg[1]
                if role in ("assistant", "ai"):
                    ai_text = content
                    break
            else:
                role = getattr(msg, "type", getattr(msg, "role", None))
                content = getattr(msg, "content", None)
                if role in ("ai", "assistant"):
                    ai_text = content
                    break

        if DEBUG_SUMMARY:
            msgs = final_state.get("messages", [])
            has_summary = bool(final_state.get("summary"))
            print(
                f"[SUMMARY-DEBUG] after step: messages={len(msgs)} est_tokens={_estimate_tokens(msgs)} has_summary={has_summary}"
            )
            roles, sys_snips = _debug_roles_and_summary(msgs)
            print(f"[SUMMARY-DEBUG] after step roles: {roles}")
            if sys_snips:
                print(f"[SUMMARY-DEBUG] after step system snippet: {sys_snips[0]}")

        if ai_text is not None:
            print(f"AI: {ai_text}")
        else:
            print("AI: [no response produced by model_node]")


if __name__ == "__main__":
    run_cli_conversation()
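
One caveat on Option B: this wiring only adds the before_model hook, which is the one SummarizationMiddleware implements. Middlewares that use before_agent, after_model, or after_agent would need matching nodes and edges, the same way create_agent adds them internally.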

Hi Pawel, thank you so much for the detailed explanation! Greatly appreciate it!
