"""
Custom graph implementation with tool calling functionality + TOOL LIMITING
"""
import operator
import os
from collections.abc import Callable
from typing import Annotated, Any, TypedDict

# NOTE: _fetch_last_ai_and_tool_messages is a private langchain helper; it may
# move or change between releases.
from langchain.agents.factory import _fetch_last_ai_and_tool_messages
from langchain.agents.middleware import ModelCallLimitMiddleware, ToolCallLimitMiddleware
from langchain.agents.middleware.types import PrivateStateAttr
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AnyMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph.channels.untracked_value import UntrackedValue
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt.tool_node import ToolCallWithContext
from langgraph.types import Send
from pydantic import BaseModel
from typing_extensions import NotRequired


class GraphState(TypedDict):
    messages: Annotated[list[AnyMessage], operator.add]
    # Counters maintained by the limit middlewares. PrivateStateAttr keeps them
    # out of the public input/output schema; UntrackedValue makes the run-scoped
    # counters reset between runs instead of being checkpointed.
    thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
    run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
    thread_tool_call_count: NotRequired[Annotated[dict[str, int], PrivateStateAttr]]
    run_tool_call_count: NotRequired[Annotated[dict[str, int], UntrackedValue, PrivateStateAttr]]
    # Written by the limit middlewares (e.g. {"jump_to": "end"}); it must be
    # declared here or LangGraph will reject the node update that carries it.
    jump_to: NotRequired[str | None]
    __end__: NotRequired[bool]
    metadata: dict | None


# Limits are read from the environment at import time; None disables a limit.
TOOL_CALL_LIMIT_PER_RUN = int(os.getenv("TOOL_CALL_LIMIT", "5"))
TOOL_CALL_LIMIT_PER_THREAD = None
TOOL_CALL_EXIT_BEHAVIOR = os.getenv("TOOL_EXIT_BEHAVIOR", "continue")
MODEL_CALL_LIMIT_PER_RUN = 5
MODEL_CALL_LIMIT_PER_THREAD = None
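
# Example configuration (hypothetical entrypoint, for illustration only):
#   TOOL_CALL_LIMIT=3 TOOL_EXIT_BEHAVIOR=end python app.py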

# Module-global so set_call_limiter() can reconfigure it per graph build; note
# that every graph created in the same process shares the latest configuration.
call_limiter: ModelCallLimitMiddleware


def set_call_limiter(run_limit: int | None) -> None:
    global call_limiter
    call_limiter = ModelCallLimitMiddleware(
        thread_limit=MODEL_CALL_LIMIT_PER_THREAD,
        run_limit=run_limit,
        exit_behavior="end",  # injects jump_to="end" once the limit is hit
    )


tool_call_limiter = ToolCallLimitMiddleware[Any, None](
    thread_limit=TOOL_CALL_LIMIT_PER_THREAD,
    run_limit=TOOL_CALL_LIMIT_PER_RUN,
    exit_behavior=TOOL_CALL_EXIT_BEHAVIOR,
)
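
# Rough shape of the middleware hooks this graph drives by hand (a sketch based
# on how the hooks are called below; exact update payloads depend on the
# installed langchain version):
#   call_limiter.before_model(state)     -> {"jump_to": "end", ...} once a
#                                           limit is exceeded, else None
#   call_limiter.after_model(state)      -> incremented model-call counters
#   tool_call_limiter.after_model(state) -> incremented per-tool counters, plus
#                                           exit handling per TOOL_CALL_EXIT_BEHAVIOR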


def create_custom_graph(
    model: BaseChatModel,
    tools: list[BaseTool],
    pre_model_hook: Callable[[GraphState], dict[str, Any]] | None = None,
    system_prompt: str | None = None,
    checkpointer: BaseCheckpointSaver | None = None,
    response_format: type[BaseModel] | None = None,
):
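    """Build a tool-calling loop with model- and tool-call limits enforced.

    ``pre_model_hook`` may return ``{"llm_input_messages": [...]}`` to override
    what is sent to the model; if it raises or returns nothing usable, the full
    message history is used instead.
    """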
builder = StateGraph(GraphState)
set_call_limiter(MODEL_CALL_LIMIT_PER_RUN)

    def initialize_state(state: GraphState) -> dict[str, Any]:
        """Reset per-run transient fields so they don't leak across turns."""
        return {
            "metadata": None,
            "__end__": False,
            # Clear any jump_to left over from a previous run on this thread,
            # otherwise a checkpointed "end" would terminate the new run at once.
            "jump_to": None,
        }

def enforce_call_limit(state: GraphState) -> dict[str, Any]:
print("------ Enforcing model call limit...")
updates = call_limiter.before_model(state, runtime=None) # type: ignore[arg-type]
return updates or {}

    def call_model_with_tools(state: GraphState, config: RunnableConfig) -> dict[str, Any]:
        print("------ Calling model with tools...")
        if state.get("__end__", False):
            return {"__end__": True}
        if pre_model_hook:
            try:
                hook_output = pre_model_hook(state)
                llm_messages = hook_output.get("llm_input_messages", state["messages"])
            except Exception:
                # Fall back to the raw history if the hook fails.
                llm_messages = state["messages"]
        else:
            llm_messages = state["messages"]
        if system_prompt:
            final_messages = [SystemMessage(content=system_prompt)] + llm_messages
        else:
            final_messages = llm_messages
        llm_with_tools = model.bind_tools(tools, strict=True)
        if response_format is not None:
            # Bind response_format only when one was supplied; binding None
            # would forward it as a kwarg on every call.
            llm_with_tools = llm_with_tools.bind(response_format=response_format)
        response = llm_with_tools.invoke(final_messages, config=config)
        return {"messages": [response], "metadata": state.get("metadata")}
def enforce_tool_call_limit(state: GraphState) -> dict[str, Any]:
print("------ Enforcing tool call limit...")
updates = tool_call_limiter.after_model(state, runtime=None) # type: ignore[arg-type]
return updates or {}

    def update_call_counts(state: GraphState) -> dict[str, Any]:
print("------ Updating call counts...")
updates = call_limiter.after_model(state, runtime=None) # type: ignore[arg-type]
return updates or {}

    def tool_limit_edge(state: GraphState) -> list[Send] | str:
        print("------ Tool limit edge...")
        # Conditional edges cannot persist state writes, so termination is
        # signalled purely by the returned route.
        # 1. Middleware forced termination
        if state.get("jump_to") == "end":
            return END
        messages = state.get("messages", [])
        last_ai, tool_messages = _fetch_last_ai_and_tool_messages(messages)
        if last_ai is None:
            return END
        tool_message_ids = {
            m.tool_call_id for m in tool_messages if m.tool_call_id is not None
        }
        # 2. No tool calls → done
        if not last_ai.tool_calls:
            return END
        # 3. Execute only the tool calls that don't yet have a matching ToolMessage
        pending_tool_calls = [
            tc for tc in last_ai.tool_calls if tc["id"] not in tool_message_ids
        ]
        if pending_tool_calls:
            return [
                Send(
                    "tools",
                    ToolCallWithContext(
                        __type="tool_call_with_context",
                        tool_call=tool_call,
                        state=state,
                    ),
                )
                for tool_call in pending_tool_calls
            ]
        # 4. All tool calls already handled → done
        return END
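
    # Each Send above schedules a separate "tools" task, so parallel tool calls
    # from a single AI message run as parallel branches that rejoin at
    # update_call_counts.
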
def route_after_enforce(state: GraphState) -> str:
print(f"------ Routing after enforce (jump_to={state.get('jump_to')})")
if state.get("__end__", False):
return END
if state.get("jump_to") == "end":
return END
return "model"
builder.add_node("initialize_state", initialize_state)
builder.add_node("enforce_call_limit", enforce_call_limit)
builder.add_node("model", call_model_with_tools)
builder.add_node("tool_call_limit", enforce_tool_call_limit)
builder.add_node("tools", ToolNode(tools))
builder.add_node("update_call_counts", update_call_counts)

    builder.add_edge(START, "initialize_state")
builder.add_edge("initialize_state", "enforce_call_limit")
builder.add_conditional_edges("enforce_call_limit", route_after_enforce)
builder.add_edge("model", "tool_call_limit")
builder.add_conditional_edges(
"tool_call_limit",
tool_limit_edge,
["tools", END],
)
builder.add_edge("tools", "update_call_counts")
builder.add_edge("update_call_counts", "enforce_call_limit")
if checkpointer:
return builder.compile(checkpointer=checkpointer)
return builder.compile()
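

# A minimal usage sketch (assumes langchain-openai is installed; the `search`
# tool and model choice are illustrative, not part of this module):
#
#     from langchain_core.messages import HumanMessage
#     from langchain_core.tools import tool
#     from langchain_openai import ChatOpenAI
#
#     @tool
#     def search(query: str) -> str:
#         """Hypothetical demo tool."""
#         return f"results for {query!r}"
#
#     graph = create_custom_graph(
#         model=ChatOpenAI(model="gpt-4o-mini"),
#         tools=[search],
#         system_prompt="You are a helpful assistant.",
#     )
#     result = graph.invoke(
#         {"messages": [HumanMessage("look up langgraph")], "metadata": None}
#     )
#     print(result["messages"][-1].content)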