Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Task protocol model #773

Draft
wants to merge 6 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions alyx/actions/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,18 +477,18 @@ def _pass_narrative_templates(context):

class SessionAdmin(BaseActionAdmin):
list_display = ['subject_l', 'start_time', 'number', 'lab', 'dataset_count',
'task_protocol', 'qc', 'user_list', 'project_']
'task_protocol_', 'qc', 'user_list', 'project_']
list_display_links = ['start_time']
fields = BaseActionAdmin.fields + [
'repo_url', 'qc', 'extended_qc', 'projects', ('type', 'task_protocol', ), 'number',
'repo_url', 'qc', 'extended_qc', 'projects', ('type', 'task_protocols', ), 'number',
'n_correct_trials', 'n_trials', 'weighing', 'auto_datetime']
list_filter = [('users', RelatedDropdownFilter),
('start_time', DateRangeFilter),
('projects', RelatedDropdownFilter),
('lab', RelatedDropdownFilter),
]
search_fields = ('subject__nickname', 'lab__name', 'projects__name', 'users__username',
'task_protocol', 'pk')
'task_protocol__name', 'pk')
ordering = ('-start_time', 'task_protocol', 'lab')
inlines = [WaterAdminInline, DatasetInline, NoteInline]
readonly_fields = ['repo_url', 'task_protocol', 'weighing', 'qc', 'extended_qc',
Expand Down Expand Up @@ -520,6 +520,9 @@ def add_view(self, request, extra_context=None):
def project_(self, obj):
return [getattr(p, 'name', None) for p in obj.projects.all()]

def task_protocol_(self, obj):
return [getattr(p, 'name', None) for p in obj.task_protocols.all()]

def repo_url(self, obj):
url = settings.SESSION_REPO_URL.format(
lab=obj.subject.lab.name,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Generated by Django 4.1.5 on 2023-02-03 10:16

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('experiments', '0012_taskprotocol'),
('actions', '0018_session_projects_alter_session_project'),
]

operations = [
migrations.AddField(
model_name='session',
name='task_protocols',
field=models.ManyToManyField(blank=True, to='experiments.taskprotocol', verbose_name='Session task protocols'),
),
migrations.AlterField(
model_name='session',
name='task_protocol',
field=models.CharField(blank=True, default='old task protocol', max_length=1023),
),
]
4 changes: 3 additions & 1 deletion alyx/actions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,9 @@ class Session(BaseAction):
help_text="User-defined session type (e.g. Base, Experiment)")
number = models.IntegerField(null=True, blank=True,
help_text="Optional session number for this level")
task_protocol = models.CharField(max_length=1023, blank=True, default='')
task_protocol = models.CharField(max_length=1023, blank=True, default='old task protocol')
task_protocols = models.ManyToManyField('experiments.TaskProtocol', blank=True,
verbose_name='Session task protocols')
n_trials = models.IntegerField(blank=True, null=True)
n_correct_trials = models.IntegerField(blank=True, null=True)

Expand Down
13 changes: 11 additions & 2 deletions alyx/actions/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
from data.models import Dataset, DatasetType
from misc.models import LabLocation, Lab
from experiments.serializers import ProbeInsertionListSerializer, FilterDatasetSerializer
from experiments.models import TaskProtocol
from misc.serializers import NoteSerializer


SESSION_FIELDS = ('subject', 'users', 'location', 'procedures', 'lab', 'projects', 'type',
'task_protocol', 'number', 'start_time', 'end_time', 'narrative',
'task_protocols', 'number', 'start_time', 'end_time', 'narrative',
'parent_session', 'n_correct_trials', 'n_trials', 'url', 'extended_qc', 'qc',
'wateradmin_session_related', 'data_dataset_session_related',
'auto_datetime')
Expand Down Expand Up @@ -121,18 +122,23 @@ class SessionListSerializer(BaseActionSerializer):
slug_field='name',
queryset=Project.objects.all(),
many=True)
task_protocols = serializers.SlugRelatedField(read_only=False,
slug_field='name',
queryset=TaskProtocol.objects.all(),
many=True)

