Implement ToolLimitationNode in stateGraph

I have one significant challenge implementing ToolCallLimitMiddleware.
I set the tool call limit to 5. When I ask the LLM to create 6 images, the assistant returns:

tool_calls = [call1, call2, call3, call4, call5, call6]


What I want is to execute only the first 5 tool calls, ignore the 6th, and then return control to the LLM.

However, with my current implementation using ToolCallLimitMiddleware, all 6 tool calls are executed sequentially, and the middleware only triggers after the 6th call, which defeats the purpose of the limit.

How can I modify the implementation so that only the first 5 tool calls are executed and the 6th is skipped?

cc: @pawel-twardziak

If an assistant message contains N tool_calls, the next messages must contain exactly N ToolMessages — one per tool_call_id.

This is causing the error for all approaches when I stop the last tool call.
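One way to satisfy that invariant while still blocking excess calls is to answer every tool_call_id, returning an error message for the blocked ones. Below is a stdlib-only sketch of the bookkeeping; plain dicts stand in for LangChain's AIMessage/ToolMessage types and the function name is illustrative:

```python
def respond_to_tool_calls(tool_calls: list[dict], limit: int) -> list[dict]:
    """Produce exactly one tool response per tool_call_id.

    The first `limit` calls get a (stubbed) real result; the rest get an
    error message, so the message history stays valid for the next LLM turn.
    """
    responses = []
    for i, call in enumerate(tool_calls):
        if i < limit:
            content = f"Result of {call['name']}"  # stub for real tool execution
        else:
            content = f"Error: tool call limit ({limit}) reached; call skipped."
        responses.append(
            {"type": "tool", "tool_call_id": call["id"], "content": content}
        )
    return responses


calls = [{"id": f"call{i}", "name": f"tool_{i}"} for i in range(6)]
messages = respond_to_tool_calls(calls, limit=5)
```

Because every tool_call_id gets exactly one response, the history remains well-formed and the LLM sees explicit errors for the skipped calls.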


Thanks @Dev, I’m on it (it’s a continuation of this post: I am gonna use langchain builtin middleware in custom state graph - #7 by pawel-twardziak).

from typing import Any

from langchain.agents.middleware import AgentState, after_model
from langchain_core.messages import RemoveMessage
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime


def create_handle_max_tool_calls(n: int = 5):
    @after_model
    def handle_max_tool_calls(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        ai_message = state["messages"][-1]
        if ai_message.type != "ai":
            return None
        if not ai_message.tool_calls:
            return None
        if len(ai_message.tool_calls) <= n:
            return None

        # Keep only the first n tool calls on a copy of the AI message.
        allowed_tool_calls = ai_message.tool_calls[:n]
        new_ai_message = ai_message.model_copy()
        new_ai_message.tool_calls = allowed_tool_calls
        new_messages = state["messages"][:-1] + [new_ai_message]

        # Replace the whole message history with the trimmed version.
        return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *new_messages]}

    return handle_max_tool_calls

You can use this factory to create a middleware that limits the tool calls by trimming them directly in the LLM’s response.

hi @Dev

if this is your full implementation (source):

limit = ToolCallLimitMiddleware(
    thread_limit=None,
    run_limit=DEFAULT_TOOL_LIMIT,
    exit_behavior=DEFAULT_EXIT_BEHAVIOR,
)
def limit_check_node(state: GraphState, runtime):
    """
    MUST be placed after tool execution.
    This is the only correct place where tool call count is updated.
    """
    updates = limit.after_model(state, runtime) or {}

    if updates.get("jump_to") == "end":
        return {"__end__": True}

    return updates

?

after_model doesn’t return jump_to. It only returns this:

    def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
        """Increment model call counts after a model call.

        Args:
            state: The current agent state.
            runtime: The langgraph runtime.

        Returns:
            State updates with incremented call counts.
        """
        return {
            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
        }

You have to create another node before your LLM call and invoke limit.before_model there.

See this (source):

from langchain.agents.middleware.model_call_limit import ModelCallLimitMiddleware, ModelCallLimitExceededError

limit = ModelCallLimitMiddleware(thread_limit=10, run_limit=5, exit_behavior="end")

def before_model_node(state, runtime):
    updates = limit.before_model(state, runtime) or {}
    # If updates includes {"jump_to": "end"}, add your conditional edge to END.
    return updates

def model_node(state):
    # ... call your model ...
    return {"messages": [...]}  # your AIMessage(s)

def after_model_node(state, runtime):
    return limit.after_model(state, runtime) or {}
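As the comment in before_model_node notes, the jump_to hint only takes effect if you wire a conditional edge yourself. A minimal routing function for that edge might look like this (plain-dict sketch; the node name "model" is illustrative, and in a real graph you would map "__end__" to langgraph’s END constant):

```python
def route_after_limit_check(state: dict) -> str:
    # The middleware signals an early exit by writing jump_to="end"
    # into the state; translate that into an edge target.
    if state.get("jump_to") == "end":
        return "__end__"
    return "model"
```

You would then wire it with something like `builder.add_conditional_edges("before_model_node", route_after_limit_check, {"model": "model", "__end__": END})`.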

Examine this script:

import argparse
import os
from typing import Annotated, Any, Literal

from dotenv import load_dotenv
from typing_extensions import NotRequired, TypedDict

from langchain_openai import ChatOpenAI
from langchain.agents.middleware import ModelCallLimitMiddleware, ToolCallLimitMiddleware
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from langchain_core.tools import tool

from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from langgraph.types import Command

#
# 1. Load environment and initialize OpenAI model
#

load_dotenv()

if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Please add it to your environment or a .env file."
    )

# ChatOpenAI reads OPENAI_API_KEY from the environment by default.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

#
# 2. Define custom graph state compatible with the middleware
#


class GraphState(TypedDict):
    """State for our custom graph.

    This mirrors the shape expected by `ModelCallLimitMiddleware`:
    - `messages`: chat history (append-only via `add_messages` reducer)
    - `thread_model_call_count`: total calls across runs / thread
    - `run_model_call_count`: calls in the current run (normally untracked)
    - `jump_to`: optional flow-control hint (e.g. "end") set by middleware
    """

    messages: Annotated[list[AnyMessage], add_messages]
    thread_model_call_count: NotRequired[int]
    run_model_call_count: NotRequired[int]
    # Optional tool call limit middleware counters (dicts keyed by tool name or "__all__")
    thread_tool_call_count: NotRequired[dict[str, int]]
    run_tool_call_count: NotRequired[dict[str, int]]
    next_model_node: NotRequired[str]
    jump_to: NotRequired[str]


#
# 3. Tools for the four-tools demo branch
#


@tool
def tool_a() -> str:
    """Return a short string for tool A."""

    return "Result from tool A"


@tool
def tool_b() -> str:
    """Return a short string for tool B."""

    return "Result from tool B"


@tool
def tool_c() -> str:
    """Return a short string for tool C."""

    return "Result from tool C"


@tool
def tool_d() -> str:
    """Return a short string for tool D."""

    return "Result from tool D"


TOOLS_FOR_DEMO = [tool_a, tool_b, tool_c, tool_d]
llm_with_tools = llm.bind_tools(TOOLS_FOR_DEMO)


#
# 4. Instantiate the built-in ModelCallLimitMiddleware
#


call_limiter: ModelCallLimitMiddleware
def set_call_limiter(run_limit: int | None) -> None:
    global call_limiter
    call_limiter = ModelCallLimitMiddleware(
        thread_limit=10,  # Max calls per long-lived "thread"
        run_limit=run_limit,  # Max calls per single graph run
        exit_behavior="end",  # Inject message + request jump to end when exceeded
    )

# 4b. ToolCallLimitMiddleware (used in four_tools_demo branch)
tool_call_limiter = ToolCallLimitMiddleware[Any, None](run_limit=3, exit_behavior="end")

#
# 4c. Wrap middleware hooks in graph nodes
#    Each node takes `state` and returns a partial state update dict.
#


def enforce_call_limit(state: GraphState) -> dict[str, Any]:
    """Run ModelCallLimitMiddleware.before_model on the current state.

    The middleware signature is `before_model(self, state, runtime)`. The current
    implementation does not use `runtime`, so we pass `None` here.

    When limits are exceeded and `exit_behavior="end"`, this returns:

        {"jump_to": "end", "messages": [AIMessage(...)]}

    Otherwise it returns either {} or None.
    """
    updates = call_limiter.before_model(state, runtime=None)  # type: ignore[arg-type]
    return updates or {}

def set_summarization_model(state: GraphState) -> Command[Literal["update_call_counts"]]:
    return Command(goto="update_call_counts", update={"next_model_node": "summarize"})

def call_model(state: GraphState) -> dict[str, Any]:
    """Call the OpenAI chat model with the accumulated messages."""
    response = llm.invoke(state["messages"])
    # Because `messages` has the `add_messages` reducer, returning a list of new
    # messages appends them to the existing list.
    return {"messages": [response]}


def call_model_with_four_tools(state: GraphState) -> dict[str, Any]:
    """Call the OpenAI chat model configured with 4 tools.

    The initial messages for this branch should strongly instruct the model to
    call *all four tools in a single response*. Even if it does so, that still
    counts as **one** model call for `ModelCallLimitMiddleware` – the middleware
    does not count tool calls individually. To enforce limits on tool calls
    themselves, you would instead use `ToolCallLimitMiddleware`.
    """
    response = llm_with_tools.invoke(state["messages"])
    return {"messages": [response]}


def update_call_counts(state: GraphState) -> dict[str, Any]:
    """Increment call counters via ModelCallLimitMiddleware.after_model."""
    updates = call_limiter.after_model(state, runtime=None)  # type: ignore[arg-type]
    return updates or {}


def enforce_tool_call_limit(state: GraphState) -> dict[str, Any]:
    """Run ToolCallLimitMiddleware.after_model on the current state.

    This middleware inspects the last AIMessage's tool calls and:
    - updates per-tool/thread/run call counters, and
    - can block excess tool calls or end the run (exit_behavior="end").
    """
    updates = tool_call_limiter.after_model(state, runtime=None)  # type: ignore[arg-type]
    return updates or {}


def summarize_after_tools(state: GraphState) -> dict[str, Any]:
    """Run a summarization LLM call after tools have executed.

    This uses the base chat model (without tools) to generate a concise
    final answer that takes any tool results into account.
    """
    response = llm.invoke(state["messages"])
    return {"messages": [response]}


#
# 5. Routing logic: respect middleware's jump_to="end"
#


def route_after_enforce(state: GraphState) -> str:
    """Determine the next node after `enforce_call_limit`.

    If the middleware set `jump_to="end"`, route directly to END.
    Otherwise, proceed to the model node.

    This pattern follows the conditional edge examples in the LangGraph
    Graph API docs:
    https://docs.langchain.com/oss/python/langgraph/graph-api#conditional-branching
    """
    if state.get("jump_to") == "end":
        return END
    return state.get("next_model_node", "model")


#
# 6. Build and compile the StateGraph
#


def build_graph(*, use_four_tools_demo: bool = False):
    """Build and compile the StateGraph with the call-limit middleware wired in.

    Args:
        use_four_tools_demo: If True, the graph's model node uses the
            tool-calling LLM (`llm_with_tools`) so the first model call can
            emit four tool calls in one shot.
    """
    builder = StateGraph(GraphState)

    if use_four_tools_demo:
        # Nodes
        builder.add_node("enforce_call_limit", enforce_call_limit)
        builder.add_node("model", call_model_with_four_tools)
        # Apply tool call limits before routing to tools
        builder.add_node("tool_call_limit", enforce_tool_call_limit)
        # ToolNode executes any tool_calls emitted by the model and appends ToolMessages.
        builder.add_node("tools", ToolNode(TOOLS_FOR_DEMO))
        # After tools complete, run a summarization LLM call.
        builder.add_node("summarize", summarize_after_tools)
        builder.add_node("set_summarization_model", set_summarization_model)
        builder.add_node("update_call_counts", update_call_counts)

        # Edges
        builder.add_edge(START, "enforce_call_limit")
        builder.add_conditional_edges("enforce_call_limit", route_after_enforce)

        # First, always run the tool call limit middleware after the model.
        builder.add_edge("model", "tool_call_limit")

        # Then, use tools_condition on the *post-middleware* state to decide
        # whether to execute tools or end the tool-calling phase.
        builder.add_conditional_edges(
            "tool_call_limit",
            tools_condition,
            {
                "tools": "tools",        # execute tools (four-tools case)
                "__end__": "update_call_counts",  # no (remaining) tool calls
            },
        )

        builder.add_edge("tools", "set_summarization_model")
        builder.add_edge("summarize", "update_call_counts")
        builder.add_edge("update_call_counts", "enforce_call_limit")
    else:
        # Nodes
        builder.add_node("enforce_call_limit", enforce_call_limit)
        builder.add_node("model", call_model)
        builder.add_node("update_call_counts", update_call_counts)

        # Edges
        builder.add_edge(START, "enforce_call_limit")
        builder.add_conditional_edges("enforce_call_limit", route_after_enforce)
        builder.add_edge("model", "update_call_counts")
        builder.add_edge("update_call_counts", "enforce_call_limit")

    return builder.compile()


#
# 7. Example CLI entrypoint
#


def main(argv: list[str] | None = None) -> None:
    """Run demo(s) showcasing ModelCallLimitMiddleware in a custom graph.

    By default this runs a simple loop that will hit the `run_limit` after a
    few model calls. With `--four-tools-demo`, it instead configures the LLM
    with four tools and prompts it to emit four tool calls in a single
    response, demonstrating that tool calls do not affect the model-call limit.
    """
    parser = argparse.ArgumentParser(description="ModelCallLimitMiddleware custom graph demo")
    parser.add_argument(
        "--four-tools-demo",
        action="store_true",
        help="Use a graph where the model is configured with 4 tools and is prompted "
        "to call all of them in one shot.",
    )
    args = parser.parse_args(argv)

    graph = build_graph(use_four_tools_demo=args.four_tools_demo)

    if args.four_tools_demo:
        set_call_limiter(2)
        # Strongly instruct the model to call all four tools exactly once,
        # in a single response.
        initial_messages: list[AnyMessage] = [
            SystemMessage(
                content=(
                    "You are a tool-calling assistant. You MUST respond to the next user "
                    "message by calling ALL of the following tools exactly once, in a "
                    "single tool_calls response: tool_a, tool_b, tool_c, tool_d. "
                    "Do not answer directly; only return tool calls."
                )
            ),
            HumanMessage(content="Demonstrate all four tools now."),
        ]
    else:
        set_call_limiter(4)
        # Initial state: a single user message and counters at 0.
        initial_messages = [HumanMessage(content="Briefly explain what LangGraph is.")]

    initial_state: GraphState = {
        "messages": initial_messages,
        "thread_model_call_count": 0,
        "run_model_call_count": 0,
    }

    result = graph.invoke(initial_state)

    print("\n=== Final messages ===")
    for msg in result["messages"]:
        # These are LangChain core message objects; most implement `.pretty_print()`.
        if hasattr(msg, "pretty_print"):
            msg.pretty_print()
        else:
            print(repr(msg))

    print(
        f"\nRun model call count: {result.get('run_model_call_count', 0)} "
        f"(limit={call_limiter.run_limit})"
    )
    print(
        f"Thread model call count: {result.get('thread_model_call_count', 0)} "
        f"(limit={call_limiter.thread_limit})"
    )


if __name__ == "__main__":
    main()

Test with these commands:

  • python main.py
  • python main.py --four-tools-demo

Hello @pawel-twardziak
Thanks for providing it. I tested this code, but it didn’t execute any tool call.

After that, it returns this message: “Tool call limit reached: run limit exceeded (6/3 calls).”

But what I want to do is:

  • the tool call limit is set to 3, and the LLM returns 6 tool calls
  • execute only 3 tool calls and block the other 3
  • return the 3 tool call results and 3 error messages to the LLM
  • finally, the LLM returns all the results

I think this is technically possible to implement; I’m just not sure what’s wrong.

And there is another problem.

After reaching the tool call limit, I can’t continue to chat on the same thread. It redirects to end, so execution never reaches the LLM call.

This happens both when using create_agent and when using a manual StateGraph.
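A possible cause (this is an assumption about the checkpointed state, not confirmed from the middleware source): the per-run counters and the jump_to hint are persisted by the checkpointer, so the next run on the same thread starts already "over the limit". A sketch of a reset node that could run right after START:

```python
def reset_run_counters(state: dict) -> dict:
    """Clear per-run counters so a checkpointed thread can keep chatting.

    Thread-level counters are left untouched on purpose; only the
    run-scoped fields are reset at the start of each new run.
    """
    return {
        "run_model_call_count": 0,
        "run_tool_call_count": {},
        "jump_to": None,  # also clear any stale jump request from the last run
    }
```

Wiring it as the first node after START would let a new run on the same thread proceed to the LLM call while keeping the thread-level totals intact.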

Yes, that is feasible - want me to change the implementation?

Yes, if you still have time.

This is by design afaik. There are run_limit and thread_limit - if you have reached thread_limit, you won’t be able to run any further tool calls, but you should still be able to continue the conversation itself.
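To illustrate the distinction, here is a stdlib simulation of the two counters: the run counter resets on every run, while the thread counter accumulates across runs (the limits and call counts are made up for the example):

```python
def simulate(runs: list[int], run_limit: int, thread_limit: int) -> list[tuple[int, int]]:
    """For each run, count model calls until a limit trips.

    `runs` gives the number of calls the model *wants* to make per run.
    Returns (calls_executed_this_run, thread_total_after_run) per run.
    """
    thread_count = 0
    results = []
    for wanted in runs:
        run_count = 0  # resets at the start of every run
        for _ in range(wanted):
            if run_count >= run_limit or thread_count >= thread_limit:
                break
            run_count += 1
            thread_count += 1  # never resets within the thread
        results.append((run_count, thread_count))
    return results
```

With `simulate([5, 5, 5], run_limit=3, thread_limit=7)`, the first two runs each execute 3 calls, and the third is cut off after 1 call by the thread limit.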
Could you elaborate a bit more on the issue?

I do, unless you mark it Solved

Yes, I will mark it as Solved if that resolves the issue.

I couldn’t get a response to the next question in the StateGraph based on your implementation.

The StateGraph was connected to a checkpointer, and I tried to ask the next question with the same checkpointer thread ID.

Have you set thread_limit?
What’s the message you are facing?

Ok, I added the checkpointer and see the issue. Will try to fix it tomorrow. Time to hit the hay.

Thanks for your help today, have a good rest of your day

Thanks for your response. It could actually be one approach for me, but it’s a bit different from what I need.

This is what I want to do:

  • when the tool call limit is set to 3, and the LLM returns 6 tool calls

  • execute only 3 tool calls and block the other 3

  • return the 3 tool call results and 3 error messages to the LLM

  • finally, the LLM returns all the results

To cover this scenario you need a dedicated approach - the built-in tool middleware doesn’t support it.
I have a solution in my head and will drop it here soon.
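In the meantime, the desired behavior can be sketched as a custom tool node: execute tool calls until the limit is hit, answer the rest with error ToolMessages so every tool_call_id is matched, then return control to the LLM. Plain dicts stand in for the LangChain message types, tool execution is stubbed via a name-to-callable registry, and all names here are illustrative:

```python
def limited_tool_node(state: dict, tools: dict, limit: int = 3) -> dict:
    """Execute tool calls up to `limit` per run; answer the rest with errors."""
    ai_message = state["messages"][-1]
    used = state.get("run_tool_call_count", 0)
    out = []
    for call in ai_message["tool_calls"]:
        if used < limit:
            result = tools[call["name"]]()  # real execution for allowed calls
            used += 1
        else:
            # Blocked calls still get a ToolMessage, keeping the history valid.
            result = f"Error: tool call limit of {limit} reached; '{call['name']}' was not executed."
        out.append({"type": "tool", "tool_call_id": call["id"], "content": result})
    return {"messages": out, "run_tool_call_count": used}
```

After this node, an edge back to the model node lets the LLM summarize the 3 real results alongside the 3 error messages.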


Do you use ToolNode for your graph’s tool node, or do you have a custom implementation for it?