diff --git a/alphatrion/log/load.py b/alphatrion/log/load.py index 5a1f270..7456905 100644 --- a/alphatrion/log/load.py +++ b/alphatrion/log/load.py @@ -28,19 +28,22 @@ async def load_dataset(id: str | uuid.UUID, output_dir: str | None = None) -> li return result +# TODO: we may need to add repo name to support sub-categorization of checkpoints, +# e.g., "ckpt/epoch". async def load_checkpoint( id: str | uuid.UUID, - version: str = "latest", + version_or_filename: str = "latest", type: str = "experiment", output_dir: str | None = None, ) -> list[str]: """ - Load checkpoint from artifact registry, the path is expected to be in the format of - "org_id/team_id/exp_id/ckpt/". + Load checkpoint from artifact registry, the path is expected to be in the format of: + - OCI: "org_id/team_id/exp_id/ckpt:version" + - S3: "org_id/team_id/exp_id/ckpt/filename" :param id: the id of the experiment. - :param version: the version of the checkpoint to load, default is "latest". - If version is "latest", it will load the latest version (for oci backend) or + :param version_or_filename: the version or filename of the checkpoint to load, default is "latest". + If version_or_filename is "latest", it will load the latest version (for oci backend) or the file with the latest timestamp (for s3 backend). :param type: the type of the checkpoint, can be "experiment" or "agent", default is "experiment". :param output_dir: the directory to which the checkpoint will be loaded. @@ -58,15 +61,17 @@ async def load_checkpoint( # We only need to do this for s3 backend, because for oci backend, # the version is the tag and "latest" tag will always point to the latest version. - if version == "latest" and artifact.storage_type == ARTIFACT_TYPE_S3: + if version_or_filename == "latest" and artifact.storage_type == ARTIFACT_TYPE_S3: versions = artifact.list_versions(repo_name) if versions is None or len(versions) == 0: return [] - version = versions[0] # Assuming versions are sorted by time, newest first + version_or_filename = versions[ + 0 + ] # Assuming versions are sorted by time, newest first result = await asyncio.get_running_loop().run_in_executor( - None, artifact.pull, repo_name, version, output_dir + None, artifact.pull, repo_name, version_or_filename, output_dir ) return result diff --git a/tests/integration/test_oci_backend.py b/tests/integration/test_oci_backend.py index 5350c03..5ed9059 100644 --- a/tests/integration/test_oci_backend.py +++ b/tests/integration/test_oci_backend.py @@ -342,7 +342,7 @@ async def test_load_checkpoint_latest(artifact): # Load latest checkpoint output_dir = os.path.join(tmpdir, "download") result = await alpha.load_checkpoint( - id=exp_id, version="latest", output_dir=output_dir + id=exp_id, version_or_filename="latest", output_dir=output_dir ) # Verify checkpoint was downloaded @@ -388,7 +388,7 @@ async def test_load_checkpoint_specific_version(artifact): ) result = await alpha.load_checkpoint( - id=exp_id, version="v1", output_dir=output_dir + id=exp_id, version_or_filename="v1", output_dir=output_dir ) # Validate output_dir was created @@ -435,7 +435,9 @@ async def test_load_checkpoint_nonexistent(artifact): # For OCI, trying to pull a non-existent tag should raise an error # (unlike S3 which returns [] when no files exist) with pytest.raises(RuntimeError, match="Failed to pull artifacts"): - await alpha.load_checkpoint(id=exp_id, version="latest", output_dir=tmpdir) + await alpha.load_checkpoint( + id=exp_id, version_or_filename="latest", output_dir=tmpdir + ) @pytest.mark.asyncio @@ -464,7 +466,7 @@ async def test_load_checkpoint_multiple_files(artifact): # Load checkpoint output_dir = os.path.join(tmpdir, "download") result = await alpha.load_checkpoint( - id=exp_id, version="v1", output_dir=output_dir + id=exp_id, version_or_filename="v1", output_dir=output_dir ) # Verify all files were downloaded diff --git a/tests/integration/test_s3_backend.py b/tests/integration/test_s3_backend.py new file mode 100644 index 0000000..7346f00 --- /dev/null +++ b/tests/integration/test_s3_backend.py @@ -0,0 +1,265 @@ +"""Integration tests for S3 artifact backend with load_checkpoint. + +Note: These tests use moto to mock AWS S3 for integration testing. + +Run tests with: + pytest tests/integration/test_s3_backend.py -v +""" + +import os +import tempfile +import uuid + +import pytest + +import alphatrion as alpha +import alphatrion.storage.runtime as storage_runtime_module + + +@pytest.fixture(autouse=True) +def s3_env_vars(): + """Set up S3 environment variables for testing.""" + original_env = {} + + # Reset storage runtime to ensure it picks up new env vars + storage_runtime_module.__STORAGE_RUNTIME__ = None + + # Environment variables to set for S3 + env_vars = { + "ALPHATRION_ARTIFACT_STORAGE_TYPE": "s3", + "ALPHATRION_ARTIFACT_S3_BUCKET": "test-bucket", + "ALPHATRION_ARTIFACT_S3_REGION": "us-east-1", + "ALPHATRION_ENABLE_ARTIFACT_STORAGE": "true", + } + + # Save original values and set S3 variables + for key, value in env_vars.items(): + original_env[key] = os.environ.get(key) + os.environ[key] = value + + yield + + # Restore original environment + for key, value in original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + storage_runtime_module.__STORAGE_RUNTIME__ = None # Reset again to clear any cached runtime + + +@pytest.fixture +def mock_s3(): + """Create a mock AWS context for S3 testing.""" + try: + from moto import mock_aws + except ImportError: + pytest.skip("moto is required for S3 backend tests") + + with mock_aws(): + import boto3 + + # Create the bucket + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="test-bucket") + + # Enable versioning on the bucket for native versioning support + s3.put_bucket_versioning( + Bucket="test-bucket", VersioningConfiguration={"Status": "Enabled"} + ) + + yield s3 + + +@pytest.fixture +def artifact(mock_s3): + """Create an artifact instance with S3 backend. + + This fixture depends on mock_s3 to ensure the mock context is active. + """ + from alphatrion.artifact.artifact import Artifact + + return Artifact() + + +@pytest.mark.asyncio +async def test_load_checkpoint_by_filename(artifact, mock_s3): + """Test load_checkpoint by filename for S3 backend.""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + exp_id = uuid.uuid4() + + alpha.init(org_id=org_id, team_id=team_id, user_id=user_id) + + with tempfile.TemporaryDirectory() as tmpdir: + # Push multiple checkpoints as separate files (S3 flat structure) + for i in range(3): + test_file = os.path.join(tmpdir, f"checkpoint_{i}.pt") + with open(test_file, "w") as f: + f.write(f"model weights version {i}") + + artifact.push( + repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt", + paths=test_file, + ) + + # Load specific checkpoint by filename + output_dir = os.path.join(tmpdir, "download") + result = await alpha.load_checkpoint( + id=exp_id, version_or_filename="checkpoint_1.pt", output_dir=output_dir + ) + + # Verify checkpoint was downloaded + assert result is not None + assert len(result) == 1 + assert os.path.exists(result[0]) + assert os.path.basename(result[0]) == "checkpoint_1.pt" + + # Verify content + with open(result[0]) as f: + content = f.read() + assert content == "model weights version 1" + + +@pytest.mark.asyncio +async def test_load_checkpoint_filename_with_dot(artifact, mock_s3): + """Test load_checkpoint with filename containing dots (e.g., checkpoint.v1.0.pt).""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + exp_id = uuid.uuid4() + + alpha.init(org_id=org_id, team_id=team_id, user_id=user_id) + + with tempfile.TemporaryDirectory() as tmpdir: + # Push checkpoint with filename containing dots + test_file = os.path.join(tmpdir, "checkpoint.v1.0.pt") + with open(test_file, "w") as f: + f.write("model weights version 1.0") + + artifact.push( + repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt", + paths=test_file, + ) + + # Load checkpoint by filename with dots + output_dir = os.path.join(tmpdir, "download") + result = await alpha.load_checkpoint( + id=exp_id, version_or_filename="checkpoint.v1.0.pt", output_dir=output_dir + ) + + # Verify checkpoint was downloaded + assert result is not None + assert len(result) == 1 + assert os.path.exists(result[0]) + assert os.path.basename(result[0]) == "checkpoint.v1.0.pt" + + # Verify content + with open(result[0]) as f: + content = f.read() + assert content == "model weights version 1.0" + + +@pytest.mark.asyncio +async def test_load_checkpoint_nonexistent(artifact, mock_s3): + """Test load_checkpoint with nonexistent checkpoint returns empty list for S3.""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + exp_id = uuid.uuid4() + + alpha.init(org_id=org_id, team_id=team_id, user_id=user_id) + + with tempfile.TemporaryDirectory() as tmpdir: + # For S3, trying to pull a non-existent version/folder should return empty list + result = await alpha.load_checkpoint( + id=exp_id, version_or_filename="nonexistent", output_dir=tmpdir + ) + assert result == [] + + +@pytest.mark.asyncio +async def test_load_checkpoint_folder(artifact, mock_s3): + """Test load_checkpoint loading all files from a folder prefix for S3 backend.""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + exp_id = uuid.uuid4() + + alpha.init(org_id=org_id, team_id=team_id, user_id=user_id) + + with tempfile.TemporaryDirectory() as tmpdir: + # Push checkpoint with multiple files in a folder + files = [] + for i in range(3): + file_path = os.path.join(tmpdir, f"layer_{i}.pt") + with open(file_path, "w") as f: + f.write(f"layer {i} weights") + files.append(file_path) + + # Push to a folder path (e.g., "epoch_10") + artifact.push( + repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt", + paths=files, + version="epoch_10", + ) + + # Load all files from the folder + output_dir = os.path.join(tmpdir, "download") + result = await alpha.load_checkpoint( + id=exp_id, version_or_filename="epoch_10", output_dir=output_dir + ) + + # Verify all files were downloaded + assert result is not None + assert len(result) == 3 + + for i in range(3): + filename = f"layer_{i}.pt" + assert any(filename in r for r in result) + + file_path = os.path.join(output_dir, filename) + assert os.path.exists(file_path) + + with open(file_path) as f: + assert f.read() == f"layer {i} weights" + + +@pytest.mark.asyncio +async def test_load_checkpoint_single_file(artifact, mock_s3): + """Test load_checkpoint pulling a single file directly (flat structure).""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + exp_id = uuid.uuid4() + + alpha.init(org_id=org_id, team_id=team_id, user_id=user_id) + + with tempfile.TemporaryDirectory() as tmpdir: + # Push checkpoint without version (flat structure) + test_file = os.path.join(tmpdir, "checkpoint.pt") + with open(test_file, "w") as f: + f.write("model weights") + + artifact.push( + repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt", + paths=test_file, + ) + + # Load checkpoint by filename + output_dir = os.path.join(tmpdir, "download") + result = await alpha.load_checkpoint( + id=exp_id, version_or_filename="checkpoint.pt", output_dir=output_dir + ) + + # Verify file was downloaded + assert result is not None + assert len(result) == 1 + assert os.path.exists(result[0]) + assert os.path.basename(result[0]) == "checkpoint.pt" + + # Verify content + with open(result[0]) as f: + assert f.read() == "model weights"