AsyncPostgresSaver and JSON serializable error

I’ve implemented the AsyncPostgresSaver as my checkpointer and connected it to my database. However, I keep getting the errors "Object of type HumanMessage is not JSON serializable" and "Object of type AIMessage is not JSON serializable" when I invoke the graph with graph.astream_events. I fixed the HumanMessage error with:

messages = [HumanMessage(content=content)]
messages = dumps(messages, ensure_ascii=False)

But I still get the Object of type AIMessage is not JSON serializable error.

AsyncPostgresSaver implementation

class Checkpointer:
    """Handle memory checkpointing and environment variable management."""

    @staticmethod
    def is_langgraph_studio_session() -> bool:
        """Check if the application is running in LangGraph API environment.

        Returns:
            bool: True if running in LangGraph API environment
        """
        return os.getenv("IS_LANGGRAPH_STUDIO_SESSION", "false").lower() == "true"

    @staticmethod
    def get_graph_with_memory_saver(graph: "StateGraph"):
        """Compile the graph with the appropriate memory saver for the environment.

        Args:
            graph (StateGraph): The graph to compile

        Returns:
            Compiled graph with appropriate memory saver configuration

        Raises:
            RuntimeError: If DATABASE_URL is not set outside LangGraph Studio.
        """
        # Don't use a custom checkpointer when running in LangGraph Studio;
        # Studio supplies its own persistence layer.
        if Checkpointer.is_langgraph_studio_session():
            return graph.compile()

        db = os.getenv("DATABASE_URL")
        if not db:
            raise RuntimeError("DATABASE_URL missing")

        # BUG FIX: in the original, this call sat *after* the unconditional
        # `raise` (and was mis-indented), so it was unreachable dead code and
        # `saver_cm` was never actually bound before use below.
        saver_cm = AsyncPostgresSaver.from_conn_string(
            db, serde=JsonPlusSerializer(pickle_fallback=True)
        )

        # Unwrap the async context manager to the actual saver.
        # NOTE(review): run_until_complete raises if an event loop is already
        # running (e.g. inside an ASGI server or notebook); prefer entering
        # the context manager from within the running loop — TODO confirm
        # how call sites invoke this.
        saver = asyncio.get_event_loop().run_until_complete(saver_cm.__aenter__())
        return graph.compile(checkpointer=saver)

I’ve tried this implementation as well:
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:

Here is the table in my database that I’ve made myself so I don’t call .setup()


-- Records which schema migrations have been applied
-- (normally populated by AsyncPostgresSaver.setup()).
CREATE TABLE checkpoint_migrations (
    v INTEGER NOT NULL PRIMARY KEY
);

-- One row per checkpoint of a thread. `checkpoint` holds the serialized
-- checkpoint payload and `metadata` holds run metadata, both as JSONB —
-- so anything written here must be JSON-serializable.
CREATE TABLE checkpoints (
    thread_id             TEXT    NOT NULL,
    checkpoint_ns         TEXT    NOT NULL DEFAULT '',
    checkpoint_id         TEXT    NOT NULL,
    parent_checkpoint_id  TEXT,
    type                  TEXT,
    checkpoint            JSONB   NOT NULL,
    metadata              JSONB   NOT NULL DEFAULT '{}',
    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);

-- Channel values stored out-of-line as binary blobs, versioned per channel.
-- `type` names the serialization format used for `blob`.
CREATE TABLE checkpoint_blobs (
    thread_id     TEXT    NOT NULL,
    checkpoint_ns TEXT    NOT NULL DEFAULT '',
    channel       TEXT    NOT NULL,
    version       TEXT    NOT NULL,
    type          TEXT    NOT NULL,
    blob          BYTEA,
    PRIMARY KEY (thread_id, checkpoint_ns, channel, version)
);

-- Pending writes produced by tasks before a checkpoint is finalized;
-- `idx` orders multiple writes from the same task.
CREATE TABLE checkpoint_writes (
    thread_id     TEXT    NOT NULL,
    checkpoint_ns TEXT    NOT NULL DEFAULT '',
    checkpoint_id TEXT    NOT NULL,
    task_id       TEXT    NOT NULL,
    task_path     TEXT    NOT NULL,
    idx           INTEGER NOT NULL,
    channel       TEXT    NOT NULL,
    type          TEXT,
    blob          BYTEA   NOT NULL,
    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);

I’m not sure why this error is happening; it does not happen with MemorySaver() as my checkpointer, so I think AsyncPostgresSaver is having trouble saving into my database.

We have the same issue, for the time being our workaround is to:

return {
   # **state,
   "messages": [message.model_dump()],
   "current_step": "final_node"
}

