Examine this script:
import argparse
import os
from typing import Annotated, Any, Literal, List
from dotenv import load_dotenv
from typing_extensions import NotRequired, TypedDict
from langchain_openai import ChatOpenAI
from langchain.agents.middleware import ModelCallLimitMiddleware, ToolCallLimitMiddleware
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from langchain_core.tools import tool
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.types import Command
#
# 1. Load environment and initialize OpenAI model
#
# Pull variables from a local .env file (if present) into os.environ.
load_dotenv()
# Fail fast with an actionable message instead of a confusing auth error
# deep inside the first model call.
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Please add it to your environment or a .env file."
    )
# ChatOpenAI reads OPENAI_API_KEY from the environment by default.
# temperature=0 keeps the demo output as deterministic as the API allows.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
#
# 2. Define custom graph state compatible with the middleware
#
class GraphState(TypedDict):
    """State for our custom graph.

    This mirrors the shape expected by `ModelCallLimitMiddleware`:
    - `messages`: chat history (append-only via `add_messages` reducer)
    - `thread_model_call_count`: total calls across runs / thread
    - `run_model_call_count`: calls in the current run (normally untracked)
    - `jump_to`: optional flow-control hint (e.g. "end") set by middleware
    """

    # Conversation history; the `add_messages` reducer appends new messages
    # instead of overwriting the list.
    messages: Annotated[list[AnyMessage], add_messages]
    # Counters maintained by ModelCallLimitMiddleware.after_model.
    thread_model_call_count: NotRequired[int]
    run_model_call_count: NotRequired[int]
    # Optional tool call limit middleware counters (dicts keyed by tool name or "__all__")
    thread_tool_call_count: NotRequired[dict[str, int]]
    run_tool_call_count: NotRequired[dict[str, int]]
    # Node name route_after_enforce should send the flow to next
    # (e.g. "summarize"); defaults to "model" when absent.
    next_model_node: NotRequired[str]
    # Set to "end" by the limit middleware to request routing to END.
    jump_to: NotRequired[str]
#
# 3. Tools for the four-tools demo branch
#
# NOTE: @tool surfaces the docstring to the model as the tool's description,
# so the docstring text is behavioral — do not edit it casually.
@tool
def tool_a() -> str:
    """Return a short string for tool A."""
    return "Result from tool A"
# NOTE: the docstring doubles as the tool description shown to the model.
@tool
def tool_b() -> str:
    """Return a short string for tool B."""
    return "Result from tool B"
# NOTE: the docstring doubles as the tool description shown to the model.
@tool
def tool_c() -> str:
    """Return a short string for tool C."""
    return "Result from tool C"
# NOTE: the docstring doubles as the tool description shown to the model.
@tool
def tool_d() -> str:
    """Return a short string for tool D."""
    return "Result from tool D"
# All four demo tools, plus a model pre-bound to them so a single LLM call
# can emit tool_calls for any (or all) of tool_a..tool_d.
TOOLS_FOR_DEMO = [tool_a, tool_b, tool_c, tool_d]
llm_with_tools = llm.bind_tools(TOOLS_FOR_DEMO)
#
# 4. Instantiate the built-in ModelCallLimitMiddleware
#
# Bug fix: the original only *annotated* `call_limiter`, so invoking the graph
# before `set_call_limiter()` raised NameError inside the nodes. Provide a
# sensible default matching the simple-demo configuration; main() still
# overrides it per branch via set_call_limiter().
call_limiter: ModelCallLimitMiddleware = ModelCallLimitMiddleware(
    thread_limit=10,  # Max calls per long-lived "thread"
    run_limit=4,      # Max calls per single graph run (default demo value)
    exit_behavior="end",
)


def set_call_limiter(run_limit: int | None) -> None:
    """Replace the global model-call limiter with one using `run_limit`.

    Args:
        run_limit: Max model calls per single graph run, or None for no
            per-run limit. The thread limit stays fixed at 10.
    """
    global call_limiter
    call_limiter = ModelCallLimitMiddleware(
        thread_limit=10,  # Max calls per long-lived "thread"
        run_limit=run_limit,  # Max calls per single graph run
        exit_behavior="end",  # Inject message + request jump to end when exceeded
    )


