@@ -1316,3 +1316,96 @@ Below is the documentation on the available arguments.
13161316 --train-ratio 0.8 Ratio of train dataset. The remaining will be used for valid and test split.
13171317 --valid-ratio 0.1 Ratio of validation set after the train data split. The remaining will be test split
13181318 --share-model
1319+
1320+ Model initialization using the Torch API
1321+ ----------------------------------------
1322+
1323+ The scikit-learn API provides parametrization for many common use cases.
1324+ The Torch API, however, allows for more flexibility and customization, e.g. for
1325+ sampling, criterions, and data loaders.
1326+
1327+ In this minimal example we show how to initialize a CEBRA model using the Torch API.
1328+ Here the :py:class:`cebra.data.single_session.DiscreteDataLoader`
1329+ gets initialized, which also allows the `prior` to be directly parametrized.
1330+
1331+ 👉 For an example notebook using the Torch API check out the :doc:`demo_notebooks/Demo_Allen`.
1332+
1333+
1334+ .. testcode ::
1335+
import numpy as np
import cebra.datasets
from cebra import plot_embedding
import torch

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

neural_data = cebra.load_data(file="neural_data.npz", key="neural")

discrete_label = cebra.load_data(
    file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"],
)

# 1. Define the CEBRA dataset
input_data = cebra.data.TensorDataset(
    torch.from_numpy(neural_data).type(torch.FloatTensor),
    discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
).to(device)

# 2. Define the CEBRA model
neural_model = cebra.models.init(
    name="offset10-model",
    num_neurons=input_data.input_dimension,
    num_units=32,
    num_output=2,
).to(device)

# Let the dataset adapt itself to the model it will feed.
input_data.configure_for(neural_model)

# 3. Define the loss function criterion and the optimizer
criterion = cebra.models.criterions.LearnableCosineInfoNCE(
    temperature=0.001,
    min_temperature=0.0001,
).to(device)

# The optimizer receives the criterion's parameters as well as the model's,
# so both sets are updated during training.
optimizer = torch.optim.Adam(
    list(neural_model.parameters()) + list(criterion.parameters()),
    lr=0.001,
    weight_decay=0,
)

# 4. Initialize the CEBRA solver
solver = cebra.solver.init(
    name="single-session",
    model=neural_model,
    criterion=criterion,
    optimizer=optimizer,
    tqdm_on=True,
).to(device)

# 5. Define the data loader
loader = cebra.data.single_session.DiscreteDataLoader(
    dataset=input_data, num_steps=10, batch_size=200, prior="uniform"
)

# 6. Fit the model
solver.fit(loader=loader)

# 7. Transform the data into the embedding
# Number of entries in the model's offset; used as the sliding window length.
offset_len = len(neural_model.get_offset())

# One window of `offset_len` consecutive samples per position along axis 0.
train_batches = np.lib.stride_tricks.sliding_window_view(
    neural_data, offset_len, axis=0
)

train_embedding = solver.transform(
    torch.from_numpy(train_batches).type(torch.FloatTensor).to(device)
).to(device)

# 8. Plot the embedding
# The embedding may live on a CUDA device; move it to host memory first,
# because a CUDA tensor cannot be converted to a NumPy array for plotting.
# Windowing yields `offset_len - 1` fewer rows than the input, so the labels
# are trimmed by the same amount to keep both the same length.
plot_embedding(
    train_embedding.cpu(),
    discrete_label[offset_len - 1 :, 0],
    markersize=10,
)
0 commit comments