From 635c48ace8d31f8ab8ced05955a7a2fd80b27b2f Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 18 Dec 2023 13:06:46 +0000 Subject: [PATCH 1/6] Adding sort by group. --- spikewrap/examples/example_full_pipeline.py | 11 +++-- spikewrap/pipeline/full_pipeline.py | 5 ++ spikewrap/pipeline/sort.py | 51 +++++++++++++++++---- tests/test_integration/base.py | 4 ++ tests/test_integration/test_slurm.py | 3 ++ 5 files changed, 62 insertions(+), 12 deletions(-) diff --git a/spikewrap/examples/example_full_pipeline.py b/spikewrap/examples/example_full_pipeline.py index 41d8194..41b050c 100644 --- a/spikewrap/examples/example_full_pipeline.py +++ b/spikewrap/examples/example_full_pipeline.py @@ -4,15 +4,19 @@ from spikewrap.pipeline.full_pipeline import run_full_pipeline base_path = Path( - r"C:\fMRIData\git-repo\spikewrap\tests\data\small_toy_data" # r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-long_origdata" + r"C:\fMRIData\git-repo\spikewrap\tests\data\small_toy_data", ) sub_name = "sub-001_type-test" -sessions_and_runs = {"all": ["all"]} + +sessions_and_runs = { + "all": ["all"], +} + # sub_name = "1119617" # sessions_and_runs = { -# "ses-001": ["1119617_LSE1_shank12_g0"], +# "ses-001": ["1119617_LSE1_shank12_g0"], # } config_name = "test_default" @@ -28,6 +32,7 @@ "spikeinterface", config_name, sorter, + sort_by_group=True, save_preprocessing_chunk_size=30000, existing_preprocessed_data="overwrite", existing_sorting_output="overwrite", diff --git a/spikewrap/pipeline/full_pipeline.py b/spikewrap/pipeline/full_pipeline.py index f82c379..7a66555 100644 --- a/spikewrap/pipeline/full_pipeline.py +++ b/spikewrap/pipeline/full_pipeline.py @@ -26,6 +26,7 @@ def run_full_pipeline( config_name: str = "default", sorter: str = "kilosort2_5", preprocess_by_group: bool = False, + sort_by_group: bool = False, concat_sessions_for_sorting: bool = False, concat_runs_for_sorting: bool = False, existing_preprocessed_data: HandleExisting = "fail_if_exists", @@ -53,6 +54,7 @@ def run_full_pipeline( "config_name": config_name, "sorter": sorter, "preprocess_by_group": preprocess_by_group, + "sort_by_group": sort_by_group, "concat_sessions_for_sorting": concat_sessions_for_sorting, "concat_runs_for_sorting": concat_runs_for_sorting, "existing_preprocessed_data": existing_preprocessed_data, @@ -72,6 +74,7 @@ def run_full_pipeline( config_name, sorter, preprocess_by_group, + sort_by_group, concat_sessions_for_sorting, concat_runs_for_sorting, existing_preprocessed_data, @@ -91,6 +94,7 @@ def _run_full_pipeline( config_name: str = "default", sorter: str = "kilosort2_5", preprocess_by_group: bool = False, + sort_by_group: bool = False, concat_sessions_for_sorting: bool = False, concat_runs_for_sorting: bool = False, existing_preprocessed_data: HandleExisting = "fail_if_exists", @@ -231,6 +235,7 @@ def _run_full_pipeline( sub_name, sessions_and_runs, sorter, + sort_by_group, concat_sessions_for_sorting, concat_runs_for_sorting, sorter_options, diff --git a/spikewrap/pipeline/sort.py b/spikewrap/pipeline/sort.py index 5680002..9b4cc1a 100644 --- a/spikewrap/pipeline/sort.py +++ b/spikewrap/pipeline/sort.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from spikewrap.utils.custom_types import HandleExisting +import numpy as np import spikeinterface.sorters as ss from spikewrap.data_classes.sorting import ( @@ -27,6 +28,7 @@ def run_sorting( sub_name: str, sessions_and_runs: Dict[str, List[str]], sorter: str, + sort_per_group: bool = False, concatenate_sessions: bool = False, 
concatenate_runs: bool = False, sorter_options: Optional[Dict] = None, @@ -47,6 +49,8 @@ def run_sorting( "sub_name": sub_name, "sessions_and_runs": sessions_and_runs, "concatenate_sessions": concatenate_sessions, + "sorter": sorter, + "sort_per_group": sort_per_group, "concatenate_runs": concatenate_runs, "sorter_options": sorter_options, "existing_sorting_output": existing_sorting_output, @@ -59,6 +63,7 @@ def run_sorting( sub_name, sessions_and_runs, sorter, + sort_per_group, concatenate_sessions, concatenate_runs, sorter_options, @@ -72,6 +77,7 @@ def _run_sorting( sub_name: str, sessions_and_runs: Dict[str, List[str]], sorter: str, + sort_per_group: bool, concatenate_sessions: bool = False, concatenate_runs: bool = False, sorter_options: Optional[Dict] = None, @@ -179,6 +185,7 @@ def _run_sorting( sorting_data, singularity_image, docker_image, + sort_per_group, existing_sorting_output=existing_sorting_output, **sorter_options_dict, ) @@ -223,6 +230,7 @@ def run_sorting_on_all_runs( sorting_data: SortingData, singularity_image: Union[Literal[True], None, str], docker_image: Optional[Literal[True]], + sort_per_group: bool, existing_sorting_output: HandleExisting, **sorter_options_dict, ) -> None: @@ -256,6 +264,7 @@ def run_sorting_on_all_runs( for ses_name, run_name in sorting_data.get_sorting_sessions_and_runs(): sorting_output_path = sorting_data.get_sorting_path(ses_name, run_name) + preprocessed_recording = sorting_data.get_preprocessed_recordings( ses_name, run_name ) @@ -284,15 +293,39 @@ def run_sorting_on_all_runs( quick_safety_check(existing_sorting_output, sorting_output_path) - ss.run_sorter( - sorting_data.sorter, - preprocessed_recording, - output_folder=sorting_output_path, - singularity_image=singularity_image, - docker_image=docker_image, - remove_existing_folder=True, - **sorter_options_dict, - ) + if sort_per_group: + if np.unqiue(preprocessed_recording.get_property("group")).size == 1: + raise RuntimeError( + "`sort_per_group` is `True` but the recording" + "only has one channel group. Set `sort_per_group`" + "to `False` for this recording." 
+ ) + + split_recordings = preprocessed_recording.split_by("group") + + for group_idx, recording in split_recordings.items(): + output_folder = sorting_output_path / f"group_{group_idx}" + + ss.run_sorter( + sorting_data.sorter, + preprocessed_recording, + output_folder=output_folder, + singularity_image=singularity_image, + docker_image=docker_image, + remove_existing_folder=True, + **sorter_options_dict, + ) + + else: + ss.run_sorter( + sorting_data.sorter, + preprocessed_recording, + output_folder=sorting_output_path, + singularity_image=singularity_image, + docker_image=docker_image, + remove_existing_folder=True, + **sorter_options_dict, + ) sorting_data.save_sorting_info(ses_name, run_name) diff --git a/tests/test_integration/base.py b/tests/test_integration/base.py index 7a4fc2b..ce118a2 100644 --- a/tests/test_integration/base.py +++ b/tests/test_integration/base.py @@ -113,6 +113,8 @@ def run_full_pipeline( data_format, config_name="test_default", sorter="kilosort2_5", + preprocess_by_group=False, + sort_by_group=False, concatenate_sessions=False, concatenate_runs=False, existing_preprocessed_data="fail_if_exists", @@ -129,6 +131,8 @@ def run_full_pipeline( data_format=data_format, config_name=config_name, sorter=sorter, + preprocess_by_group=preprocess_by_group, + sort_by_group=sort_by_group, concat_sessions_for_sorting=concatenate_sessions, concat_runs_for_sorting=concatenate_runs, existing_preprocessed_data=existing_preprocessed_data, diff --git a/tests/test_integration/test_slurm.py b/tests/test_integration/test_slurm.py index 9a4522c..da6d78a 100644 --- a/tests/test_integration/test_slurm.py +++ b/tests/test_integration/test_slurm.py @@ -28,6 +28,9 @@ class TestSLURM(BaseTest): # TODO: cannot test the actual output. # can test recording at least. + # TODO: this is just a smoke test. Need to test against actual sorting + # to ensure matches as expected. Missed case where sorter was not passed + # and default was used! @pytest.mark.skipif(CAN_SLURM is False, reason="CAN_SLURM is false") @pytest.mark.parametrize( "concatenation", [(False, False), (False, True), (True, True)] From 6890270bb61277845120cf11a6b6ab7688779988 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 18 Dec 2023 15:44:07 +0000 Subject: [PATCH 2/6] Working version of split by group for sorting. 
--- spikewrap/data_classes/postprocessing.py | 5 +- spikewrap/data_classes/sorting.py | 46 ++++++++---- spikewrap/examples/example_full_pipeline.py | 4 - spikewrap/pipeline/full_pipeline.py | 71 +++++++++++------- spikewrap/pipeline/sort.py | 83 ++++++++++----------- spikewrap/utils/utils.py | 5 ++ 6 files changed, 120 insertions(+), 94 deletions(-) diff --git a/spikewrap/data_classes/postprocessing.py b/spikewrap/data_classes/postprocessing.py index 566dbdc..d762f14 100644 --- a/spikewrap/data_classes/postprocessing.py +++ b/spikewrap/data_classes/postprocessing.py @@ -42,6 +42,7 @@ class PostprocessingData: def __init__(self, sorting_path: Union[str, Path]) -> None: self.sorting_path = Path(sorting_path) + self.sorter_output_path = self.sorting_path / "sorter_output" self.sorting_info_path = self.sorting_path / utils.canonical_names( "sorting_yaml" @@ -151,9 +152,7 @@ def get_sorting_extractor_object(self) -> si.SortingExtractor: return sorting_without_excess_spikes def get_postprocessing_path(self) -> Path: - return self.sorting_data.get_postprocessing_path( - self.sorted_ses_name, self.sorted_run_name - ) + return utils.make_postprocessing_path(self.sorting_path) def get_quality_metrics_path(self) -> Path: return self.get_postprocessing_path() / "quality_metrics.csv" diff --git a/spikewrap/data_classes/sorting.py b/spikewrap/data_classes/sorting.py index 6e4dd2f..860df51 100644 --- a/spikewrap/data_classes/sorting.py +++ b/spikewrap/data_classes/sorting.py @@ -186,19 +186,31 @@ def check_ses_or_run_folders_in_datetime_order( # Paths # ---------------------------------------------------------------------------------- - def get_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: - return self._get_base_sorting_path(ses_name, run_name) / "sorting" - - def get_sorter_output_path(self, ses_name: str, run_name: Optional[str]) -> Path: - return self.get_sorting_path(ses_name, run_name) / "sorter_output" + def get_sorting_path( + self, ses_name: str, run_name: Optional[str], group_idx: Optional[int] = None + ) -> Path: + if group_idx is None: + format_group_name = "" + else: + format_group_name = f"group-{group_idx}" - def _get_sorting_info_path(self, ses_name: str, run_name: Optional[str]) -> Path: - return self.get_sorting_path(ses_name, run_name) / utils.canonical_names( - "sorting_yaml" + return ( + self.get_base_sorting_path(ses_name, run_name) + / format_group_name + / "sorting" ) - def get_postprocessing_path(self, ses_name: str, run_name: Optional[str]) -> Path: - return self._get_base_sorting_path(ses_name, run_name) / "postprocessing" + def get_sorter_output_path( + self, ses_name: str, run_name: Optional[str], group_idx: Optional[int] = None + ) -> Path: + return self.get_sorting_path(ses_name, run_name, group_idx) / "sorter_output" + + def _get_sorting_info_path( + self, ses_name: str, run_name: Optional[str], group_idx: Optional[int] = None + ) -> Path: + return self.get_sorting_path( + ses_name, run_name, group_idx + ) / utils.canonical_names("sorting_yaml") def _validate_derivatives_inputs(self): self._validate_inputs( @@ -247,7 +259,9 @@ def _make_run_name_from_multiple_run_names(self, run_names: List[str]) -> str: # Sorting info # ---------------------------------------------------------------------------------- - def save_sorting_info(self, ses_name: str, run_name: str) -> None: + def save_sorting_info( + self, ses_name: str, run_name: str, group_idx: Optional[int] = None + ) -> None: """ Save a sorting_info.yaml file containing a dictionary holding important 
information on the sorting. This is for provenance. @@ -289,7 +303,7 @@ def save_sorting_info(self, ses_name: str, run_name: str) -> None: sorting_info["datetime_created"] = utils.get_formatted_datetime() utils.dump_dict_to_yaml( - self._get_sorting_info_path(ses_name, run_name), sorting_info + self._get_sorting_info_path(ses_name, run_name, group_idx), sorting_info ) @property @@ -325,7 +339,7 @@ def get_preprocessed_recordings( raise NotImplementedError @abstractmethod - def _get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: + def get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: raise NotImplementedError @@ -364,7 +378,7 @@ def get_preprocessed_recordings( self.assert_names(ses_name, run_name) return self[self.concat_ses_name()] - def _get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: + def get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: """""" self.assert_names(ses_name, run_name) @@ -447,7 +461,7 @@ def get_preprocessed_recordings( return self[ses_name][run_name] - def _get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: + def get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: assert run_name == self.concat_run_name(ses_name) assert run_name is not None @@ -501,7 +515,7 @@ def get_preprocessed_recordings( assert run_name is not None return self[ses_name][run_name] - def _get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: + def get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path: assert run_name is not None # TODO: centralise paths!!# TODO: centralise paths!!# TODO: centralise paths!! return ( diff --git a/spikewrap/examples/example_full_pipeline.py b/spikewrap/examples/example_full_pipeline.py index 41b050c..93dd67a 100644 --- a/spikewrap/examples/example_full_pipeline.py +++ b/spikewrap/examples/example_full_pipeline.py @@ -39,10 +39,6 @@ overwrite_postprocessing=True, concat_sessions_for_sorting=False, # TODO: validate this at the start, in `run_full_pipeline` concat_runs_for_sorting=False, - # existing_preprocessed_data="skip_if_exists", # this is kind of confusing... 
- # existing_sorting_output="overwrite", - # overwrite_postprocessing=True, - # slurm_batch=False, ) print(f"TOOK {time.time() - t}") diff --git a/spikewrap/pipeline/full_pipeline.py b/spikewrap/pipeline/full_pipeline.py index 7a66555..41cc821 100644 --- a/spikewrap/pipeline/full_pipeline.py +++ b/spikewrap/pipeline/full_pipeline.py @@ -245,20 +245,26 @@ def _run_full_pipeline( # Run Postprocessing for ses_name, run_name in sorting_data.get_sorting_sessions_and_runs(): - sorting_path = sorting_data.get_sorting_path(ses_name, run_name) - - postprocess_data = run_postprocess( - sorting_path, - overwrite_postprocessing=overwrite_postprocessing, - existing_waveform_data="fail_if_exists", - waveform_options=waveform_options, - ) + for sorting_path in _get_sorting_paths( + sorting_data, ses_name, run_name, sort_by_group + ): + postprocess_data = run_postprocess( + sorting_path, + overwrite_postprocessing=overwrite_postprocessing, + existing_waveform_data="fail_if_exists", + waveform_options=waveform_options, + ) # Delete intermediate files for ses_name, run_name in sorting_data.get_sorting_sessions_and_runs(): - handle_delete_intermediate_files( - ses_name, run_name, sorting_data, delete_intermediate_files - ) + for sorting_path in _get_sorting_paths( + sorting_data, ses_name, run_name, sort_by_group + ): + postprocessing_path = utils.make_postprocessing_path(sorting_path) + + handle_delete_intermediate_files( + sorting_path, postprocessing_path, delete_intermediate_files + ) logs.stop_logging() return ( @@ -267,15 +273,37 @@ def _run_full_pipeline( ) +def _get_sorting_paths( + sorting_data: SortingData, ses_name: str, run_name: str, sort_by_group: bool +) -> List[Path]: + """ """ + if sort_by_group: + all_group_paths = sorting_data.get_base_sorting_path(ses_name, run_name).glob( + "group-*" + ) + group_indexes = [ + int(group.name.split("group-")[1]) + for group in all_group_paths + if group.is_dir() + ] # TODO: kind of hacky + all_sorting_paths = [ + sorting_data.get_sorting_path(ses_name, run_name, idx) + for idx in group_indexes + ] + else: + all_sorting_paths = [sorting_data.get_sorting_path(ses_name, run_name)] + + return all_sorting_paths + + # -------------------------------------------------------------------------------------- # Remove Intermediate Files # -------------------------------------------------------------------------------------- def handle_delete_intermediate_files( - ses_name: str, - run_name: Optional[str], - sorting_data: SortingData, + sorting_path: Path, + postprocessing_path: Path, delete_intermediate_files: DeleteIntermediate, ): """ @@ -284,22 +312,13 @@ def handle_delete_intermediate_files( for Kilosort). 
See `run_full_pipeline` for inputs """ if "recording.dat" in delete_intermediate_files: - if ( - recording_file := sorting_data.get_sorter_output_path(ses_name, run_name) - / "recording.dat" - ).is_file(): + if (recording_file := sorting_path / "recording.dat").is_file(): recording_file.unlink() if "temp_wh.dat" in delete_intermediate_files: - if ( - recording_file := sorting_data.get_sorter_output_path(ses_name, run_name) - / "temp_wh.dat" - ).is_file(): + if (recording_file := sorting_path / "temp_wh.dat").is_file(): recording_file.unlink() if "waveforms" in delete_intermediate_files: - if ( - waveforms_path := sorting_data.get_postprocessing_path(ses_name, run_name) - / "waveforms" - ).is_dir(): + if (waveforms_path := postprocessing_path / "waveforms").is_dir(): shutil.rmtree(waveforms_path) diff --git a/spikewrap/pipeline/sort.py b/spikewrap/pipeline/sort.py index 9b4cc1a..128a0a6 100644 --- a/spikewrap/pipeline/sort.py +++ b/spikewrap/pipeline/sort.py @@ -7,7 +7,6 @@ if TYPE_CHECKING: from spikewrap.utils.custom_types import HandleExisting -import numpy as np import spikeinterface.sorters as ss from spikewrap.data_classes.sorting import ( @@ -263,63 +262,56 @@ def run_sorting_on_all_runs( utils.message_user(f"Starting {sorting_data.sorter} sorting...") for ses_name, run_name in sorting_data.get_sorting_sessions_and_runs(): - sorting_output_path = sorting_data.get_sorting_path(ses_name, run_name) + utils.message_user(f"Sorting session: {ses_name} \n" f"run: {run_name}...") - preprocessed_recording = sorting_data.get_preprocessed_recordings( + orig_preprocessed_recording = sorting_data.get_preprocessed_recordings( ses_name, run_name ) - utils.message_user( - f"Sorting session: {ses_name} \n" - f"run: {ses_name}..." - # TODO: I think can just use run_name now? - ) + if sort_per_group: + split_preprocessing = orig_preprocessed_recording.split_by("group") - if sorting_output_path.is_dir(): - if existing_sorting_output == "fail_if_exists": + if len(split_preprocessing.keys()) == 1: raise RuntimeError( - f"Sorting output already exists at {sorting_output_path} and" - f"`existing_sorting_output` is set to 'fail_if_exists'." - ) - - elif existing_sorting_output == "skip_if_exists": - utils.message_user( - f"Sorting output already exists at {sorting_output_path}. Nothing " - f"will be done. The existing sorting will be used for " - f"postprocessing " - f"if running with `run_full_pipeline`" + "`sort_per_group` is `True` but the recording only has " + "one channel group. Set `sort_per_group`to `False` " + "for this recording." ) - continue - - quick_safety_check(existing_sorting_output, sorting_output_path) - if sort_per_group: - if np.unqiue(preprocessed_recording.get_property("group")).size == 1: - raise RuntimeError( - "`sort_per_group` is `True` but the recording" - "only has one channel group. Set `sort_per_group`" - "to `False` for this recording." 
- ) + group_indexes = list(split_preprocessing.keys()) + all_preprocessed_recordings = list(split_preprocessing.values()) + else: + group_indexes = [None] + all_preprocessed_recordings = [orig_preprocessed_recording] + + for group_idx, prepro_recording in zip( + group_indexes, all_preprocessed_recordings + ): + sorting_output_path = sorting_data.get_sorting_path( + ses_name, run_name, group_idx + ) - split_recordings = preprocessed_recording.split_by("group") + if sorting_output_path.is_dir(): + if existing_sorting_output == "fail_if_exists": + raise RuntimeError( + f"Sorting output already exists at {sorting_output_path} and" + f"`existing_sorting_output` is set to 'fail_if_exists'." + ) - for group_idx, recording in split_recordings.items(): - output_folder = sorting_output_path / f"group_{group_idx}" + elif existing_sorting_output == "skip_if_exists": + utils.message_user( + f"Sorting output already exists at {sorting_output_path}. Nothing " + f"will be done. The existing sorting will be used for " + f"postprocessing " + f"if running with `run_full_pipeline`" + ) + continue - ss.run_sorter( - sorting_data.sorter, - preprocessed_recording, - output_folder=output_folder, - singularity_image=singularity_image, - docker_image=docker_image, - remove_existing_folder=True, - **sorter_options_dict, - ) + quick_safety_check(existing_sorting_output, sorting_output_path) - else: ss.run_sorter( sorting_data.sorter, - preprocessed_recording, + prepro_recording, output_folder=sorting_output_path, singularity_image=singularity_image, docker_image=docker_image, @@ -327,7 +319,8 @@ def run_sorting_on_all_runs( **sorter_options_dict, ) - sorting_data.save_sorting_info(ses_name, run_name) + # TODO: how does this interact with concat sessions and recordings? + sorting_data.save_sorting_info(ses_name, run_name, group_idx) def quick_safety_check( diff --git a/spikewrap/utils/utils.py b/spikewrap/utils/utils.py index f45c2e7..663cf15 100644 --- a/spikewrap/utils/utils.py +++ b/spikewrap/utils/utils.py @@ -166,6 +166,11 @@ def cast_pp_steps_values( # -------------------------------------------------------------------------------------- +# TODO: move +def make_postprocessing_path(sorting_path: Path): + return sorting_path.parent / "postprocessing" + + def get_keys_first_char( data: Union[PreprocessingData, SortingData], as_int: bool = False ) -> Union[List[str], List[int]]: From a5cd6ee534db621f0cebd3bf5e5f384f00efb5d0 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 18 Dec 2023 18:49:23 +0000 Subject: [PATCH 3/6] Add `sort_by_group` to relevant tests and add new tests explicitly checking sorting." 
--- .../small_toy_data/in_container_params.json | 1 + .../in_container_recording.json | 743 +++++++++++++++ .../in_container_sorter_script.py | 46 + .../sorter_output/firings.npz | Bin 0 -> 1574 bytes .../spikeinterface_log.json | 8 + .../spikeinterface_params.json | 25 + .../spikeinterface_recording.json | 869 ++++++++++++++++++ tests/test_integration/base.py | 15 +- tests/test_integration/test_full_pipeline.py | 383 ++++++-- 9 files changed, 2012 insertions(+), 78 deletions(-) create mode 100644 tests/data/small_toy_data/in_container_params.json create mode 100644 tests/data/small_toy_data/in_container_recording.json create mode 100644 tests/data/small_toy_data/in_container_sorter_script.py create mode 100644 tests/data/small_toy_data/mountainsort5_output/sorter_output/firings.npz create mode 100644 tests/data/small_toy_data/mountainsort5_output/spikeinterface_log.json create mode 100644 tests/data/small_toy_data/mountainsort5_output/spikeinterface_params.json create mode 100644 tests/data/small_toy_data/mountainsort5_output/spikeinterface_recording.json diff --git a/tests/data/small_toy_data/in_container_params.json b/tests/data/small_toy_data/in_container_params.json new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/tests/data/small_toy_data/in_container_params.json @@ -0,0 +1 @@ +{} diff --git a/tests/data/small_toy_data/in_container_recording.json b/tests/data/small_toy_data/in_container_recording.json new file mode 100644 index 0000000..926e962 --- /dev/null +++ b/tests/data/small_toy_data/in_container_recording.json @@ -0,0 +1,743 @@ +{ + "class": "spikeinterface.core.channelslice.ChannelSliceRecording", + "module": "spikeinterface", + "kwargs": { + "parent_recording": { + "class": "spikeinterface.preprocessing.common_reference.CommonReferenceRecording", + "module": "spikeinterface", + "kwargs": { + "recording": { + "class": "spikeinterface.preprocessing.filter.BandpassFilterRecording", + "module": "spikeinterface", + "kwargs": { + "recording": { + "class": "spikeinterface.preprocessing.phase_shift.PhaseShiftRecording", + "module": "spikeinterface", + "kwargs": { + "recording": { + "class": "spikeinterface.preprocessing.astype.AstypeRecording", + "module": "spikeinterface", + "kwargs": { + "recording": { + "class": "spikeinterface.core.binaryfolder.BinaryFolderRecording", + "module": "spikeinterface", + "kwargs": { + "folder_path": "/fMRIData/git-repo/spikewrap/tests/data/small_toy_data/rawdata/sub-001_type-test/ses-001/ephys/ses-001_run-001" + }, + "version": "0.100.0.dev0", + "annotations": { + "is_filtered": true + }, + "properties": { + "group": [ + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 2, + 2, + 2, + 2, + 3, + 3, + 3, + 3 + ], + "location": [ + [ + 0.0, + 0.0 + ], + [ + 0.0, + 40.0 + ], + [ + 0.0, + 80.0 + ], + [ + 0.0, + 120.0 + ], + [ + 0.0, + 160.0 + ], + [ + 0.0, + 200.0 + ], + [ + 0.0, + 240.0 + ], + [ + 0.0, + 280.0 + ], + [ + 0.0, + 320.0 + ], + [ + 0.0, + 360.0 + ], + [ + 0.0, + 400.0 + ], + [ + 0.0, + 440.0 + ], + [ + 0.0, + 480.0 + ], + [ + 0.0, + 520.0 + ], + [ + 0.0, + 560.0 + ], + [ + 0.0, + 600.0 + ] + ], + "gain_to_uV": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ], + "offset_to_uV": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + }, + "relative_paths": false + }, + "dtype": "*WS8 zkb6))5p>7Y3iZ4ygXQ_2rA4 z1Telp80?>d%&624G66G|_G(jv?dc_%2H2Br5x>03~mkp-fz5lXLt(qgO-@l8+~HGB(07$eUB 
z6@aiPD7}EgHzzSEH3v0#K?*1C%MmGaBzaX0t4NOmv*hSAh=w_g%31l;N Z1JfGT6cXUg$_7%%3WQET$DRWxJ^&XQ4qyNP literal 0 HcmV?d00001 diff --git a/tests/data/small_toy_data/mountainsort5_output/spikeinterface_log.json b/tests/data/small_toy_data/mountainsort5_output/spikeinterface_log.json new file mode 100644 index 0000000..42f4d75 --- /dev/null +++ b/tests/data/small_toy_data/mountainsort5_output/spikeinterface_log.json @@ -0,0 +1,8 @@ +{ + "sorter_name": "mountainsort5", + "sorter_version": "0.3.0", + "datetime": "2023-12-20T14:07:55.752677", + "runtime_trace": [], + "error": false, + "run_time": 0.2969591000000946 +} diff --git a/tests/data/small_toy_data/mountainsort5_output/spikeinterface_params.json b/tests/data/small_toy_data/mountainsort5_output/spikeinterface_params.json new file mode 100644 index 0000000..93387ee --- /dev/null +++ b/tests/data/small_toy_data/mountainsort5_output/spikeinterface_params.json @@ -0,0 +1,25 @@ +{ + "sorter_name": "mountainsort5", + "sorter_params": { + "scheme": "2", + "detect_threshold": 5.5, + "detect_sign": -1, + "detect_time_radius_msec": 0.5, + "snippet_T1": 20, + "snippet_T2": 20, + "npca_per_channel": 3, + "npca_per_subdivision": 10, + "snippet_mask_radius": 250, + "scheme1_detect_channel_radius": 150, + "scheme2_phase1_detect_channel_radius": 200, + "scheme2_detect_channel_radius": 50, + "scheme2_max_num_snippets_per_training_batch": 200, + "scheme2_training_duration_sec": 300, + "scheme2_training_recording_sampling_mode": "uniform", + "scheme3_block_duration_sec": 1800, + "freq_min": 300, + "freq_max": 6000, + "filter": false, + "whiten": false + } +} diff --git a/tests/data/small_toy_data/mountainsort5_output/spikeinterface_recording.json b/tests/data/small_toy_data/mountainsort5_output/spikeinterface_recording.json new file mode 100644 index 0000000..47e654b --- /dev/null +++ b/tests/data/small_toy_data/mountainsort5_output/spikeinterface_recording.json @@ -0,0 +1,869 @@ +{ + "class": "spikeinterface.core.channelslice.ChannelSliceRecording", + "module": "spikeinterface", + "kwargs": { + "parent_recording": { + "class": "spikeinterface.preprocessing.common_reference.CommonReferenceRecording", + "module": "spikeinterface", + "kwargs": { + "recording": { + "class": "spikeinterface.preprocessing.filter.BandpassFilterRecording", + "module": "spikeinterface", + "kwargs": { + "recording": { + "class": "spikeinterface.preprocessing.phase_shift.PhaseShiftRecording", + "module": "spikeinterface", + "kwargs": { + "recording": { + "class": "spikeinterface.preprocessing.astype.AstypeRecording", + "module": "spikeinterface", + "kwargs": { + "recording": { + "class": "spikeinterface.core.binaryfolder.BinaryFolderRecording", + "module": "spikeinterface", + "kwargs": { + "folder_path": "C:\\fMRIData\\git-repo\\spikewrap\\tests\\data\\small_toy_data\\rawdata\\sub-001_type-test\\ses-003\\ephys\\ses-003_run-002" + }, + "version": "0.100.0.dev0", + "annotations": { + "is_filtered": true, + "probe_0_planar_contour": [ + [ + -20.0, + 620.0 + ], + [ + -20.0, + -20.0 + ], + [ + 20.0, + -20.0 + ], + [ + 20.0, + 620.0 + ] + ], + "probes_info": [ + {} + ] + }, + "properties": { + "group": [ + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 2, + 2, + 2, + 2, + 3, + 3, + 3, + 3 + ], + "location": [ + [ + 0.0, + 0.0 + ], + [ + 0.0, + 40.0 + ], + [ + 0.0, + 80.0 + ], + [ + 0.0, + 120.0 + ], + [ + 0.0, + 160.0 + ], + [ + 0.0, + 200.0 + ], + [ + 0.0, + 240.0 + ], + [ + 0.0, + 280.0 + ], + [ + 0.0, + 320.0 + ], + [ + 0.0, + 360.0 + ], + [ + 0.0, + 400.0 + ], + [ + 0.0, + 440.0 + ], + [ + 0.0, + 
480.0 + ], + [ + 0.0, + 520.0 + ], + [ + 0.0, + 560.0 + ], + [ + 0.0, + 600.0 + ] + ], + "gain_to_uV": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ], + "offset_to_uV": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + }, + "relative_paths": false + }, + "dtype": " 1 + + for sorting_output_path in sorted_groups: + assert (sorting_output_path / "sorting").is_dir() + assert (sorting_output_path / "postprocessing").is_dir() + ses_path = sub_path / ses_name / "ephys" concat_all_run_names = "".join( diff --git a/tests/test_integration/test_full_pipeline.py b/tests/test_integration/test_full_pipeline.py index 35c0048..373ef0f 100644 --- a/tests/test_integration/test_full_pipeline.py +++ b/tests/test_integration/test_full_pipeline.py @@ -4,7 +4,7 @@ import pytest import spikeinterface as si import spikeinterface.extractors as se -from spikeinterface import concatenate_recordings +from spikeinterface import concatenate_recordings, sorters from spikeinterface.preprocessing import ( astype, bandpass_filter, @@ -12,7 +12,9 @@ phase_shift, ) -from spikewrap.data_classes.postprocessing import load_saved_sorting_output +from spikewrap.data_classes.postprocessing import ( + load_saved_sorting_output, +) from spikewrap.pipeline import full_pipeline, preprocess from spikewrap.pipeline.load_data import load_data from spikewrap.utils import checks, utils @@ -125,7 +127,8 @@ def test_no_concatenation_all_sorters_single_run(self, test_info, sorter): self.check_no_concat_results(test_info, loaded_data, sorting_data, sorter) @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) - def test_no_concatenation_single_run(self, test_info): + @pytest.mark.parametrize("sort_by_group", [True, False]) + def test_no_concatenation_single_run(self, test_info, sort_by_group): """ Run the full pipeline for a single session and run, and check preprocessing, sorting and waveforms. @@ -135,18 +138,22 @@ def test_no_concatenation_single_run(self, test_info): loaded_data, sorting_data = self.run_full_pipeline( *test_info, data_format=DEFAULT_FORMAT, + sort_by_group=sort_by_group, sorter=DEFAULT_SORTER, concatenate_sessions=False, concatenate_runs=False, ) - self.check_correct_folders_exist(test_info, False, False, DEFAULT_SORTER) + self.check_correct_folders_exist( + test_info, False, False, DEFAULT_SORTER, sort_by_group=sort_by_group + ) self.check_no_concat_results( - test_info, loaded_data, sorting_data, DEFAULT_SORTER + test_info, loaded_data, sorting_data, DEFAULT_SORTER, sort_by_group ) @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) - def test_no_concatenation_multiple_runs(self, test_info): + @pytest.mark.parametrize("sort_by_group", [True, False]) + def test_no_concatenation_multiple_runs(self, test_info, sort_by_group): """ For DEFAULT_SORTER, check `full_pipeline` across multiple sessions and runs without concatenation. 
@@ -156,16 +163,20 @@ def test_no_concatenation_multiple_runs(self, test_info): data_format=DEFAULT_FORMAT, concatenate_sessions=False, concatenate_runs=False, + sort_by_group=sort_by_group, sorter=DEFAULT_SORTER, ) - self.check_correct_folders_exist(test_info, False, False, DEFAULT_SORTER) - - self.check_correct_folders_exist(test_info, False, False, DEFAULT_SORTER) - self.check_no_concat_results(test_info, loaded_data, sorting_data) + self.check_correct_folders_exist( + test_info, False, False, DEFAULT_SORTER, sort_by_group=sort_by_group + ) + self.check_no_concat_results( + test_info, loaded_data, sorting_data, DEFAULT_SORTER, sort_by_group + ) @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) - def test_concatenate_runs_but_not_sessions(self, test_info): + @pytest.mark.parametrize("sort_by_group", [True, False]) + def test_concatenate_runs_but_not_sessions(self, test_info, sort_by_group): """ For DEFAULT_SORTER, check `full_pipeline` across multiple sessions concatenating runs, but not sessions. This results in a single @@ -177,16 +188,24 @@ def test_concatenate_runs_but_not_sessions(self, test_info): data_format=DEFAULT_FORMAT, concatenate_sessions=False, concatenate_runs=True, + sort_by_group=sort_by_group, sorter=DEFAULT_SORTER, ) - self.check_correct_folders_exist(test_info, False, True, DEFAULT_SORTER) + self.check_correct_folders_exist( + test_info, + False, + True, + DEFAULT_SORTER, + sort_by_group=sort_by_group, + ) self.check_concatenate_runs_but_not_sessions( - test_info, loaded_data, sorting_data + test_info, loaded_data, sorting_data, sort_by_group ) @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) - def test_concatenate_sessions_and_runs(self, test_info): + @pytest.mark.parametrize("sort_by_group", [True, False]) + def test_concatenate_sessions_and_runs(self, test_info, sort_by_group): """ For DEFAULT_SORTER, check `full_pipeline` across multiple sessions concatenating runs and sessions. This will lead to a single @@ -198,10 +217,15 @@ def test_concatenate_sessions_and_runs(self, test_info): concatenate_sessions=True, concatenate_runs=True, sorter=DEFAULT_SORTER, + sort_by_group=sort_by_group, ) - self.check_correct_folders_exist(test_info, True, True, DEFAULT_SORTER) - self.check_concatenate_sessions_and_runs(test_info, loaded_data, sorting_data) + self.check_correct_folders_exist( + test_info, True, True, DEFAULT_SORTER, sort_by_group + ) + self.check_concatenate_sessions_and_runs( + test_info, loaded_data, sorting_data, sort_by_group=sort_by_group + ) @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) def test_ses_concat_no_run_concat(self, test_info): @@ -225,7 +249,8 @@ def test_ses_concat_no_run_concat(self, test_info): ) @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) - def test_existing_output_settings(self, test_info): + @pytest.mark.parametrize("sort_by_group", [True, False]) + def test_existing_output_settings(self, test_info, sort_by_group): """ In spikewrap existing preprocessed and sorting output data is handled with options `fail_if_exists`, `skip_if_exists` or @@ -245,6 +270,7 @@ def test_existing_output_settings(self, test_info): self.run_full_pipeline( *test_info, data_format=DEFAULT_FORMAT, + sort_by_group=sort_by_group, existing_preprocessed_data="fail_if_exists", existing_sorting_output="fail_if_exists", overwrite_postprocessing=False, @@ -252,11 +278,14 @@ def test_existing_output_settings(self, test_info): ) # Test outputs are overwritten if `overwrite` set. 
- file_paths = self.write_an_empty_file_in_outputs(test_info, ses_name, run_name) + file_paths = self.write_an_empty_file_in_outputs( + test_info, ses_name, run_name, sort_by_group + ) self.run_full_pipeline( *test_info, data_format=DEFAULT_FORMAT, + sort_by_group=sort_by_group, existing_preprocessed_data="overwrite", existing_sorting_output="overwrite", overwrite_postprocessing=True, @@ -266,13 +295,16 @@ def test_existing_output_settings(self, test_info): for path_ in file_paths: assert not path_.is_file() - file_paths = self.write_an_empty_file_in_outputs(test_info, ses_name, run_name) + file_paths = self.write_an_empty_file_in_outputs( + test_info, ses_name, run_name, sort_by_group + ) # Test outputs are not overwritten if `skip_if_exists`. # Postprocessing is always deleted self.run_full_pipeline( *test_info, data_format=DEFAULT_FORMAT, + sort_by_group=sort_by_group, existing_preprocessed_data="skip_if_exists", existing_sorting_output="skip_if_exists", overwrite_postprocessing=True, @@ -287,6 +319,7 @@ def test_existing_output_settings(self, test_info): self.run_full_pipeline( *test_info, data_format=DEFAULT_FORMAT, + sort_by_group=sort_by_group, existing_preprocessed_data="fail_if_exists", existing_sorting_output="skip_if_exists", overwrite_postprocessing=True, @@ -307,6 +340,7 @@ def test_existing_output_settings(self, test_info): self.run_full_pipeline( *test_info, data_format=DEFAULT_FORMAT, + sort_by_group=sort_by_group, existing_preprocessed_data="skip_if_exists", existing_sorting_output="fail_if_exists", overwrite_postprocessing=True, @@ -320,6 +354,7 @@ def test_existing_output_settings(self, test_info): self.run_full_pipeline( *test_info, data_format=DEFAULT_FORMAT, + sort_by_group=sort_by_group, existing_preprocessed_data="skip_if_exists", existing_sorting_output="skip_if_exists", overwrite_postprocessing=False, @@ -354,7 +389,12 @@ def test_smoke_supply_chunk_size(self, test_info, capsys, specify_chunk_size): # ---------------------------------------------------------------------------------- def check_no_concat_results( - self, test_info, loaded_data, sorting_data, sorter=DEFAULT_SORTER + self, + test_info, + loaded_data, + sorting_data, + sorter=DEFAULT_SORTER, + sort_by_group=False, ): """ After `full_pipeline` is run, check the preprocessing, sorting and postprocessing @@ -410,20 +450,23 @@ def check_no_concat_results( ) paths = self.get_output_paths( - test_info, ses_name, run_name, sorter=sorter + test_info, ses_name, run_name, sort_by_group, sorter=sorter ) - self.check_waveforms( - paths["sorter_output"], - paths["postprocessing"], - recs_to_test=[ - sorting_data[ses_name][run_name], - ], - sorter=sorter, - ) + for sorter_output_path, postprocessing_path in zip( + paths["sorter_output"], paths["postprocessing"] + ): + self.check_waveforms( + sorter_output_path, + postprocessing_path, + recs_to_test=[ + sorting_data[ses_name][run_name], + ], + sorter=sorter, + ) def check_concatenate_runs_but_not_sessions( - self, test_info, loaded_data, sorting_data + self, test_info, loaded_data, sorting_data, sort_by_group ): """ Similar to `check_no_concat_results()`, however now test with @@ -483,27 +526,35 @@ def check_concatenate_runs_but_not_sessions( # Load the recording.dat and check it matches the expected data. # Finally, check the waveforms match the preprocessed data. 
paths = self.get_output_paths( - test_info, ses_name, concat_run_name, concatenate_runs=True + test_info, + ses_name, + concat_run_name, + concatenate_runs=True, + sort_by_group=sort_by_group, ) + for sorter_output_path, postprocessing_path, recording_dat_path in zip( + paths["sorter_output"], paths["postprocessing"], paths["recording_dat"] + ): + if "kilosort" in sorting_data.sorter: + saved_recording = si.read_binary( + recording_dat_path, + sampling_frequency=sorting_data_pp_run.get_sampling_frequency(), + dtype=data_type, + num_channels=sorting_data_pp_run.get_num_channels(), + ) + self.check_recordings_are_the_same( + saved_recording, test_concat_runs, n_split=2 + ) - if "kilosort" in sorting_data.sorter: - saved_recording = si.read_binary( - paths["recording_dat"], - sampling_frequency=sorting_data_pp_run.get_sampling_frequency(), - dtype=data_type, - num_channels=sorting_data_pp_run.get_num_channels(), - ) - self.check_recordings_are_the_same( - saved_recording, test_concat_runs, n_split=2 + self.check_waveforms( + sorter_output_path, + postprocessing_path, + recs_to_test=[sorting_data[ses_name][concat_run_name]], ) - self.check_waveforms( - paths["sorter_output"], - paths["postprocessing"], - recs_to_test=[sorting_data[ses_name][concat_run_name]], - ) - - def check_concatenate_sessions_and_runs(self, test_info, loaded_data, sorting_data): + def check_concatenate_sessions_and_runs( + self, test_info, loaded_data, sorting_data, sort_by_group + ): """ Similar to `check_no_concat_results()` and `check_concatenate_runs_but_not_sessions()`, but now we are checking when `concatenate_sessions=True` and `concatenate_runs=`True`. @@ -549,35 +600,38 @@ def check_concatenate_sessions_and_runs(self, test_info, loaded_data, sorting_da # dtype is converted to original dtype on file writing. test_concat_all = astype(test_concat_all, data_type) + self.check_recordings_are_the_same( + sorted_data_concat_all, test_concat_all, n_split=6 + ) + paths = self.get_output_paths( test_info, + sort_by_group=sort_by_group, ses_name=concat_ses_name, run_name=None, concatenate_sessions=True, concatenate_runs=True, ) + for sorter_output_path, postprocessing_path, recording_dat_path in zip( + paths["sorter_output"], paths["postprocessing"], paths["recording_dat"] + ): + if "kilosort" in sorting_data.sorter: + saved_recording = si.read_binary( + recording_dat_path, + sampling_frequency=sorted_data_concat_all.get_sampling_frequency(), + dtype=data_type, + num_channels=sorted_data_concat_all.get_num_channels(), + ) + self.check_recordings_are_the_same( + saved_recording, test_concat_all, n_split=6 + ) - self.check_recordings_are_the_same( - sorted_data_concat_all, test_concat_all, n_split=6 - ) - - if "kilosort" in sorting_data.sorter: - saved_recording = si.read_binary( - paths["recording_dat"], - sampling_frequency=sorted_data_concat_all.get_sampling_frequency(), - dtype=data_type, - num_channels=sorted_data_concat_all.get_num_channels(), - ) - self.check_recordings_are_the_same( - saved_recording, test_concat_all, n_split=6 + self.check_waveforms( + sorter_output_path, + postprocessing_path, + recs_to_test=[sorting_data[concat_ses_name]], ) - self.check_waveforms( - paths["sorter_output"], - paths["postprocessing"], - recs_to_test=[sorting_data[concat_ses_name]], - ) - def check_recordings_are_the_same(self, rec_1, rec_2, n_split=1): """ Check that two SI recording objects are exactly the same. 
When the @@ -678,23 +732,26 @@ def check_waveforms( assert np.array_equal(data, first_unit_waveforms[0]) def write_an_empty_file_in_outputs( - self, test_info, ses_name, run_name, sorter=DEFAULT_SORTER + self, test_info, ses_name, run_name, sort_by_group, sorter=DEFAULT_SORTER ): """ Write a file called `test_file.txt` with contents `test_file` in the preprocessed, sorting and postprocessing output path for this session / run. """ - paths = self.get_output_paths(test_info, ses_name, run_name, sorter=sorter) + paths = self.get_output_paths( + test_info, ses_name, run_name, sorter=sorter, sort_by_group=sort_by_group + ) + + paths_to_write = [paths["preprocessing"] / "test_file.txt"] - paths_to_write = [] - for output in ["preprocessing", "sorting_path", "postprocessing"]: - paths_to_write.append(paths[output] / "test_file.txt") + for output in ["sorting_path", "postprocessing"]: + for group_path in paths[output]: + paths_to_write.append(group_path / "test_file.txt") for path_ in paths_to_write: - with open(path_, "w") as file: + with open(path_.as_posix(), "w") as file: file.write("test file.") - return paths_to_write def get_output_paths( @@ -702,6 +759,7 @@ def get_output_paths( test_info, ses_name, run_name, + sort_by_group=False, sorter=DEFAULT_SORTER, concatenate_sessions=False, concatenate_runs=False, @@ -746,14 +804,185 @@ def get_output_paths( paths = { "preprocessing": run_path / "preprocessing", - "sorting_path": run_path / sorter / "sorting", - "postprocessing": run_path / sorter / "postprocessing", + "postprocessing": [], + "sorting_path": [], + "sorter_output": [], + "recording_dat": [], } - paths["sorter_output"] = paths["sorting_path"] / "sorter_output" - paths["recording_dat"] = paths["sorter_output"] / "recording.dat" + + if sort_by_group: + all_groups = sorted((run_path / sorter).glob("group-*")) + assert any(all_groups), "Groups output not found." + + for group in all_groups: + sorting_path = group / "sorting" + paths["sorting_path"].append(sorting_path) + paths["postprocessing"].append(group / "postprocessing") + paths["sorter_output"].append(sorting_path / "sorter_output") + paths["recording_dat"].append( + sorting_path / "sorter_output" / "recording.dat" + ) # TODO: this is only for kilosort! 
+ else: + sorting_path = run_path / sorter / "sorting" + paths["sorting_path"] = [sorting_path] + paths["postprocessing"] = [run_path / sorter / "postprocessing"] + paths["sorter_output"] = [sorting_path / "sorter_output"] + paths["recording_dat"] = [sorting_path / "sorter_output" / "recording.dat"] return paths + @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) + def test_sort_by_group_concat_sessions(self, test_info): + preprocess_data, sorting_data = self.run_full_pipeline( + *test_info, + data_format=DEFAULT_FORMAT, + concatenate_runs=True, + concatenate_sessions=True, + sort_by_group=True, + existing_preprocessed_data="overwrite", + existing_sorting_output="overwrite", + overwrite_postprocessing=True, + sorter=DEFAULT_SORTER, + ) + + concat_ses_name = list(sorting_data.keys())[0] + + prepo_recordings = [ + val["3-raw-phase_shift-bandpass_filter-common_reference"] + for ses_name in preprocess_data.keys() + for val in preprocess_data[ses_name].values() + ] + + test_preprocessed = concatenate_recordings(prepo_recordings) + + sorting_output_paths = self.get_output_paths( + test_info, + concat_ses_name, + run_name=None, + sorter=DEFAULT_SORTER, + concatenate_sessions=True, + concatenate_runs=True, + sort_by_group=True, + )["sorting_path"] + + self.check_sorting_is_correct(test_preprocessed, sorting_output_paths) + + @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) + def test_sort_by_group_concat_runs_not_sessions(self, test_info): + preprocess_data, sorting_data = self.run_full_pipeline( + *test_info, + data_format=DEFAULT_FORMAT, + concatenate_runs=True, + concatenate_sessions=False, + sort_by_group=True, + existing_preprocessed_data="overwrite", + existing_sorting_output="overwrite", + overwrite_postprocessing=True, + sorter=DEFAULT_SORTER, + ) + + base_path, sub_name, sessions_and_runs = test_info + + for ses_name in sessions_and_runs.keys(): + concat_run_name = list(sorting_data[ses_name].keys())[0] + + prepo_recordings = [ + val["3-raw-phase_shift-bandpass_filter-common_reference"] + for val in preprocess_data[ses_name].values() + ] + test_preprocessed = concatenate_recordings(prepo_recordings) + + sorting_output_paths = self.get_output_paths( + test_info, + ses_name, + run_name=concat_run_name, + sorter=DEFAULT_SORTER, + concatenate_sessions=False, + concatenate_runs=True, + sort_by_group=True, + )["sorting_path"] + + self.check_sorting_is_correct(test_preprocessed, sorting_output_paths) + + @pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True) + def test_sort_by_group_no_concat(self, test_info): + self.run_full_pipeline( + *test_info, + data_format=DEFAULT_FORMAT, + concatenate_runs=False, + concatenate_sessions=False, + sort_by_group=True, + existing_preprocessed_data="overwrite", + existing_sorting_output="overwrite", + overwrite_postprocessing=True, + sorter=DEFAULT_SORTER, + ) + + base_path, sub_name, sessions_and_runs = test_info + + for ses_name in sessions_and_runs.keys(): + for run_name in sessions_and_runs[ses_name]: + sorting_output_paths = self.get_output_paths( + test_info, + ses_name, + run_name=run_name, + sorter=DEFAULT_SORTER, + sort_by_group=True, + )["sorting_path"] + + _, test_preprocessed = self.get_test_rawdata_and_preprocessed_data( + base_path, sub_name, ses_name, run_name + ) + self.check_sorting_is_correct(test_preprocessed, sorting_output_paths) + + def check_sorting_is_correct(self, test_preprocessed, sorting_output_paths): + """""" + split_recording = test_preprocessed.split_by("group") + + if 
"kilosort" in DEFAULT_SORTER: + singularity_image = True if platform.system() == "Linux" else False + docker_image = not singularity_image + else: + singularity_image = docker_image = False + + sortings = {} + for group, sub_recording in split_recording.items(): + sorting = sorters.run_sorter( + sorter_name=DEFAULT_SORTER, + recording=sub_recording, + output_folder=None, + docker_image=docker_image, + singularity_image=singularity_image, + remove_existing_folder=True, + **{ + "scheme": "2", + "filter": False, + "whiten": False, + "verbose": True, + }, + ) + + sortings[group] = sorting + + assert len(sorting_output_paths) > 1, "Groups output not found." + + for idx, path_ in enumerate(sorting_output_paths): + group_sorting = load_saved_sorting_output( + path_ / "sorter_output", DEFAULT_SORTER + ) + + assert np.array_equal( + group_sorting.get_unit_ids(), sortings[idx].get_unit_ids() + ) + + for unit in group_sorting.get_unit_ids(): + assert np.allclose( + group_sorting.get_unit_spike_train(unit), + sortings[idx].get_unit_spike_train(unit), + rtol=0, + atol=1e-10, + ), f"{idx}, {group_sorting}, {sortings}" + # ---------------------------------------------------------------------------------- # Getters # ---------------------------------------------------------------------------------- From 7acf67be7d93b6293c81c1574dfdd74ec5f5a844 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 20 Dec 2023 15:07:18 +0000 Subject: [PATCH 4/6] Rename instances of sort_per_group to sort_by_group. --- spikewrap/pipeline/sort.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/spikewrap/pipeline/sort.py b/spikewrap/pipeline/sort.py index 128a0a6..5d5b62e 100644 --- a/spikewrap/pipeline/sort.py +++ b/spikewrap/pipeline/sort.py @@ -27,7 +27,7 @@ def run_sorting( sub_name: str, sessions_and_runs: Dict[str, List[str]], sorter: str, - sort_per_group: bool = False, + sort_by_group: bool = False, concatenate_sessions: bool = False, concatenate_runs: bool = False, sorter_options: Optional[Dict] = None, @@ -49,7 +49,7 @@ def run_sorting( "sessions_and_runs": sessions_and_runs, "concatenate_sessions": concatenate_sessions, "sorter": sorter, - "sort_per_group": sort_per_group, + "sort_by_group": sort_by_group, "concatenate_runs": concatenate_runs, "sorter_options": sorter_options, "existing_sorting_output": existing_sorting_output, @@ -62,7 +62,7 @@ def run_sorting( sub_name, sessions_and_runs, sorter, - sort_per_group, + sort_by_group, concatenate_sessions, concatenate_runs, sorter_options, @@ -76,7 +76,7 @@ def _run_sorting( sub_name: str, sessions_and_runs: Dict[str, List[str]], sorter: str, - sort_per_group: bool, + sort_by_group: bool, concatenate_sessions: bool = False, concatenate_runs: bool = False, sorter_options: Optional[Dict] = None, @@ -184,7 +184,7 @@ def _run_sorting( sorting_data, singularity_image, docker_image, - sort_per_group, + sort_by_group, existing_sorting_output=existing_sorting_output, **sorter_options_dict, ) @@ -229,7 +229,7 @@ def run_sorting_on_all_runs( sorting_data: SortingData, singularity_image: Union[Literal[True], None, str], docker_image: Optional[Literal[True]], - sort_per_group: bool, + sort_by_group: bool, existing_sorting_output: HandleExisting, **sorter_options_dict, ) -> None: @@ -268,13 +268,13 @@ def run_sorting_on_all_runs( ses_name, run_name ) - if sort_per_group: + if sort_by_group: split_preprocessing = orig_preprocessed_recording.split_by("group") if len(split_preprocessing.keys()) == 1: raise RuntimeError( - "`sort_per_group` is 
`True` but the recording only has " - "one channel group. Set `sort_per_group`to `False` " + "`sort_by_group` is `True` but the recording only has " + "one channel group. Set `sort_by_group`to `False` " "for this recording." ) From 6771293e60b4d07cf863643bad62e3f4b7b860d4 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 20 Dec 2023 16:37:57 +0000 Subject: [PATCH 5/6] First refactor of preprocessing. --- spikewrap/__init__.py | 3 +- spikewrap/examples/example_preprocess.py | 43 +- spikewrap/pipeline/full_pipeline.py | 3 +- spikewrap/pipeline/preprocess.py | 385 ++++---- .../sorter_output/firings.npz | Bin 1574 -> 0 bytes .../spikeinterface_log.json | 8 - .../spikeinterface_params.json | 25 - .../spikeinterface_recording.json | 869 ------------------ 8 files changed, 213 insertions(+), 1123 deletions(-) delete mode 100644 tests/data/small_toy_data/mountainsort5_output/sorter_output/firings.npz delete mode 100644 tests/data/small_toy_data/mountainsort5_output/spikeinterface_log.json delete mode 100644 tests/data/small_toy_data/mountainsort5_output/spikeinterface_params.json delete mode 100644 tests/data/small_toy_data/mountainsort5_output/spikeinterface_recording.json diff --git a/spikewrap/__init__.py b/spikewrap/__init__.py index 781354f..90c560d 100644 --- a/spikewrap/__init__.py +++ b/spikewrap/__init__.py @@ -7,7 +7,8 @@ pass from .pipeline.full_pipeline import run_full_pipeline -from .pipeline.preprocess import _preprocess_and_save_all_runs + +# from .pipeline.preprocess import _preprocess_and_save_all_runs from .pipeline.sort import run_sorting from .pipeline.postprocess import run_postprocess diff --git a/spikewrap/examples/example_preprocess.py b/spikewrap/examples/example_preprocess.py index 7bf88eb..4b54741 100644 --- a/spikewrap/examples/example_preprocess.py +++ b/spikewrap/examples/example_preprocess.py @@ -1,34 +1,45 @@ from pathlib import Path from spikewrap.pipeline.load_data import load_data -from spikewrap.pipeline.preprocess import run_preprocessing +from spikewrap.pipeline.preprocess import PreprocessPipeline base_path = Path( - r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-short-multises" + r"C:\fMRIData\git-repo\spikewrap\tests\data\small_toy_data" + # r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-short-multises" # r"C:\data\ephys\test_data\steve_multi_run\1119617\time-miniscule-multises" ) -sub_name = "sub-1119617" +sub_name = "sub-001_type-test" +# sub_name = "sub-1119617" + sessions_and_runs = { - "ses-001": [ - "run-001_1119617_LSE1_shank12_g0", - "run-002_made_up_g0", - ], - "ses-002": [ - "run-001_1119617_pretest1_shank12_g0", - ], - "ses-003": [ - "run-002_1119617_pretest1_shank12_g0", - ], + "ses-001": ["all"], + "ses-002": ["all"], } -loaded_data = load_data(base_path, sub_name, sessions_and_runs, data_format="spikeglx") +if False: + sessions_and_runs = { + "ses-001": [ + "run-001_1119617_LSE1_shank12_g0", + "run-002_made_up_g0", + ], + "ses-002": [ + "run-001_1119617_pretest1_shank12_g0", + ], + "ses-003": [ + "run-002_1119617_pretest1_shank12_g0", + ], + } + +loaded_data = load_data( + base_path, sub_name, sessions_and_runs, data_format="spikeinterface" +) -run_preprocessing( +preprocess_pipeline = PreprocessPipeline( loaded_data, pp_steps="default", handle_existing_data="overwrite", preprocess_by_group=True, log=True, - slurm_batch=False, ) +preprocess_pipeline.run(slurm_batch=False) diff --git a/spikewrap/pipeline/full_pipeline.py 
b/spikewrap/pipeline/full_pipeline.py index 41cc821..d633d73 100644 --- a/spikewrap/pipeline/full_pipeline.py +++ b/spikewrap/pipeline/full_pipeline.py @@ -13,7 +13,8 @@ from spikewrap.configs.configs import get_configs from spikewrap.pipeline.load_data import load_data from spikewrap.pipeline.postprocess import run_postprocess -from spikewrap.pipeline.preprocess import run_preprocessing + +# from spikewrap.pipeline.preprocess import run_preprocessing from spikewrap.pipeline.sort import run_sorting from spikewrap.utils import logging_sw, slurm, utils, validate diff --git a/spikewrap/pipeline/preprocess.py b/spikewrap/pipeline/preprocess.py index 5821a0f..310baf5 100644 --- a/spikewrap/pipeline/preprocess.py +++ b/spikewrap/pipeline/preprocess.py @@ -1,4 +1,5 @@ import json +from types import MappingProxyType from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -15,81 +16,198 @@ # -------------------------------------------------------------------------------------- -def run_preprocessing( - preprocess_data: PreprocessingData, - pp_steps: Union[Dict, str], - handle_existing_data: HandleExisting, - preprocess_by_group: bool, - chunk_size: Optional[int] = None, - slurm_batch: Union[bool, Dict] = False, - log: bool = True, -): - """ - Main entry function to run preprocessing and write to file. Preprocessed - lazy spikeinterface recordings will be added to all sessions / runs in - `preprocess_data` and written to file. - - Parameters - ---------- - - preprocess_data : PreprocessingData - A preprocessing data object that has as attributes the - paths to rawdata. The pp_steps attribute is set on - this class during execution of this function. - - pp_steps: The name of valid preprocessing .yaml file (without the yaml extension). - stored in spikewrap/configs. - - existing_preprocessed_data : custom_types.HandleExisting - Determines how existing preprocessed data (e.g. from a prior pipeline run) - is handled. - "overwrite" : Will overwrite any existing preprocessed data output. - This will delete the 'preprocessed' folder. Therefore, - never save derivative work there. - "skip_if_exists" : will search for existing data and skip preprocesing - if it exists (sorting will run on existing - preprocessed data). - Otherwise, will preprocess and save the current run. - "fail_if_exists" : If existing preprocessed data is found, an error - will be raised. - - slurm_batch : Union[bool, Dict] - see `run_full_pipeline()` for details. - """ - # TOOD: refactor and handle argument groups separately. - # Avoid duplication with logging. 
- passed_arguments = locals() - validate.check_function_arguments(passed_arguments) - - if isinstance(pp_steps, Dict): - pp_steps_dict = pp_steps - else: - # TODO: do some check the name is valid - pp_steps_dict, _, _ = configs.get_configs(pp_steps) # TODO: call 'config_name' +class PreprocessPipeline: + """ """ + + def __init__( + self, + preprocess_data: PreprocessingData, + pp_steps: Union[Dict, str], + handle_existing_data: HandleExisting, + preprocess_by_group: bool, + chunk_size: Optional[int] = None, + # slurm_batch: Union[bool, Dict] = False, + log: bool = True, + ): + if isinstance(pp_steps, Dict): + pp_steps_dict = pp_steps + else: + pp_steps_dict, _, _ = configs.get_configs(pp_steps) + pp_steps_dict = MappingProxyType(pp_steps_dict) - if slurm_batch: - slurm.run_in_slurm( - slurm_batch, - _preprocess_and_save_all_runs, + self.passed_arguments = MappingProxyType( { "preprocess_data": preprocess_data, - "pp_steps": pp_steps_dict, + "pp_steps_dict": pp_steps_dict, + "handle_existing_data": handle_existing_data, "preprocess_by_group": preprocess_by_group, "chunk_size": chunk_size, - "handle_existing_data": handle_existing_data, + # "slurm_batch": slurm_batch, "log": log, - }, - ), - else: - _preprocess_and_save_all_runs( + } + ) + validate.check_function_arguments(self.passed_arguments) + + # TODO: do some check the name is valid + def run(self, slurm_batch: Union[bool, Dict] = False): + """ """ + if slurm_batch: + slurm.run_in_slurm( + slurm_batch, + self._preprocess_and_save_all_runs, + **self.passed_arguments, + ), + else: + self._preprocess_and_save_all_runs(**self.passed_arguments) + + # -------------------------------------------------------------------------------------- + # Private Functions + # -------------------------------------------------------------------------------------- + + def _preprocess_and_save_all_runs( + self, + preprocess_data: PreprocessingData, + pp_steps_dict: Dict, + handle_existing_data: HandleExisting, + preprocess_by_group: bool, + chunk_size: Optional[int] = None, + log: bool = True, + ) -> None: + """ + Handle the loading of existing preprocessed data. + See `run_preprocessing()` for details. + + This function validates all input arguments and initialises logging. + Then, it will iterate over every run in `preprocess_data` and + check whether preprocessing needs to be run and saved based on the + `handle_existing_data` option. If so, it will fill the relevant run + with the preprocessed spikeinterface recording object and save to disk. 
+ """ + passed_arguments = locals() + validate.check_function_arguments(passed_arguments) + + if log: + logs = logging_sw.get_started_logger( + utils.get_logging_path( + preprocess_data.base_path, preprocess_data.sub_name + ), + "preprocessing", + ) + utils.show_passed_arguments(passed_arguments, "`run_preprocessing`") + + for ses_name, run_name in preprocess_data.flat_sessions_and_runs(): + utils.message_user(f"Preprocessing run {run_name}...") + + to_save, overwrite = _handle_existing_data_options( + preprocess_data, ses_name, run_name, handle_existing_data + ) + + if to_save: + _preprocess_and_save_single_run( + preprocess_data, + ses_name, + run_name, + pp_steps_dict, + overwrite, + preprocess_by_group, + chunk_size, + ) + + if log: + logs.stop_logging() + + def _preprocess_and_save_single_run( + self, + preprocess_data: PreprocessingData, + ses_name: str, + run_name: str, + pp_steps_dict: Dict, + overwrite: bool, + preprocess_by_group: bool, + chunk_size: Optional[int], + ) -> None: + """ + Given a single session and run, fill the entry for this run + on the `preprocess_data` object and write to disk. + """ + _fill_run_data_with_preprocessed_recording( preprocess_data, + ses_name, + run_name, pp_steps_dict, - handle_existing_data, preprocess_by_group, - chunk_size, - log, ) + preprocess_data.save_preprocessed_data( + ses_name, run_name, overwrite, chunk_size + ) + + def _handle_existing_data_options( + self, + preprocess_data: PreprocessingData, + ses_name: str, + run_name: str, + handle_existing_data: HandleExisting, + ) -> Tuple[bool, bool]: + """ + Determine whether preprocesing for this run needs to be performed based + on the `handle_existing_data setting`. If preprocessing does not exist, + preprocessing + is always run. Otherwise, if it already exists, the behaviour depends on + the `handle_existing_data` setting. + + Returns + ------- + + to_save : bool + Whether the preprocessing needs to be run and saved. + + to_overwrite : bool + If saving, set the `overwrite` flag to confirm existing data should + be overwritten. + """ + preprocess_path = preprocess_data.get_preprocessing_path(ses_name, run_name) + + if handle_existing_data == "skip_if_exists": + if preprocess_path.is_dir(): + utils.message_user( + f"\nSkipping preprocessing, using file at " + f"{preprocess_path} for sorting.\n" + ) + to_save = False + overwrite = False + else: + utils.message_user( + f"No data found at {preprocess_path}, saving preprocessed data." + ) + to_save = True + overwrite = False + + elif handle_existing_data == "overwrite": + if preprocess_path.is_dir(): + utils.message_user(f"Removing existing file at {preprocess_path}\n") + + utils.message_user(f"Saving preprocessed data to {preprocess_path}") + to_save = True + overwrite = True + + elif handle_existing_data == "fail_if_exists": + if preprocess_path.is_dir(): + raise FileExistsError( + f"Preprocessed binary already exists at " + f"{preprocess_path}. 
" + f"To overwrite, set 'existing_preprocessed_data' to 'overwrite'" + ) + to_save = True + overwrite = False + + return to_save, overwrite + + +# -------------------------------------------------------------------------------------- +# Preprocessing Functions +# -------------------------------------------------------------------------------------- + def fill_all_runs_with_preprocessed_recording( preprocess_data: PreprocessingData, @@ -121,145 +239,6 @@ def fill_all_runs_with_preprocessed_recording( ) -# -------------------------------------------------------------------------------------- -# Private Functions -# -------------------------------------------------------------------------------------- - - -def _preprocess_and_save_all_runs( - preprocess_data: PreprocessingData, - pp_steps_dict: Dict, - handle_existing_data: HandleExisting, - preprocess_by_group: bool, - chunk_size: Optional[int] = None, - log: bool = True, -) -> None: - """ - Handle the loading of existing preprocessed data. - See `run_preprocessing()` for details. - - This function validates all input arguments and initialises logging. - Then, it will iterate over every run in `preprocess_data` and - check whether preprocessing needs to be run and saved based on the - `handle_existing_data` option. If so, it will fill the relevant run - with the preprocessed spikeinterface recording object and save to disk. - """ - passed_arguments = locals() - validate.check_function_arguments(passed_arguments) - - if log: - logs = logging_sw.get_started_logger( - utils.get_logging_path(preprocess_data.base_path, preprocess_data.sub_name), - "preprocessing", - ) - utils.show_passed_arguments(passed_arguments, "`run_preprocessing`") - - for ses_name, run_name in preprocess_data.flat_sessions_and_runs(): - utils.message_user(f"Preprocessing run {run_name}...") - - to_save, overwrite = _handle_existing_data_options( - preprocess_data, ses_name, run_name, handle_existing_data - ) - - if to_save: - _preprocess_and_save_single_run( - preprocess_data, - ses_name, - run_name, - pp_steps_dict, - overwrite, - preprocess_by_group, - chunk_size, - ) - - if log: - logs.stop_logging() - - -def _preprocess_and_save_single_run( - preprocess_data: PreprocessingData, - ses_name: str, - run_name: str, - pp_steps_dict: Dict, - overwrite: bool, - preprocess_by_group: bool, - chunk_size: Optional[int], -) -> None: - """ - Given a single session and run, fill the entry for this run - on the `preprocess_data` object and write to disk. - """ - _fill_run_data_with_preprocessed_recording( - preprocess_data, - ses_name, - run_name, - pp_steps_dict, - preprocess_by_group, - ) - - preprocess_data.save_preprocessed_data(ses_name, run_name, overwrite, chunk_size) - - -def _handle_existing_data_options( - preprocess_data: PreprocessingData, - ses_name: str, - run_name: str, - handle_existing_data: HandleExisting, -) -> Tuple[bool, bool]: - """ - Determine whether preprocesing for this run needs to be performed based - on the `handle_existing_data setting`. If preprocessing does not exist, preprocessing - is always run. Otherwise, if it already exists, the behaviour depends on - the `handle_existing_data` setting. - - Returns - ------- - - to_save : bool - Whether the preprocessing needs to be run and saved. - - to_overwrite : bool - If saving, set the `overwrite` flag to confirm existing data should - be overwritten. 
- """ - preprocess_path = preprocess_data.get_preprocessing_path(ses_name, run_name) - - if handle_existing_data == "skip_if_exists": - if preprocess_path.is_dir(): - utils.message_user( - f"\nSkipping preprocessing, using file at " - f"{preprocess_path} for sorting.\n" - ) - to_save = False - overwrite = False - else: - utils.message_user( - f"No data found at {preprocess_path}, saving preprocessed data." - ) - to_save = True - overwrite = False - - elif handle_existing_data == "overwrite": - if preprocess_path.is_dir(): - utils.message_user(f"Removing existing file at {preprocess_path}\n") - - utils.message_user(f"Saving preprocessed data to {preprocess_path}") - to_save = True - overwrite = True - - elif handle_existing_data == "fail_if_exists": - if preprocess_path.is_dir(): - raise FileExistsError( - f"Preprocessed binary already exists at " - f"{preprocess_path}. " - f"To overwrite, set 'existing_preprocessed_data' to 'overwrite'" - ) - to_save = True - overwrite = False - - return to_save, overwrite - - def _fill_run_data_with_preprocessed_recording( preprocess_data: PreprocessingData, ses_name: str, diff --git a/tests/data/small_toy_data/mountainsort5_output/sorter_output/firings.npz b/tests/data/small_toy_data/mountainsort5_output/sorter_output/firings.npz deleted file mode 100644 index d01359b46fa6623c11b271c251619a551406cf94..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1574 zcmWIWW@Zs#fB;2?g%U#kEkF(k^Du}ol;&lY#Al`y>*WS8 zkb6))5p>7Y3iZ4ygXQ_2rA4 z1Telp80?>d%&624G66G|_G(jv?dc_%2H2Br5x>03~mkp-fz5lXLt(qgO-@l8+~HGB(07$eUB z6@aiPD7}EgHzzSEH3v0#K?*1C%MmGaBzaX0t4NOmv*hSAh=w_g%31l;N Z1JfGT6cXUg$_7%%3WQET$DRWxJ^&XQ4qyNP diff --git a/tests/data/small_toy_data/mountainsort5_output/spikeinterface_log.json b/tests/data/small_toy_data/mountainsort5_output/spikeinterface_log.json deleted file mode 100644 index 42f4d75..0000000 --- a/tests/data/small_toy_data/mountainsort5_output/spikeinterface_log.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "sorter_name": "mountainsort5", - "sorter_version": "0.3.0", - "datetime": "2023-12-20T14:07:55.752677", - "runtime_trace": [], - "error": false, - "run_time": 0.2969591000000946 -} diff --git a/tests/data/small_toy_data/mountainsort5_output/spikeinterface_params.json b/tests/data/small_toy_data/mountainsort5_output/spikeinterface_params.json deleted file mode 100644 index 93387ee..0000000 --- a/tests/data/small_toy_data/mountainsort5_output/spikeinterface_params.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "sorter_name": "mountainsort5", - "sorter_params": { - "scheme": "2", - "detect_threshold": 5.5, - "detect_sign": -1, - "detect_time_radius_msec": 0.5, - "snippet_T1": 20, - "snippet_T2": 20, - "npca_per_channel": 3, - "npca_per_subdivision": 10, - "snippet_mask_radius": 250, - "scheme1_detect_channel_radius": 150, - "scheme2_phase1_detect_channel_radius": 200, - "scheme2_detect_channel_radius": 50, - "scheme2_max_num_snippets_per_training_batch": 200, - "scheme2_training_duration_sec": 300, - "scheme2_training_recording_sampling_mode": "uniform", - "scheme3_block_duration_sec": 1800, - "freq_min": 300, - "freq_max": 6000, - "filter": false, - "whiten": false - } -} diff --git a/tests/data/small_toy_data/mountainsort5_output/spikeinterface_recording.json b/tests/data/small_toy_data/mountainsort5_output/spikeinterface_recording.json deleted file mode 100644 index 47e654b..0000000 --- a/tests/data/small_toy_data/mountainsort5_output/spikeinterface_recording.json +++ /dev/null @@ -1,869 +0,0 @@ -{ - "class": 
"spikeinterface.core.channelslice.ChannelSliceRecording", - "module": "spikeinterface", - "kwargs": { - "parent_recording": { - "class": "spikeinterface.preprocessing.common_reference.CommonReferenceRecording", - "module": "spikeinterface", - "kwargs": { - "recording": { - "class": "spikeinterface.preprocessing.filter.BandpassFilterRecording", - "module": "spikeinterface", - "kwargs": { - "recording": { - "class": "spikeinterface.preprocessing.phase_shift.PhaseShiftRecording", - "module": "spikeinterface", - "kwargs": { - "recording": { - "class": "spikeinterface.preprocessing.astype.AstypeRecording", - "module": "spikeinterface", - "kwargs": { - "recording": { - "class": "spikeinterface.core.binaryfolder.BinaryFolderRecording", - "module": "spikeinterface", - "kwargs": { - "folder_path": "C:\\fMRIData\\git-repo\\spikewrap\\tests\\data\\small_toy_data\\rawdata\\sub-001_type-test\\ses-003\\ephys\\ses-003_run-002" - }, - "version": "0.100.0.dev0", - "annotations": { - "is_filtered": true, - "probe_0_planar_contour": [ - [ - -20.0, - 620.0 - ], - [ - -20.0, - -20.0 - ], - [ - 20.0, - -20.0 - ], - [ - 20.0, - 620.0 - ] - ], - "probes_info": [ - {} - ] - }, - "properties": { - "group": [ - 0, - 0, - 0, - 0, - 1, - 1, - 1, - 1, - 2, - 2, - 2, - 2, - 3, - 3, - 3, - 3 - ], - "location": [ - [ - 0.0, - 0.0 - ], - [ - 0.0, - 40.0 - ], - [ - 0.0, - 80.0 - ], - [ - 0.0, - 120.0 - ], - [ - 0.0, - 160.0 - ], - [ - 0.0, - 200.0 - ], - [ - 0.0, - 240.0 - ], - [ - 0.0, - 280.0 - ], - [ - 0.0, - 320.0 - ], - [ - 0.0, - 360.0 - ], - [ - 0.0, - 400.0 - ], - [ - 0.0, - 440.0 - ], - [ - 0.0, - 480.0 - ], - [ - 0.0, - 520.0 - ], - [ - 0.0, - 560.0 - ], - [ - 0.0, - 600.0 - ] - ], - "gain_to_uV": [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0 - ], - "offset_to_uV": [ - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0 - ] - }, - "relative_paths": false - }, - "dtype": " Date: Wed, 20 Dec 2023 16:47:54 +0000 Subject: [PATCH 6/6] Small fixes when checking slurm. 
--- spikewrap/examples/example_preprocess.py | 5 ++-- spikewrap/pipeline/preprocess.py | 33 ++++++++++++------------ 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/spikewrap/examples/example_preprocess.py b/spikewrap/examples/example_preprocess.py index 4b54741..e811f50 100644 --- a/spikewrap/examples/example_preprocess.py +++ b/spikewrap/examples/example_preprocess.py @@ -4,7 +4,8 @@ from spikewrap.pipeline.preprocess import PreprocessPipeline base_path = Path( - r"C:\fMRIData\git-repo\spikewrap\tests\data\small_toy_data" + "/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/code/spikewrap/tests/data/small_toy_data" + # r"C:\fMRIData\git-repo\spikewrap\tests\data\small_toy_data" # r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-short-multises" # r"C:\data\ephys\test_data\steve_multi_run\1119617\time-miniscule-multises" ) @@ -42,4 +43,4 @@ preprocess_by_group=True, log=True, ) -preprocess_pipeline.run(slurm_batch=False) +preprocess_pipeline.run(slurm_batch=True) diff --git a/spikewrap/pipeline/preprocess.py b/spikewrap/pipeline/preprocess.py index 310baf5..1b8abe5 100644 --- a/spikewrap/pipeline/preprocess.py +++ b/spikewrap/pipeline/preprocess.py @@ -1,5 +1,4 @@ import json -from types import MappingProxyType from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -33,19 +32,19 @@ def __init__( pp_steps_dict = pp_steps else: pp_steps_dict, _, _ = configs.get_configs(pp_steps) - pp_steps_dict = MappingProxyType(pp_steps_dict) - - self.passed_arguments = MappingProxyType( - { - "preprocess_data": preprocess_data, - "pp_steps_dict": pp_steps_dict, - "handle_existing_data": handle_existing_data, - "preprocess_by_group": preprocess_by_group, - "chunk_size": chunk_size, - # "slurm_batch": slurm_batch, - "log": log, - } - ) + # pp_steps_dict = MappingProxyType(pp_steps_dict) + + self.passed_arguments = { # MappingProxyType( + # { + "preprocess_data": preprocess_data, + "pp_steps_dict": pp_steps_dict, + "handle_existing_data": handle_existing_data, + "preprocess_by_group": preprocess_by_group, + "chunk_size": chunk_size, + # "slurm_batch": slurm_batch, + "log": log, + } + # ) validate.check_function_arguments(self.passed_arguments) # TODO: do some check the name is valid @@ -55,7 +54,7 @@ def run(self, slurm_batch: Union[bool, Dict] = False): slurm.run_in_slurm( slurm_batch, self._preprocess_and_save_all_runs, - **self.passed_arguments, + self.passed_arguments, ), else: self._preprocess_and_save_all_runs(**self.passed_arguments) @@ -98,12 +97,12 @@ def _preprocess_and_save_all_runs( for ses_name, run_name in preprocess_data.flat_sessions_and_runs(): utils.message_user(f"Preprocessing run {run_name}...") - to_save, overwrite = _handle_existing_data_options( + to_save, overwrite = self._handle_existing_data_options( preprocess_data, ses_name, run_name, handle_existing_data ) if to_save: - _preprocess_and_save_single_run( + self._preprocess_and_save_single_run( preprocess_data, ses_name, run_name,