Skip to content

Commit

Permalink
Optim waveforms (#44)
Browse files Browse the repository at this point in the history
* modify wiggle plot to allow filling of positive and/or negative peaks

* double wiggle: double trouble. And tests.

* pre-processing should have time samples contiguous in memory

* WIP compress before

* instantiate memmap only once

* waveformextraction: remove AGC for the CAR option

* Write to output is parallelized

* add decompression in waveform extraction

* waveform extract: add removal of temporary file

* flatten waveforms

* allow for cluster indices not starting at 0 and flake8 fixes

* starting on test w/ new format

* wip waveforms loader version 1 and 2

* load w/ indices

* some more fixes

* add colours to the channel detection plots

* remove the Windows tests from CI

* fix bug when cluster ids do not start at 0

* fix last set of tests for the waveform extraction

* remove neurodsp after grace period expired

* Make sure the waveform loader is compatible with the notebook

* add bad channel plot helper and default LF parameters

* add some splitting / splicing function on window generator

* add lfp pre-processing

* fix channel labels bug

* fix default lfp filter parameters

* fix waveform extractor conflict

* add compatible syntax for 3.9

---------

Co-authored-by: chris-langfield <[email protected]>
  • Loading branch information
oliche and chris-langfield authored Oct 5, 2024
1 parent 2bf5aea commit 97c4824
Show file tree
Hide file tree
Showing 14 changed files with 349 additions and 286 deletions.
Binary file removed .DS_Store
Binary file not shown.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: ["ubuntu-latest", "windows-latest"]
os: ["ubuntu-latest"]
python-version: ["3.9", "3.10"]
steps:
- name: Checkout ibl-neuropixel repo
Expand Down
10 changes: 10 additions & 0 deletions release_notes.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
# 1.4

## 1.4.0 2024-10-05
- Waveform extraction:
- Optimization of the waveform extractor, outputs flattened waveforms
- Refactoring of the waveform loader with backward compatibility
- Bad channel detector:
- The bad channel detector has a plot option to visualize the bad channels and thresholds
- The default low-cut filters are set to 300Hz for AP band and 2 Hz for LF band

# 1.3

## 1.3.2 2024-09-18
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setuptools.setup(
name="ibl-neuropixel",
version="1.3.2",
version="1.4.0",
author="The International Brain Laboratory",
description="Collection of tools for Neuropixel 1.0 and 2.0 probes data",
long_description=long_description,
Expand Down
Binary file removed src/.DS_Store
Binary file not shown.
23 changes: 10 additions & 13 deletions src/ibldsp/plots.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,33 @@
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal


def show_channels_labels(raw, fs, channel_labels, xfeats):
def show_channels_labels(raw, fs, channel_labels, xfeats, similarity_threshold, psd_hf_threshold=0.02):
"""
Shows the features side by side a snippet of raw data
:param sr:
:return:
"""
nc, ns = raw.shape
ns_plot = np.minimum(ns, 3000)
vaxis_uv = 75
sos_hp = scipy.signal.butter(**{"N": 3, "Wn": 300 / fs * 2, "btype": "highpass"}, output="sos")
butt = scipy.signal.sosfiltfilt(sos_hp, raw)
vaxis_uv = 250 if fs < 2600 else 75
fig, ax = plt.subplots(1, 5, figsize=(18, 6), gridspec_kw={'width_ratios': [1, 1, 1, 8, .2]})
ax[0].plot(xfeats['xcor_hf'], np.arange(nc))
ax[0].plot(xfeats['xcor_hf'][(iko := channel_labels == 1)], np.arange(nc)[iko], 'r*')
ax[0].plot([- .5, -.5], [0, nc], 'r--')
ax[0].plot(xfeats['xcor_hf'][(iko := channel_labels == 1)], np.arange(nc)[iko], 'k*')
ax[0].plot(similarity_threshold[0] * np.ones(2), [0, nc], 'k--')
ax[0].plot(similarity_threshold[1] * np.ones(2), [0, nc], 'r--')
ax[0].set(ylabel='channel #', xlabel='high coherence', ylim=[0, nc], title='a) dead channel')
ax[1].plot(xfeats['psd_hf'], np.arange(nc))
ax[1].plot(xfeats['psd_hf'][(iko := channel_labels == 2)], np.arange(nc)[iko], 'r*')
ax[1].plot([.02, .02], [0, nc], 'r--')

ax[1].plot(psd_hf_threshold * np.array([1, 1]), [0, nc], 'r--')
ax[1].set(yticklabels=[], xlabel='PSD', ylim=[0, nc], title='b) noisy channel')
ax[1].sharey(ax[0])
ax[2].plot(xfeats['xcor_lf'], np.arange(nc))
ax[2].plot(xfeats['xcor_lf'][(iko := channel_labels == 3)], np.arange(nc)[iko], 'r*')
ax[2].plot([-.75, -.75], [0, nc], 'r--')
ax[2].set(yticklabels=[], xlabel='low coherence', ylim=[0, nc], title='c) outside')
ax[2].plot(xfeats['xcor_lf'][(iko := channel_labels == 3)], np.arange(nc)[iko], 'y*')
ax[2].plot([-.75, -.75], [0, nc], 'y--')
ax[2].set(yticklabels=[], xlabel='LF coherence', ylim=[0, nc], title='c) outside')
ax[2].sharey(ax[0])
im = ax[3].imshow(butt[:, :ns_plot] * 1e6, origin='lower', cmap='PuOr', aspect='auto',
im = ax[3].imshow(raw[:, :ns_plot] * 1e6, origin='lower', cmap='PuOr', aspect='auto',
vmin=-vaxis_uv, vmax=vaxis_uv, extent=[0, ns_plot / fs * 1e3, 0, nc])
ax[3].set(yticklabels=[], title='d) Raw data', xlabel='time (ms)', ylim=[0, nc])
ax[3].grid(False)
Expand Down
34 changes: 33 additions & 1 deletion src/ibldsp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,39 @@ def __init__(self, ns, nswin, overlap):
self.iw = None

@property
def firstlast(self):
def firstlast_splicing(self):
"""
Generator that yields the indices as well as an amplitude function that can be used
to splice the windows together.
In the overlap, the amplitude function gradually transitions the amplitude from one window
to the next. The amplitudes always sum to one (ie. windows are symmetrical)
:return: tuple of (first_index, last_index, amplitude_vector]
"""
# rising half of a symmetric Hann window, self.overlap samples long: the slice skips
# the leading zero sample and stops just before the central peak
w = scipy.signal.windows.hann((self.overlap + 1) * 2 + 1, sym=True)[1:self.overlap + 1]
# sanity check: the rising taper and its mirror image add up to one at every sample
assert np.all(np.isclose(w + np.flipud(w), 1))

# self.firstlast yields the (first, last) sample indices of each window
for first, last in self.firstlast:
# start from a flat unit gain over the whole window
amp = np.ones(last - first)
# taper-in over the leading overlap, except for the very first window (first == 0)
amp[:self.overlap] = 1 if first == 0 else w
# taper-out over the trailing overlap, except for the window that ends at self.ns
# NOTE(review): with self.overlap == 0 the slice -self.overlap: is -0:, which spans
# the whole vector instead of an empty tail - confirm a zero overlap is never used here
amp[-self.overlap:] = 1 if last == self.ns else np.flipud(w)
yield (first, last, amp)

@property
def firstlast_valid(self):
"""
Generator that yields a tuple of first, last, first_valid, last_valid index of windows
The valid indices span up to half of the overlap
:return:
"""
# the overlap is halved below, so it has to be even for the two halves to be equal
assert self.overlap % 2 == 0, "Overlap must be even"
# self.firstlast yields the (first, last) sample indices of each window
for first, last in self.firstlast:
# the valid range starts half an overlap into the window, except for the first window
first_valid = 0 if first == 0 else first + self.overlap // 2
# the valid range stops half an overlap before the end, except for the window ending at self.ns
last_valid = last if last == self.ns else last - self.overlap // 2
yield (first, last, first_valid, last_valid)

@property
def firstlast(self, return_valid=False):
"""
Generator that yields first and last index of windows
Expand Down
79 changes: 33 additions & 46 deletions src/ibldsp/voltage.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,40 +142,28 @@ def fk(
return xf * gain


def car(x, collection=None, lagc=300, butter_kwargs=None, **kwargs):
def car(x, collection=None, operator='median', **kwargs):
"""
Applies common average referencing with optional automatic gain control
:param x: the input array to be filtered. dimension, the filtering is considering
:param x: np.array(nc, ns) the input array to be de-referenced. dimension, the filtering is considering
axis=0: spatial dimension, axis=1 temporal dimension. (ntraces, ns)
:param collection:
:param lagc: window size for time domain automatic gain control (no agc otherwise)
:param butter_kwargs: filtering parameters: defaults: {'N': 3, 'Wn': 0.1, 'btype': 'highpass'}
:param collection: vector length ntraces. Each unique value set of traces is a collection and will be handled
separately. Useful for shanks.
:param operator: 'median' or 'average'
:return:
"""
if butter_kwargs is None:
butter_kwargs = {"N": 3, "Wn": 0.1, "btype": "highpass"}
if collection is not None:
xout = np.zeros_like(x)
for c in np.unique(collection):
sel = collection == c
xout[sel, :] = kfilt(
x=x[sel, :],
ntr_pad=0,
ntr_tap=None,
collection=None,
butter_kwargs=butter_kwargs,
)
xout[sel, :] = car(x=x[sel, :], collection=None, **kwargs)
return xout

# apply agc and keep the gain in handy
if not lagc:
xf = np.copy(x)
gain = 1
else:
xf, gain = agc(x, wl=lagc, si=1.0)
# apply CAR and then un-apply the gain
xf = xf - np.median(xf, axis=0)
return xf * gain
if operator == 'median':
x = x - np.median(x, axis=0)
elif operator == 'average':
x = x - np.mean(x, axis=0)
return x


def kfilt(
Expand Down Expand Up @@ -390,21 +378,18 @@ def destripe(
return x


def destripe_lfp(x, fs, channel_labels=None, **kwargs):
def destripe_lfp(x, fs, channel_labels=None, butter_kwargs=None, k_filter=False):
"""
Wrapper around the destipe function with some default parameters to destripe the LFP band
Wrapper around the destripe function with some default parameters to destripe the LFP band
See help destripe function for documentation
:param x:
:param fs:
:return:
:param x: demultiplexed array (nc, ns)
:param fs: sampling frequency
:param channel_labels: see destripe
"""
kwargs["butter_kwargs"] = {"N": 3, "Wn": 2 / fs * 2, "btype": "highpass"}
kwargs["k_filter"] = False
butter_kwargs = {"N": 3, "Wn": [0.5, 300], "btype": "bandpass", "fs": fs} if butter_kwargs is None else butter_kwargs
if channel_labels is True:
kwargs["channel_labels"], _ = detect_bad_channels(
x, fs=fs, psd_hf_threshold=1.4
)
return destripe(x, fs, **kwargs)
channel_labels, _ = detect_bad_channels(x, fs=fs, psd_hf_threshold=1.4)
return destripe(x, fs, butter_kwargs=butter_kwargs, k_filter=k_filter, channel_labels=channel_labels)


def decompress_destripe_cbin(
Expand Down Expand Up @@ -632,11 +617,9 @@ def my_function(i_chunk, n_chunk):
saturation_data = np.load(file_saturation)
assert rms_data.shape[0] == time_data.shape[0] * ncv
rms_data = rms_data.reshape(time_data.shape[0], ncv)
output_qc_path = output_qc_path or output_file.parent
output_qc_path = output_file.parent if output_qc_path is None else output_qc_path
np.save(output_qc_path.joinpath("_iblqc_ephysTimeRmsAP.rms.npy"), rms_data)
np.save(
output_qc_path.joinpath("_iblqc_ephysTimeRmsAP.timestamps.npy"), time_data
)
np.save(output_qc_path.joinpath("_iblqc_ephysTimeRmsAP.timestamps.npy"), time_data)
np.save(output_qc_path.joinpath("_iblqc_ephysSaturation.samples.npy"), saturation_data)


Expand Down Expand Up @@ -715,15 +698,18 @@ def nxcor(x, ref):
raw = raw - np.mean(raw, axis=-1)[:, np.newaxis] # removes DC offset
xcor = channels_similarity(raw)
fscale, psd = scipy.signal.welch(raw * 1e6, fs=fs) # units; uV ** 2 / Hz
if psd_hf_threshold is None:
# the LFP band data is obviously much stronger so auto-adjust the default threshold
psd_hf_threshold = 1.4 if fs < 5000 else 0.02
sos_hp = scipy.signal.butter(
**{"N": 3, "Wn": 300 / fs * 2, "btype": "highpass"}, output="sos"
)
# auto-detection of the band with which we are working
band = 'ap' if fs > 2600 else 'lf'
# the LFP band data is obviously much stronger so auto-adjust the default threshold
if band == 'ap':
psd_hf_threshold = 0.02 if psd_hf_threshold is None else psd_hf_threshold
filter_kwargs = {"N": 3, "Wn": 300 / fs * 2, "btype": "highpass"}
elif band == 'lf':
psd_hf_threshold = 1.4 if psd_hf_threshold is None else psd_hf_threshold
filter_kwargs = {"N": 3, "Wn": 1 / fs * 2, "btype": "highpass"}
sos_hp = scipy.signal.butter(**filter_kwargs, output="sos")
hf = scipy.signal.sosfiltfilt(sos_hp, raw)
xcorf = channels_similarity(hf)

xfeats = {
"ind": np.arange(nc),
"rms_raw": utils.rms(raw), # very similar to the rms avfter butterworth filter
Expand Down Expand Up @@ -754,7 +740,8 @@ def nxcor(x, ref):
# from ibllib.plots.figures import ephys_bad_channels
# ephys_bad_channels(x, 30000, ichannels, xfeats)
if display:
ibldsp.plots.show_channels_labels(raw, fs, ichannels, xfeats)
ibldsp.plots.show_channels_labels(
raw, fs, ichannels, xfeats, similarity_threshold=similarity_threshold, psd_hf_threshold=psd_hf_threshold)
return ichannels, xfeats


Expand Down
Loading

0 comments on commit 97c4824

Please sign in to comment.