@staticmethod
def setup_eager_loading(queryset):
""" Perform necessary eager loading of data to avoid horrible performance."""
queryset = queryset.select_related('subject', 'lab')
queryset = queryset.prefetch_related('projects')
queryset = queryset.prefetch_related('task_protocols')
return queryset.order_by('-start_time')

class Meta:
model = Session
fields = ('id', 'subject', 'start_time', 'number', 'lab', 'projects', 'url',
'task_protocol')
'task_protocols')


class SessionDetailSerializer(BaseActionSerializer):
Expand All @@ -142,6 +148,9 @@ class SessionDetailSerializer(BaseActionSerializer):
probe_insertion = ProbeInsertionListSerializer(read_only=True, many=True)
projects = serializers.SlugRelatedField(read_only=False, slug_field='name', many=True,
queryset=Project.objects.all(), required=False)
task_protocols = serializers.SlugRelatedField(
read_only=False, slug_field='name', many=True,
queryset=TaskProtocol.objects.all(), required=False)
notes = NoteSerializer(read_only=True, many=True)
qc = BaseSerializerEnumField(required=False)

Expand Down
44 changes: 43 additions & 1 deletion alyx/actions/tests_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from alyx import base
from alyx.base import BaseTests
from subjects.models import Subject, Project
from experiments.models import TaskProtocol
from misc.models import Lab, Note, ContentType
from actions.models import Session, WaterType, WaterAdministration

Expand All @@ -21,6 +22,8 @@ def setUp(self):
self.lab02 = Lab.objects.create(name='awesomelab')
self.projectX = Project.objects.create(name='projectX')
self.projectY = Project.objects.create(name='projectY')
self.protocolX = TaskProtocol.objects.create(name='ephysChoiceWorld')
self.protocolY = TaskProtocol.objects.create(name='passiveChoiceWorld')
# Set an implant weight.
self.subject.implant_weight = 4.56
self.subject.save()
Expand Down Expand Up @@ -187,6 +190,45 @@ def test_sessions_projects(self):
d = self.ar(self.client.get(reverse('session-list') + f'?projects={self.projectY.name}'))
self.assertEqual(len(d), 1)

def test_sessions_protocols(self):
ses1dict = {'subject': self.subject.nickname,
'users': [self.superuser.username],
'projects': [self.projectX.name],
'start_time': '2020-07-09T12:34:56',
'end_time': '2020-07-09T12:34:57',
'type': 'Base',
'number': '1',
'lab': self.lab01.name,
'task_protocol': [self.protocolX]
}
ses2dict = {'subject': self.subject.nickname,
'users': [self.superuser.username, self.superuser2.username],
'projects': [self.projectX.name],
'start_time': '2020-07-09T12:34:56',
'end_time': '2020-07-09T12:34:57',
'type': 'Base',
'number': '2',
'lab': self.lab01.name,
'task_protocol': [self.protocolX, self.protocolY]
}
self.ar(self.post(reverse('session-list'), data=ses1dict), 201)
self.ar(self.post(reverse('session-list'), data=ses2dict), 201)
# Test the user filter, this should return 2 sessions
q = f'?task_protocols={self.protocolX.name}'
d = self.ar(self.client.get(reverse('session-list') + q))
self.assertEqual(len(d), 2)
# This should return only one session
q = f'?task_protocols={self.protocolY.name}'
d = self.ar(self.client.get(reverse('session-list') + q))
self.assertEqual(len(d), 1)
# test the legacy filter that should act in the same way
q = f'?task_protocol={self.protocolX.name}'
d = self.ar(self.client.get(reverse('session-list') + q))
self.assertEqual(len(d), 2)
q = f'?task_protocols={self.protocolY.name}'
d = self.ar(self.client.get(reverse('session-list') + q))
self.assertEqual(len(d), 1)

