+from __future__ import annotations
+
+from enum import Enum
+from typing import Annotated, Dict, List, Literal, Optional, Union
+
+import aind_behavior_services.task_logic.distributions as distributions
+from aind_behavior_services.task_logic import AindBehaviorTaskLogicModel
+from aind_behavior_vr_foraging import __version__
+from pydantic import BaseModel, Field, RootModel
+
+
+
+
[docs]
+
def scalar_value(value: float) -> distributions.Scalar:
+
"""
+
Helper function to create a scalar value distribution for a given value.
+
+
Args:
+
value (float): The value of the scalar distribution.
+
+
Returns:
+
distributions.Scalar: The scalar distribution type.
+
"""
+
return distributions.Scalar(distribution_parameters=distributions.ScalarDistributionParameter(value=value))
+
+
+
+
+
[docs]
+
class Size(BaseModel):
+
width: float = Field(default=0, description="Width of the texture")
+
height: float = Field(default=0, description="Height of the texture")
+
+
+
+
+
[docs]
+
class Vector2(BaseModel):
+
x: float = Field(default=0, description="X coordinate of the point")
+
y: float = Field(default=0, description="Y coordinate of the point")
+
+
+
+
+
[docs]
+
class Vector3(BaseModel):
+
x: float = Field(default=0, description="X coordinate of the point")
+
y: float = Field(default=0, description="Y coordinate of the point")
+
z: float = Field(default=0, description="Z coordinate of the point")
+
+
+
+
+
[docs]
+
class Matrix2D(BaseModel):
+
data: List[List[float]] = Field([[1]], description="Defines a 2D matrix")
+
+
+
+# Updaters
+
+
[docs]
+
class NumericalUpdaterOperation(str, Enum):
+
NONE = "None"
+
OFFSET = "Offset"
+
GAIN = "Gain"
+
SET = "Set"
+
OFFSETPERCENTAGE = "OffsetPercentage"
+
+
+
+
+
[docs]
+
class NumericalUpdaterParameters(BaseModel):
+
initial_value: float = Field(default=0.0, description="Initial value of the parameter")
+
increment: float = Field(default=0.0, description="Value to increment the parameter by")
+
decrement: float = Field(default=0.0, description="Value to decrement the parameter by")
+
minimum: float = Field(default=0.0, description="Minimum value of the parameter")
+
maximum: float = Field(default=0.0, description="Minimum value of the parameter")
+
+
+
+
+
[docs]
+
class NumericalUpdater(BaseModel):
+
operation: NumericalUpdaterOperation = Field(
+
default=NumericalUpdaterOperation.NONE, description="Operation to perform on the parameter"
+
)
+
parameters: NumericalUpdaterParameters = Field(
+
NumericalUpdaterParameters(), description="Parameters of the updater"
+
)
+
+
+
+
+
[docs]
+
class Texture(BaseModel):
+
name: str = Field(default="default", description="Name of the texture")
+
size: Size = Field(default=Size(width=40, height=40), description="Size of the texture")
+
+
+
+
+
[docs]
+
class OdorSpecification(BaseModel):
+
index: int = Field(..., ge=0, le=3, description="Index of the odor to be used")
+
concentration: float = Field(default=1, ge=0, le=1, description="Concentration of the odor")
+
+
+
+
+
[docs]
+
class OperantLogic(BaseModel):
+
is_operant: bool = Field(default=True, description="Will the trial implement operant logic")
+
stop_duration: float = Field(
+
default=0, ge=0, description="Duration (s) the animal must stop for to lock its choice"
+
)
+
time_to_collect_reward: float = Field(
+
default=100000, ge=0, description="Time(s) the animal has to collect the reward"
+
)
+
grace_distance_threshold: float = Field(
+
default=10, ge=0, description="Virtual distance (cm) the animal must be within to not abort the current choice"
+
)
+
+
+
+
+
[docs]
+
class PowerFunction(BaseModel):
+
function_type: Literal["PowerFunction"] = "PowerFunction"
+
mininum: float = Field(default=0, description="Minimum value of the function")
+
maximum: float = Field(default=1, description="Maximum value of the function")
+
a: float = Field(default=1, description="Coefficient a of the function: value = a * pow(b, c * x) + d")
+
b: float = Field(
+
default=2.718281828459045, description="Coefficient b of the function: value = a * pow(b, c * x) + d"
+
)
+
c: float = Field(default=-1, description="Coefficient c of the function: value = a * pow(b, c * x) + d")
+
d: float = Field(default=0, description="Coefficient d of the function: value = a * pow(b, c * x) + d")
+
+
+
+
+
[docs]
+
class LinearFunction(BaseModel):
+
function_type: Literal["LinearFunction"] = "LinearFunction"
+
mininum: float = Field(default=0, description="Minimum value of the function")
+
maximum: float = Field(default=9999, description="Maximum value of the function")
+
a: float = Field(default=1, description="Coefficient a of the function: value = a * x + b")
+
b: float = Field(default=0, description="Coefficient b of the function: value = a * x + b")
+
+
+
+
+
[docs]
+
class ConstantFunction(BaseModel):
+
function_type: Literal["ConstantFunction"] = "ConstantFunction"
+
value: float = Field(default=1, description="Value of the function")
+
+
+
+
+
[docs]
+
class RewardFunction(RootModel):
+
root: Annotated[Union[ConstantFunction, LinearFunction, PowerFunction], Field(discriminator="function_type")]
+
+
+
+
+
[docs]
+
class DepletionRule(str, Enum):
+
ON_REWARD = ("OnReward",)
+
ON_CHOICE = ("OnChoice",)
+
ON_TIME = ("OnTime",)
+
ON_DISTANCE = "OnDistance"
+
+
+
+
+
[docs]
+
class PatchRewardFunction(BaseModel):
+
amount: RewardFunction = Field(
+
default=ConstantFunction(value=1),
+
description="Determines the amount of reward to be delivered. The value is in microliters",
+
validate_default=True,
+
)
+
probability: RewardFunction = Field(
+
default=ConstantFunction(value=1),
+
description="Determines the probability that a reward will be delivered",
+
validate_default=True,
+
)
+
available: RewardFunction = Field(
+
default=LinearFunction(mininum=0, a=-1, b=5),
+
description="Determines the total amount of reward available left in the patch. The value is in microliters",
+
validate_default=True,
+
)
+
depletion_rule: DepletionRule = Field(default=DepletionRule.ON_CHOICE, description="Depletion")
+
+
+
+
+
[docs]
+
class RewardSpecification(BaseModel):
+
operant_logic: Optional[OperantLogic] = Field(None, description="The optional operant logic of the reward")
+
delay: distributions.Distribution = Field(
+
default=scalar_value(0),
+
description="The optional distribution where the delay to reward will be drawn from",
+
validate_default=True,
+
)
+
reward_function: PatchRewardFunction = Field(
+
default=PatchRewardFunction(), description="Reward function of the patch."
+
)
+
+
+
+
+
[docs]
+
class VirtualSiteLabels(str, Enum):
+
UNSPECIFIED = "Unspecified"
+
INTERPATCH = "InterPatch"
+
REWARDSITE = "RewardSite"
+
INTERSITE = "InterSite"
+
+
+
+
+
[docs]
+
class RenderSpecification(BaseModel):
+
contrast: Optional[float] = Field(default=None, ge=0, le=1, description="Contrast of the texture")
+
+
+
+
+
[docs]
+
class VirtualSiteGenerator(BaseModel):
+
render_specification: RenderSpecification = Field(
+
default=RenderSpecification(), description="Contrast of the environment"
+
)
+
label: VirtualSiteLabels = Field(default=VirtualSiteLabels.UNSPECIFIED, description="Label of the virtual site")
+
length_distribution: distributions.Distribution = Field(
+
default=scalar_value(20), description="Distribution of the length of the virtual site", validate_default=True
+
)
+
+
+
+
+
[docs]
+
class VirtualSiteGeneration(BaseModel):
+
inter_site: VirtualSiteGenerator = Field(
+
VirtualSiteGenerator(), description="Generator of the inter-site virtual sites"
+
)
+
inter_patch: VirtualSiteGenerator = Field(
+
VirtualSiteGenerator(), description="Generator of the inter-patch virtual sites"
+
)
+
reward_site: VirtualSiteGenerator = Field(
+
VirtualSiteGenerator(), description="Generator of the reward-site virtual sites"
+
)
+
+
+
+
+
[docs]
+
class VirtualSite(BaseModel):
+
id: int = Field(default=0, ge=0, description="Id of the virtual site")
+
label: VirtualSiteLabels = Field(VirtualSiteLabels.UNSPECIFIED, description="Label of the virtual site")
+
length: float = Field(20, description="Length of the virtual site (cm)")
+
start_position: float = Field(default=0, ge=0, description="Start position of the virtual site (cm)")
+
odor_specification: Optional[OdorSpecification] = Field(
+
None, description="The optional odor specification of the virtual site"
+
)
+
reward_specification: Optional[RewardSpecification] = Field(
+
None, description="The optional reward specification of the virtual site"
+
)
+
render_specification: RenderSpecification = Field(
+
RenderSpecification(), description="The optional render specification of the virtual site"
+
)
+
+
+
+
+
[docs]
+
class PatchStatistics(BaseModel):
+
label: str = Field(default="", description="Label of the patch")
+
state_index: int = Field(default=0, ge=0, description="Index of the state")
+
odor_specification: Optional[OdorSpecification] = Field(
+
default=None, description="The optional odor specification of the patch"
+
)
+
reward_specification: Optional[RewardSpecification] = Field(
+
default=None, description="The optional reward specification of the patch"
+
)
+
virtual_site_generation: VirtualSiteGeneration = Field(
+
VirtualSiteGeneration(), description="Virtual site generation specification"
+
)
+
+
+
+
+
[docs]
+
class WallTextures(BaseModel):
+
floor: Texture = Field(..., description="The texture of the floor")
+
ceiling: Texture = Field(..., description="The texture of the ceiling")
+
left: Texture = Field(..., description="The texture of the left")
+
right: Texture = Field(..., description="The texture of the right")
+
+
+
+
+
[docs]
+
class VisualCorridor(BaseModel):
+
id: int = Field(default=0, ge=0, description="Id of the visual corridor object")
+
size: Size = Field(default=Size(width=40, height=40), description="Size of the corridor (cm)")
+
start_position: float = Field(default=0, ge=0, description="Start position of the corridor (cm)")
+
length: float = Field(default=120, ge=0, description="Length of the corridor site (cm)")
+
textures: WallTextures = Field(..., description="The textures of the corridor")
+
+
+
+
+
[docs]
+
class EnvironmentStatistics(BaseModel):
+
patches: List[PatchStatistics] = Field(default_factory=list, description="List of patches")
+
transition_matrix: Matrix2D = Field(default=Matrix2D(), description="Transition matrix between patches")
+
first_state: Optional[int] = Field(
+
default=None, ge=0, description="The first state to be visited. If None, it will be randomly drawn."
+
)
+
+
+
+
+
[docs]
+
class ServoMotor(BaseModel):
+
period: int = Field(default=20000, ge=1, description="Period", units="us")
+
min_pulse_duration: int = Field(default=1000, ge=1, description="Minimum pulse duration", units="us")
+
max_pulse_duration: int = Field(default=2000, ge=1, description="Maximum pulse duration", units="us")
+
default_pulse_duration: int = Field(default=2000, ge=1, description="Default pulse duration", units="us")
+
+
+
+
+
[docs]
+
class MovableSpoutControl(BaseModel):
+
enabled: bool = Field(default=False, description="Whether the movable spout is enabled")
+
time_to_collect_after_reward: float = Field(default=1, ge=0, description="Time (s) to collect after reward")
+
servo_motor: ServoMotor = Field(default=ServoMotor(), description="Servo motor settings")
+
+
+
+
+
[docs]
+
class OdorControl(BaseModel):
+
valve_max_open_time: float = Field(
+
default=10, ge=0, description="Maximum time (s) the valve can be open continuously"
+
)
+
target_total_flow: float = Field(
+
default=1000, ge=100, le=1000, description="Target total flow (ml/s) of the odor mixture"
+
)
+
use_channel_3_as_carrier: bool = Field(default=True, description="Whether to use channel 3 as carrier")
+
target_odor_flow: float = Field(
+
default=100, ge=0, le=100, description="Target odor flow (ml/s) in the odor mixture"
+
)
+
+
+
+
+
[docs]
+
class PositionControl(BaseModel):
+
gain: Vector3 = Field(default=Vector3(x=1, y=1, z=1), description="Gain of the position control.")
+
initial_position: Vector3 = Field(default=Vector3(x=0, y=2.56, z=0), description="Gain of the position control.")
+
frequency_filter_cutoff: float = Field(
+
default=0.5,
+
ge=0,
+
le=100,
+
description="Cutoff frequency (Hz) of the low-pass filter used to filter the velocity signal.",
+
)
+
velocity_threshold: float = Field(
+
default=1, ge=0, description="Threshold (cm/s) of the velocity signal used to detect when the animal is moving."
+
)
+
+
+
+
+
[docs]
+
class AudioControl(BaseModel):
+
duration: float = Field(default=0.2, ge=0, description="Duration", units="s")
+
frequency: float = Field(default=1000, ge=100, description="Frequency", units="Hz")
+
+
+
+
+
[docs]
+
class OperationControl(BaseModel):
+
movable_spout_control: MovableSpoutControl = Field(
+
default=MovableSpoutControl(), description="Control of the movable spout"
+
)
+
odor_control: OdorControl = Field(default=OdorControl(), description="Control of the odor", validate_default=True)
+
position_control: PositionControl = Field(
+
default=PositionControl(), description="Control of the position", validate_default=True
+
)
+
audio_control: AudioControl = Field(
+
default=AudioControl(), description="Control of the audio", validate_default=True
+
)
+
+
+
+
+
[docs]
+
class TaskMode(str, Enum):
+
DEBUG = "DEBUG"
+
HABITUATION = "HABITUATION"
+
FORAGING = "FORAGING"
+
+
+
+
+
[docs]
+
class TaskModeSettingsBase(BaseModel):
+
task_mode: TaskMode = Field(default=TaskMode.FORAGING, description="Stage of the task")
+
+
+
+
+
[docs]
+
class HabituationSettings(TaskModeSettingsBase):
+
task_mode: Literal[TaskMode.HABITUATION] = TaskMode.HABITUATION
+
distance_to_reward: distributions.Distribution = Field(..., description="Distance (cm) to the reward")
+
render_specification: RenderSpecification = Field(
+
RenderSpecification(), description="The optional render specification of the virtual site", validate_default=True
+
)
+
+
+
+
+
[docs]
+
class DebugSettings(TaskModeSettingsBase):
+
"""This class is not currently implemented"""
+
+
task_mode: Literal[TaskMode.DEBUG] = TaskMode.DEBUG
+
visual_corridors: List[VisualCorridor]
+
virtual_sites: List[VirtualSite]
+
+
+
+
+
[docs]
+
class ForagingSettings(TaskModeSettingsBase):
+
task_mode: Literal[TaskMode.FORAGING] = TaskMode.FORAGING
+
+
+
+
+
[docs]
+
class TaskModeSettings(RootModel):
+
root: Annotated[Union[HabituationSettings, ForagingSettings, DebugSettings], Field(discriminator="task_mode")]
+
+
+
+
+
[docs]
+
class AindVrForagingTaskLogic(AindBehaviorTaskLogicModel):
+
schema_version: Literal[__version__] = __version__
+
updaters: Dict[str, NumericalUpdater] = Field(default_factory=dict, description="List of numerical updaters")
+
environment_statistics: EnvironmentStatistics = Field(..., description="Statistics of the environment")
+
task_mode_settings: TaskModeSettings = Field(
+
default=ForagingSettings(), description="Settings of the task stage", validate_default=True
+
)
+
operation_control: OperationControl = Field(..., description="Control of the operation")
+
+
+
+
+
[docs]
+
def schema() -> BaseModel:
+
return AindVrForagingTaskLogic
+
+
+