Hi @hnanhtuan
Maybe this?
References: BaseCheckpointSaver API, Persistence guide
DynamoDB data model
- Table:
langgraph_checkpoints
- PK:
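thread_id (S) — composite value `<thread_id>#<checkpoint_ns>`, so each namespace of a thread is its own partition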
- SK:
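checkpoint_id (S) — the checkpoint["id"] that LangGraph assigns. LangGraph generates time-ordered UUIDv6 ids, so sorting on this key orders a thread's checkpoints chronologically; querying with ScanIndexForward=False returns the newest first.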
- Attributes:
checkpoint_ns (S) — namespace (empty string by default)
ts (S) — ISO timestamp from checkpoint["ts"]
checkpoint_b (B) — serialized checkpoint bytes produced by the saver's serializer (serde)
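new_versions (M of S) — the channel versions written by this checkpoint, stored as strings; optional and denormalized (the checkpoint payload is authoritative)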
metadata (M) — CheckpointMetadata (small JSON)
parent_checkpoint_id (S, optional)
- Table:
langgraph_writes
- PK:
thread_checkpoint (S) — composite value `<thread_id>#<checkpoint_ns>#<checkpoint_id>`
- SK:
task_key (S) — `<task_id>#<seq>`, with seq zero-padded to 6 digits so a task's writes sort in the order they were made
- Attributes:
task_path (S)
channel (S)
value_b (B) — serialized write value
Note: You can also keep writes in a map inside the checkpoint item, but a separate table scales better: DynamoDB caps each item at 400 KB, and writes accumulate per task.
Minimal implementation (sync)
import os
import uuid
from typing import Any, AsyncIterator, Iterator, Mapping, Sequence, Tuple

import boto3
from boto3.dynamodb.conditions import Attr, Key
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
    BaseCheckpointSaver,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
class DynamoDBSaver(BaseCheckpointSaver[str]):
    """LangGraph checkpointer that persists state in two DynamoDB tables.

    * ``checkpoints_table``: partition key ``thread_id`` holds the composite
      ``"<thread_id>#<checkpoint_ns>"``; sort key ``checkpoint_id`` holds the
      id LangGraph assigned to the checkpoint.
    * ``writes_table``: partition key ``thread_checkpoint`` holds
      ``"<thread_id>#<checkpoint_ns>#<checkpoint_id>"``; sort key ``task_key``
      holds ``"<task_id>#<seq>"``.

    Payloads (checkpoint, metadata, write values) are stored as the output of
    ``serde.dumps_typed``: a type tag attribute (``*_t``) plus a bytes
    attribute (``*_b``). Only the typed serializer API is used, so serializers
    that implement nothing else (e.g. ``EncryptedSerializer``) work. It also
    avoids boto3 rejecting Python floats inside map attributes. Only the
    synchronous API is implemented here.
    """

    def __init__(
        self,
        checkpoints_table: str,
        writes_table: str,
        *,
        serde=None,
        dynamodb_resource=None,
    ) -> None:
        """Bind the saver to its two DynamoDB tables.

        Args:
            checkpoints_table: Name of the table holding checkpoints.
            writes_table: Name of the table holding pending writes.
            serde: Serializer exposing ``dumps_typed``/``loads_typed``;
                defaults to ``JsonPlusSerializer()``.
            dynamodb_resource: Optional pre-built ``boto3.resource("dynamodb")``;
                one is created from the default session when omitted.
        """
        super().__init__()
        self.serde = serde or JsonPlusSerializer()
        self._dynamo = dynamodb_resource or boto3.resource("dynamodb")
        self._t_checkpoints = self._dynamo.Table(checkpoints_table)
        self._t_writes = self._dynamo.Table(writes_table)

    # ---- Required (sync) ----
    def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Fetch one checkpoint together with its pending writes.

        Returns the checkpoint named by ``configurable.checkpoint_id`` when it
        is set, otherwise the newest checkpoint of the thread/namespace.
        Returns ``None`` when nothing matches.
        """
        cfg = config.get("configurable", {})
        thread_id: str = cfg["thread_id"]
        ckpt_ns: str = cfg.get("checkpoint_ns", "")
        ckpt_id: str | None = cfg.get("checkpoint_id")
        pk = self._pk(thread_id, ckpt_ns)
        if ckpt_id:
            item = self._t_checkpoints.get_item(
                Key={"thread_id": pk, "checkpoint_id": ckpt_id}
            ).get("Item")
        else:
            # Descending sort-key order puts the newest checkpoint first.
            resp = self._t_checkpoints.query(
                KeyConditionExpression=Key("thread_id").eq(pk),
                ScanIndexForward=False,
                Limit=1,
            )
            items = resp.get("Items", [])
            item = items[0] if items else None
        if not item:
            return None
        return self._item_to_tuple(item, thread_id, ckpt_ns)

    def list(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int | None = None,
    ) -> Iterator[CheckpointTuple]:
        """Yield checkpoints (with pending writes) matching the criteria.

        Args:
            config: When it carries ``thread_id``, only that thread's namespace
                (``checkpoint_ns``, default ``""``) is listed, newest first. An
                additional ``checkpoint_id`` narrows it to that checkpoint.
                Without ``thread_id`` the whole table is scanned: expensive,
                and the order is unspecified.
            filter: Metadata key/value pairs that must all match.
            before: Only checkpoints whose id sorts strictly before this
                config's ``checkpoint_id`` are yielded.
            limit: Maximum number of checkpoints yielded, counted after
                filtering. ``0`` or less yields nothing.
        """
        if limit is not None and limit <= 0:
            return
        cfg = (config or {}).get("configurable", {})
        thread_id: str | None = cfg.get("thread_id")
        ckpt_id: str | None = cfg.get("checkpoint_id")
        before_id: str | None = (
            (before or {}).get("configurable", {}).get("checkpoint_id")
        )

        if thread_id:
            ckpt_ns: str = cfg.get("checkpoint_ns", "")
            key_cond = Key("thread_id").eq(self._pk(thread_id, ckpt_ns))
            # DynamoDB accepts only one sort-key condition. ``before`` is also
            # enforced in the loop below, which covers the case where both are set.
            if ckpt_id:
                key_cond = key_cond & Key("checkpoint_id").eq(ckpt_id)
            elif before_id:
                key_cond = key_cond & Key("checkpoint_id").lt(before_id)
            items = self._paginate(
                self._t_checkpoints.query,
                KeyConditionExpression=key_cond,
                ScanIndexForward=False,
            )
        else:
            items = self._paginate(self._t_checkpoints.scan)

        yielded = 0
        for item in items:
            if ckpt_id and item["checkpoint_id"] != ckpt_id:
                continue
            if before_id and item["checkpoint_id"] >= before_id:
                continue
            if filter:
                meta = self._load_metadata(item)
                if not all(meta.get(k) == v for k, v in filter.items()):
                    continue
            yield self._item_to_tuple(item, *self._split_pk(item))
            yielded += 1
            if limit is not None and yielded >= limit:
                return

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Mapping[str, str | int | float],
    ) -> RunnableConfig:
        """Store ``checkpoint`` and return the config that addresses it.

        The incoming ``config`` addresses the checkpoint this one was derived
        from. Its ``checkpoint_id``, if any, is therefore recorded as the parent.
        """
        cfg = config.get("configurable", {})
        thread_id: str = cfg["thread_id"]
        ckpt_ns: str = cfg.get("checkpoint_ns", "")
        parent_id: str | None = cfg.get("checkpoint_id")

        ckpt_id = checkpoint.get("id")
        if not ckpt_id:
            # LangGraph normally assigns time-ordered ids. A random uuid4 keeps
            # the write valid, but it will not sort chronologically.
            ckpt_id = str(uuid.uuid4())
            checkpoint = {**checkpoint, "id": ckpt_id}

        ckpt_t, ckpt_b = self.serde.dumps_typed(checkpoint)
        meta_t, meta_b = self.serde.dumps_typed(metadata)
        item = {
            "thread_id": self._pk(thread_id, ckpt_ns),
            "checkpoint_id": ckpt_id,
            "ts": checkpoint["ts"],
            "checkpoint_ns": ckpt_ns,
            "checkpoint_t": ckpt_t,
            "checkpoint_b": ckpt_b,  # bytes -> DynamoDB Binary attribute
            "metadata_t": meta_t,
            "metadata_b": meta_b,
            # Denormalized for inspection only; the checkpoint payload is authoritative.
            "new_versions": {k: str(v) for k, v in new_versions.items()},
        }
        if parent_id:
            item["parent_checkpoint_id"] = parent_id
        self._t_checkpoints.put_item(Item=item)

        return {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": ckpt_ns,
                "checkpoint_id": ckpt_id,
            }
        }

    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Store the ``(channel, value)`` writes of ``task_id``.

        The writes are attached to the checkpoint addressed by ``config``.
        Each write becomes one row keyed by ``"<task_id>#<seq>"``, with seq
        zero-padded so rows sort in write order. Re-sending the same task
        overwrites its rows.
        """
        cfg = config.get("configurable", {})
        thread_id: str = cfg["thread_id"]
        ckpt_ns: str = cfg.get("checkpoint_ns", "")
        ckpt_id: str = cfg["checkpoint_id"]
        pk = f"{self._pk(thread_id, ckpt_ns)}#{ckpt_id}"
        with self._t_writes.batch_writer() as bw:
            for seq, (channel, value) in enumerate(writes):
                value_t, value_b = self.serde.dumps_typed(value)
                bw.put_item(
                    Item={
                        "thread_checkpoint": pk,
                        "task_key": f"{task_id}#{seq:06d}",
                        "task_id": task_id,
                        "task_path": task_path,
                        "channel": channel,
                        "value_t": value_t,
                        "value_b": value_b,
                    }
                )

    def delete_thread(self, thread_id: str) -> None:
        """Delete every checkpoint and pending write of ``thread_id`` in all namespaces.

        Both tables are scanned page by page, because the namespace is part
        of the partition key. ``begins_with("<thread_id>#")`` also matches
        other threads whose id merely starts with that text. Each hit is
        therefore re-checked for an exact thread match before it is deleted.
        """
        prefix = f"{thread_id}#"
        ckpt_items = [
            it
            for it in self._paginate(
                self._t_checkpoints.scan,
                FilterExpression=Attr("thread_id").begins_with(prefix),
                # Fetch keys only; the checkpoint blobs are not needed here.
                ProjectionExpression="#pk, #sk, #ns",
                ExpressionAttributeNames={
                    "#pk": "thread_id",
                    "#sk": "checkpoint_id",
                    "#ns": "checkpoint_ns",
                },
            )
            if self._split_pk(it)[0] == thread_id
        ]
        thread_pks = {it["thread_id"] for it in ckpt_items}

        write_items = []
        for it in self._paginate(
            self._t_writes.scan,
            FilterExpression=Attr("thread_checkpoint").begins_with(prefix),
            ProjectionExpression="#pk, #sk",
            ExpressionAttributeNames={"#pk": "thread_checkpoint", "#sk": "task_key"},
        ):
            # Drop the trailing "#<checkpoint_id>" to recover "<thread_id>#<ns>".
            owner_pk = it["thread_checkpoint"].rpartition("#")[0]
            if owner_pk in thread_pks or self._unpk_thread(owner_pk)[0] == thread_id:
                write_items.append(it)

        # Delete writes first so an interrupted run never leaves orphaned writes behind.
        with self._t_writes.batch_writer() as bw:
            for it in write_items:
                bw.delete_item(
                    Key={
                        "thread_checkpoint": it["thread_checkpoint"],
                        "task_key": it["task_key"],
                    }
                )
        with self._t_checkpoints.batch_writer() as bw:
            for it in ckpt_items:
                bw.delete_item(
                    Key={
                        "thread_id": it["thread_id"],
                        "checkpoint_id": it["checkpoint_id"],
                    }
                )

    # ---- Helpers ----
    def _item_to_tuple(
        self, item: Mapping[str, Any], thread_id: str, ckpt_ns: str
    ) -> CheckpointTuple:
        """Convert a checkpoints-table item into a ``CheckpointTuple`` with its pending writes.

        Loading the pending writes costs one extra query per checkpoint.
        """
        ckpt_id: str = item["checkpoint_id"]

        def _cfg(cid: str) -> RunnableConfig:
            return {
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": ckpt_ns,
                    "checkpoint_id": cid,
                }
            }

        parent_id = item.get("parent_checkpoint_id")
        return CheckpointTuple(
            config=_cfg(ckpt_id),
            checkpoint=self._loads(item.get("checkpoint_t"), item["checkpoint_b"]),
            metadata=self._load_metadata(item),
            parent_config=_cfg(parent_id) if parent_id else None,
            pending_writes=self._load_writes(thread_id, ckpt_ns, ckpt_id),
        )

    def _load_metadata(self, item: Mapping[str, Any]) -> CheckpointMetadata:
        """Decode the metadata of a checkpoint item."""
        if "metadata_b" in item:
            return self._loads(item.get("metadata_t"), item["metadata_b"])
        # Items written by older versions of this saver stored a plain map.
        return item.get("metadata", {})

    def _load_writes(
        self, thread_id: str, ckpt_ns: str, ckpt_id: str
    ) -> Sequence[tuple[str, str, Any]]:
        """Return ``(task_id, channel, value)`` triples for one checkpoint, in ``task_key`` order."""
        pk = f"{self._pk(thread_id, ckpt_ns)}#{ckpt_id}"
        writes = []
        for it in self._paginate(
            self._t_writes.query,
            KeyConditionExpression=Key("thread_checkpoint").eq(pk),
        ):
            task_id = it.get("task_id") or it["task_key"].rsplit("#", 1)[0]
            value = self._loads(it.get("value_t"), it["value_b"])
            writes.append((task_id, it["channel"], value))
        return writes

    def _loads(self, type_: str | None, blob: Any) -> Any:
        """Decode a ``(type tag, bytes)`` pair produced by ``serde.dumps_typed``.

        boto3 returns Binary attributes as ``boto3.dynamodb.types.Binary``,
        whose ``.value`` holds the raw bytes. A missing type tag means the item
        came from an older version of this saver that wrote JSON bytes.
        """
        raw = getattr(blob, "value", blob)
        return self.serde.loads_typed((type_ or "json", bytes(raw)))

    @staticmethod
    def _paginate(operation, **kwargs) -> Iterator[dict[str, Any]]:
        """Yield every item of a Query/Scan, following ``LastEvaluatedKey`` across 1 MB pages."""
        while True:
            resp = operation(**kwargs)
            yield from resp.get("Items", [])
            last_key = resp.get("LastEvaluatedKey")
            if not last_key:
                return
            kwargs["ExclusiveStartKey"] = last_key

    @staticmethod
    def _pk(thread_id: str, ns: str) -> str:
        """Build the checkpoints-table partition key ``"<thread_id>#<ns>"``."""
        return f"{thread_id}#{ns}"

    @staticmethod
    def _unpk_thread(pk: str) -> tuple[str, str]:
        """Split ``"<thread_id>#<ns>"`` on its LAST ``#``.

        Thread ids may contain ``#`` (they are user-supplied); LangGraph
        namespaces are not expected to. A key without ``#`` is treated as a
        bare thread id in the default namespace.
        """
        tid, sep, ns = pk.rpartition("#")
        return (tid, ns) if sep else (pk, "")

    @classmethod
    def _split_pk(cls, item: Mapping[str, Any]) -> tuple[str, str]:
        """Recover ``(thread_id, checkpoint_ns)`` from a checkpoint item.

        Prefers the stored ``checkpoint_ns`` attribute, which makes the split
        exact. Falls back to ``_unpk_thread`` when that attribute is missing.
        """
        pk: str = item["thread_id"]
        ns = item.get("checkpoint_ns")
        if isinstance(ns, str):
            suffix = f"#{ns}"
            if pk.endswith(suffix):
                return pk[: -len(suffix)], ns
        return cls._unpk_thread(pk)
Notes:
- DynamoDB Binary attributes: with boto3, store raw bytes. The serializer output above is bytes, which boto3 writes as B automatically (wrapping it in boto3.dynamodb.types.Binary explicitly is optional). On reads, boto3 returns a Binary object whose .value holds the original bytes.
get works via the base class, which calls get_tuple internally. aget likewise calls aget_tuple, so it only works once the async methods below are implemented.
get_next_version can stay default (monotonic integer). If you wish to store versions as strings, you can override to return strings.
Optional async variant
If you use .ainvoke()/.astream(), you need async counterparts. The skeleton below uses aioboto3 and leaves the method bodies to be filled in: mirror each sync method with awaited aioboto3 calls.
import aioboto3
from typing import AsyncIterator
class AsyncDynamoDBSaver(DynamoDBSaver):
    """Async skeleton of ``DynamoDBSaver`` built on aioboto3.

    Use it as ``async with AsyncDynamoDBSaver(...) as saver:`` so the aioboto3
    resource and table handles are opened and closed around use.

    The ``a*`` methods are NOT implemented yet. They raise
    ``NotImplementedError`` so a graph using this saver fails loudly; an
    empty ``...`` body would return ``None`` and the graph would act as if no
    checkpoint ever existed. Implement each one by mirroring the matching
    sync method with awaited aioboto3 table calls.
    """

    def __init__(self, checkpoints_table: str, writes_table: str, *, serde=None, session=None):
        # NOTE(review): the sync base __init__ also builds a boto3 resource
        # (which needs a configured region); it is replaced in __aenter__.
        super().__init__(checkpoints_table, writes_table, serde=serde)
        self._session = session or aioboto3.Session()
        self._checkpoints_table = checkpoints_table
        self._writes_table = writes_table
        self._resource_cm = None  # the async context manager returned by session.resource()
        self._dynamo = None
        self._t_checkpoints = None
        self._t_writes = None

    async def __aenter__(self):
        # Keep the context manager itself: __aexit__ must be called on it,
        # not on the resource object its __aenter__ returns.
        self._resource_cm = self._session.resource("dynamodb")
        self._dynamo = await self._resource_cm.__aenter__()
        try:
            self._t_checkpoints = await self._dynamo.Table(self._checkpoints_table)
            self._t_writes = await self._dynamo.Table(self._writes_table)
        except BaseException as err:
            # Close the freshly opened resource instead of leaking it.
            await self.__aexit__(type(err), err, err.__traceback__)
            raise
        return self

    async def __aexit__(self, exc_type, exc, tb):
        resource_cm, self._resource_cm = self._resource_cm, None
        self._dynamo = self._t_checkpoints = self._t_writes = None
        if resource_cm is not None:
            await resource_cm.__aexit__(exc_type, exc, tb)

    async def aget_tuple(self, config):
        # TODO: mirror get_tuple with awaited aioboto3 get_item/query calls.
        raise NotImplementedError("AsyncDynamoDBSaver.aget_tuple is not implemented")

    async def alist(self, config, *, filter=None, before=None, limit=None) -> AsyncIterator:
        # TODO: mirror list with awaited aioboto3 query/scan calls.
        raise NotImplementedError("AsyncDynamoDBSaver.alist is not implemented")
        yield  # unreachable; makes this an async generator so `async for` reaches the raise

    async def aput(self, config, checkpoint, metadata, new_versions):
        # TODO: mirror put with an awaited aioboto3 put_item call.
        raise NotImplementedError("AsyncDynamoDBSaver.aput is not implemented")

    async def aput_writes(self, config, writes, task_id, task_path: str = ""):
        # TODO: mirror put_writes using `async with table.batch_writer()`.
        raise NotImplementedError("AsyncDynamoDBSaver.aput_writes is not implemented")

    async def adelete_thread(self, thread_id: str):
        # TODO: mirror delete_thread with paginated awaited scans and batch deletes.
        raise NotImplementedError("AsyncDynamoDBSaver.adelete_thread is not implemented")
Using the saver
from langgraph.graph import StateGraph
from langgraph.checkpoint.serde.encrypted import EncryptedSerializer
# Optional: encrypt payloads at rest. With no key argument, the AES key is
# read from the LANGGRAPH_AES_KEY environment variable.
serde = EncryptedSerializer.from_pycryptodome_aes()

# Omit `serde=` to fall back to the default JsonPlusSerializer().
checkpointer = DynamoDBSaver(
    "langgraph_checkpoints",
    "langgraph_writes",
    serde=serde,
)


def _increment(value):
    """Single graph node: add one to the integer state."""
    return value + 1


builder = StateGraph(int)
builder.add_node("inc", _increment)
builder.set_entry_point("inc")
builder.set_finish_point("inc")
graph = builder.compile(checkpointer=checkpointer)

# Every invocation on the same thread_id persists its checkpoints to DynamoDB.
config = {"configurable": {"thread_id": "user-123"}}
result = graph.invoke(1, config)
Table creation (one-time)
import boto3
d = boto3.client("dynamodb")

# (table name, partition key, sort key); every key attribute is a string (S).
_TABLE_KEYS = (
    ("langgraph_checkpoints", "thread_id", "checkpoint_id"),  # checkpoints
    ("langgraph_writes", "thread_checkpoint", "task_key"),  # pending writes
)

for table_name, hash_key, range_key in _TABLE_KEYS:
    try:
        d.create_table(
            TableName=table_name,
            AttributeDefinitions=[
                {"AttributeName": hash_key, "AttributeType": "S"},
                {"AttributeName": range_key, "AttributeType": "S"},
            ],
            KeySchema=[
                {"AttributeName": hash_key, "KeyType": "HASH"},
                {"AttributeName": range_key, "KeyType": "RANGE"},
            ],
            BillingMode="PAY_PER_REQUEST",
        )
    except d.exceptions.ResourceInUseException:
        # The table already exists (or is being created), so re-running
        # this one-time setup is a no-op.
        pass
    # create_table returns while the table is still CREATING. Block until it
    # is ACTIVE so the saver can be used right after this script.
    d.get_waiter("table_exists").wait(TableName=table_name)