Skip to content

Commit

Permalink
Print the arguments passed to main public methods, tidy up postproces…
Browse files Browse the repository at this point in the history
…sing logs. (#86)
  • Loading branch information
JoeZiminski authored Aug 11, 2023
1 parent e2926ab commit 1c2f422
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 18 deletions.
2 changes: 1 addition & 1 deletion spikewrap/data_classes/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def __init__(self, sorting_path):
self.sorting_info["preprocessing_run_names"],
self.sorting_info["sorter"],
self.sorting_info["concat_for_sorting"],
print_messages=False,
)
self.sorting_data.load_preprocessed_binary()

self.sorted_run_name = self.sorting_info["sorted_run_name"]
self.preprocessing_info = self.sorting_info["preprocessing"]
Expand Down
29 changes: 19 additions & 10 deletions spikewrap/data_classes/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,19 @@ class SortingData(BaseUserDict):
"""

def __init__(
self, base_path, sub_name, run_names, sorter: str, concat_for_sorting: bool
self,
base_path,
sub_name,
run_names,
sorter: str,
concat_for_sorting: bool,
print_messages: bool = True,
):
super(SortingData, self).__init__(base_path, sub_name, run_names)

self.concat_for_sorting = concat_for_sorting
self.sorter = sorter
self.print_messages = print_messages

self._check_preprocessing_exists()

Expand Down Expand Up @@ -190,21 +197,23 @@ def _concatenate_si_recording(self, recordings: Dict) -> si.BaseRecording:

concatenated_recording = concatenate_recordings(recordings_list)

# Perform some checks before returning.
assert loaded_prepro_run_names == tuple(
self.preprocessing_run_names
), "Something has gone wrong in the `run_names` ordering."

if not self._run_names_are_in_datetime_order("creation"):
warnings.warn(
"The runs provided are not in creation datetime order.\n"
"They will be concatenated in the order provided."
if self.print_messages:
if not self._run_names_are_in_datetime_order("creation"):
warnings.warn(
"The runs provided are not in creation datetime order.\n"
"They will be concatenated in the order provided."
)

utils.message_user(
f"Preprocessed data loaded prior to sorting. "
f"Runs were concatenated runs in the order: "
f"{loaded_prepro_run_names}"
)

utils.message_user(
f"Concatenating runs in the order: " f"{loaded_prepro_run_names}"
)

return concatenated_recording

# Paths
Expand Down
4 changes: 2 additions & 2 deletions spikewrap/examples/example_full_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
config_name,
sorter,
concat_for_sorting=True,
existing_preprocessed_data="load_if_exists",
existing_sorting_output="load_if_exists",
existing_preprocessed_data="overwrite",
existing_sorting_output="overwrite",
overwrite_postprocessing=True,
delete_intermediate_files=(
"recording.dat",
Expand Down
6 changes: 5 additions & 1 deletion spikewrap/pipeline/full_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def run_full_pipeline(
If True, the pipeline will be run in a SLURM job. Set False
if running on an interactive job, or locally.
"""
passed_arguments = locals()

if slurm_batch:
local_args = copy.deepcopy(locals())
slurm.run_full_pipeline_slurm(**local_args)
Expand All @@ -133,8 +135,10 @@ def run_full_pipeline(
pp_steps, sorter_options, waveform_options = get_configs(config_name)

logs = logging_sw.get_started_logger(
utils.get_logging_path(base_path, sub_name), "full_pipeline"
utils.get_logging_path(base_path, sub_name),
"full_pipeline",
)
utils.show_passed_arguments(passed_arguments, "`run_full pipeline`")

loaded_data = load_data(base_path, sub_name, run_names, data_format="spikeglx")

Expand Down
5 changes: 4 additions & 1 deletion spikewrap/pipeline/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,18 @@ def run_postprocess(
A dictionary containing options passed to SpikeInterface's
`extract_waveforms()` function as kwargs.
"""
passed_arguments = locals()

postprocess_data = PostprocessingData(sorting_path)

logs = logging_sw.get_started_logger(
utils.get_logging_path(
postprocess_data.sorting_info["base_path"],
postprocess_data.sorting_info["sub_name"],
),
"full_pipeline",
"postprocess",
)
utils.show_passed_arguments(passed_arguments, "`run_postprocess`")

utils.message_user(f"Postprocessing run: {postprocess_data.sorted_run_name}...")

Expand Down
5 changes: 4 additions & 1 deletion spikewrap/pipeline/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,12 @@ def run_sorting(
if running on an interactive job, or locally.
"""
passed_arguments = locals()
logs = logging_sw.get_started_logger(
utils.get_logging_path(base_path, sub_name), "full_pipeline"
utils.get_logging_path(base_path, sub_name),
"sorting",
)
utils.show_passed_arguments(passed_arguments, "`run_sorting`")

sorting_data = SortingData(
base_path,
Expand Down
7 changes: 5 additions & 2 deletions spikewrap/utils/logging_sw.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,18 @@ def flush(self):

def get_started_logger(
log_filepath: Path,
run_name: Literal["full_pipeline", "preprocess", "sorting", "postprocess"],
run_name: Literal["full_pipeline", "sorting", "postprocess"],
) -> HandleLogging:
"""
Convenience function that creates a logger name and starts a
HandleLogging() instance.
HandleLogging() instance. Note that this may be called
even when the logging does not log, see HandleLogging()
docs for details.
"""
format_datetime = utils.get_formatted_datetime()
log_name = f"{format_datetime}_{run_name}.log"

logs = HandleLogging()
logs.start_logging(log_filepath / log_name)

return logs
7 changes: 7 additions & 0 deletions spikewrap/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ def get_logging_path(base_path: Union[str, Path], sub_name: str) -> Path:
return Path(base_path) / "derivatives" / sub_name / "logs"


def show_passed_arguments(passed_arguments, function_name):
    """Report the arguments a public function was invoked with.

    Builds a single message naming the function and echoing the
    argument mapping (typically ``locals()`` captured at the top of
    the caller), then emits it through ``message_user``.
    """
    announcement = (
        f"Running {function_name}. The function was called "
        f"with the arguments {passed_arguments}."
    )
    message_user(announcement)


def message_user(message: str, verbose: bool = True) -> None:
"""
Method to interact with user.
Expand Down

0 comments on commit 1c2f422

Please sign in to comment.