# 4b. ToolCallLimitMiddleware (used in four_tools_demo branch)
tool_call_limiter = ToolCallLimitMiddleware[Any, None](run_limit=3, exit_behavior="end")
#
# 4. Wrap middleware hooks in graph nodes
# Each node takes `state` and returns a partial state update dict.
#
def enforce_call_limit(state: GraphState) -> dict[str, Any]:
    """Apply ModelCallLimitMiddleware.before_model to the incoming state.

    The middleware hook signature is ``before_model(self, state, runtime)``;
    the current implementation ignores ``runtime``, so ``None`` is supplied.
    When a limit is exceeded (with ``exit_behavior="end"``) the hook returns
    ``{"jump_to": "end", "messages": [AIMessage(...)]}``; otherwise it
    yields ``{}`` or ``None``.
    """
    result = call_limiter.before_model(state, runtime=None)  # type: ignore[arg-type]
    if result is None:
        return {}
    return result
def set_summarization_model(state: GraphState) -> Command[Literal["update_call_counts"]]:
    """Record "summarize" as the next model node and jump to the counter node.

    Returned as a Command so the state update and the routing decision happen
    atomically in one node.
    """
    state_update = {"next_model_node": "summarize"}
    return Command[Literal["update_call_counts"]](
        goto="update_call_counts",
        update=state_update,
    )
def call_model(state: GraphState) -> dict[str, Any]:
    """Invoke the base OpenAI chat model on the accumulated conversation.

    Because `messages` carries the `add_messages` reducer, the single-item
    list returned here is appended to — not substituted for — the history.
    """
    reply = llm.invoke(state["messages"])
    return {"messages": [reply]}
def call_model_with_four_tools(state: GraphState) -> dict[str, Any]:
    """Invoke the tool-bound OpenAI chat model on the current history.

    The initial messages for this branch should strongly instruct the model
    to call *all four tools in a single response*. Even when it does, that
    still counts as **one** model call for `ModelCallLimitMiddleware` — the
    middleware never counts tool calls individually. Limiting tool calls
    themselves is the job of `ToolCallLimitMiddleware`.
    """
    reply = llm_with_tools.invoke(state["messages"])
    return {"messages": [reply]}
def update_call_counts(state: GraphState) -> dict[str, Any]:
    """Advance the model-call counters via ModelCallLimitMiddleware.after_model."""
    counted = call_limiter.after_model(state, runtime=None)  # type: ignore[arg-type]
    if not counted:
        return {}
    return counted
def enforce_tool_call_limit(state: GraphState) -> dict[str, Any]:
    """Apply ToolCallLimitMiddleware.after_model to the current state.

    The middleware inspects the most recent AIMessage's tool calls and:
    - updates the per-tool / thread / run counters, and
    - may block excess tool calls or end the run (exit_behavior="end").
    """
    result = tool_call_limiter.after_model(state, runtime=None)  # type: ignore[arg-type]
    return result if result else {}
def summarize_after_tools(state: GraphState) -> dict[str, Any]:
    """Produce a final summarization turn once tools have executed.

    Uses the plain chat model (no tools bound) so the reply is a concise
    answer that takes any ToolMessages already in the history into account.
    """
    summary = llm.invoke(state["messages"])
    return {"messages": [summary]}
#
# 5. Routing logic: respect middleware's jump_to="end"
#
def route_after_enforce(state: GraphState) -> str:
    """Pick the next node after `enforce_call_limit`.

    A middleware-set `jump_to == "end"` short-circuits straight to END.
    Otherwise route to the node named in `next_model_node`, falling back to
    "model" when that key is absent.

    This pattern follows the conditional edge examples in the LangGraph
    Graph API docs:
    https://docs.langchain.com/oss/python/langgraph/graph-api#conditional-branching
    """
    destination = state.get("next_model_node", "model")
    if state.get("jump_to") == "end":
        return END
    return destination