def test_sessions(self):
a_dict4json = {'String': 'this is not a JSON', 'Integer': 4, 'List': ['titi', 4]}
ses_dict = {'subject': self.subject.nickname,
Expand All @@ -201,7 +243,7 @@ def test_sessions(self):
'lab': self.lab01.name,
'n_trials': 100,
'n_correct_trials': 75,
'task_protocol': self.test_protocol,
'task_protocol': [self.protocolX],
'json': a_dict4json}
# Test the session creation
r = self.post(reverse('session-list'), data=ses_dict)
Expand Down
2 changes: 2 additions & 0 deletions alyx/actions/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ class SessionFilter(BaseFilterSet):
date_range = django_filters.CharFilter(field_name='date_range', method=('filter_date_range'))
type = django_filters.CharFilter(field_name='type', lookup_expr=('iexact'))
lab = django_filters.CharFilter(field_name='lab__name', lookup_expr=('iexact'))
task_protocols = django_filters.CharFilter(field_name='task_protocols__name',
lookup_expr=('icontains'))
task_protocol = django_filters.CharFilter(field_name='task_protocol',
lookup_expr=('icontains'))
qc = django_filters.CharFilter(method='enum_field_filter')
Expand Down
27 changes: 27 additions & 0 deletions alyx/experiments/migrations/0012_taskprotocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Generated by Django 4.1.5 on 2023-02-03 10:16

from django.db import migrations, models
import uuid


class Migration(migrations.Migration):

dependencies = [
('experiments', '0011_chronic_insertion'),
]

