Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 48 additions & 22 deletions src/spikeinterface/extractors/phykilosortextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
ComputeTemplates,
create_sorting_analyzer,
SortingAnalyzer,
aggregate_channels,
)
from spikeinterface.core.core_tools import define_function_from_class

from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations
from probeinterface import read_prb, Probe
from probeinterface import read_prb, Probe, ProbeGroup


class BasePhyKilosortSortingExtractor(BaseSorting):
Expand Down Expand Up @@ -314,7 +315,9 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove
read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort")


def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offset_to_uV=None) -> SortingAnalyzer:
def read_kilosort_as_analyzer(
folder_path, recording=None, unwhiten=True, gain_to_uV=None, offset_to_uV=None
) -> SortingAnalyzer:
"""
Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and
above are supported. The function may work on older versions of Kilosort output,
Expand All @@ -324,6 +327,8 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse
----------
folder_path : str or Path
Path to the output Phy folder (containing the params.py).
recording : BaseRecording
A spikeinterface Recording object which will be attached to the analyzer
unwhiten : bool, default: True
Unwhiten the templates computed by kilosort.
gain_to_uV : float | None, default: None
Expand Down Expand Up @@ -359,25 +364,49 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse

if (phy_path / "probe.prb").is_file():
probegroup = read_prb(phy_path / "probe.prb")
if len(probegroup.probes) > 0:
warnings.warn("Found more than one probe. Selecting the first probe in ProbeGroup.")
probe = probegroup.probes[0]
elif (phy_path / "channel_positions.npy").is_file():
probe = Probe(si_units="um")
channel_positions = np.load(phy_path / "channel_positions.npy")
probe.set_contacts(channel_positions)
probe.set_device_channel_indices(range(probe.get_contact_count()))
channel_map = np.load(phy_path / "channel_map.npy")
probe.set_device_channel_indices(channel_map)

probegroup = ProbeGroup()
probegroup.add_probe(probe)
else:
AssertionError(f"Cannot read probe layout from folder {phy_path}.")

# to make the initial analyzer, we'll use a fake recording and set it to None later
recording, _ = generate_ground_truth_recording(
probe=probe,
sampling_frequency=sampling_frequency,
durations=[duration],
num_units=1,
seed=1205,
)
if recording is not None:
# Re-wire recording to match the output from the kilosort probe
user_gave_recording = True
all_contact_positions = np.vstack([probe.contact_positions for probe in probegroup.probes])

new_device_channel_indices = []
for recording_channel_location in recording.get_channel_locations():
for channel_index, probe_contact_position in enumerate(all_contact_positions):
if np.all(recording_channel_location == probe_contact_position):
new_device_channel_indices.append(channel_index)
break

if len(new_device_channel_indices) != len(all_contact_positions):
raise ValueError("The channel locations in your `recording` and the probe channel locations do not match.")

recording.get_probe().set_device_channel_indices(new_device_channel_indices)

else:
user_gave_recording = False
# to make the initial analyzer, we'll use a fake recording and set it to None later
recordings = []
for probe in probegroup.probes:
one_recording, _ = generate_ground_truth_recording(
probe=probe,
sampling_frequency=sampling_frequency,
durations=[duration],
num_units=1,
seed=1205,
)
recordings.append(one_recording)
recording = aggregate_channels(recordings)

sparsity = _make_sparsity_from_templates(sorting, recording, phy_path)

Expand All @@ -397,7 +426,9 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse
)
_make_locations(sorting_analyzer, phy_path)

sorting_analyzer._recording = None
if not user_gave_recording:
sorting_analyzer._recording = None

return sorting_analyzer


Expand All @@ -413,14 +444,9 @@ def _make_locations(sorting_analyzer, kilosort_output_path):
else:
return

# Check that the spike locations vector is the same size as the spike vector
# When recording is given, need to trim spike locations to match spikes in sorting
num_spikes = len(sorting_analyzer.sorting.to_spike_vector())
num_spike_locs = len(locs_np)
if num_spikes != num_spike_locs:
warnings.warn(
"The number of spikes does not match the number of spike locations in `spike_positions.npy`. Skipping spike locations."
)
return
locs_np = locs_np[:num_spikes]

num_dims = len(locs_np[0])
column_names = ["x", "y", "z"][:num_dims]
Expand Down
Loading