Hi @Dev,
what do you mean by "my custom StateGraph"?
Afaik LangChain’s built‑in agent middleware (e.g., ModelCallLimitMiddleware) is wired into the agent graph that create_agent builds. In a hand-built StateGraph, there is no generic graph.add_middleware(...) and the built-ins assume the standard agent loop nodes (model, tools, end) and jump semantics.
I reckon you can instantiate ModelCallLimitMiddleware and call its hooks around your model node:
- Call before_model(state, runtime) (or abefore_model) before the model; if it returns {"jump_to": "end", "messages": [...]}, route to your END node; if exit_behavior="error", handle ModelCallLimitExceededError
- Call after_model(state, runtime) (or aafter_model) after the model to increment counters
- Ensure your state carries thread_model_call_count and run_model_call_count; if you use a checkpointer, treat run count as non-persistent (untracked) yourself
Maybe something like this:

    from langchain.agents.middleware.model_call_limit import (
        ModelCallLimitExceededError,
        ModelCallLimitMiddleware,
    )

    limit = ModelCallLimitMiddleware(thread_limit=10, run_limit=5, exit_behavior="end")

    def before_model_node(state, runtime):
        # With exit_behavior="error", this call raises ModelCallLimitExceededError instead.
        updates = limit.before_model(state, runtime) or {}
        # If updates includes {"jump_to": "end"}, add your conditional edge to END.
        return updates

    def model_node(state):
        # ... call your model ...
        return {"messages": [...]}  # your AIMessage(s)

    def after_model_node(state, runtime):
        # Returns the incremented counters as a state update.
        return limit.after_model(state, runtime) or {}
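For the routing part, here is a rough sketch of a function you could hand to add_conditional_edges right after before_model_node. route_after_limit_check and the node names "model" and "end" are names I made up for illustration. It also assumes jump_to is a key in your own state schema, so the update returned by before_model is actually visible to the router.

```python
from typing import Any, Mapping


def route_after_limit_check(state: Mapping[str, Any]) -> str:
    """Pick the next node from the jump_to value left by before_model_node."""
    # When the limit is not hit, before_model_node returns {} and
    # jump_to is absent (or None), so we continue to the model.
    if state.get("jump_to") == "end":
        return "end"
    return "model"


# Hypothetical wiring in a hand-built graph (not executed here):
# builder.add_conditional_edges(
#     "before_model", route_after_limit_check, {"end": END, "model": "model"}
# )
```

You would probably also want to clear jump_to afterwards (or keep it out of the checkpoint), otherwise a stale "end" could short-circuit the next run.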
Key caveats:
- The middleware's hook_config(can_jump_to=["end"]) metadata isn't wired automatically in a plain StateGraph; you must add conditional edges based on the returned jump_to
- The built-in uses UntrackedValue for the run counter; replicate that by not checkpointing run_model_call_count (or resetting it per run)
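If you go the "reset it per run" route, one option is to zero the run counter in the input of every invoke. This is a minimal sketch: new_run_input is a hypothetical helper, and it assumes run_model_call_count is a plain overwrite-style key in your state (no reducer), so the 0 replaces whatever the checkpointer restored.

```python
from typing import Any


def new_run_input(user_input: dict[str, Any]) -> dict[str, Any]:
    """Return the graph input for a fresh run with the per-run counter reset."""
    # thread_model_call_count is deliberately not touched here, so the
    # checkpointed per-thread value carries over between runs.
    return {**user_input, "run_model_call_count": 0}


# Hypothetical usage: graph.invoke(new_run_input({"messages": [...]}), config)
```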