operations = [
migrations.CreateModel(
name='TaskProtocol',
fields=[
('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)),
('json', models.JSONField(blank=True, help_text='Structured data, formatted in a user-defined way', null=True)),
('name', models.CharField(max_length=255)),
('version', models.CharField(help_text='The major version of the task protocol', max_length=255)),
('description', models.CharField(blank=True, help_text='Description of the task protocol', max_length=1023)),
],
options={
'unique_together': {('name', 'version')},
},
),
]
13 changes: 13 additions & 0 deletions alyx/experiments/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,16 @@ class Meta:
def save(self, *args, **kwargs):
super(Channel, self).save(*args, **kwargs)
self.trajectory_estimate.save() # this will bump the datetime auto-update of trajectory


class TaskProtocol(BaseModel):
name = models.CharField(max_length=255)
version = models.CharField(max_length=255, help_text='The major version of the task protocol')
description = models.CharField(
max_length=1023, blank=True, help_text='Description of the task protocol')

class Meta:
unique_together = (('name', 'version'),)

def __str__(self):
return "<TaskProtocol %s>" % self.name
11 changes: 6 additions & 5 deletions alyx/jobs/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ class TaskAdmin(BaseAdmin):
exclude = ['json']
readonly_fields = ['session', 'log', 'parents']
list_display = ['name', 'graph', 'status', 'version_str', 'level', 'datetime',
'session_str', 'session_task_protocol', 'session_projects']
'session_str', 'session_task_protocols', 'session_projects']
search_fields = ('session__id', 'session__lab__name', 'session__subject__nickname',
'log', 'version', 'session__task_protocol', 'session__projects__name')
'log', 'version', 'session__task_protocols__name', 'session__projects__name')
ordering = ('-session__start_time', 'level')
list_editable = ('status', )
list_filter = [('name', DropdownFilter),
Expand All @@ -32,9 +32,10 @@ def session_projects(self, obj):
return obj.session.projects.name
session_projects.short_description = 'projects'

def session_task_protocol(self, obj):
return obj.session.task_protocol
session_task_protocol.short_description = 'task_protocol'
def session_task_protocols(self, obj):
if obj.session.task_protocols is not None:
return obj.session.task_protocols.name
session_task_protocols.short_description = 'task_protocols'

def session_str(self, obj):
url = get_admin_url(obj.session)
Expand Down
13 changes: 9 additions & 4 deletions alyx/misc/management/commands/one_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,15 @@ def generate_sessions_frame(int_id=True, tags=None) -> pd.DataFrame:
)
"""
fields = ('id', 'lab__name', 'subject__nickname', 'start_time__date',
'number', 'task_protocol', 'all_projects')
'number', 'all_protocols', 'all_projects')
projects = ArrayAgg('projects__name')
protocols = ArrayAgg('task_protocols__name')
query = (Session
.objects
.select_related('subject', 'lab')
.prefetch_related('projects')
.annotate(all_projects=ArrayAgg('projects__name'))
.prefetch_related('task_protocols')
.annotate(all_projects=projects, all_protocols=protocols)
.order_by('-start_time', 'subject__nickname', '-number')) # FIXME Ignores nickname :(
if tags:
if not isinstance(tags, str):
Expand All @@ -327,16 +330,18 @@ def generate_sessions_frame(int_id=True, tags=None) -> pd.DataFrame:
logger.debug(f'Raw session frame = {getsizeof(df) / 1024**2} MiB')
# Rename, sort fields
df['all_projects'] = df['all_projects'].map(lambda x: ','.join(filter(None, set(x))))
df['all_protocols'] = df['all_protocols'].map(lambda x: ','.join(filter(None, set(x))))
renames = {'start_time': 'date', 'all_projects': 'projects', 'all_protocols': 'task_protocols'}
df = (
(df
.rename(lambda x: x.split('__')[0], axis=1)
.rename({'start_time': 'date', 'all_projects': 'projects'}, axis=1)
.rename(renames, axis=1)
.dropna(subset=['number', 'date', 'subject', 'lab']) # Remove dud or base sessions
.sort_values(['date', 'subject', 'number'], ascending=False))
)
df['number'] = df['number'].astype(int) # After dropping nans we can convert number to int
# These columns may be empty; ensure None -> ''
for col in ('task_protocol', 'projects'):
for col in ('task_protocols', 'projects'):
df[col] = df[col].astype(str)

if int_id:
Expand Down
6 changes: 5 additions & 1 deletion alyx/subjects/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,11 @@ def new_litter_autoname(self):
def new_subject_autoname(self):
self.subject_autoname_index = self.subject_autoname_index + 1
self.save()
return '%s_%04d' % (self.nickname, self.subject_autoname_index)
new_name = '%s_%04d' % (self.nickname, self.subject_autoname_index)
if Subject.objects.filter(nickname=new_name).count() > 0:
return self.new_subject_autoname()
assert Subject.objects.filter(nickname=new_name).count() == 0
return new_name

def set_autoname(self, obj):
if isinstance(obj, BreedingPair):
Expand Down
10 changes: 5 additions & 5 deletions requirements_frozen.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
asgiref==3.6.0
backports.zoneinfo==0.2.1
boto3==1.26.51
botocore==1.29.51
boto3==1.26.56
botocore==1.29.56
certifi==2022.12.7
cffi==1.15.1
charset-normalizer==3.0.1
Expand Down Expand Up @@ -36,7 +36,7 @@ flake8==6.0.0
fonttools==4.38.0
globus-cli==3.10.1
globus-sdk==3.15.0
iblutil==1.4.0
iblutil==1.5.0
idna==3.4
importlib-metadata==6.0.0
itypes==1.2.0
Expand All @@ -51,9 +51,9 @@ matplotlib==3.6.3
mccabe==0.7.0
numba==0.56.4
numpy==1.23.5
ONE-api==1.18.0
ONE-api==1.19.0
packaging==23.0
pandas==1.5.2
pandas==1.5.3
Pillow==9.4.0
psycopg2-binary==2.9.5
pyarrow==10.0.1
Expand Down