Implement ToolLimitationNode in StateGraph

I was using ToolNode, but I am open to using anything else if that solves the problem.

Hey @arthurdorigueto , I tested this code snippet, but it doesn’t work as expected.
When the LLM generates 6 tool calls, even if we remove the last 2 tool call messages, OpenAI still expects responses for all generated tool_call_ids.

`openai.OpenAIError: Error code: 400 - {'error': {'message': "An assistant message with 'tool_calls' must be followed by tool messages responding to each 'tool_call_id'. The following tool_call_ids did not have response messages: call_dTukDpkNykgMR17pdBlmczHd, call_q9btjq3s6Tm7IgBcDpsSLbQ4, call_roeTD5s0HuImKIijnHP38AQz, call_E1VW3kH4cjeuz9GHzxdFZRs0, call_iuFONoEYz3ysirP0YPEizJRi, call_25oGrlVZZLrmvGmZTJar9JRw", 'type': 'invalid_request_error', 'param': 'messages.[4].role', 'code': None}}`

Any good idea on your side?

So on my side, I am going to add a `skip` parameter to every tool function, and in the limit-check node I will set `skip=True` for the blocked tool calls.

What do you think?

cc: @pawel-twardziak

I am still looking at that, almost there. I am not sure where you want to add that `skip` param, but be careful: try to stay as idiomatic as possible and avoid dirty workarounds, as they usually turn out to be hard to maintain.

@tool
def file_tool(data: str, skip: bool, timeout: int | None = None) -> str:
    if skip:
        return "skip this call due to limitation"
    print('tool calling start...')

This won’t be handled by the graph in an idiomatic way. Hang on, I’m finalizing - the key part is to reconstruct how the middleware is handled within the create_react graph.
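In short: OpenAI insists that every `tool_call_id` on an assistant message gets a tool response, so instead of dropping the excess tool calls you keep them all and answer the blocked ones with error-style tool messages. That is essentially what `ToolCallLimitMiddleware` with `exit_behavior="continue"` does for you. A plain-dict sketch of the idea (the helper name and message wording are mine, purely to illustrate):

```python
# Sketch only: dicts stand in for AIMessage.tool_calls / ToolMessage objects.
# In LangChain you would emit ToolMessage(tool_call_id=..., status="error", ...)
# instead of the raw dicts below.

def block_excess_tool_calls(
    tool_calls: list[dict], limit: int
) -> tuple[list[dict], list[dict]]:
    """Split tool calls into ones to execute and synthetic error responses.

    Every tool_call_id past the limit still receives a response message, so the
    OpenAI 400 ("tool_calls must be followed by tool messages") never fires.
    """
    to_execute = tool_calls[:limit]
    blocked = [
        {
            "role": "tool",
            "tool_call_id": tc["id"],
            "status": "error",
            "content": f"Tool call '{tc['name']}' was not executed: tool call limit reached.",
        }
        for tc in tool_calls[limit:]
    ]
    return to_execute, blocked
```

With 6 generated calls and a limit of 4, the first 4 go to the tool node and the last 2 get error responses, so all 6 ids are answered.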

ok, I think I’m done (tool_limit_edge is the key part):

"""
Custom LangGraph StateGraph that wires in LangChain's ModelCallLimitMiddleware
around an OpenAI chat model, using python-dotenv for configuration.

This demonstrates how to *reuse* the built-in middleware in a hand-written
graph instead of only via `langchain.agents.create_agent`.

Key ideas:
- We define a TypedDict state compatible with what the middleware expects
  (message history + call counters).
- We call the middleware's `before_model` and `after_model` hooks from
  dedicated graph nodes.
- We use a conditional edge to respect `jump_to="end"` returned by the
  middleware when limits are exceeded.

Relevant docs:
- LangChain middleware (Python):
  https://docs.langchain.com/oss/python/langchain/middleware
- Built-in middleware (Model call limit):
  https://docs.langchain.com/oss/python/langchain/middleware/built-in#model-call-limit
- LangGraph Graph API:
  https://docs.langchain.com/oss/python/langgraph/graph-api
"""
import argparse
import os
from typing import Annotated, Any, Literal, List, cast

from dotenv import load_dotenv
from langchain.agents.middleware.model_call_limit import ModelCallLimitState
from langchain.agents.middleware.tool_call_limit import ToolCallLimitState
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.prebuilt.tool_node import ToolCallWithContext
from typing_extensions import NotRequired, TypedDict

from langchain_openai import ChatOpenAI
from langchain.agents.middleware import ModelCallLimitMiddleware, ToolCallLimitMiddleware
from langchain.agents.factory import _fetch_last_ai_and_tool_messages
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, AIMessage, ToolMessage
from langchain_core.tools import tool

from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode

from langgraph.types import Command, Send

#
# 1. Load environment and initialize OpenAI model
#

load_dotenv()

if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Please add it to your environment or a .env file."
    )

# ChatOpenAI reads OPENAI_API_KEY from the environment by default.
# See: https://reference.langchain.com/python/integrations/langchain_openai/ChatOpenAI
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

#
# 2. Define custom graph state compatible with the middleware
#


class GraphState(TypedDict, ModelCallLimitState, ToolCallLimitState):
    """State for our custom graph.

    This mirrors the shape expected by `ModelCallLimitMiddleware`:
    - `messages`: chat history (append-only via `add_messages` reducer)
    - `thread_model_call_count`: total calls across runs / thread
    - `run_model_call_count`: calls in the current run (normally untracked)
    - `jump_to`: optional flow-control hint (e.g. "end") set by middleware
    """

    messages: Annotated[list[AnyMessage], add_messages]
    # thread_model_call_count: NotRequired[int]
    # run_model_call_count: NotRequired[int]

    # Optional tool call limit middleware counters (dicts keyed by tool name or "__all__")
    # thread_tool_call_count: NotRequired[Annotated[dict[str, int], PrivateStateAttr]]
    # run_tool_call_count: NotRequired[Annotated[dict[str, int], UntrackedValue, PrivateStateAttr]]

    next_model_node: NotRequired[str]
    # jump_to: NotRequired[str]


#
# 3. Tools for the four-tools demo branch
#


@tool
def tool_a() -> str:
    """Return a short string for tool A."""

    return "Result from tool A"


@tool
def tool_b() -> str:
    """Return a short string for tool B."""

    return "Result from tool B"


@tool
def tool_c() -> str:
    """Return a short string for tool C."""

    return "Result from tool C"


@tool
def tool_d() -> str:
    """Return a short string for tool D."""

    return "Result from tool D"


TOOLS_FOR_DEMO = [tool_a, tool_b, tool_c, tool_d]
llm_with_tools = llm.bind_tools(TOOLS_FOR_DEMO)

# Max number of tool calls we will allow to be executed
TOOL_CALL_LIMIT_PER_RUN = 2
TOOL_CALL_LIMIT_PER_THREAD = None
# Max number of model calls we will allow to be executed
MODEL_CALL_LIMIT_PER_RUN = 2
MODEL_CALL_LIMIT_PER_THREAD = None
# Thread ID increment
THREAD_ID_INCREMENT = 18

#
# 4. Instantiate the built-in ModelCallLimitMiddleware
#


call_limiter: ModelCallLimitMiddleware
def set_call_limiter(run_limit: int | None) -> None:
    global call_limiter
    call_limiter = ModelCallLimitMiddleware(
        thread_limit=MODEL_CALL_LIMIT_PER_THREAD,  # Max calls per long-lived "thread"
        run_limit=run_limit,  # Max calls per single graph run
        exit_behavior="end",  # Inject message + request jump to end when exceeded
    )

# 4b. ToolCallLimitMiddleware (used in four_tools_demo branch).
# We set `exit_behavior="continue"` so that blocked tool calls are turned into
# error ToolMessages, but we still allow the agent to continue and execute the
# remaining allowed tool calls. This mirrors how `create_agent` wires this
# middleware together with `ToolNode` and `_make_model_to_tools_edge`.
tool_call_limiter = ToolCallLimitMiddleware[Any, None](
    thread_limit=TOOL_CALL_LIMIT_PER_THREAD,
    run_limit=TOOL_CALL_LIMIT_PER_RUN,
    exit_behavior="continue",
)

#
# 4c. Wrap middleware hooks in graph nodes
#    Each node takes `state` and returns a partial state update dict.
#


def enforce_call_limit(state: GraphState) -> dict[str, Any]:
    """Run ModelCallLimitMiddleware.before_model on the current state.

    The middleware signature is `before_model(self, state, runtime)`. The current
    implementation does not use `runtime`, so we pass `None` here.

    When limits are exceeded and `exit_behavior="end"`, this returns:

        {"jump_to": "end", "messages": [AIMessage(...)]}

    Otherwise it returns either {} or None.
    """
    print("------ Enforcing call limit...")
    updates = call_limiter.before_model(state, runtime=None)  # type: ignore[arg-type]
    return updates or {}

def set_summarization_model(state: GraphState) -> Command[Literal["update_call_counts"]]:
    print("------ Setting summarization model and going to `update_call_counts`...")
    return Command(goto="update_call_counts", update={"next_model_node": "summarize"})

def call_model(state: GraphState) -> dict[str, Any]:
    """Call the OpenAI chat model with the accumulated messages."""
    print("------ Calling model...")
    response = llm.invoke(state["messages"])
    # Because `messages` has the `add_messages` reducer, returning a list of new
    # messages appends them to the existing list.
    return {"messages": [response]}


def call_model_with_four_tools(state: GraphState) -> dict[str, Any]:
    """Call the OpenAI chat model configured with 4 tools.

    The initial messages for this branch should strongly instruct the model to
    call *all four tools in a single response*. Even if it does so, that still
    counts as **one** model call for `ModelCallLimitMiddleware` – the middleware
    does not count tool calls individually. To enforce limits on tool calls
    themselves, you would instead use `ToolCallLimitMiddleware`.
    """
    print("------ Calling model with four tools...")
    response = llm_with_tools.invoke(state["messages"])
    return {"messages": [response]}


def update_call_counts(state: GraphState) -> dict[str, Any]:
    """Increment call counters via ModelCallLimitMiddleware.after_model."""
    print("------ Updating call counts...")
    updates = call_limiter.after_model(state, runtime=None)  # type: ignore[arg-type]
    return updates or {}


def enforce_tool_call_limit(state: GraphState) -> dict[str, Any]:
    """Run ToolCallLimitMiddleware.after_model on the current state.

    This middleware inspects the last AIMessage's tool calls and:
    - updates per-tool/thread/run call counters, and
    - can block excess tool calls or end the run (exit_behavior="end").
    """
    print("------ Enforcing tool call limit...")
    updates = tool_call_limiter.after_model(state, runtime=None)  # type: ignore[arg-type]
    return updates or {}

""" this is imported from langchain.agents.factory """
# def _fetch_last_ai_and_tool_messages(
#     messages: list[AnyMessage],
# ) -> tuple[AIMessage | None, list[ToolMessage]]:
#     """Return the last AIMessage and all following ToolMessages.
#
#     This mirrors the helper used in `langchain.agents.factory._make_model_to_tools_edge`.
#     """
#     print("------ Fetching last ai and tool messages...")
#     last_ai_index: int | None = None
#     last_ai_message: AIMessage | None = None
#
#     for i in range(len(messages) - 1, -1, -1):
#         msg = messages[i]
#         if isinstance(msg, AIMessage):
#             last_ai_index = i
#             last_ai_message = cast(AIMessage, msg)
#             break
#
#     if last_ai_index is None or last_ai_message is None:
#         return None, []
#
#     tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)]
#     return last_ai_message, tool_messages


def tool_limit_edge(state: GraphState) -> list[Send] | str:
    """Decide whether to route from `tool_call_limit` to `tools` or `update_call_counts`.

    Patterned after `_make_model_to_tools_edge`:

    1. If `jump_to` is set (e.g. by another middleware) – honor it.
    2. Take the last `AIMessage` and all subsequent `ToolMessage`s.
    3. If the model did not emit any `tool_calls` → finish (nothing to send to `ToolNode`).
    4. Compute *pending* tool calls as those for which there is **no** corresponding
       `ToolMessage`. In our setup `ToolCallLimitMiddleware` generates `ToolMessage`s
       with status `"error"` for blocked calls, so those should not go to `ToolNode`.
    5. If there are pending calls → route to the `tools` node, otherwise finish.
    """
    print("------ Tool limit edge...")
    # 1. Honor explicitly set `jump_to` (e.g. by other middleware).
    jump_to = state.get("jump_to")
    if jump_to == "end":
        print("------ Ending tool limit edge - update_call_counts #1 ...")
        return "update_call_counts"
    if jump_to == "tools":
        print("------ Ending tool limit edge - tools ...")
        return "tools"

    messages = state.get("messages", [])
    last_ai, tool_messages = _fetch_last_ai_and_tool_messages(messages)
    if last_ai is None:
        print("------ Ending tool limit edge - update_call_counts #2 ...")
        return "update_call_counts"

    tool_message_ids = [m.tool_call_id for m in tool_messages if m.tool_call_id is not None]

    # 2. If the model did not call any tools – finish.
    if not last_ai.tool_calls:
        print("------ Ending tool limit edge - update_call_counts #3 ...")
        return "update_call_counts"

    # 3. Identify calls that do not yet have a corresponding ToolMessage.
    pending_tool_calls = [
        tc for tc in last_ai.tool_calls if tc.get("id") not in tool_message_ids
    ]

    # 4. If there is anything left to execute – go to `tools`.
    if pending_tool_calls:
        print("------ Ending tool limit edge - tool_call_with_context ...")
        return [
            Send(
                "tools",
                ToolCallWithContext(
                    __type="tool_call_with_context",
                    tool_call=tool_call,
                    state=state,
                ),
            )
            for tool_call in pending_tool_calls
        ]

    # 5. All calls already have a ToolMessage (e.g. they were blocked by middleware),
    # so there is nothing more to execute – go to the end node for this iteration.
    print("------ Ending tool limit edge - update_call_counts #4 ...")
    return "update_call_counts"


def summarize_after_tools(state: GraphState) -> dict[str, Any]:
    """Run a summarization LLM call after tools have executed.

    This uses the base chat model (without tools) to generate a concise
    final answer that takes any tool results into account.
    """
    print("------ Summarizing after tools ...")
    response = llm.invoke(state["messages"])
    return {"messages": [response]}


#
# 5. Routing logic: respect middleware's jump_to="end"
#


def route_after_enforce(state: GraphState) -> str:
    """Determine the next node after `enforce_call_limit`.

    If the middleware set `jump_to="end"`, route directly to END.
    Otherwise, proceed to the model node.

    This pattern follows the conditional edge examples in the LangGraph
    Graph API docs:
    https://docs.langchain.com/oss/python/langgraph/graph-api#conditional-branching
    """
    print(f"------ Routing after enforce ({state.get("jump_to")}, {state.get("next_model_node", "model")}) ...")
    if state.get("jump_to") == "end":
        return END
    return state.get("next_model_node", "model")


#
# 6. Build and compile the StateGraph
#


def build_graph(*, use_four_tools_demo: bool = False, checkpointer: BaseCheckpointSaver):
    """Build and compile the StateGraph with the call-limit middleware wired in.

    Args:
        use_four_tools_demo: If True, the graph's model node uses the
            tool-calling LLM (`llm_with_tools`) so the first model call can
            emit four tool calls in one shot.
    """
    builder = StateGraph(GraphState)

    if use_four_tools_demo:
        # Nodes
        builder.add_node("enforce_call_limit", enforce_call_limit)
        builder.add_node("model", call_model_with_four_tools)
        # Apply tool call limits before routing to tools
        builder.add_node("tool_call_limit", enforce_tool_call_limit)
        # ToolNode executes any tool_calls emitted by the model and appends ToolMessages.
        builder.add_node("tools", ToolNode(TOOLS_FOR_DEMO))
        # After tools complete, run a summarization LLM call.
        builder.add_node("summarize", summarize_after_tools)
        builder.add_node("set_summarization_model", set_summarization_model)
        builder.add_node("update_call_counts", update_call_counts)

        # Edges
        builder.add_edge(START, "enforce_call_limit")
        builder.add_conditional_edges("enforce_call_limit", route_after_enforce)

        # First, always run the ToolCallLimitMiddleware after the model.
        builder.add_edge("model", "tool_call_limit")

        # Then, based on the state after the middleware, decide whether to go to `tools`
        # or to finish the current iteration directly.
        builder.add_conditional_edges(
            "tool_call_limit",
            tool_limit_edge,
            ["tools", "update_call_counts"],
        )

        builder.add_edge("tools", "set_summarization_model")
        builder.add_edge("summarize", "update_call_counts")
        builder.add_edge("update_call_counts", "enforce_call_limit")
    else:
        # Nodes
        builder.add_node("enforce_call_limit", enforce_call_limit)
        builder.add_node("model", call_model)
        builder.add_node("update_call_counts", update_call_counts)

        # Edges
        builder.add_edge(START, "enforce_call_limit")
        builder.add_conditional_edges("enforce_call_limit", route_after_enforce)
        builder.add_edge("model",  "update_call_counts")
        builder.add_edge("update_call_counts", "enforce_call_limit")

    # Attach the in-memory checkpointer so state can be persisted
    return builder.compile(checkpointer=checkpointer)


#
# 7. Example CLI entrypoint
#


def main(argv: list[str] | None = None) -> None:
    """Run demo(s) showcasing ModelCallLimitMiddleware in a custom graph.

    By default this runs a simple loop that will hit the `run_limit` after a
    few model calls. With `--four-tools-demo`, it instead configures the LLM
    with four tools and prompts it to emit four tool calls in a single
    response, demonstrating that tool calls do not affect the model-call limit.
    """
    postgres_uri = os.getenv("POSTGRES_URI")
    if not postgres_uri:
        raise RuntimeError(
            "POSTGRES_URI is not set. Please add it to your environment or a .env file."
        )
    with PostgresSaver.from_conn_string(postgres_uri) as memory:
        memory.setup()
        parser = argparse.ArgumentParser(description="ModelCallLimitMiddleware custom graph demo")
        parser.add_argument(
            "--four-tools-demo",
            action="store_true",
            help="Use a graph where the model is configured with 4 tools and is prompted "
            "to call all of them in one shot.",
        )
        args = parser.parse_args(argv)

        config = {"configurable": {"thread_id": f"thread_id_{THREAD_ID_INCREMENT}_{"four_tools_demo" if args.four_tools_demo else "regular_demo"}", "user_id": "twapaw"}}

        graph = build_graph(use_four_tools_demo=args.four_tools_demo, checkpointer=memory)

        state = graph.get_state(config=config)

        print(f"------ Initial state ------ {state}")

        if args.four_tools_demo:
            set_call_limiter(MODEL_CALL_LIMIT_PER_RUN)
            # Strongly instruct the model to call all four tools exactly once,
            # in a single response.
            initial_messages: list[AnyMessage] = [
                SystemMessage(
                    content=(
                        "You are a tool-calling assistant. You MUST respond to the next user "
                        "message by calling ALL of the following tools exactly once, in a "
                        "single tool_calls response: tool_a, tool_b, tool_c, tool_d. "
                        "Do not answer directly; only return tool calls."
                    )
                ),
                HumanMessage(content="Demonstrate all four tools now."),
            ]
        else:
            set_call_limiter(MODEL_CALL_LIMIT_PER_RUN * 2)
            # Initial state: a single user message and counters at 0.
            initial_messages = [HumanMessage(content="Briefly explain what LangGraph is.")]

        initial_state: GraphState = {
            "messages": initial_messages,
            "thread_model_call_count": 0,
            "run_model_call_count": 0,
        }

        result = graph.invoke(initial_state, config=config)

        print("\n=== Final messages ===")
        for msg in result["messages"]:
            # LangChain messages also implement `.pretty_print()`; we print the
            # repr here instead to keep the raw message structure visible.
            print(f"\n{msg!r}\n")

        print(
            f"\nRun model call count: {result.get('run_model_call_count', 0)} "
            f"(limit={call_limiter.run_limit})"
        )
        print(
            f"Thread model call count: {result.get('thread_model_call_count', 0)} "
            f"(limit={call_limiter.thread_limit})"
        )

        print(
            f"Run tool call count: {result.get('run_tool_call_count', 0)} "
            f"(limit={tool_call_limiter.run_limit})"
        )

        print(
            f"Thread tool call count: {result.get('thread_tool_call_count', 0)} "
            f"(limit={tool_call_limiter.thread_limit})"
        )


if __name__ == "__main__":
    main()

Thanks, I will try to test several cases and come back to you.

Enjoy 🙂

My solution replicates what is happening under the hood inside the react_agent. It might need some minor fixes, but in general the key things are there.
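If it helps with review: the core of `tool_limit_edge` is just "send to `ToolNode` only the tool calls that do not yet have a matching `ToolMessage`" (blocked calls already received an error `ToolMessage` from the middleware). Stripped of LangChain types, with dicts standing in for messages, the check reduces to:

```python
# Sketch only: dicts stand in for tool_call / ToolMessage objects.

def pending_tool_calls(
    ai_tool_calls: list[dict], tool_messages: list[dict]
) -> list[dict]:
    """Return the tool calls that still need to be executed by ToolNode.

    A call is pending if no ToolMessage (including the error ones injected by
    ToolCallLimitMiddleware) references its id yet.
    """
    answered = {m["tool_call_id"] for m in tool_messages}
    return [tc for tc in ai_tool_calls if tc["id"] not in answered]
```

If this returns an empty list, every call is already answered and the edge routes straight to `update_call_counts` instead of `tools`.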

Small change to the script above: instead of my custom `_fetch_last_ai_and_tool_messages`, it is now imported from `langchain.agents.factory`.

Hello @pawel-twardziak
I embedded your code into my codebase, and it perfectly resolved what I had in mind.

Thank you very much

My pleasure 🙂