I was using ToolNode, but I am open to using anything if that will be the solution
Hey @arthurdorigueto, I tested this code snippet, but it doesn't work as expected.
When the LLM generates 6 tool calls, even if we remove the last 2 tool call messages, OpenAI still expects responses for all generated tool_call_ids.
`openai.OpenAIError: Error code: 400 - {'error': {'message': "An assistant message with 'tool_calls' must be followed by tool messages responding to each 'tool_call_id'. The following tool_call_ids did not have response messages: call_dTukDpkNykgMR17pdBlmczHd, call_q9btjq3s6Tm7IgBcDpsSLbQ4, call_roeTD5s0HuImKIijnHP38AQz, call_E1VW3kH4cjeuz9GHzxdFZRs0, call_iuFONoEYz3ysirP0YPEizJRi, call_25oGrlVZZLrmvGmZTJar9JRw", 'type': 'invalid_request_error', 'param': 'messages.[4].role', 'code': None}}`
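For context, the 400 is structural: after an assistant message with `tool_calls`, the API requires a tool message for every single `tool_call_id`; removing calls from the history is not enough, each id must be answered. A minimal sketch with plain dicts (the helper name is mine, no LangChain types) of finding the still-unanswered calls:

```python
def pending_tool_calls(ai_tool_calls: list[dict], tool_messages: list[dict]) -> list[dict]:
    """Return the tool calls that do not yet have a matching tool response."""
    answered = {m["tool_call_id"] for m in tool_messages}
    return [tc for tc in ai_tool_calls if tc["id"] not in answered]

calls = [
    {"id": "call_1", "name": "tool_a", "args": {}},
    {"id": "call_2", "name": "tool_b", "args": {}},
]
responses = [{"tool_call_id": "call_1", "content": "ok"}]
print([tc["id"] for tc in pending_tool_calls(calls, responses)])  # → ['call_2']
```

Any call left in this "pending" set when the next model request is sent will trigger exactly the error above.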
Any good idea on your side?
So for me, I am going to add a skip parameter to every tool call function, and in the limitcheck node I will set skip=True for the blocked tool calls.
What do you think?
cc: @pawel-twardziak
I am still looking at it, almost there. I am not sure where you want to add that skip param, but be careful: try to be as idiomatic as possible and avoid dirty workarounds, as they usually turn out hard to maintain.
@tool
def file_tool(data: str, skip: bool, timeout: int | None = None) -> dict:
    if skip:
        return {"status": "skipped", "reason": "skipped due to the call limit"}
    print("tool calling start...")
    ...
this won't be handled by the graph in an idiomatic way. Hang on, I'm finalizing - the key part is to reconstruct how the middleware is handled within the create_react graph.
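The shape of the fix, in plain Python (dict stand-ins, the function name is mine, not LangChain internals): rather than skipping inside the tool, answer each blocked `tool_call_id` with an error tool message, so the provider's pairing contract is never violated:

```python
def block_excess_calls(ai_message: dict, limit: int, executed_so_far: int = 0):
    """Split tool calls into ones to execute and synthetic error responses.

    Every blocked call still gets a message carrying its tool_call_id, which
    is what the API requires after an assistant message with tool_calls.
    """
    allowed: list[dict] = []
    blocked_responses: list[dict] = []
    for tc in ai_message["tool_calls"]:
        if executed_so_far < limit:
            allowed.append(tc)
            executed_so_far += 1
        else:
            blocked_responses.append({
                "role": "tool",
                "tool_call_id": tc["id"],
                "status": "error",
                "content": f"Tool call '{tc['name']}' was blocked: call limit reached.",
            })
    return allowed, blocked_responses
```

This is what `ToolCallLimitMiddleware(exit_behavior="continue")` does for you, as the full script below shows.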
ok, I think I'm done (tool_limit_edge is the key part):
"""
Custom LangGraph StateGraph that wires in LangChain's ModelCallLimitMiddleware
around an OpenAI chat model, using python-dotenv for configuration.
This demonstrates how to *reuse* the built-in middleware in a hand-written
graph instead of only via `langchain.agents.create_agent`.
Key ideas:
- We define a TypedDict state compatible with what the middleware expects
(message history + call counters).
- We call the middleware's `before_model` and `after_model` hooks from
dedicated graph nodes.
- We use a conditional edge to respect `jump_to="end"` returned by the
middleware when limits are exceeded.
Relevant docs:
- LangChain middleware (Python):
https://docs.langchain.com/oss/python/langchain/middleware
- Built-in middleware (Model call limit):
https://docs.langchain.com/oss/python/langchain/middleware/built-in#model-call-limit
- LangGraph Graph API:
https://docs.langchain.com/oss/python/langgraph/graph-api
"""
import argparse
import os
from typing import Annotated, Any, Literal, List, cast
from dotenv import load_dotenv
from langchain.agents.middleware.model_call_limit import ModelCallLimitState
from langchain.agents.middleware.tool_call_limit import ToolCallLimitState
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.prebuilt.tool_node import ToolCallWithContext
from typing_extensions import NotRequired, TypedDict
from langchain_openai import ChatOpenAI
from langchain.agents.middleware import ModelCallLimitMiddleware, ToolCallLimitMiddleware
from langchain.agents.factory import _fetch_last_ai_and_tool_messages
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, AIMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.types import Command, Send
#
# 1. Load environment and initialize OpenAI model
#
load_dotenv()
if not os.getenv("OPENAI_API_KEY"):
raise RuntimeError(
"OPENAI_API_KEY is not set. Please add it to your environment or a .env file."
)
# ChatOpenAI reads OPENAI_API_KEY from the environment by default.
# See: https://reference.langchain.com/python/integrations/langchain_openai/ChatOpenAI
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
#
# 2. Define custom graph state compatible with the middleware
#
class GraphState(TypedDict, ModelCallLimitState, ToolCallLimitState):
"""State for our custom graph.
This mirrors the shape expected by `ModelCallLimitMiddleware`:
- `messages`: chat history (append-only via `add_messages` reducer)
- `thread_model_call_count`: total calls across runs / thread
- `run_model_call_count`: calls in the current run (normally untracked)
- `jump_to`: optional flow-control hint (e.g. "end") set by middleware
"""
messages: Annotated[list[AnyMessage], add_messages]
# thread_model_call_count: NotRequired[int]
# run_model_call_count: NotRequired[int]
# Optional tool call limit middleware counters (dicts keyed by tool name or "__all__")
# thread_tool_call_count: NotRequired[Annotated[dict[str, int], PrivateStateAttr]]
# run_tool_call_count: NotRequired[Annotated[dict[str, int], UntrackedValue, PrivateStateAttr]]
next_model_node: NotRequired[str]
# jump_to: NotRequired[str]
#
# 3. Tools for the four-tools demo branch
#
@tool
def tool_a() -> str:
"""Return a short string for tool A."""
return "Result from tool A"
@tool
def tool_b() -> str:
"""Return a short string for tool B."""
return "Result from tool B"
@tool
def tool_c() -> str:
"""Return a short string for tool C."""
return "Result from tool C"
@tool
def tool_d() -> str:
"""Return a short string for tool D."""
return "Result from tool D"
TOOLS_FOR_DEMO = [tool_a, tool_b, tool_c, tool_d]
llm_with_tools = llm.bind_tools(TOOLS_FOR_DEMO)
# Max number of tool calls we will allow to be executed
TOOL_CALL_LIMIT_PER_RUN = 2
TOOL_CALL_LIMIT_PER_THREAD = None
# Max number of model calls we will allow to be executed
MODEL_CALL_LIMIT_PER_RUN = 2
MODEL_CALL_LIMIT_PER_THREAD = None
# Thread ID increment
THREAD_ID_INCREMENT = 18
#
# 4. Instantiate the built-in ModelCallLimitMiddleware
#
call_limiter: ModelCallLimitMiddleware
def set_call_limiter(run_limit: int | None) -> None:
global call_limiter
call_limiter = ModelCallLimitMiddleware(
thread_limit=MODEL_CALL_LIMIT_PER_THREAD, # Max calls per long-lived "thread"
run_limit=run_limit, # Max calls per single graph run
exit_behavior="end", # Inject message + request jump to end when exceeded
)
# 4b. ToolCallLimitMiddleware (used in four_tools_demo branch).
# We set `exit_behavior="continue"` so that blocked tool calls are turned into
# error ToolMessages, but we still allow the agent to continue and execute the
# remaining allowed tool calls. This mirrors how `create_agent` wires this
# middleware together with `ToolNode` and `_make_model_to_tools_edge`.
tool_call_limiter = ToolCallLimitMiddleware[Any, None](
thread_limit=TOOL_CALL_LIMIT_PER_THREAD,
run_limit=TOOL_CALL_LIMIT_PER_RUN,
exit_behavior="continue",
)
#
# 4c. Wrap middleware hooks in graph nodes
# Each node takes `state` and returns a partial state update dict.
#
def enforce_call_limit(state: GraphState) -> dict[str, Any]:
"""Run ModelCallLimitMiddleware.before_model on the current state.
The middleware signature is `before_model(self, state, runtime)`. The current
implementation does not use `runtime`, so we pass `None` here.
When limits are exceeded and `exit_behavior="end"`, this returns:
{"jump_to": "end", "messages": [AIMessage(...)]}
Otherwise it returns either {} or None.
"""
print("------ Enforcing call limit...")
updates = call_limiter.before_model(state, runtime=None) # type: ignore[arg-type]
return updates or {}
def set_summarization_model(state: GraphState) -> Command[Literal["update_call_counts"]]:
print("------ Setting summarization model and going to `update_call_counts`...")
return Command[Literal['update_call_counts']](goto="update_call_counts", update={"next_model_node": "summarize"})
def call_model(state: GraphState) -> dict[str, Any]:
"""Call the OpenAI chat model with the accumulated messages."""
print("------ Calling model...")
response = llm.invoke(state["messages"])
# Because `messages` has the `add_messages` reducer, returning a list of new
# messages appends them to the existing list.
return {"messages": [response]}
def call_model_with_four_tools(state: GraphState) -> dict[str, Any]:
"""Call the OpenAI chat model configured with 4 tools.
The initial messages for this branch should strongly instruct the model to
call *all four tools in a single response*. Even if it does so, that still
    counts as **one** model call for `ModelCallLimitMiddleware`; the middleware
does not count tool calls individually. To enforce limits on tool calls
themselves, you would instead use `ToolCallLimitMiddleware`.
"""
print("------ Calling model with four tools...")
response = llm_with_tools.invoke(state["messages"])
return {"messages": [response]}
def update_call_counts(state: GraphState) -> dict[str, Any]:
"""Increment call counters via ModelCallLimitMiddleware.after_model."""
print("------ Updating call counts...")
updates = call_limiter.after_model(state, runtime=None) # type: ignore[arg-type]
return updates or {}
def enforce_tool_call_limit(state: GraphState) -> dict[str, Any]:
"""Run ToolCallLimitMiddleware.after_model on the current state.
This middleware inspects the last AIMessage's tool calls and:
- updates per-tool/thread/run call counters, and
- can block excess tool calls or end the run (exit_behavior="end").
"""
print("------ Enforcing tool call limit...")
updates = tool_call_limiter.after_model(state, runtime=None) # type: ignore[arg-type]
return updates or {}
""" this is imported from langchain.agents.factory """
# def _fetch_last_ai_and_tool_messages(
# messages: list[AnyMessage],
# ) -> tuple[AIMessage | None, list[ToolMessage]]:
# """Return the last AIMessage and all following ToolMessages.
#
# This mirrors the helper used in `langchain.agents.factory._make_model_to_tools_edge`.
# """
# print("------ Fetching last ai and tool messages...")
# last_ai_index: int | None = None
# last_ai_message: AIMessage | None = None
#
# for i in range(len(messages) - 1, -1, -1):
# msg = messages[i]
# if isinstance(msg, AIMessage):
# last_ai_index = i
# last_ai_message = cast(AIMessage, msg)
# break
#
# if last_ai_index is None or last_ai_message is None:
# return None, []
#
# tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)]
# return last_ai_message, tool_messages
def tool_limit_edge(state: GraphState) -> list[Send] | str:
"""Decide whether to route from `tool_call_limit` to `tools` or `update_call_counts`.
Patterned after `_make_model_to_tools_edge`:
    1. If `jump_to` is set (e.g. by another middleware) -> honor it.
2. Take the last `AIMessage` and all subsequent `ToolMessage`s.
    3. If the model did not emit any `tool_calls` -> finish (nothing to send to `ToolNode`).
4. Compute *pending* tool calls as those for which there is **no** corresponding
`ToolMessage`. In our setup `ToolCallLimitMiddleware` generates `ToolMessage`s
with status `"error"` for blocked calls, so those should not go to `ToolNode`.
    5. If there are pending calls -> route to the `tools` node, otherwise finish.
"""
print("------ Tool limit edge...")
# 1. Honor explicitly set `jump_to` (e.g. by other middleware).
jump_to = state.get("jump_to")
if jump_to == "end":
print("------ Ending tool limit edge - update_call_counts #1 ...")
return "update_call_counts"
if jump_to == "tools":
print("------ Ending tool limit edge - tools ...")
return "tools"
messages = state.get("messages", [])
last_ai, tool_messages = _fetch_last_ai_and_tool_messages(messages)
if last_ai is None:
print("------ Ending tool limit edge - update_call_counts #2 ...")
return "update_call_counts"
tool_message_ids = [m.tool_call_id for m in tool_messages if m.tool_call_id is not None]
    # 2. If the model did not call any tools -> finish.
if not last_ai.tool_calls:
print("------ Ending tool limit edge - update_call_counts #3 ...")
return "update_call_counts"
# 3. Identify calls that do not yet have a corresponding ToolMessage.
pending_tool_calls = [
tc for tc in last_ai.tool_calls if tc.get("id") not in tool_message_ids
]
    # 4. If there is anything left to execute -> go to `tools`.
if pending_tool_calls:
print("------ Ending tool limit edge - tool_call_with_context ...")
return [
Send(
"tools",
ToolCallWithContext(
__type="tool_call_with_context",
tool_call=tool_call,
state=state,
),
)
for tool_call in pending_tool_calls
]
# 5. All calls already have a ToolMessage (e.g. they were blocked by middleware),
    #    so there is nothing more to execute -> go to the end node for this iteration.
print("------ Ending tool limit edge - update_call_counts #4 ...")
return "update_call_counts"
def summarize_after_tools(state: GraphState) -> dict[str, Any]:
"""Run a summarization LLM call after tools have executed.
This uses the base chat model (without tools) to generate a concise
final answer that takes any tool results into account.
"""
print("------ Summarizing after tools ...")
response = llm.invoke(state["messages"])
return {"messages": [response]}
#
# 5. Routing logic: respect middleware's jump_to="end"
#
def route_after_enforce(state: GraphState) -> str:
"""Determine the next node after `enforce_call_limit`.
If the middleware set `jump_to="end"`, route directly to END.
Otherwise, proceed to the model node.
This pattern follows the conditional edge examples in the LangGraph
Graph API docs:
https://docs.langchain.com/oss/python/langgraph/graph-api#conditional-branching
"""
print(f"------ Routing after enforce ({state.get("jump_to")}, {state.get("next_model_node", "model")}) ...")
if state.get("jump_to") == "end":
return END
return state.get("next_model_node", "model")
#
# 6. Build and compile the StateGraph
#
def build_graph(*, use_four_tools_demo: bool = False, checkpointer: BaseCheckpointSaver):
"""Build and compile the StateGraph with the call-limit middleware wired in.
Args:
use_four_tools_demo: If True, the graph's model node uses the
tool-calling LLM (`llm_with_tools`) so the first model call can
emit four tool calls in one shot.
"""
builder = StateGraph(GraphState)
if use_four_tools_demo:
# Nodes
builder.add_node("enforce_call_limit", enforce_call_limit)
builder.add_node("model", call_model_with_four_tools)
# Apply tool call limits before routing to tools
builder.add_node("tool_call_limit", enforce_tool_call_limit)
# ToolNode executes any tool_calls emitted by the model and appends ToolMessages.
builder.add_node("tools", ToolNode(TOOLS_FOR_DEMO))
# After tools complete, run a summarization LLM call.
builder.add_node("summarize", summarize_after_tools)
builder.add_node("set_summarization_model", set_summarization_model)
builder.add_node("update_call_counts", update_call_counts)
# Edges
builder.add_edge(START, "enforce_call_limit")
builder.add_conditional_edges("enforce_call_limit", route_after_enforce)
# First, always run the ToolCallLimitMiddleware after the model.
builder.add_edge("model", "tool_call_limit")
# Then, based on the state after the middleware, decide whether to go to `tools`
# or to finish the current iteration directly.
builder.add_conditional_edges(
"tool_call_limit",
tool_limit_edge,
["tools", "update_call_counts"],
)
builder.add_edge("tools", "set_summarization_model")
builder.add_edge("summarize", "update_call_counts")
builder.add_edge("update_call_counts", "enforce_call_limit")
else:
# Nodes
builder.add_node("enforce_call_limit", enforce_call_limit)
builder.add_node("model", call_model)
builder.add_node("update_call_counts", update_call_counts)
# Edges
builder.add_edge(START, "enforce_call_limit")
builder.add_conditional_edges("enforce_call_limit", route_after_enforce)
builder.add_edge("model", "update_call_counts")
builder.add_edge("update_call_counts", "enforce_call_limit")
# Attach the in-memory checkpointer so state can be persisted
return builder.compile(checkpointer=checkpointer)
#
# 7. Example CLI entrypoint
#
def main(argv: list[str] | None = None) -> None:
    """Run demo(s) showcasing ModelCallLimitMiddleware in a custom graph.

    By default this runs a simple loop that will hit the `run_limit` after a
    few model calls. With `--four-tools-demo`, it instead configures the LLM
    with four tools and prompts it to emit four tool calls in a single
    response, demonstrating that tool calls do not affect the model-call limit.
    """
    postgres_uri = os.getenv("POSTGRES_URI")
    if not postgres_uri:
        raise RuntimeError(
            "POSTGRES_URI is not set. Please add it to your environment or a .env file."
        )
    with PostgresSaver.from_conn_string(postgres_uri) as memory:
        memory.setup()
parser = argparse.ArgumentParser(description="ModelCallLimitMiddleware custom graph demo")
parser.add_argument(
"--four-tools-demo",
action="store_true",
help="Use a graph where the model is configured with 4 tools and is prompted "
"to call all of them in one shot.",
)
args = parser.parse_args(argv)
config = {"configurable": {"thread_id": f"thread_id_{THREAD_ID_INCREMENT}_{"four_tools_demo" if args.four_tools_demo else "regular_demo"}", "user_id": "twapaw"}}
graph = build_graph(use_four_tools_demo=args.four_tools_demo, checkpointer=memory)
state = graph.get_state(config=config)
print(f"------ Initial state ------ {state}")
if args.four_tools_demo:
set_call_limiter(MODEL_CALL_LIMIT_PER_RUN)
# Strongly instruct the model to call all four tools exactly once,
# in a single response.
initial_messages: list[AnyMessage] = [
SystemMessage(
content=(
"You are a tool-calling assistant. You MUST respond to the next user "
"message by calling ALL of the following tools exactly once, in a "
"single tool_calls response: tool_a, tool_b, tool_c, tool_d. "
"Do not answer directly; only return tool calls."
)
),
HumanMessage(content="Demonstrate all four tools now."),
]
else:
set_call_limiter(MODEL_CALL_LIMIT_PER_RUN * 2)
# Initial state: a single user message and counters at 0.
initial_messages = [HumanMessage(content="Briefly explain what LangGraph is.")]
initial_state: GraphState = {
"messages": initial_messages,
"thread_model_call_count": 0,
"run_model_call_count": 0,
}
result = graph.invoke(initial_state, config=config)
print("\n=== Final messages ===")
for msg in result["messages"]:
            # Messages in the state are LangChain core messages, which all
            # implement `.pretty_print()`.
            if hasattr(msg, "pretty_print"):
                msg.pretty_print()
            else:
                print(f"\n{msg!r}\n")
print(
f"\nRun model call count: {result.get('run_model_call_count', 0)} "
f"(limit={call_limiter.run_limit})"
)
print(
f"Thread model call count: {result.get('thread_model_call_count', 0)} "
f"(limit={call_limiter.thread_limit})"
)
print(
f"Run tool call count: {result.get('run_tool_call_count', 0)} "
f"(limit={tool_call_limiter.run_limit})"
)
print(
f"Thread tool call count: {result.get('thread_tool_call_count', 0)} "
f"(limit={tool_call_limiter.thread_limit})"
)
if __name__ == "__main__":
main()
Thanks, I will try to test several cases and come back to you.
Enjoy!
My solution replicates what is happening under the hood inside the react_agent. It might need some minor fixes, but in general the key things are there.
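Stripped of the middleware plumbing, the loop being replicated is roughly this skeleton (plain Python with dict stand-ins and made-up names, not the actual LangChain internals):

```python
def react_loop(call_model, run_tools, run_limit):
    """Minimal skeleton of the model -> tools loop with a per-run model-call limit."""
    messages, model_calls = [], 0
    while True:
        if model_calls >= run_limit:          # before_model hook: enforce the limit
            messages.append({"role": "ai", "content": "Model call limit exceeded."})
            return messages
        ai = call_model(messages)             # model node
        model_calls += 1                      # after_model hook: update counters
        messages.append(ai)
        if not ai.get("tool_calls"):          # no tools requested -> finish
            return messages
        messages.extend(run_tools(ai["tool_calls"]))  # tools node
```

In the script, `enforce_call_limit`, `update_call_counts`, and `tool_limit_edge` play these three roles with the real middleware state.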
Small change to the script above: instead of my custom _fetch_last_ai_and_tool_messages, it's now imported from langchain.agents.factory.
Hello @pawel-twardziak
I embedded your code into my codebase, and it resolved exactly what I had in mind.
Thank you very much
My pleasure!