How to implement custom BaseCheckpointSaver?

Hi everyone,

I’m working on saving checkpoints in AWS DynamoDB. I’ve gone through the documentation for BaseCheckpointSaver (https://langchain-ai.github.io/langgraph/reference/checkpoints/), but it’s still not clear to me.

Does anyone have an example or a more detailed explanation they could share? Any guidance would be greatly appreciated.

Thanks a lot!

Hi @hnanhtuan

Here is a working sketch you can adapt:

References: BaseCheckpointSaver API, Persistence guide


DynamoDB data model

  • Table: langgraph_checkpoints
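    • PK: thread_id (S) — stores the composite string "{thread_id}#{checkpoint_ns}", so each namespace of a thread gets its own partition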
    • SK: checkpoint_id (S) — use the checkpoint["id"] that LangGraph provides. LangGraph generates time-ordered UUIDs (UUIDv6), so sorting the ids as strings orders checkpoints by creation time.
    • Attributes:
      • checkpoint_ns (S) — namespace (empty string by default)
      • ts (S) — ISO timestamp from checkpoint["ts"]
      • checkpoint_b (B) — serialized checkpoint bytes, produced by the saver's serializer (serde)
      • new_versions (M of S) — optional, denormalized copy of the new_versions mapping passed to put(), with each version stored as a string
      • metadata (M) — CheckpointMetadata (small JSON)
      • parent_checkpoint_id (S, optional)
  • Table: langgraph_writes
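    • PK: thread_checkpoint (S) — the composite string "{thread_id}#{checkpoint_ns}#{checkpoint_id}"
    • SK: task_key (S) — "{task_id}#{seq:06d}"; the zero-padded sequence number keeps each task's writes in a stable order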
    • Attributes:
      • task_path (S)
      • channel (S)
      • value_b (B) — serialized write value

Note: You can also keep writes in a map inside the checkpoint item, but a separate table scales better: DynamoDB caps every item at 400 KB, and a single checkpoint can accumulate many writes.


Minimal implementation (sync)

import os
import uuid
from typing import Any, AsyncIterator, Iterator, Mapping, Sequence, Tuple

import boto3
from boto3.dynamodb.conditions import Attr, Key
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
    BaseCheckpointSaver,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer


class DynamoDBSaver(BaseCheckpointSaver[str]):
    """LangGraph checkpoint saver that persists state in two DynamoDB tables.

    * ``checkpoints_table`` -- one item per checkpoint. Partition key
      ``thread_id`` holds ``"{thread_id}#{checkpoint_ns}"``; sort key
      ``checkpoint_id`` holds LangGraph's checkpoint id, whose string order is
      creation order, so a descending query yields the newest first.
    * ``writes_table`` -- one item per pending write. Partition key
      ``thread_checkpoint`` holds ``"{thread_id}#{checkpoint_ns}#{checkpoint_id}"``;
      sort key ``task_key`` holds ``"{task_id}#{seq:06d}"``.

    Checkpoints, metadata and write values are each stored as a pair of
    attributes: a type tag (``*_t``) plus the bytes (``*_b``) returned by
    ``serde.dumps_typed``. That works with any LangGraph serializer (including
    ``EncryptedSerializer``), and boto3 never sees raw Python floats, which it
    refuses to store. Items written by an earlier revision of this class
    (untagged ``checkpoint_b``/``value_b``, plain ``metadata`` map) are still
    read, provided the serializer still offers ``loads``.

    All reads use ``ConsistentRead=True`` so a run resumes from the
    checkpoint that was just written, not a stale replica.

    NOTE(review): DynamoDB limits items to 400 KB; large graph states would
    need offloading (e.g. to S3), which this class does not do.
    """

    def __init__(
        self,
        checkpoints_table: str,
        writes_table: str,
        *,
        serde=None,
        dynamodb_resource=None,
    ) -> None:
        """Bind the saver to its two tables.

        Args:
            checkpoints_table: Name of the checkpoints table.
            writes_table: Name of the pending-writes table.
            serde: Serializer; defaults to ``JsonPlusSerializer()``.
            dynamodb_resource: Optional pre-built ``boto3.resource("dynamodb")``
                (custom region/endpoint, tests, ...).
        """
        # Hand the serializer to the base class, which stores it as self.serde.
        super().__init__(serde=serde or JsonPlusSerializer())
        self._dynamo = dynamodb_resource or boto3.resource("dynamodb")
        self._t_checkpoints = self._dynamo.Table(checkpoints_table)
        self._t_writes = self._dynamo.Table(writes_table)

    # ---- Required (sync) ----
    def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Return the checkpoint addressed by ``config``, or the newest one.

        If ``config["configurable"]`` has a ``checkpoint_id`` that exact
        checkpoint is fetched; otherwise the latest checkpoint of the
        thread/namespace is returned. Returns ``None`` if nothing is stored.
        """
        cfg = config.get("configurable", {})
        pk = self._pk(cfg["thread_id"], cfg.get("checkpoint_ns", ""))

        if ckpt_id := cfg.get("checkpoint_id"):
            item = self._t_checkpoints.get_item(
                Key={"thread_id": pk, "checkpoint_id": ckpt_id},
                ConsistentRead=True,
            ).get("Item")
        else:
            resp = self._t_checkpoints.query(
                KeyConditionExpression=Key("thread_id").eq(pk),
                ScanIndexForward=False,  # newest checkpoint_id first
                Limit=1,
                ConsistentRead=True,
            )
            items = resp.get("Items", [])
            item = items[0] if items else None

        return self._item_to_tuple(item) if item else None

    def list(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int | None = None,
    ) -> Iterator[CheckpointTuple]:
        """Yield stored checkpoints as ``CheckpointTuple`` objects, newest first.

        Args:
            config: Selects ``thread_id``/``checkpoint_ns`` and optionally an
                exact ``checkpoint_id``. Without a ``thread_id`` the whole
                table is scanned: expensive, and the order is unspecified.
            filter: Metadata key/value pairs that must all match.
            before: Only yield checkpoints whose id sorts before this
                config's ``checkpoint_id``.
            limit: Maximum number of checkpoints to yield, counted after
                ``filter`` is applied.

        Every result page is followed (DynamoDB returns at most 1 MB per
        call). Each tuple includes its pending writes, which costs one extra
        query per yielded checkpoint.
        """
        if limit is not None and limit <= 0:
            return

        cfg = (config or {}).get("configurable", {})
        thread_id: str | None = cfg.get("thread_id")
        before_id = (before or {}).get("configurable", {}).get("checkpoint_id")

        if not thread_id:
            items = self._paginate(self._t_checkpoints.scan)
        else:
            key_cond = Key("thread_id").eq(
                self._pk(thread_id, cfg.get("checkpoint_ns", ""))
            )
            # DynamoDB allows a single condition on the sort key, so an exact
            # checkpoint_id wins; `before` is re-checked client-side below.
            if ckpt_id := cfg.get("checkpoint_id"):
                key_cond &= Key("checkpoint_id").eq(ckpt_id)
            elif before_id:
                key_cond &= Key("checkpoint_id").lt(before_id)
            items = self._paginate(
                self._t_checkpoints.query,
                KeyConditionExpression=key_cond,
                ScanIndexForward=False,  # newest first
                ConsistentRead=True,
            )

        yielded = 0
        for item in items:
            if before_id and item["checkpoint_id"] >= before_id:
                continue
            metadata = self._load_metadata(item)
            if filter and not all(metadata.get(k) == v for k, v in filter.items()):
                continue
            yield self._item_to_tuple(item, metadata=metadata)
            yielded += 1
            # Stop here so no further result page is requested.
            if limit is not None and yielded >= limit:
                return

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Mapping[str, str | int | float],
    ) -> RunnableConfig:
        """Store ``checkpoint`` and return a config that points at it.

        LangGraph passes the id of the checkpoint this one was derived from
        as ``config["configurable"]["checkpoint_id"]`` (the reference savers
        treat it the same way); it is recorded as the parent.

        Raises:
            ValueError: if ``checkpoint`` has no ``"id"``.
        """
        cfg = config.get("configurable", {})
        thread_id: str = cfg["thread_id"]
        ckpt_ns: str = cfg.get("checkpoint_ns", "")

        ckpt_id = checkpoint.get("id")
        if not ckpt_id:
            # A random fallback id would break newest-first ordering.
            raise ValueError(
                "checkpoint has no 'id'; LangGraph assigns a time-ordered one"
            )

        ckpt_type, ckpt_blob = self.serde.dumps_typed(checkpoint)
        meta_type, meta_blob = self.serde.dumps_typed(metadata)
        item = {
            "thread_id": self._pk(thread_id, ckpt_ns),
            "checkpoint_id": ckpt_id,
            "ts": checkpoint["ts"],
            "checkpoint_ns": ckpt_ns,
            "checkpoint_t": ckpt_type,
            "checkpoint_b": ckpt_blob,  # bytes -> Binary attribute
            "metadata_t": meta_type,
            "metadata_b": meta_blob,
            "new_versions": {k: str(v) for k, v in new_versions.items()},
        }
        if parent_id := cfg.get("checkpoint_id"):
            item["parent_checkpoint_id"] = parent_id

        self._t_checkpoints.put_item(Item=item)
        return self._config(thread_id, ckpt_ns, ckpt_id)

    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Persist the ``(channel, value)`` writes ``task_id`` produced.

        The writes are attached to the checkpoint named by
        ``config["configurable"]["checkpoint_id"]``. Re-writing the same task
        overwrites items with the same sequence numbers.
        """
        cfg = config.get("configurable", {})
        pk = self._writes_pk(
            cfg["thread_id"], cfg.get("checkpoint_ns", ""), cfg["checkpoint_id"]
        )

        with self._t_writes.batch_writer() as bw:
            for seq, (channel, value) in enumerate(writes):
                value_type, value_blob = self.serde.dumps_typed(value)
                bw.put_item(
                    Item={
                        "thread_checkpoint": pk,
                        # Zero-padded so string order == write order.
                        "task_key": f"{task_id}#{seq:06d}",
                        "task_id": task_id,
                        "task_path": task_path,
                        "channel": channel,
                        "value_t": value_type,
                        "value_b": value_blob,
                    }
                )

    def delete_thread(self, thread_id: str) -> None:
        """Delete all checkpoints and pending writes of ``thread_id`` (every namespace).

        Uses paginated full-table scans, so cost grows with table size.
        NOTE(review): items are matched by the ``f"{thread_id}#"`` prefix, so
        a different thread id that starts with ``thread_id + "#"`` would also
        be deleted.
        """
        prefix = f"{thread_id}#"
        # Collect keys first, then delete, so deletes never race the scan.
        ckpt_keys = [
            {"thread_id": it["thread_id"], "checkpoint_id": it["checkpoint_id"]}
            for it in self._paginate(
                self._t_checkpoints.scan,
                FilterExpression=Attr("thread_id").begins_with(prefix),
            )
        ]
        write_keys = [
            {"thread_checkpoint": it["thread_checkpoint"], "task_key": it["task_key"]}
            for it in self._paginate(
                self._t_writes.scan,
                FilterExpression=Attr("thread_checkpoint").begins_with(prefix),
            )
        ]
        with self._t_checkpoints.batch_writer() as bw:
            for key in ckpt_keys:
                bw.delete_item(Key=key)
        with self._t_writes.batch_writer() as bw:
            for key in write_keys:
                bw.delete_item(Key=key)

    # ---- Helpers ----
    def _item_to_tuple(
        self,
        item: Mapping[str, Any],
        *,
        metadata: CheckpointMetadata | None = None,
    ) -> CheckpointTuple:
        """Rebuild a ``CheckpointTuple`` (with pending writes) from a checkpoints item."""
        thread_id, ckpt_ns = self._split_pk(item)
        ckpt_id = item["checkpoint_id"]
        parent_config = None
        if parent_id := item.get("parent_checkpoint_id"):
            parent_config = self._config(thread_id, ckpt_ns, parent_id)
        return CheckpointTuple(
            config=self._config(thread_id, ckpt_ns, ckpt_id),
            checkpoint=self._load(item.get("checkpoint_t"), item["checkpoint_b"]),
            metadata=metadata if metadata is not None else self._load_metadata(item),
            parent_config=parent_config,
            pending_writes=self._load_writes(thread_id, ckpt_ns, ckpt_id),
        )

    def _load(self, type_tag: str | None, blob: Any) -> Any:
        """Deserialize a stored (type tag, bytes) attribute pair."""
        # boto3 returns Binary attributes wrapped in
        # boto3.dynamodb.types.Binary; the raw bytes live in `.value`.
        data = blob.value if hasattr(blob, "value") else bytes(blob)
        if type_tag is None:
            # Untagged items were written with serde.dumps() by an earlier
            # revision of this saver.
            return self.serde.loads(data)
        return self.serde.loads_typed((type_tag, data))

    def _load_metadata(self, item: Mapping[str, Any]) -> CheckpointMetadata:
        """Return the checkpoint metadata stored on a checkpoints item."""
        if "metadata_b" in item:
            return self._load(item.get("metadata_t"), item["metadata_b"])
        # Older items stored metadata as a native DynamoDB map.
        return item.get("metadata", {})

    def _load_writes(
        self, thread_id: str, ckpt_ns: str, ckpt_id: str
    ) -> Sequence[tuple[str, Any, Any]]:
        """Return ``(task_id, channel, value)`` for every write of one checkpoint."""
        rows = self._paginate(
            self._t_writes.query,
            KeyConditionExpression=Key("thread_checkpoint").eq(
                self._writes_pk(thread_id, ckpt_ns, ckpt_id)
            ),
            ConsistentRead=True,
        )
        return [
            (
                # Older rows have no task_id attribute; recover it from the sort key.
                row.get("task_id") or row["task_key"].rsplit("#", 1)[0],
                row["channel"],
                self._load(row.get("value_t"), row["value_b"]),
            )
            for row in rows
        ]

    @staticmethod
    def _paginate(operation: Any, **kwargs: Any) -> Iterator[dict[str, Any]]:
        """Yield every item of a boto3 ``query``/``scan``, following ``LastEvaluatedKey``."""
        while True:
            resp = operation(**kwargs)
            yield from resp.get("Items", [])
            if "LastEvaluatedKey" not in resp:
                return
            kwargs["ExclusiveStartKey"] = resp["LastEvaluatedKey"]

    @staticmethod
    def _config(thread_id: str, ckpt_ns: str, ckpt_id: str) -> RunnableConfig:
        """Build the RunnableConfig that addresses one checkpoint."""
        return {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": ckpt_ns,
                "checkpoint_id": ckpt_id,
            }
        }

    @staticmethod
    def _pk(thread_id: str, ns: str) -> str:
        """Partition key of the checkpoints table."""
        return f"{thread_id}#{ns}"

    @classmethod
    def _writes_pk(cls, thread_id: str, ns: str, ckpt_id: str) -> str:
        """Partition key of the writes table (one partition per checkpoint)."""
        return f"{cls._pk(thread_id, ns)}#{ckpt_id}"

    @staticmethod
    def _unpk_thread(pk: str) -> tuple[str, str]:
        """Split a checkpoints-table PK into ``(thread_id, checkpoint_ns)``.

        Splits on the LAST "#", because ``_pk`` always appends ``"#{ns}"``
        and user-supplied thread ids are the part more likely to contain "#".
        """
        tid, sep, ns = pk.rpartition("#")
        return (tid, ns) if sep else (pk, "")

    @classmethod
    def _split_pk(cls, item: Mapping[str, Any]) -> tuple[str, str]:
        """Recover ``(thread_id, checkpoint_ns)`` exactly from a checkpoints item."""
        pk = item["thread_id"]
        ns = item.get("checkpoint_ns")
        if isinstance(ns, str) and pk.endswith(f"#{ns}"):
            # The stored namespace tells us exactly where the thread id ends.
            return pk[: len(pk) - len(ns) - 1], ns
        return cls._unpk_thread(pk)

Notes:

  • DynamoDB Binary attributes: with boto3, write raw bytes. The serializer produces bytes, which boto3 stores as type B automatically (you can also wrap them in boto3.dynamodb.types.Binary explicitly). On reads, boto3 returns a Binary wrapper, which is why the code unwraps it with .value before deserializing.
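  • get is inherited from the base class: it calls get_tuple() and returns the .checkpoint field of the result, so get_tuple() must return a CheckpointTuple rather than a plain tuple. aget works the same way but calls aget_tuple(), so it only works once the async methods are implemented.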
  • get_next_version can stay default (monotonic integer). If you wish to store versions as strings, you can override to return strings.

Optional async variant

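If you use .ainvoke()/.astream(), implement the async counterparts with aioboto3. The method bodies below are placeholders: fill in each one with the same logic as its sync counterpart, awaiting the aioboto3 calls, before using this class.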

import aioboto3
from typing import AsyncIterator

class AsyncDynamoDBSaver(DynamoDBSaver):
    """Async (aioboto3) skeleton of ``DynamoDBSaver``.

    Use it as an async context manager so the aioboto3 resource is opened
    before the graph runs and closed afterwards::

        async with AsyncDynamoDBSaver("langgraph_checkpoints", "langgraph_writes") as saver:
            graph = builder.compile(checkpointer=saver)
            await graph.ainvoke(1, config)

    The ``a*`` methods are placeholders. Each raises ``NotImplementedError``
    rather than silently returning ``None``: a ``None`` from ``aget_tuple``
    would be taken to mean "no checkpoint stored", so state would be lost
    without any error. Port the matching sync method, awaiting the aioboto3
    table calls (``self._t_checkpoints`` / ``self._t_writes``).
    """

    def __init__(self, checkpoints_table: str, writes_table: str, *, serde=None, session=None):
        # Skip DynamoDBSaver.__init__ on purpose: it would eagerly create a
        # *sync* boto3 resource that this class never uses.
        BaseCheckpointSaver.__init__(self, serde=serde or JsonPlusSerializer())
        self._session = session or aioboto3.Session()
        self._checkpoints_table = checkpoints_table
        self._writes_table = writes_table
        self._resource_cm = None  # context manager returned by session.resource()
        self._dynamo = None
        self._t_checkpoints = None
        self._t_writes = None

    async def __aenter__(self):
        """Open the aioboto3 DynamoDB resource and bind both tables."""
        cm = self._session.resource("dynamodb")
        dynamo = await cm.__aenter__()
        try:
            t_checkpoints = await dynamo.Table(self._checkpoints_table)
            t_writes = await dynamo.Table(self._writes_table)
        except BaseException as exc:
            # Do not leak the open resource if binding a table fails.
            await cm.__aexit__(type(exc), exc, exc.__traceback__)
            raise
        self._resource_cm, self._dynamo = cm, dynamo
        self._t_checkpoints, self._t_writes = t_checkpoints, t_writes
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Close the resource opened by ``__aenter__`` (exceptions are not suppressed)."""
        cm, self._resource_cm = self._resource_cm, None
        self._dynamo = self._t_checkpoints = self._t_writes = None
        if cm is not None:
            # Exit the object we entered, per the async context-manager protocol.
            await cm.__aexit__(exc_type, exc, tb)

    async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Async ``get_tuple``; not implemented yet."""
        raise NotImplementedError(
            "AsyncDynamoDBSaver.aget_tuple: port DynamoDBSaver.get_tuple with awaited aioboto3 calls"
        )

    async def alist(
        self, config, *, filter=None, before=None, limit=None
    ) -> AsyncIterator[CheckpointTuple]:
        """Async ``list``; not implemented yet."""
        raise NotImplementedError(
            "AsyncDynamoDBSaver.alist: port DynamoDBSaver.list with awaited aioboto3 calls"
        )
        yield  # unreachable; its presence makes this an async generator, as `async for` requires

    async def aput(self, config, checkpoint, metadata, new_versions) -> RunnableConfig:
        """Async ``put``; not implemented yet."""
        raise NotImplementedError(
            "AsyncDynamoDBSaver.aput: port DynamoDBSaver.put with awaited aioboto3 calls"
        )

    async def aput_writes(self, config, writes, task_id, task_path: str = "") -> None:
        """Async ``put_writes``; not implemented yet."""
        raise NotImplementedError(
            "AsyncDynamoDBSaver.aput_writes: port DynamoDBSaver.put_writes with awaited aioboto3 calls"
        )

    async def adelete_thread(self, thread_id: str) -> None:
        """Async ``delete_thread``; not implemented yet."""
        raise NotImplementedError(
            "AsyncDynamoDBSaver.adelete_thread: port DynamoDBSaver.delete_thread with awaited aioboto3 calls"
        )

Using the saver

from langgraph.checkpoint.serde.encrypted import EncryptedSerializer
from langgraph.graph import StateGraph

# Encrypting stored payloads is optional: the AES key comes from the
# LANGGRAPH_AES_KEY environment variable. Drop `serde=` below to use the
# saver's default JsonPlusSerializer instead.
serde = EncryptedSerializer.from_pycryptodome_aes()
checkpointer = DynamoDBSaver(
    "langgraph_checkpoints",  # checkpoints_table
    "langgraph_writes",  # writes_table
    serde=serde,
)

# A one-node graph whose state is a single int that gets incremented.
builder = StateGraph(int)
builder.add_node("inc", lambda value: value + 1)
builder.set_entry_point("inc")
builder.set_finish_point("inc")
graph = builder.compile(checkpointer=checkpointer)

# The saver keys every stored checkpoint by this thread_id.
config = {"configurable": {"thread_id": "user-123"}}
result = graph.invoke(1, config=config)  # persists checkpoints to DynamoDB

Table creation (one-time)

import boto3

# One-time setup: create the two tables DynamoDBSaver expects. Safe to re-run.
d = boto3.client("dynamodb")

# (table name, partition key, sort key) -- both keys are strings.
_TABLES = (
    ("langgraph_checkpoints", "thread_id", "checkpoint_id"),
    ("langgraph_writes", "thread_checkpoint", "task_key"),
)

for table_name, hash_key, range_key in _TABLES:
    try:
        d.create_table(
            TableName=table_name,
            AttributeDefinitions=[
                {"AttributeName": hash_key, "AttributeType": "S"},
                {"AttributeName": range_key, "AttributeType": "S"},
            ],
            KeySchema=[
                {"AttributeName": hash_key, "KeyType": "HASH"},
                {"AttributeName": range_key, "KeyType": "RANGE"},
            ],
            BillingMode="PAY_PER_REQUEST",
        )
    except d.exceptions.ResourceInUseException:
        # The table already exists (or is being created); re-running is fine.
        pass
    # create_table returns while the table is still CREATING; wait until it
    # is ACTIVE so the saver can be used right after this script finishes.
    d.get_waiter("table_exists").wait(TableName=table_name)

Thank you very much.