#
# 6. Build and compile the StateGraph
#
def build_graph(*, use_four_tools_demo: bool = False):
    """Build and compile the StateGraph with the call-limit middleware wired in.

    Args:
        use_four_tools_demo: If True, the graph's model node uses the
            tool-calling LLM (`llm_with_tools`) so the first model call can
            emit four tool calls in one shot.

    Returns:
        The compiled graph (result of `builder.compile()`).
    """
    builder = StateGraph(GraphState)
    if use_four_tools_demo:
        # Nodes
        builder.add_node("enforce_call_limit", enforce_call_limit)
        builder.add_node("model", call_model_with_four_tools)
        # Apply tool call limits before routing to tools
        builder.add_node("tool_call_limit", enforce_tool_call_limit)
        # ToolNode executes any tool_calls emitted by the model and appends ToolMessages.
        builder.add_node("tools", ToolNode(TOOLS_FOR_DEMO))
        # After tools complete, run a summarization LLM call.
        builder.add_node("summarize", summarize_after_tools)
        builder.add_node("set_summarization_model", set_summarization_model)
        builder.add_node("update_call_counts", update_call_counts)
        # Edges
        builder.add_edge(START, "enforce_call_limit")
        # No path map here: route_after_enforce returns the target node name
        # directly (END, "model", or whatever `next_model_node` holds).
        builder.add_conditional_edges("enforce_call_limit", route_after_enforce)
        # First, always run the tool call limit middleware after the model.
        builder.add_edge("model", "tool_call_limit")
        # Then, use tools_condition on the *post-middleware* state to decide
        # whether to execute tools or end the tool-calling phase.
        builder.add_conditional_edges(
            "tool_call_limit",
            tools_condition,
            {
                "tools": "tools",  # execute tools (four-tools case)
                "__end__": "update_call_counts",  # no (remaining) tool calls
            },
        )
        # set_summarization_model returns a Command that both records
        # next_model_node="summarize" and jumps to update_call_counts,
        # so no outgoing edge from it is declared here.
        builder.add_edge("tools", "set_summarization_model")
        builder.add_edge("summarize", "update_call_counts")
        # Loop back so the model-call limit is re-checked before every call.
        builder.add_edge("update_call_counts", "enforce_call_limit")
    else:
        # Nodes
        builder.add_node("enforce_call_limit", enforce_call_limit)
        builder.add_node("model", call_model)
        builder.add_node("update_call_counts", update_call_counts)
        # Edges: a simple check -> call -> count loop that only terminates
        # once the middleware sets jump_to="end".
        builder.add_edge(START, "enforce_call_limit")
        builder.add_conditional_edges("enforce_call_limit", route_after_enforce)
        builder.add_edge("model", "update_call_counts")
        builder.add_edge("update_call_counts", "enforce_call_limit")
    return builder.compile()
#
# 7. Example CLI entrypoint
#
def main(argv: list[str] | None = None) -> None:
    """Run demo(s) showcasing ModelCallLimitMiddleware in a custom graph.

    By default this runs a simple loop that will hit the `run_limit` after a
    few model calls. With `--four-tools-demo`, it instead configures the LLM
    with four tools and prompts it to emit four tool calls in a single
    response, demonstrating that tool calls do not affect the model-call limit.

    Args:
        argv: Optional argument list for argparse (defaults to sys.argv[1:]).
    """
    parser = argparse.ArgumentParser(description="ModelCallLimitMiddleware custom graph demo")
    parser.add_argument(
        "--four-tools-demo",
        action="store_true",
        help="Use a graph where the model is configured with 4 tools and is prompted "
        "to call all of them in one shot.",
    )
    args = parser.parse_args(argv)
    # The nodes read the `call_limiter` global lazily at invoke time, so
    # building the graph before set_call_limiter() runs below is safe.
    graph = build_graph(use_four_tools_demo=args.four_tools_demo)
    if args.four_tools_demo:
        set_call_limiter(2)
        # Strongly instruct the model to call all four tools exactly once,
        # in a single response.
        initial_messages: list[AnyMessage] = [
            SystemMessage(
                content=(
                    "You are a tool-calling assistant. You MUST respond to the next user "
                    "message by calling ALL of the following tools exactly once, in a "
                    "single tool_calls response: tool_a, tool_b, tool_c, tool_d. "
                    "Do not answer directly; only return tool calls."
                )
            ),
            HumanMessage(content="Demonstrate all four tools now."),
        ]
    else:
        set_call_limiter(4)
        # Initial state: a single user message and counters at 0.
        initial_messages = [HumanMessage(content="Briefly explain what LangGraph is.")]
    initial_state: GraphState = {
        "messages": initial_messages,
        "thread_model_call_count": 0,
        "run_model_call_count": 0,
    }
    result = graph.invoke(initial_state)
    print("\n=== Final messages ===")
    for msg in result["messages"]:
        # `AnyMessage` comes from langgraph.graph.message but wraps LangChain messages.
        # They implement `.pretty_print()` in LangChain core.
        if hasattr(msg, "pretty_print"):
            msg.pretty_print()
        else:
            print(repr(msg))
    # Report how close the run came to each limit configured on the middleware.
    print(
        f"\nRun model call count: {result.get('run_model_call_count', 0)} "
        f"(limit={call_limiter.run_limit})"
    )
    print(
        f"Thread model call count: {result.get('thread_model_call_count', 0)} "
        f"(limit={call_limiter.thread_limit})"
    )
# Standard entry-point guard: run the demo only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
Test with these commands:
python main.py
python main.py --four-tools-demo