Problems encountered when constructing a chain with LangChain

While building a LangChain chain, I ran into a problem. First, please take a look at my code below:

async def main_chain():
    agent = await agent_get()
    chain = (RunnableLambda(retriever_get)
            | RunnableLambda(temp1)
            | prompt
            | agent
            )
    return chain

My first step obtains the knowledge retrieved by RAG, but this causes another problem when I call the chain as a stream: I cannot pass stream_mode, because the first step is not an agent. Is it common to build it this way? If so, how should I choose the streaming output mode?

async def main():
    chain = await main_chain()
    async for chunk in chain.astream({"input": "test"}):
        print(chunk)

Or, another idea: separate out the agent part, package the remaining chain into a tool, and let the agent call it. That way I only need to call the agent to stream, so could I then pass stream_mode?

Perhaps this approach could also avoid wasting time querying data on every request, since the agent could dynamically decide whether to query or not?

hi @Huimin-station

This is a common pain point when combining retrieval pre-processing with an agent, imho :slight_smile:

stream_mode is a LangGraph-specific parameter available on CompiledStateGraph objects (what create_agent() returns). When you pipe components together with |, you create a LangChain Core RunnableSequence, a different type of runnable that doesn’t support stream_mode.

# This is a RunnableSequence, NOT a LangGraph agent
chain = RunnableLambda(retriever_get) | RunnableLambda(temp1) | prompt | agent
type(chain)  # => RunnableSequence

So chain.astream(input, stream_mode="messages") won’t work because RunnableSequence.astream() doesn’t accept stream_mode.
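
For contrast, the compiled agent on its own does accept stream_mode. A minimal sketch, assuming agent is the CompiledStateGraph returned by create_agent():

async for message_chunk, metadata in agent.astream(
    {"messages": [{"role": "user", "content": "test"}]},
    stream_mode="messages",
):
    if message_chunk.content:
        print(message_chunk.content, end="", flush=True)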

Try one of these approaches:

  1. Use astream_events on the chain

async def main():
    chain = await main_chain()
    async for event in chain.astream_events({"input": "test"}, version="v2"):
        if event["event"] == "on_chat_model_stream":
            token = event["data"]["chunk"].content
            if token:
                print(token, end="", flush=True)

You can also filter by specific components using include_names, include_tags, or include_types:

async for event in chain.astream_events(
    {"input": "测试"},
    version="v2",
    include_types=["chat_model"], 
):
    if event["event"] == "on_chat_model_stream":
        print(event["data"]["chunk"].content, end="", flush=True)

  2. Use @before_model middleware in create_agent

Maybe the best approach is to move your retrieval logic into the agent itself using create_agent with a @before_model middleware. This keeps everything inside a single agent, giving you full access to stream_mode:

from langchain.agents import create_agent
from langchain.agents.middleware import before_model
from langgraph.runtime import Runtime

@before_model
def retrieval_middleware(state, runtime: Runtime):
    """Retrieve context before each LLM call."""
    messages = state["messages"]
    query = messages[-1].content

    docs = retriever_get(query)
    context = temp1(docs)  # Your processing step

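    # llm_input_messages replaces the messages sent to the model for this
    # call only; the injected system message is not persisted to the
    # agent's conversation state.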
    return {
        "llm_input_messages": [
            {"role": "system", "content": f"Use this context to answer:\n\n{context}"},
            *messages,
        ]
    }

agent = create_agent(
    model="your-model",
    tools=[...],
    middleware=[retrieval_middleware],
)

async for chunk in agent.astream(
    {"messages": [{"role": "user", "content": "测试"}]},
    stream_mode="messages",
):
    message_chunk, metadata = chunk
    if message_chunk.content:
        print(message_chunk.content, end="", flush=True)

  3. Make retrieval a tool

from langchain.agents import create_agent
from langgraph.config import get_stream_writer

def search_knowledge_base(query: str) -> str:
    """Search the knowledge base for relevant information."""
    writer = get_stream_writer()
    writer(f"Searching for: {query}") 

    docs = retriever_get(query)
    result = temp1(docs)

    writer(f"Found {len(docs)} relevant documents")
    return result

agent = create_agent(
    model="your-model",
    tools=[search_knowledge_base],
)

for stream_mode, chunk in agent.stream(
    {"messages": [{"role": "user", "content": "测试"}]},
    stream_mode=["messages", "custom"],
):
    if stream_mode == "messages":
        msg, metadata = chunk
        if msg.content:
            print(msg.content, end="", flush=True)
    elif stream_mode == "custom":
        print(f"\n[Status] {chunk}")

  4. Custom LangGraph StateGraph (maximum control)

from typing import Annotated, TypedDict

from langchain.chat_models import init_chat_model
from langchain_core.messages import AnyMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages

model = init_chat_model("your-model")  # placeholder model id, as above

class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
    context: str

def retrieve_node(state: State) -> dict:
    query = state["messages"][-1].content
    docs = retriever_get(query)
    context = temp1(docs)
    return {"context": context}

def agent_node(state: State) -> dict:
    system_msg = {"role": "system", "content": f"Context:\n{state['context']}"}
    response = model.invoke([system_msg] + state["messages"])
    return {"messages": [response]}

graph = (
    StateGraph(State)
    .add_node("retrieve", retrieve_node)
    .add_node("agent", agent_node)
    .add_edge(START, "retrieve")
    .add_edge("retrieve", "agent")
    .add_edge("agent", END)
    .compile()
)

# With stream_mode="messages", each chunk is a (message_chunk, metadata) tuple
async for msg, metadata in graph.astream(
    {"messages": [{"role": "user", "content": "test"}]},
    stream_mode="messages",
):
    if msg.content:
        print(msg.content, end="", flush=True)