Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions alphatrion/log/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,22 @@ async def load_dataset(id: str | uuid.UUID, output_dir: str | None = None) -> li
return result


# TODO: we may need to add repo name to support sub-categorization of checkpoints,
# e.g., "ckpt/epoch".
async def load_checkpoint(
id: str | uuid.UUID,
version: str = "latest",
version_or_filename: str = "latest",
type: str = "experiment",
output_dir: str | None = None,
) -> list[str]:
"""
Load checkpoint from artifact registry, the path is expected to be in the format of
"org_id/team_id/exp_id/ckpt/".
Load checkpoint from artifact registry, the path is expected to be in the format of:
- OCI: "org_id/team_id/exp_id/ckpt:version"
- S3: "org_id/team_id/exp_id/ckpt/filename"

:param id: the id of the experiment.
:param version: the version of the checkpoint to load, default is "latest".
If version is "latest", it will load the latest version (for oci backend) or
:param version_or_filename: the version or filename of the checkpoint to load, default is "latest".
If version_or_filename is "latest", it will load the latest version (for oci backend) or
the file with the latest timestamp (for s3 backend).
:param type: the type of the checkpoint, can be "experiment" or "agent", default is "experiment".
:param output_dir: the directory to which the checkpoint will be loaded.
Expand All @@ -58,15 +61,17 @@ async def load_checkpoint(

# We only need to do this for s3 backend, because for oci backend,
# the version is the tag and "latest" tag will always point to the latest version.
if version == "latest" and artifact.storage_type == ARTIFACT_TYPE_S3:
if version_or_filename == "latest" and artifact.storage_type == ARTIFACT_TYPE_S3:
versions = artifact.list_versions(repo_name)
if versions is None or len(versions) == 0:
return []

version = versions[0] # Assuming versions are sorted by time, newest first
version_or_filename = versions[
0
] # Assuming versions are sorted by time, newest first

result = await asyncio.get_running_loop().run_in_executor(
None, artifact.pull, repo_name, version, output_dir
None, artifact.pull, repo_name, version_or_filename, output_dir
)

return result
10 changes: 6 additions & 4 deletions tests/integration/test_oci_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ async def test_load_checkpoint_latest(artifact):
# Load latest checkpoint
output_dir = os.path.join(tmpdir, "download")
result = await alpha.load_checkpoint(
id=exp_id, version="latest", output_dir=output_dir
id=exp_id, version_or_filename="latest", output_dir=output_dir
)

# Verify checkpoint was downloaded
Expand Down Expand Up @@ -388,7 +388,7 @@ async def test_load_checkpoint_specific_version(artifact):
)

result = await alpha.load_checkpoint(
id=exp_id, version="v1", output_dir=output_dir
id=exp_id, version_or_filename="v1", output_dir=output_dir
)

# Validate output_dir was created
Expand Down Expand Up @@ -435,7 +435,9 @@ async def test_load_checkpoint_nonexistent(artifact):
# For OCI, trying to pull a non-existent tag should raise an error
# (unlike S3 which returns [] when no files exist)
with pytest.raises(RuntimeError, match="Failed to pull artifacts"):
await alpha.load_checkpoint(id=exp_id, version="latest", output_dir=tmpdir)
await alpha.load_checkpoint(
id=exp_id, version_or_filename="latest", output_dir=tmpdir
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -464,7 +466,7 @@ async def test_load_checkpoint_multiple_files(artifact):
# Load checkpoint
output_dir = os.path.join(tmpdir, "download")
result = await alpha.load_checkpoint(
id=exp_id, version="v1", output_dir=output_dir
id=exp_id, version_or_filename="v1", output_dir=output_dir
)

# Verify all files were downloaded
Expand Down
265 changes: 265 additions & 0 deletions tests/integration/test_s3_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
"""Integration tests for S3 artifact backend with load_checkpoint.

Note: These tests use moto to mock AWS S3 for integration testing.

Run tests with:
pytest tests/integration/test_s3_backend.py -v
"""

import os
import tempfile
import uuid

import pytest

import alphatrion as alpha
import alphatrion.storage.runtime as storage_runtime_module


@pytest.fixture(autouse=True)
def s3_env_vars():
"""Set up S3 environment variables for testing."""
original_env = {}

# Reset storage runtime to ensure it picks up new env vars
storage_runtime_module.__STORAGE_RUNTIME__ = None

# Environment variables to set for S3
env_vars = {
"ALPHATRION_ARTIFACT_STORAGE_TYPE": "s3",
"ALPHATRION_ARTIFACT_S3_BUCKET": "test-bucket",
"ALPHATRION_ARTIFACT_S3_REGION": "us-east-1",
"ALPHATRION_ENABLE_ARTIFACT_STORAGE": "true",
}

# Save original values and set S3 variables
for key, value in env_vars.items():
original_env[key] = os.environ.get(key)
os.environ[key] = value

yield

# Restore original environment
for key, value in original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value

storage_runtime_module.__STORAGE_RUNTIME__ = None # Reset again to clear any cached runtime


@pytest.fixture
def mock_s3():
"""Create a mock AWS context for S3 testing."""
try:
from moto import mock_aws
except ImportError:
pytest.skip("moto is required for S3 backend tests")

with mock_aws():
import boto3

# Create the bucket
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="test-bucket")

# Enable versioning on the bucket for native versioning support
s3.put_bucket_versioning(
Bucket="test-bucket", VersioningConfiguration={"Status": "Enabled"}
)

yield s3


@pytest.fixture
def artifact(mock_s3):
"""Create an artifact instance with S3 backend.

