Major refactor pipelines #173

Open · wants to merge 6 commits into base: dev
3 changes: 2 additions & 1 deletion spikewrap/__init__.py
@@ -7,7 +7,8 @@
     pass

 from .pipeline.full_pipeline import run_full_pipeline
-from .pipeline.preprocess import _preprocess_and_save_all_runs
+
+# from .pipeline.preprocess import _preprocess_and_save_all_runs
 from .pipeline.sort import run_sorting
 from .pipeline.postprocess import run_postprocess
5 changes: 2 additions & 3 deletions spikewrap/data_classes/postprocessing.py
@@ -42,6 +42,7 @@ class PostprocessingData:

     def __init__(self, sorting_path: Union[str, Path]) -> None:
         self.sorting_path = Path(sorting_path)
+
         self.sorter_output_path = self.sorting_path / "sorter_output"
         self.sorting_info_path = self.sorting_path / utils.canonical_names(
             "sorting_yaml"
@@ -151,9 +152,7 @@ def get_sorting_extractor_object(self) -> si.SortingExtractor:
         return sorting_without_excess_spikes

     def get_postprocessing_path(self) -> Path:
-        return self.sorting_data.get_postprocessing_path(
-            self.sorted_ses_name, self.sorted_run_name
-        )
+        return utils.make_postprocessing_path(self.sorting_path)

     def get_quality_metrics_path(self) -> Path:
         return self.get_postprocessing_path() / "quality_metrics.csv"
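Note on this change: `get_postprocessing_path` no longer asks `sorting_data` for the path; it is now derived from `self.sorting_path` alone via the new `utils.make_postprocessing_path` helper, whose body is not shown in this diff. A minimal sketch of the behaviour the call sites appear to assume (a hypothetical reconstruction, not the actual helper):

from pathlib import Path

def make_postprocessing_path(sorting_path: Path) -> Path:
    # Assumed behaviour: "postprocessing" sits beside the "sorting" folder,
    # e.g. .../run-001/sorting -> .../run-001/postprocessing
    return sorting_path.parent / "postprocessing"

Under that assumption, `get_quality_metrics_path` resolves to a `postprocessing/quality_metrics.csv` next to the sorting output, matching the pre-refactor layout.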
46 changes: 30 additions & 16 deletions spikewrap/data_classes/sorting.py
@@ -186,19 +186,31 @@ def check_ses_or_run_folders_in_datetime_order(
     # Paths
     # ----------------------------------------------------------------------------------

-    def get_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path:
-        return self._get_base_sorting_path(ses_name, run_name) / "sorting"
-
-    def get_sorter_output_path(self, ses_name: str, run_name: Optional[str]) -> Path:
-        return self.get_sorting_path(ses_name, run_name) / "sorter_output"
+    def get_sorting_path(
+        self, ses_name: str, run_name: Optional[str], group_idx: Optional[int] = None
+    ) -> Path:
+        if group_idx is None:
+            format_group_name = ""
+        else:
+            format_group_name = f"group-{group_idx}"

-    def _get_sorting_info_path(self, ses_name: str, run_name: Optional[str]) -> Path:
-        return self.get_sorting_path(ses_name, run_name) / utils.canonical_names(
-            "sorting_yaml"
+        return (
+            self.get_base_sorting_path(ses_name, run_name)
+            / format_group_name
+            / "sorting"
         )

-    def get_postprocessing_path(self, ses_name: str, run_name: Optional[str]) -> Path:
-        return self._get_base_sorting_path(ses_name, run_name) / "postprocessing"
+    def get_sorter_output_path(
+        self, ses_name: str, run_name: Optional[str], group_idx: Optional[int] = None
+    ) -> Path:
+        return self.get_sorting_path(ses_name, run_name, group_idx) / "sorter_output"
+
+    def _get_sorting_info_path(
+        self, ses_name: str, run_name: Optional[str], group_idx: Optional[int] = None
+    ) -> Path:
+        return self.get_sorting_path(
+            ses_name, run_name, group_idx
+        ) / utils.canonical_names("sorting_yaml")

     def _validate_derivatives_inputs(self):
         self._validate_inputs(
@@ -247,7 +259,9 @@ def _make_run_name_from_multiple_run_names(self, run_names: List[str]) -> str:
     # Sorting info
     # ----------------------------------------------------------------------------------

-    def save_sorting_info(self, ses_name: str, run_name: str) -> None:
+    def save_sorting_info(
+        self, ses_name: str, run_name: str, group_idx: Optional[int] = None
+    ) -> None:
         """
         Save a sorting_info.yaml file containing a dictionary holding
         important information on the sorting. This is for provenance.
@@ -289,7 +303,7 @@ def save_sorting_info(self, ses_name: str, run_name: str) -> None:
         sorting_info["datetime_created"] = utils.get_formatted_datetime()

         utils.dump_dict_to_yaml(
-            self._get_sorting_info_path(ses_name, run_name), sorting_info
+            self._get_sorting_info_path(ses_name, run_name, group_idx), sorting_info
         )

     @property
@@ -325,7 +339,7 @@ def get_preprocessed_recordings(
         raise NotImplementedError

     @abstractmethod
-    def _get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path:
+    def get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path:
         raise NotImplementedError

@@ -364,7 +378,7 @@ def get_preprocessed_recordings(
         self.assert_names(ses_name, run_name)
         return self[self.concat_ses_name()]

-    def _get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path:
+    def get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path:
         """"""
         self.assert_names(ses_name, run_name)

@@ -447,7 +461,7 @@ def get_preprocessed_recordings(

         return self[ses_name][run_name]

-    def _get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path:
+    def get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path:
         assert run_name == self.concat_run_name(ses_name)
         assert run_name is not None

@@ -501,7 +515,7 @@ def get_preprocessed_recordings(
         assert run_name is not None
         return self[ses_name][run_name]

-    def _get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path:
+    def get_base_sorting_path(self, ses_name: str, run_name: Optional[str]) -> Path:
         assert run_name is not None
         # TODO: centralise paths!!
         return (
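Worth flagging in review: the new `get_sorting_path` relies on `pathlib` dropping empty path segments, so passing `group_idx=None` reproduces the pre-refactor layout exactly. A self-contained sketch of the mechanism (the directory names are illustrative, not taken from the repo):

from pathlib import Path

base = Path("derivatives/sub-001/ses-001/run-001")  # stand-in for get_base_sorting_path(...)

# group_idx=None -> format_group_name == "", and pathlib ignores empty segments,
# so the old layout is unchanged:
assert base / "" / "sorting" == base / "sorting"

# group_idx=1 -> a "group-1" level is inserted before "sorting":
assert base / "group-1" / "sorting" == Path(
    "derivatives/sub-001/ses-001/run-001/group-1/sorting"
)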
15 changes: 8 additions & 7 deletions spikewrap/examples/example_full_pipeline.py
@@ -4,15 +4,19 @@
 from spikewrap.pipeline.full_pipeline import run_full_pipeline

 base_path = Path(
-    r"C:\fMRIData\git-repo\spikewrap\tests\data\small_toy_data"
+    # r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-long_origdata"
+    r"C:\fMRIData\git-repo\spikewrap\tests\data\small_toy_data",
 )

 sub_name = "sub-001_type-test"
-sessions_and_runs = {"all": ["all"]}
+
+sessions_and_runs = {
+    "all": ["all"],
+}

 # sub_name = "1119617"
 # sessions_and_runs = {
 #     "ses-001": ["1119617_LSE1_shank12_g0"],
 # }

 config_name = "test_default"
@@ -28,16 +32,13 @@
     "spikeinterface",
     config_name,
     sorter,
+    sort_by_group=True,
+    save_preprocessing_chunk_size=30000,
     existing_preprocessed_data="overwrite",
     existing_sorting_output="overwrite",
     overwrite_postprocessing=True,
     concat_sessions_for_sorting=False,  # TODO: validate this at the start, in `run_full_pipeline`
     concat_runs_for_sorting=False,
-    # existing_preprocessed_data="skip_if_exists", # this is kind of confusing...
-    # existing_sorting_output="overwrite",
-    # overwrite_postprocessing=True,
-    # slurm_batch=False,
 )

 print(f"TOOK {time.time() - t}")
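With `sort_by_group=True`, the run above should produce one sorting output per group under each run folder. A sketch of the expected layout, assembled from the path helpers elsewhere in this PR (names are illustrative; the postprocessing location assumes `make_postprocessing_path` places it beside each group's `sorting` folder):

derivatives/<sub>/<ses>/<run>/group-0/sorting/sorter_output/
derivatives/<sub>/<ses>/<run>/group-0/postprocessing/
derivatives/<sub>/<ses>/<run>/group-1/sorting/sorter_output/
derivatives/<sub>/<ses>/<run>/group-1/postprocessing/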
44 changes: 28 additions & 16 deletions spikewrap/examples/example_preprocess.py
@@ -1,34 +1,46 @@
 from pathlib import Path

 from spikewrap.pipeline.load_data import load_data
-from spikewrap.pipeline.preprocess import run_preprocessing
+from spikewrap.pipeline.preprocess import PreprocessPipeline

 base_path = Path(
-    r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-short-multises"
+    "/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/code/spikewrap/tests/data/small_toy_data"
+    # r"C:\fMRIData\git-repo\spikewrap\tests\data\small_toy_data"
+    # r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-short-multises"
     # r"C:\data\ephys\test_data\steve_multi_run\1119617\time-miniscule-multises"
 )

-sub_name = "sub-1119617"
+sub_name = "sub-001_type-test"
+# sub_name = "sub-1119617"

 sessions_and_runs = {
-    "ses-001": [
-        "run-001_1119617_LSE1_shank12_g0",
-        "run-002_made_up_g0",
-    ],
-    "ses-002": [
-        "run-001_1119617_pretest1_shank12_g0",
-    ],
-    "ses-003": [
-        "run-002_1119617_pretest1_shank12_g0",
-    ],
+    "ses-001": ["all"],
+    "ses-002": ["all"],
 }

-loaded_data = load_data(base_path, sub_name, sessions_and_runs, data_format="spikeglx")
+if False:
+    sessions_and_runs = {
+        "ses-001": [
+            "run-001_1119617_LSE1_shank12_g0",
+            "run-002_made_up_g0",
+        ],
+        "ses-002": [
+            "run-001_1119617_pretest1_shank12_g0",
+        ],
+        "ses-003": [
+            "run-002_1119617_pretest1_shank12_g0",
+        ],
+    }
+
+loaded_data = load_data(
+    base_path, sub_name, sessions_and_runs, data_format="spikeinterface"
+)

-run_preprocessing(
+preprocess_pipeline = PreprocessPipeline(
     loaded_data,
     pp_steps="default",
     handle_existing_data="overwrite",
     preprocess_by_group=True,
     log=True,
-    slurm_batch=False,
 )
+preprocess_pipeline.run(slurm_batch=True)
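For scripts migrating to this PR, the change is mechanical: the one-shot `run_preprocessing(...)` call becomes construct-then-run, with `slurm_batch` moving from a function argument to `run()`. A minimal sketch using only the keyword arguments shown in the example above:

# Before (old API):
# run_preprocessing(loaded_data, pp_steps="default", handle_existing_data="overwrite", slurm_batch=False)

# After (this PR):
pipeline = PreprocessPipeline(loaded_data, pp_steps="default", handle_existing_data="overwrite")
pipeline.run(slurm_batch=False)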
79 changes: 52 additions & 27 deletions spikewrap/pipeline/full_pipeline.py
@@ -13,7 +13,8 @@
 from spikewrap.configs.configs import get_configs
 from spikewrap.pipeline.load_data import load_data
 from spikewrap.pipeline.postprocess import run_postprocess
-from spikewrap.pipeline.preprocess import run_preprocessing
+
+# from spikewrap.pipeline.preprocess import run_preprocessing
 from spikewrap.pipeline.sort import run_sorting
 from spikewrap.utils import logging_sw, slurm, utils, validate

@@ -26,6 +27,7 @@ def run_full_pipeline(
     config_name: str = "default",
     sorter: str = "kilosort2_5",
     preprocess_by_group: bool = False,
+    sort_by_group: bool = False,
     concat_sessions_for_sorting: bool = False,
     concat_runs_for_sorting: bool = False,
     existing_preprocessed_data: HandleExisting = "fail_if_exists",
@@ -53,6 +55,7 @@
         "config_name": config_name,
         "sorter": sorter,
         "preprocess_by_group": preprocess_by_group,
+        "sort_by_group": sort_by_group,
         "concat_sessions_for_sorting": concat_sessions_for_sorting,
         "concat_runs_for_sorting": concat_runs_for_sorting,
         "existing_preprocessed_data": existing_preprocessed_data,
@@ -72,6 +75,7 @@
             config_name,
             sorter,
             preprocess_by_group,
+            sort_by_group,
             concat_sessions_for_sorting,
             concat_runs_for_sorting,
             existing_preprocessed_data,
Expand All @@ -91,6 +95,7 @@ def _run_full_pipeline(
config_name: str = "default",
sorter: str = "kilosort2_5",
preprocess_by_group: bool = False,
sort_by_group: bool = False,
concat_sessions_for_sorting: bool = False,
concat_runs_for_sorting: bool = False,
existing_preprocessed_data: HandleExisting = "fail_if_exists",
@@ -231,6 +236,7 @@
             sub_name,
             sessions_and_runs,
             sorter,
+            sort_by_group,
             concat_sessions_for_sorting,
             concat_runs_for_sorting,
             sorter_options,
@@ -240,20 +246,26 @@

     # Run Postprocessing
     for ses_name, run_name in sorting_data.get_sorting_sessions_and_runs():
-        sorting_path = sorting_data.get_sorting_path(ses_name, run_name)
-
-        postprocess_data = run_postprocess(
-            sorting_path,
-            overwrite_postprocessing=overwrite_postprocessing,
-            existing_waveform_data="fail_if_exists",
-            waveform_options=waveform_options,
-        )
+        for sorting_path in _get_sorting_paths(
+            sorting_data, ses_name, run_name, sort_by_group
+        ):
+            postprocess_data = run_postprocess(
+                sorting_path,
+                overwrite_postprocessing=overwrite_postprocessing,
+                existing_waveform_data="fail_if_exists",
+                waveform_options=waveform_options,
+            )

     # Delete intermediate files
     for ses_name, run_name in sorting_data.get_sorting_sessions_and_runs():
-        handle_delete_intermediate_files(
-            ses_name, run_name, sorting_data, delete_intermediate_files
-        )
+        for sorting_path in _get_sorting_paths(
+            sorting_data, ses_name, run_name, sort_by_group
+        ):
+            postprocessing_path = utils.make_postprocessing_path(sorting_path)
+
+            handle_delete_intermediate_files(
+                sorting_path, postprocessing_path, delete_intermediate_files
+            )
     logs.stop_logging()

     return (
@@ -262,15 +274,37 @@
     )


+def _get_sorting_paths(
+    sorting_data: SortingData, ses_name: str, run_name: str, sort_by_group: bool
+) -> List[Path]:
+    """ """
+    if sort_by_group:
+        all_group_paths = sorting_data.get_base_sorting_path(ses_name, run_name).glob(
+            "group-*"
+        )
+        group_indexes = [
+            int(group.name.split("group-")[1])
+            for group in all_group_paths
+            if group.is_dir()
+        ]  # TODO: kind of hacky
+        all_sorting_paths = [
+            sorting_data.get_sorting_path(ses_name, run_name, idx)
+            for idx in group_indexes
+        ]
+    else:
+        all_sorting_paths = [sorting_data.get_sorting_path(ses_name, run_name)]
+
+    return all_sorting_paths
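On the `# TODO: kind of hacky` above: one way to tighten the parse is to match the whole folder name instead of splitting on the `group-` prefix, so a stray folder such as `group-old` cannot raise `ValueError`. A sketch of that alternative, not part of this PR:

import re
from pathlib import Path
from typing import List

def _group_indexes(base_sorting_path: Path) -> List[int]:
    # Accept only directories named exactly "group-<int>".
    return sorted(
        int(m.group(1))
        for p in base_sorting_path.glob("group-*")
        if p.is_dir() and (m := re.fullmatch(r"group-(\d+)", p.name))
    )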


 # --------------------------------------------------------------------------------------
 # Remove Intermediate Files
 # --------------------------------------------------------------------------------------


 def handle_delete_intermediate_files(
-    ses_name: str,
-    run_name: Optional[str],
-    sorting_data: SortingData,
+    sorting_path: Path,
+    postprocessing_path: Path,
     delete_intermediate_files: DeleteIntermediate,
 ):
     """
@@ -279,22 +313,13 @@ def handle_delete_intermediate_files(
     for Kilosort). See `run_full_pipeline` for inputs
     """
     if "recording.dat" in delete_intermediate_files:
-        if (
-            recording_file := sorting_data.get_sorter_output_path(ses_name, run_name)
-            / "recording.dat"
-        ).is_file():
+        if (recording_file := sorting_path / "recording.dat").is_file():
             recording_file.unlink()

     if "temp_wh.dat" in delete_intermediate_files:
-        if (
-            recording_file := sorting_data.get_sorter_output_path(ses_name, run_name)
-            / "temp_wh.dat"
-        ).is_file():
+        if (recording_file := sorting_path / "temp_wh.dat").is_file():
             recording_file.unlink()

     if "waveforms" in delete_intermediate_files:
-        if (
-            waveforms_path := sorting_data.get_postprocessing_path(ses_name, run_name)
-            / "waveforms"
-        ).is_dir():
+        if (waveforms_path := postprocessing_path / "waveforms").is_dir():
             shutil.rmtree(waveforms_path)
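Since the helper now receives resolved paths rather than (`ses_name`, `run_name`, `sorting_data`), it works unchanged for grouped and ungrouped layouts. A hypothetical call, with paths invented for illustration:

from pathlib import Path

handle_delete_intermediate_files(
    sorting_path=Path("derivatives/sub-001/ses-001/run-001/group-0/sorting"),
    postprocessing_path=Path("derivatives/sub-001/ses-001/run-001/group-0/postprocessing"),
    delete_intermediate_files=("recording.dat", "temp_wh.dat", "waveforms"),
)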