Commit b1dd4e1
Merge branch 'release/2.40.0'
oliche committed Oct 28, 2024
2 parents c4418cb + b235446
Showing 33 changed files with 454 additions and 220 deletions.
21 changes: 14 additions & 7 deletions brainbox/io/one.py
@@ -866,13 +866,21 @@ def _get_attributes(dataset_types):
         waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes))
         return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes}
 
-    def _get_spike_sorting_collection(self, spike_sorter='pykilosort'):
+    def _get_spike_sorting_collection(self, spike_sorter=None):
         """
         Filters a list or array of collections to get the relevant spike sorting dataset
         if there is a pykilosort, load it
         """
-        collection = next(filter(lambda c: c == f'alf/{self.pname}/{spike_sorter}', self.collections), None)
-        # otherwise, prefers the shortest
+        for sorter in list([spike_sorter, 'iblsorter', 'pykilosort']):
+            if sorter is None:
+                continue
+            if sorter == "":
+                collection = next(filter(lambda c: c == f'alf/{self.pname}', self.collections), None)
+            else:
+                collection = next(filter(lambda c: c == f'alf/{self.pname}/{sorter}', self.collections), None)
+            if collection is not None:
+                return collection
+        # if none is found amongst the defaults, prefers the shortest
         collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None)
        _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}")
        return collection
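
Note: the new loop makes the sorter preference explicit — the requested sorter first, then 'iblsorter', then 'pykilosort', then the shortest matching collection as a last resort. A self-contained sketch of that selection order (hypothetical helper, not part of the diff):

def pick_collection(collections, pname, spike_sorter=None):
    # Try the requested sorter first, then the defaults, in order
    for sorter in (spike_sorter, 'iblsorter', 'pykilosort'):
        if sorter is None:
            continue
        target = f'alf/{pname}' if sorter == '' else f'alf/{pname}/{sorter}'
        if target in collections:
            return target
    # Otherwise fall back to the shortest collection containing the probe name
    candidates = sorted((c for c in collections if f'alf/{pname}' in c), key=len)
    return next(iter(candidates), None)

# With both sorters present, 'iblsorter' now wins by default:
cols = ['alf/probe00/pykilosort', 'alf/probe00/iblsorter']
assert pick_collection(cols, 'probe00') == 'alf/probe00/iblsorter'
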
@@ -982,14 +990,13 @@ def download_raw_waveforms(self, **kwargs):
         """
         _logger.debug(f"loading waveforms from {self.collection}")
         return self.one.load_object(
-            self.eid, "waveforms",
-            attribute=["traces", "templates", "table", "channels"],
+            id=self.eid, obj="waveforms", attribute=["traces", "templates", "table", "channels"],
             collection=self._get_spike_sorting_collection("pykilosort"), download_only=True, **kwargs
         )
 
     def raw_waveforms(self, **kwargs):
         wf_paths = self.download_raw_waveforms(**kwargs)
-        return WaveformsLoader(wf_paths[0].parent, wfs_dtype=np.float16)
+        return WaveformsLoader(wf_paths[0].parent)
 
     def load_channels(self, **kwargs):
         """
@@ -1022,7 +1029,7 @@ def load_channels(self, **kwargs):
             self.histology = 'alf'
         return Bunch(channels)
 
-    def load_spike_sorting(self, spike_sorter='pykilosort', revision=None, enforce_version=True, good_units=False, **kwargs):
+    def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False, **kwargs):
         """
         Loads spikes, clusters and channels
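
Note: with spike_sorter='iblsorter' and enforce_version=False as the new defaults, existing calls keep working but may resolve to a different spike sorting collection. A hedged usage sketch (placeholder pid; loader API as published in brainbox.io.one):

from one.api import ONE
from brainbox.io.one import SpikeSortingLoader

one = ONE()
ssl = SpikeSortingLoader(pid='<probe-insertion-uuid>', one=one)  # placeholder pid
spikes, clusters, channels = ssl.load_spike_sorting()  # now defaults to iblsorter
# Pin the previous behaviour explicitly if an analysis depends on pykilosort output:
spikes, clusters, channels = ssl.load_spike_sorting(spike_sorter='pykilosort')
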
1 change: 1 addition & 0 deletions brainbox/io/spikeglx.py
@@ -128,6 +128,7 @@ def __init__(self, pid, one, typ='ap', cache_folder=None, remove_cached=False):
         self.file_chunks = self.one.load_dataset(self.eid, f'*.{typ}.ch', collection=f"*{self.pname}")
         meta_file = self.one.load_dataset(self.eid, f'*.{typ}.meta', collection=f"*{self.pname}")
         cbin_rec = self.one.list_datasets(self.eid, collection=f"*{self.pname}", filename=f'*{typ}.*bin', details=True)
+        cbin_rec.index = cbin_rec.index.map(lambda x: (self.eid, x))
         self.url_cbin = self.one.record2url(cbin_rec)[0]
         with open(self.file_chunks, 'r') as f:
             self.chunks = json.load(f)
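
Note: the added line turns the dataset-id index into (eid, dataset-id) tuples before the record2url call — presumably matching the index shape newer ONE cache tables expect. The transformation itself, in isolation (toy ids):

import pandas as pd

cbin_rec = pd.DataFrame(
    {'rel_path': ['raw_ephys_data/probe00/_spikeglx_ephysData_g0_t0.imec0.ap.cbin']},
    index=['11111111-2222-3333-4444-555555555555'])  # indexed by dataset id only
eid = 'aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee'
cbin_rec.index = cbin_rec.index.map(lambda x: (eid, x))  # as in the commit
print(cbin_rec.index[0])  # ('aaaaaaaa-...', '11111111-...')
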
2 changes: 1 addition & 1 deletion ibllib/__init__.py
@@ -2,7 +2,7 @@
 import logging
 import warnings
 
-__version__ = '2.39.1'
+__version__ = '2.40.0'
 warnings.filterwarnings('always', category=DeprecationWarning, module='ibllib')
 
 # if this becomes a full-blown library we should let the logging configuration to the discretion of the dev
2 changes: 1 addition & 1 deletion ibllib/ephys/sync_probes.py
@@ -47,7 +47,7 @@ def sync(ses_path, **kwargs):
         return version3B(ses_path, **kwargs)
 
 
-def version3A(ses_path, display=True, type='smooth', tol=2.1):
+def version3A(ses_path, display=True, type='smooth', tol=2.1, probe_names=None):
     """
     From a session path with _spikeglx_sync arrays extracted, locate ephys files for 3A and
     outputs one sync.timestamps.probeN.npy file per acquired probe. By convention the reference
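
Note: the new probe_names argument presumably restricts synchronisation to the named probes rather than every probe found under the session; the body of that change is not shown here. A hedged call sketch using only the signature visible in the diff:

from ibllib.ephys.sync_probes import version3A

# Hypothetical session path; probe_names=None keeps the previous behaviour
version3A('/data/subject/2024-10-28/001', display=False, probe_names=['probe00'])
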
16 changes: 8 additions & 8 deletions ibllib/io/extractors/biased_trials.py
@@ -97,7 +97,7 @@ class TrialsTableBiased(BaseBpodTrialsExtractor):
     save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                   '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
     var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
-                 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement')
+                 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times', 'is_final_movement')
 
     def _extract(self, extractor_classes=None, **kwargs):
         extractor_classes = extractor_classes or []
@@ -125,7 +125,7 @@ class TrialsTableEphys(BaseBpodTrialsExtractor):
                   '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None,
                   None, None, None, '_ibl_trials.quiescencePeriod.npy')
     var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
-                 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement',
+                 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times', 'is_final_movement',
                  'phase', 'position', 'quiescence')
 
     def _extract(self, extractor_classes=None, **kwargs):
@@ -152,12 +152,12 @@ class BiasedTrials(BaseBpodTrialsExtractor):
     save_names = ('_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None,
                   '_ibl_trials.stimOffTrigger_times.npy', None, None, '_ibl_trials.table.pqt',
                   '_ibl_trials.stimOff_times.npy', None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
-                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, '_ibl_trials.included.npy',
-                  None, None, '_ibl_trials.quiescencePeriod.npy')
+                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None,
+                  '_ibl_trials.included.npy', None, None, '_ibl_trials.quiescencePeriod.npy')
     var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
                  'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
-                 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included',
-                 'phase', 'position', 'quiescence')
+                 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times', 'is_final_movement',
+                 'included', 'phase', 'position', 'quiescence')
 
     def _extract(self, extractor_classes=None, **kwargs) -> dict:
         extractor_classes = extractor_classes or []
Expand All @@ -182,8 +182,8 @@ class EphysTrials(BaseBpodTrialsExtractor):
'_ibl_trials.included.npy', None, None, '_ibl_trials.quiescencePeriod.npy')
var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included',
'phase', 'position', 'quiescence')
'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times', 'is_final_movement',
'included', 'phase', 'position', 'quiescence')

def _extract(self, extractor_classes=None, **kwargs) -> dict:
extractor_classes = extractor_classes or []
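
Note: in all of these classes the renamed wheelMoves_peakVelocity_times variable still pairs with a None entry in save_names, i.e. it is returned by _extract but never written to disk. A toy illustration of the positional pairing convention (hypothetical two-element tuples, not the real extractor attributes):

save_names = ('_ibl_wheelMoves.peakAmplitude.npy', None)
var_names = ('wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times')

for var, save in zip(var_names, save_names):
    print(f"{var}: {'saved as ' + save if save else 'returned only, not saved'}")
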
30 changes: 16 additions & 14 deletions ibllib/io/extractors/bpod_trials.py
@@ -3,17 +3,14 @@
 This module will extract the Bpod trials and wheel data based on the task protocol,
 i.e. habituation, training or biased.
 """
-import logging
 import importlib
 
-from ibllib.io.extractors.base import get_bpod_extractor_class, protocol2extractor
+from ibllib.io.extractors.base import get_bpod_extractor_class, protocol2extractor, BaseExtractor
 from ibllib.io.extractors.habituation_trials import HabituationTrials
 from ibllib.io.extractors.training_trials import TrainingTrials
 from ibllib.io.extractors.biased_trials import BiasedTrials, EphysTrials
 from ibllib.io.extractors.base import BaseBpodTrialsExtractor
 
-_logger = logging.getLogger(__name__)
-
 
 def get_bpod_extractor(session_path, protocol=None, task_collection='raw_behavior_data') -> BaseBpodTrialsExtractor:
     """
@@ -39,20 +36,25 @@ def get_bpod_extractor(session_path, protocol=None, task_collection='raw_behavio
         'BiasedTrials': BiasedTrials,
         'EphysTrials': EphysTrials
     }
+
     if protocol:
-        class_name = protocol2extractor(protocol)
+        extractor_class_name = protocol2extractor(protocol)
     else:
-        class_name = get_bpod_extractor_class(session_path, task_collection=task_collection)
-    if class_name in builtins:
-        return builtins[class_name](session_path)
+        extractor_class_name = get_bpod_extractor_class(session_path, task_collection=task_collection)
+    if extractor_class_name in builtins:
+        return builtins[extractor_class_name](session_path)
 
     # look if there are custom extractor types in the personal projects repo
-    if not class_name.startswith('projects.'):
-        class_name = 'projects.' + class_name
-    module, class_name = class_name.rsplit('.', 1)
+    if not extractor_class_name.startswith('projects.'):
+        extractor_class_name = 'projects.' + extractor_class_name
+    module, extractor_class_name = extractor_class_name.rsplit('.', 1)
     mdl = importlib.import_module(module)
-    extractor_class = getattr(mdl, class_name, None)
+    extractor_class = getattr(mdl, extractor_class_name, None)
     if extractor_class:
-        return extractor_class(session_path)
+        my_extractor = extractor_class(session_path)
+        if not isinstance(my_extractor, BaseExtractor):
+            raise ValueError(
+                f"{my_extractor} should be an Extractor class inheriting from ibllib.io.extractors.base.BaseExtractor")
+        return my_extractor
     else:
-        raise ValueError(f'extractor {class_name} not found')
+        raise ValueError(f'extractor {extractor_class_name} not found')
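
Note: besides the class_name → extractor_class_name rename, custom extractors imported from the projects namespace are now type-checked before use. A hedged usage sketch (placeholder session path; the protocol name is illustrative):

from ibllib.io.extractors.bpod_trials import get_bpod_extractor

# Built-in protocols resolve to bundled classes such as BiasedTrials;
# anything else is imported from the 'projects.' namespace and must now
# inherit from BaseExtractor, otherwise a ValueError is raised.
extractor = get_bpod_extractor('/data/subject/2024-10-28/001',
                               protocol='_iblrig_tasks_biasedChoiceWorld')
trials, _ = extractor.extract(save=False)
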
4 changes: 2 additions & 2 deletions ibllib/io/extractors/ephys_fpga.py
@@ -585,13 +585,13 @@ class FpgaTrials(extractors_base.BaseExtractor):
                   '_ibl_trials.stimOff_times.npy', None, None, None, '_ibl_trials.quiescencePeriod.npy',
                   '_ibl_trials.table.pqt', '_ibl_wheel.timestamps.npy',
                   '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy',
-                  '_ibl_wheelMoves.peakAmplitude.npy')
+                  '_ibl_wheelMoves.peakAmplitude.npy', None)
     var_names = ('goCueTrigger_times', 'stimOnTrigger_times',
                  'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times',
                  'errorCue_times', 'itiIn_times', 'stimFreeze_times', 'stimOff_times',
                  'valveOpen_times', 'phase', 'position', 'quiescence', 'table',
                  'wheel_timestamps', 'wheel_position',
-                 'wheelMoves_intervals', 'wheelMoves_peakAmplitude')
+                 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times')
 
     bpod_rsync_fields = ('intervals', 'response_times', 'goCueTrigger_times',
                          'stimOnTrigger_times', 'stimOffTrigger_times',
2 changes: 1 addition & 1 deletion ibllib/io/extractors/fibrephotometry.py
@@ -225,7 +225,7 @@ def _extract(self, light_source_map=None, collection=None, regions=None, **kwarg
         regions = regions or [k for k in fp_data['raw'].keys() if 'Region' in k]
         out_df = fp_data['raw'].filter(items=regions, axis=1).sort_index(axis=1)
         out_df['times'] = ts
-        out_df['wavelength'] = np.NaN
+        out_df['wavelength'] = np.nan
         out_df['name'] = ''
         out_df['color'] = ''
         # Extract channel index
2 changes: 1 addition & 1 deletion ibllib/io/extractors/mesoscope.py
@@ -4,8 +4,8 @@
 import numpy as np
 from scipy.signal import find_peaks
 import one.alf.io as alfio
-from one.util import ensure_list
 from one.alf.files import session_path_parts
+from iblutil.util import ensure_list
 import matplotlib.pyplot as plt
 from packaging import version
 
8 changes: 4 additions & 4 deletions ibllib/io/extractors/opto_trials.py
@@ -16,8 +16,8 @@ class LaserBool(BaseBpodTrialsExtractor):
     def _extract(self, **kwargs):
         _logger.info('Extracting laser datasets')
         # reference pybpod implementation
-        lstim = np.array([float(t.get('laser_stimulation', np.NaN)) for t in self.bpod_trials])
-        lprob = np.array([float(t.get('laser_probability', np.NaN)) for t in self.bpod_trials])
+        lstim = np.array([float(t.get('laser_stimulation', np.nan)) for t in self.bpod_trials])
+        lprob = np.array([float(t.get('laser_probability', np.nan)) for t in self.bpod_trials])
 
         # Karolina's choice world legacy implementation - from Slack message:
         # it is possible that some versions I have used:
@@ -30,9 +30,9 @@ def _extract(self, **kwargs):
         #     laserOFF_trials=(optoOUT ==0);
         if 'PROBABILITY_OPTO' in self.settings.keys() and np.all(np.isnan(lstim)):
             lprob = np.zeros_like(lprob) + self.settings['PROBABILITY_OPTO']
-            lstim = np.array([float(t.get('opto_ON_time', np.NaN)) for t in self.bpod_trials])
+            lstim = np.array([float(t.get('opto_ON_time', np.nan)) for t in self.bpod_trials])
             if np.all(np.isnan(lstim)):
-                lstim = np.array([float(t.get('optoOUT', np.NaN)) for t in self.bpod_trials])
+                lstim = np.array([float(t.get('optoOUT', np.nan)) for t in self.bpod_trials])
                 lstim[lstim == 255] = 1
         else:
             lstim[~np.isnan(lstim)] = 1
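
Note: these np.NaN → np.nan substitutions (also in fibrephotometry.py and training_trials.py) track NumPy 2.0, which removed the np.NaN alias; np.nan is the canonical spelling. The trial-default pattern used above, in isolation:

import numpy as np

bpod_trials = [{'laser_stimulation': 1.0}, {}]  # second trial lacks the key
lstim = np.array([float(t.get('laser_stimulation', np.nan)) for t in bpod_trials])
print(lstim)            # [ 1. nan]
print(np.isnan(lstim))  # [False  True]
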
6 changes: 3 additions & 3 deletions ibllib/io/extractors/training_trials.py
@@ -32,7 +32,7 @@ def _extract(self):
         feedbackType = np.zeros(len(self.bpod_trials), np.int64)
         for i, t in enumerate(self.bpod_trials):
             state_names = ['correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go']
-            outcome = {sn: ~np.isnan(t['behavior_data']['States timestamps'].get(sn, [[np.NaN]])[0][0]) for sn in state_names}
+            outcome = {sn: ~np.isnan(t['behavior_data']['States timestamps'].get(sn, [[np.nan]])[0][0]) for sn in state_names}
             assert np.sum(list(outcome.values())) == 1
             outcome = next(k for k in outcome if outcome[k])
             if outcome == 'correct':
@@ -709,7 +709,7 @@ class TrialsTable(BaseBpodTrialsExtractor):
     save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                   '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
     var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
-                 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement')
+                 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times', 'is_final_movement')
 
     def _extract(self, extractor_classes=None, **kwargs):
         base = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes, FeedbackType,
@@ -732,7 +732,7 @@ class TrainingTrials(BaseBpodTrialsExtractor):
     var_names = ('repNum', 'goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                  'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times',
                  'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', 'wheelMoves_peakAmplitude',
-                 'peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence', 'pause_duration')
+                 'wheelMoves_peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence', 'pause_duration')
 
     def _extract(self) -> dict:
         base = [RepNum, GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
14 changes: 4 additions & 10 deletions ibllib/io/extractors/training_wheel.py
@@ -330,11 +330,6 @@ def extract_first_movement_times(wheel_moves, trials, min_qt=None):
     gap between quiescence end and cue start, or during the quiescence period but sub-
     threshold). The movement is sufficiently large if it is greater than or equal to THRESH.
 
-    :param wheel_moves:
-    :param trials: dictionary of trial data
-    :param min_qt:
-    :return: numpy array of
-
     Parameters
     ----------
     wheel_moves : dict
@@ -407,9 +402,9 @@ class Wheel(BaseBpodTrialsExtractor):
     save_names = ('_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                   '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None,
                   '_ibl_trials.firstMovement_times.npy', None)
-    var_names = ('wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
-                 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'firstMovement_times',
-                 'is_final_movement')
+    var_names = ('wheel_timestamps', 'wheel_position',
+                 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times',
+                 'firstMovement_times', 'is_final_movement')
 
     def _extract(self):
         ts, pos = get_wheel_position(self.session_path, self.bpod_trials, task_collection=self.task_collection)
@@ -425,6 +420,5 @@ def _extract(self):
         min_qt = self.settings.get('QUIESCENT_PERIOD', None)
 
         first_moves, is_final, _ = extract_first_movement_times(moves, trials, min_qt=min_qt)
-        output = (ts, pos, moves['intervals'], moves['peakAmplitude'],
-                  moves['peakVelocity_times'], first_moves, is_final)
+        output = (ts, pos, moves['intervals'], moves['peakAmplitude'], moves['peakVelocity_times'], first_moves, is_final)
         return output
19 changes: 17 additions & 2 deletions ibllib/io/session_params.py
@@ -27,6 +27,7 @@
 import logging
 import socket
 from pathlib import Path
+from itertools import chain
 from copy import deepcopy
 
 from one.converters import ConversionMixin
@@ -77,6 +78,9 @@ def _patch_file(data: dict) -> dict:
     if 'tasks' in data and isinstance(data['tasks'], dict):
         data['tasks'] = [{k: v} for k, v in data['tasks'].copy().items()]
     data['version'] = SPEC_VERSION
+    # Ensure all items in tasks list are single value dicts
+    if 'tasks' in data:
+        data['tasks'] = [{k: v} for k, v in chain.from_iterable(map(dict.items, data['tasks']))]
     return data
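
Note: the new patch step flattens any multi-key dicts in the tasks list into one single-key dict per protocol, which is the shape merge_params below assumes. The idiom in isolation (toy protocol names):

from itertools import chain

tasks = [{'taskA': {'collection': 'raw_task_data_00'},
          'taskB': {'collection': 'raw_task_data_01'}}]  # wrongly bundled entry
tasks = [{k: v} for k, v in chain.from_iterable(map(dict.items, tasks))]
print(tasks)
# [{'taskA': {'collection': 'raw_task_data_00'}}, {'taskB': {'collection': 'raw_task_data_01'}}]
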


@@ -168,8 +172,19 @@ def merge_params(a, b, copy=False):
             assert k not in a or a[k] == b[k], 'multiple sync fields defined'
         if isinstance(b[k], list):
             prev = list(a.get(k, []))
-            # For procedures and projects, remove duplicates
-            to_add = b[k] if k == 'tasks' else set(b[k]) - set(prev)
+            if k == 'tasks':
+                # For tasks, keep order and skip duplicates
+                # Assert tasks is a list of single value dicts
+                assert (not prev or set(map(len, prev)) == {1}) and set(map(len, b[k])) == {1}
+                # Convert protocol -> dict map to hashable tuple of protocol + sorted key value pairs
+                to_hashable = lambda itm: (itm[0], *chain.from_iterable(sorted(itm[1].items())))  # noqa
+                # Get the set of previous tasks
+                prev_tasks = set(map(to_hashable, chain.from_iterable(map(dict.items, prev))))
+                tasks = chain.from_iterable(map(dict.items, b[k]))
+                to_add = [dict([itm]) for itm in tasks if to_hashable(itm) not in prev_tasks]
+            else:
+                # For procedures and projects, remove duplicates
+                to_add = set(b[k]) - set(prev)
             a[k] = prev + list(to_add)
         elif isinstance(b[k], dict):
             a[k] = {**a.get(k, {}), **b[k]}
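
Note: the tasks branch now deduplicates while preserving order — each {protocol: settings} entry is reduced to a hashable tuple so tasks already present in the accumulator are skipped. The same idea as a standalone sketch (toy data, hypothetical helper name):

from itertools import chain

def merge_task_lists(prev, new):
    # Order-preserving merge of single-key task dicts, skipping duplicates
    to_hashable = lambda itm: (itm[0], *chain.from_iterable(sorted(itm[1].items())))
    prev_tasks = set(map(to_hashable, chain.from_iterable(map(dict.items, prev))))
    new_items = chain.from_iterable(map(dict.items, new))
    return prev + [dict([itm]) for itm in new_items if to_hashable(itm) not in prev_tasks]

a = [{'passiveChoiceWorld': {'collection': 'raw_task_data_00'}}]
b = [{'passiveChoiceWorld': {'collection': 'raw_task_data_00'}},   # duplicate, skipped
     {'ephysChoiceWorld': {'collection': 'raw_task_data_01'}}]
print(merge_task_lists(a, b))
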