It also persists further down the graph in future nodes, so we need to dump all messages written to state after the first time we do it. For example, in the graph:

 async def create_graph(checkpointer: AsyncPostgresSaver):
    # Initialize the graph
    workflow = StateGraph(State)

    # Add our nodes
    workflow.add_node("setup_node", setup_node)
    workflow.add_node("tool_node", ToolNode([greet]))
    workflow.add_node("final_node", final_node)

    # Add edges
    workflow.add_edge(START, "setup_node")
    workflow.add_edge("setup_node", "tool_node")
    # workflow.add_edge("tool_node", END)
    workflow.add_edge("tool_node", "final_node")
    workflow.add_edge("final_node", END)

    # Create necessary tables
    # await checkpointer.setup()

    # Compile the graph with the checkpointer
    return workflow.compile(checkpointer=checkpointer)


async def setup_node(state: State):
    """First node that processes the input"""
    print("--Running setup node--")

    api_key = SecretStr(os.getenv("GOOGLE_API_KEY") or "")
    model = ChatGoogleGenerativeAI(
        model="gemini-1.5-pro",
        temperature=0,
        api_key=api_key,
    )
    model_with_tools = model.bind_tools([greet])

    prompt = "Use the tool to send a message to a user whose name you come up with."
    response = await model_with_tools.ainvoke(prompt)

    # Dump to a plain dict so the checkpointer only sees JSON-friendly data.
    return {
        "messages": [response.model_dump()],
        "current_step": "setup_node",
    }

async def final_node(state: State):
    """Node that shows the final state"""
    print("--Running final node--")
    print(state)

    completion_msg = HumanMessage(content="Completed")

    # Dump to a plain dict so the checkpointer only sees JSON-friendly data.
    return {
        "messages": [completion_msg.model_dump()],
        "current_step": "final_node",
    }

@tool("greet", parse_docstring=True)
def greet(
    name: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
    state: Annotated[dict, InjectedState],
):
    """
    This is a test tool that says hello world.

    Args:
        name: The name to say hello to.
    """
    greeting = f"Hello {name}!"
    print("--Running greet tool--")
    print(greeting)

    # Build the tool result and dump it to a plain dict before it is
    # written into state.
    tool_result = ToolMessage(
        content=greeting,
        tool_call_id=tool_call_id,
        name="greet",
    )
    return Command(update={**state, "messages": [tool_result.model_dump()]})


The setup_node will produce an AIMessage in the messages, which will result in the issue once the graph finishes. (This code includes the work-around; remove the model_dump calls to reproduce the error.)

# Example usage
async def main():
    """Run the demo graph once against a Postgres-backed checkpointer,
    printing a summary of each node's output."""
    db_url = os.getenv("POSTGRES_DATABASE_URL") or ""
    async with AsyncPostgresSaver.from_conn_string(
        db_url, serde=JsonPlusSerializer(pickle_fallback=True)
    ) as checkpointer:
        graph = await create_graph(checkpointer)

        # Initial state
        state: State = {
            "messages": [],
            "current_step": START
        }

        # Configuration for the thread
        config: RunnableConfig = {
            "configurable": {
                "thread_id": "5"
            }
        }

        # Stream node outputs and print a human-readable summary of each.
        async for output in graph.astream(state, config):
            for key, value in output.items():
                print(f"\nProcessing Node: {key}")
                if isinstance(value, dict):
                    # Messages here are plain dicts (model_dump output),
                    # not message objects.
                    for msg in value.get('messages', []):
                        # .get avoids a KeyError on messages without content.
                        if msg.get('content'):
                            print(f"Message: {msg['content']}")

                        # BUG FIX: the original used hasattr(msg, 'tool_calls')
                        # and hasattr(msg, 'additional_kwargs'); those are
                        # always False for a dict, so tool/function calls were
                        # never printed. Use key lookups instead.
                        for tool_call in msg.get('tool_calls') or []:
                            print(f"Tool Call: {tool_call['name']}({tool_call['args']})")

                        func_call = (msg.get('additional_kwargs') or {}).get('function_call')
                        if func_call:
                            print(f"Function Call: {func_call['name']}({func_call['arguments']})")

                    if 'current_step' in value:
                        print(f"Current Step: {value['current_step']}")
                print("---")

# Run the example.
# NOTE: a bare top-level `await main()` is only valid in a notebook/REPL;
# in a script it is a SyntaxError, so start the event loop explicitly.
if __name__ == "__main__":
    asyncio.run(main())

The error is thrown at ainvoke/astream only when using an AsyncPostgresSaver; MemorySaver works fine. What is very annoying about this bug is how it persists down the graph. If you were to uncomment just the model_dump calls in the final node, you would start getting "AIMessage is not serializable" even though it was already made serializable earlier when adding it to the messages, forcing you to model_dump all messages after the original setup_node.