This fixture depends on mock_s3 to ensure the mock context is active.
"""
from alphatrion.artifact.artifact import Artifact

return Artifact()


@pytest.mark.asyncio
async def test_load_checkpoint_by_filename(artifact, mock_s3):
"""Test load_checkpoint by filename for S3 backend."""
org_id = uuid.uuid4()
team_id = uuid.uuid4()
user_id = uuid.uuid4()
exp_id = uuid.uuid4()

alpha.init(org_id=org_id, team_id=team_id, user_id=user_id)

with tempfile.TemporaryDirectory() as tmpdir:
# Push multiple checkpoints as separate files (S3 flat structure)
for i in range(3):
test_file = os.path.join(tmpdir, f"checkpoint_{i}.pt")
with open(test_file, "w") as f:
f.write(f"model weights version {i}")

artifact.push(
repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt",
paths=test_file,
)

# Load specific checkpoint by filename
output_dir = os.path.join(tmpdir, "download")
result = await alpha.load_checkpoint(
id=exp_id, version_or_filename="checkpoint_1.pt", output_dir=output_dir
)

# Verify checkpoint was downloaded
assert result is not None
assert len(result) == 1
assert os.path.exists(result[0])
assert os.path.basename(result[0]) == "checkpoint_1.pt"

# Verify content
with open(result[0]) as f:
content = f.read()
assert content == "model weights version 1"


@pytest.mark.asyncio
async def test_load_checkpoint_filename_with_dot(artifact, mock_s3):
"""Test load_checkpoint with filename containing dots (e.g., checkpoint.v1.0.pt)."""
org_id = uuid.uuid4()
team_id = uuid.uuid4()
user_id = uuid.uuid4()
exp_id = uuid.uuid4()

alpha.init(org_id=org_id, team_id=team_id, user_id=user_id)

with tempfile.TemporaryDirectory() as tmpdir:
# Push checkpoint with filename containing dots
test_file = os.path.join(tmpdir, "checkpoint.v1.0.pt")
with open(test_file, "w") as f:
f.write("model weights version 1.0")

artifact.push(
repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt",
paths=test_file,
)

# Load checkpoint by filename with dots
output_dir = os.path.join(tmpdir, "download")
result = await alpha.load_checkpoint(
id=exp_id, version_or_filename="checkpoint.v1.0.pt", output_dir=output_dir
)

# Verify checkpoint was downloaded
assert result is not None
assert len(result) == 1
assert os.path.exists(result[0])
assert os.path.basename(result[0]) == "checkpoint.v1.0.pt"

# Verify content
with open(result[0]) as f:
content = f.read()
assert content == "model weights version 1.0"


@pytest.mark.asyncio
async def test_load_checkpoint_nonexistent(artifact, mock_s3):
"""Test load_checkpoint with nonexistent checkpoint returns empty list for S3."""
org_id = uuid.uuid4()
team_id = uuid.uuid4()
user_id = uuid.uuid4()
exp_id = uuid.uuid4()

alpha.init(org_id=org_id, team_id=team_id, user_id=user_id)

with tempfile.TemporaryDirectory() as tmpdir:
# For S3, trying to pull a non-existent version/folder should return empty list
result = await alpha.load_checkpoint(
id=exp_id, version_or_filename="nonexistent", output_dir=tmpdir
)
assert result == []


@pytest.mark.asyncio
async def test_load_checkpoint_folder(artifact, mock_s3):
"""Test load_checkpoint loading all files from a folder prefix for S3 backend."""
org_id = uuid.uuid4()
team_id = uuid.uuid4()
user_id = uuid.uuid4()
exp_id = uuid.uuid4()

alpha.init(org_id=org_id, team_id=team_id, user_id=user_id)

with tempfile.TemporaryDirectory() as tmpdir:
# Push checkpoint with multiple files in a folder
files = []
for i in range(3):
file_path = os.path.join(tmpdir, f"layer_{i}.pt")
with open(file_path, "w") as f:
f.write(f"layer {i} weights")
files.append(file_path)

# Push to a folder path (e.g., "epoch_10")
artifact.push(
repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt",
paths=files,
version="epoch_10",
)

# Load all files from the folder
output_dir = os.path.join(tmpdir, "download")
result = await alpha.load_checkpoint(
id=exp_id, version_or_filename="epoch_10", output_dir=output_dir
)

# Verify all files were downloaded
assert result is not None
assert len(result) == 3

for i in range(3):
filename = f"layer_{i}.pt"
assert any(filename in r for r in result)

file_path = os.path.join(output_dir, filename)
assert os.path.exists(file_path)

with open(file_path) as f:
assert f.read() == f"layer {i} weights"


@pytest.mark.asyncio
async def test_load_checkpoint_single_file(artifact, mock_s3):
"""Test load_checkpoint pulling a single file directly (flat structure)."""
org_id = uuid.uuid4()
team_id = uuid.uuid4()
user_id = uuid.uuid4()
exp_id = uuid.uuid4()

alpha.init(org_id=org_id, team_id=team_id, user_id=user_id)

with tempfile.TemporaryDirectory() as tmpdir:
# Push checkpoint without version (flat structure)
test_file = os.path.join(tmpdir, "checkpoint.pt")
with open(test_file, "w") as f:
f.write("model weights")

artifact.push(
repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt",
paths=test_file,
)

# Load checkpoint by filename
output_dir = os.path.join(tmpdir, "download")
result = await alpha.load_checkpoint(
id=exp_id, version_or_filename="checkpoint.pt", output_dir=output_dir
)

# Verify file was downloaded
assert result is not None
assert len(result) == 1
assert os.path.exists(result[0])
assert os.path.basename(result[0]) == "checkpoint.pt"

# Verify content
with open(result[0]) as f:
assert f.read() == "model weights"
Loading