Skip to content

Commit c82695e

Browse files
committed
Merge branch 'add_options_mixed' of https://github.com/timonmerk/CEBRA into add_options_mixed
2 parents 9a23197 + 30a70e5 commit c82695e

1 file changed

Lines changed: 93 additions & 0 deletions

File tree

docs/source/usage.rst

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,3 +1316,96 @@ Below is the documentation on the available arguments.
13161316
--train-ratio 0.8 Ratio of train dataset. The remaining will be used for valid and test split.
13171317
--valid-ratio 0.1 Ratio of validation set after the train data split. The remaining will be test split
13181318
--share-model
1319+
1320+
Model initialization using the Torch API
1321+
----------------------------------------
1322+
1323+
The scikit-learn API provides parametrization to many common use cases.
1324+
The Torch API however allows for more flexibility and customization, for e.g.
1325+
sampling, criterions, and data loaders.
1326+
1327+
In this minimal example we show how to initialize a CEBRA model using the Torch API.
1328+
Here the :py:class:`cebra.data.single_session.DiscreteDataLoader`
1329+
gets initilized which also allows the `prior` to be directly parametrized.
1330+
1331+
👉 For an example notebook using the Torch API check out the :doc:`demo_notebooks/Demo_Allen`.
1332+
1333+
1334+
.. testcode::
1335+
1336+
import numpy as np
1337+
import cebra.datasets
1338+
from cebra import plot_embedding
1339+
import torch
1340+
1341+
if torch.cuda.is_available():
1342+
device = "cuda"
1343+
else:
1344+
device = "cpu"
1345+
1346+
neural_data = cebra.load_data(file="neural_data.npz", key="neural")
1347+
1348+
discrete_label = cebra.load_data(
1349+
file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"],
1350+
)
1351+
1352+
# 1. Define Cebra Dataset
1353+
InputData = cebra.data.TensorDataset(
1354+
torch.from_numpy(neural_data).type(torch.FloatTensor),
1355+
discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
1356+
).to(device)
1357+
1358+
# 2. Define Cebra Model
1359+
neural_model = cebra.models.init(
1360+
name="offset10-model",
1361+
num_neurons=InputData.input_dimension,
1362+
num_units=32,
1363+
num_output=2,
1364+
).to(device)
1365+
1366+
InputData.configure_for(neural_model)
1367+
1368+
# 3. Define Loss Function Criterion and Optimizer
1369+
Crit = cebra.models.criterions.LearnableCosineInfoNCE(
1370+
temperature=0.001,
1371+
min_temperature=0.0001
1372+
).to(device)
1373+
1374+
Opt = torch.optim.Adam(
1375+
list(neural_model.parameters()) + list(Crit.parameters()),
1376+
lr=0.001,
1377+
weight_decay=0,
1378+
)
1379+
1380+
# 4. Initialize Cebra Model
1381+
solver = cebra.solver.init(
1382+
name="single-session",
1383+
model=neural_model,
1384+
criterion=Crit,
1385+
optimizer=Opt,
1386+
tqdm_on=True,
1387+
).to(device)
1388+
1389+
# 5. Define Data Loader
1390+
loader = cebra.data.single_session.DiscreteDataLoader(
1391+
dataset=InputData, num_steps=10, batch_size=200, prior="uniform"
1392+
)
1393+
1394+
# 6. Fit Model
1395+
solver.fit(loader=loader)
1396+
1397+
# 7. Transform Embedding
1398+
TrainBatches = np.lib.stride_tricks.sliding_window_view(
1399+
neural_data, neural_model.get_offset().__len__(), axis=0
1400+
)
1401+
1402+
X_train_emb = solver.transform(
1403+
torch.from_numpy(TrainBatches[:]).type(torch.FloatTensor).to(device)
1404+
).to(device)
1405+
1406+
# 8. Plot Embedding
1407+
plot_embedding(
1408+
X_train_emb,
1409+
discrete_label[neural_model.get_offset().__len__() - 1 :, 0],
1410+
markersize=10,
1411+
)

0 commit comments

Comments
 (0)