I am updating some code to LangChain 1.0 and trying out the new agents and middleware, and I noticed that the model sometimes selects multiple tools in a single response, which then get executed concurrently.
I would like to execute them sequentially, since the output of the first call is the input to the second.
Is that possible?
I'm working in Python, using create_agent with a couple of @tool-annotated functions.
The snippet below tends to reproduce the behavior, but you might have to run it a few times before the model elects to make more than one tool call at once.
# Imports assume the LangChain 1.0 package layout; adjust paths if your version differs.
import threading
import uuid
from typing import Callable

from langchain_core.messages import HumanMessage
from langchain_core.runnables import ensure_config
from langchain_fireworks import ChatFireworks
from langchain.agents import create_agent
from langchain.agents.middleware.types import ModelRequest, ModelResponse, wrap_model_call
from langchain.tools import tool, ToolRuntime


@tool
def compute_answer(request: str, runtime: ToolRuntime) -> str:
    """Given data in a data frame, use natural language to compute some value or values."""
    print("CALLING COMPUTE ANSWER")
    return "42"


@tool
def fetch_data(request: str, runtime: ToolRuntime) -> str:
    """Fetch data from the database using a structured natural language query."""
    print("CALLING FETCH DATA")
    return "42"
@wrap_model_call
def sequential_tool_call_model(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    request.model = request.model.bind_tools(
        request.tools,
        parallel_tool_calls=False
    )
    return handler(request)
if __name__ == "__main__":
    question = "predict the total sales for products in the Sports category in 2021 using 2015-2020 as a baseline."
    llm = ChatFireworks(model="accounts/fireworks/models/llama4-maverick-instruct-basic")
    system_prompt = """Respond to the human's question as helpfully and accurately as possible.
    Make as many tool calls at the same time as you can. Speed is key."""
    agent = create_agent(model=llm, tools=[fetch_data, compute_answer],
                         middleware=[sequential_tool_call_model],
                         system_prompt=system_prompt)
    config = ensure_config()
    config["configurable"]["thread_id"] = f"{threading.current_thread().ident}"
    input_state = {"messages": [HumanMessage(content=question, id=str(uuid.uuid4()))]}
    idx = 0
    for st in agent.stream(input=input_state,
                           config=config,
                           stream_mode="values"):
        for msg in st["messages"][idx:]:
            if not msg.content:
                print(f"{msg.type}: {msg.tool_calls}")
            else:
                print(f"{msg.type}: {msg.content}")
        idx = len(st["messages"])
        print("--------------------------------")
I have worked out the approach below (not tested yet); please try it:
Provider-agnostic sequentialization (should work even when the model ignores the flag):
Queue the extra tool calls after the model step, then replay them one by one without re-calling the model.
from typing import Any, Callable
from typing_extensions import TypedDict, NotRequired
from langchain_core.messages import AIMessage
from langchain.agents.middleware.types import after_model, wrap_model_call, AgentState, ModelRequest, ModelResponse
class SeqState(AgentState):
    queued_tool_calls: NotRequired[list[dict[str, Any]]]


@after_model(state_schema=SeqState, name="QueueAndTruncate")
def queue_and_truncate_tool_calls(state: SeqState, runtime) -> dict[str, Any] | None:
    if not state["messages"]:
        return None
    last = state["messages"][-1]
    if isinstance(last, AIMessage) and getattr(last, "tool_calls", None) and len(last.tool_calls) > 1:
        rest = last.tool_calls[1:]
        last.tool_calls = [last.tool_calls[0]]
        queued = list(state.get("queued_tool_calls") or [])
        queued.extend(rest)
        return {"queued_tool_calls": queued}
    return None


@wrap_model_call(state_schema=SeqState, name="ReplayQueuedIfAny")
def replay_queued_tool_call(request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]) -> ModelResponse:
    queued = list(request.state.get("queued_tool_calls") or [])
    if queued:
        next_call = queued.pop(0)
        request.state["queued_tool_calls"] = queued  # persist updated queue
        synthetic = AIMessage(content="", tool_calls=[next_call])
        return ModelResponse(result=[synthetic])  # append synthetic AIMessage; tools run next
    return handler(request)
Add both middlewares when creating the agent (see the wiring sketch below), and adjust the system prompt to "Call at most one tool per turn; if multiple steps are needed, call one, wait for results, then continue."
This should guarantee that if the model emits N tool calls in a single turn, they are executed one by one across subsequent iterations, without relying on provider support for parallel_tool_calls.
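A minimal wiring sketch, assuming the same llm, fetch_data, and compute_answer from your original snippet:

# Sketch only: reuses llm, fetch_data, and compute_answer from the original snippet.
agent = create_agent(
    model=llm,
    tools=[fetch_data, compute_answer],
    middleware=[queue_and_truncate_tool_calls, replay_queued_tool_call],
    system_prompt=(
        "Respond to the human's question as helpfully and accurately as possible. "
        "Call at most one tool per turn; if multiple steps are needed, call one, "
        "wait for results, then continue."
    ),
)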
Your code as-is hit GraphRecursionError. What I believe is happening is that state cannot be modified from the wrap-style middleware, so the queue was never decremented. What I observed was the second (and only queued) tool call being executed repeatedly until the recursion limit was reached.
I was able to combine your middlewares into a single working version:
@after_model(state_schema=SeqState, name="QueueAndTruncate")
def queue_and_truncate_tool_calls(state: SeqState, runtime) -> dict[str, Any] | None:
    if not state["messages"]:
        return None
    last = state["messages"][-1]
    queued = list(state.get("queued_tool_calls") or [])
    if isinstance(last, AIMessage) and getattr(last, "tool_calls", None) and len(last.tool_calls) > 1:
        rest = last.tool_calls[1:]
        last.tool_calls = [last.tool_calls[0]]
        queued.extend(rest)
        return {"queued_tool_calls": queued}
    elif queued:
        next_call = queued.pop(0)
        synthetic = AIMessage(content="", tool_calls=[next_call])
        return {"messages": [synthetic], "queued_tool_calls": queued}
    return None
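For reference, a sketch of how I wired the combined middleware in (same create_agent call and system_prompt as my original snippet, just swapping the middleware list):

# Sketch: same llm, tools, and system_prompt as the original snippet,
# with only the single combined middleware registered.
agent = create_agent(model=llm, tools=[fetch_data, compute_answer],
                     middleware=[queue_and_truncate_tool_calls],
                     system_prompt=system_prompt)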
While this seems to work, it doesn't feel like what the framework authors intended. Is there a more canonical way to do this?